structure saas with tools
@@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from google.cloud.aiplatform.datasets.dataset import _Dataset
from google.cloud.aiplatform.datasets.column_names_dataset import _ColumnNamesDataset
from google.cloud.aiplatform.datasets.tabular_dataset import TabularDataset
from google.cloud.aiplatform.datasets.time_series_dataset import TimeSeriesDataset
from google.cloud.aiplatform.datasets.image_dataset import ImageDataset
from google.cloud.aiplatform.datasets.text_dataset import TextDataset
from google.cloud.aiplatform.datasets.video_dataset import VideoDataset


__all__ = (
    "_Dataset",
    "_ColumnNamesDataset",
    "TabularDataset",
    "TimeSeriesDataset",
    "ImageDataset",
    "TextDataset",
    "VideoDataset",
)
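For orientation, a minimal usage sketch of the surface re-exported above. This is a hedged sketch, not part of the commit: it assumes the usual top-level `aiplatform` re-exports and an initialized SDK; the project, location, bucket, and display name are placeholders.

    from google.cloud import aiplatform

    aiplatform.init(project="my-project", location="us-central1")

    # TabularDataset is one of the classes re-exported above; the CSV path is a placeholder.
    dataset = aiplatform.TabularDataset.create(
        display_name="my-tabular-dataset",
        gcs_source=["gs://my-bucket/data.csv"],
    )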
@@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import abc
from typing import Optional, Dict, Sequence, Union
from google.cloud.aiplatform import schema

from google.cloud.aiplatform.compat.types import (
    io as gca_io,
    dataset as gca_dataset,
)


class Datasource(abc.ABC):
    """An abstract class that sets dataset_metadata."""

    @property
    @abc.abstractmethod
    def dataset_metadata(self):
        """Dataset Metadata."""
        pass


class DatasourceImportable(abc.ABC):
    """An abstract class that sets import_data_config."""

    @property
    @abc.abstractmethod
    def import_data_config(self):
        """Import Data Config."""
        pass


class TabularDatasource(Datasource):
    """Datasource for creating a tabular dataset for Vertex AI."""

    def __init__(
        self,
        gcs_source: Optional[Union[str, Sequence[str]]] = None,
        bq_source: Optional[str] = None,
    ):
        """Creates a tabular datasource.

        Args:
            gcs_source (Union[str, Sequence[str]]):
                Cloud Storage URI of one or more files. Only CSV files are supported.
                The first line of the CSV file is used as the header.
                If there are multiple files, the header is the first line of
                the lexicographically first file; the other files must either
                contain the exact same header or omit the header.
                examples:
                    str: "gs://bucket/file.csv"
                    Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
            bq_source (str):
                The URI of a BigQuery table.
                example:
                    "bq://project.dataset.table_name"

        Raises:
            ValueError: If the source configuration is not valid.
        """

        dataset_metadata = None

        if gcs_source and isinstance(gcs_source, str):
            gcs_source = [gcs_source]

        if gcs_source and bq_source:
            raise ValueError("Only one of gcs_source or bq_source can be set.")

        if not any([gcs_source, bq_source]):
            raise ValueError("One of gcs_source or bq_source must be set.")

        if gcs_source:
            dataset_metadata = {"inputConfig": {"gcsSource": {"uri": gcs_source}}}
        elif bq_source:
            dataset_metadata = {"inputConfig": {"bigquerySource": {"uri": bq_source}}}

        self._dataset_metadata = dataset_metadata

    @property
    def dataset_metadata(self) -> Optional[Dict]:
        """Dataset Metadata."""
        return self._dataset_metadata


class NonTabularDatasource(Datasource):
    """Datasource for creating an empty non-tabular dataset for Vertex AI."""

    @property
    def dataset_metadata(self) -> Optional[Dict]:
        return None


class NonTabularDatasourceImportable(NonTabularDatasource, DatasourceImportable):
    """Datasource for creating a non-tabular dataset for Vertex AI and
    importing data into the dataset."""

    def __init__(
        self,
        gcs_source: Union[str, Sequence[str]],
        import_schema_uri: str,
        data_item_labels: Optional[Dict] = None,
    ):
        """Creates a non-tabular datasource.

        Args:
            gcs_source (Union[str, Sequence[str]]):
                Required. The Google Cloud Storage location for the input content.
                Google Cloud Storage URI(-s) to the input file(s).

                Examples:
                    str: "gs://bucket/file.csv"
                    Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
            import_schema_uri (str):
                Required. Points to a YAML file stored on Google Cloud
                Storage describing the import format. Validation will be
                done against the schema. The schema is defined as an
                `OpenAPI 3.0.2 Schema
                Object <https://tinyurl.com/y538mdwt>`__.
            data_item_labels (Dict):
                Labels that will be applied to newly imported DataItems. If
                an identical DataItem as one being imported already exists
                in the Dataset, then these labels will be appended to those
                of the already existing one, and if a label with an identical
                key was imported before, the old label value will be
                overwritten. If two DataItems are identical in the same
                import data operation, the labels will be combined and if a
                key collision happens in this case, one of the values will
                be picked randomly. Two DataItems are considered identical
                if their content bytes are identical (e.g. image bytes or
                pdf bytes). These labels will be overridden by Annotation
                labels specified inside the index file referenced by
                ``import_schema_uri``,
                e.g. a jsonl file.
        """
        super().__init__()
        self._gcs_source = [gcs_source] if isinstance(gcs_source, str) else gcs_source
        self._import_schema_uri = import_schema_uri
        self._data_item_labels = data_item_labels

    @property
    def import_data_config(self) -> gca_dataset.ImportDataConfig:
        """Import Data Config."""
        return gca_dataset.ImportDataConfig(
            gcs_source=gca_io.GcsSource(uris=self._gcs_source),
            import_schema_uri=self._import_schema_uri,
            data_item_labels=self._data_item_labels,
        )


def create_datasource(
    metadata_schema_uri: str,
    import_schema_uri: Optional[str] = None,
    gcs_source: Optional[Union[str, Sequence[str]]] = None,
    bq_source: Optional[str] = None,
    data_item_labels: Optional[Dict] = None,
) -> Datasource:
    """Creates a datasource.

    Args:
        metadata_schema_uri (str):
            Required. Points to a YAML file stored on Google Cloud Storage
            describing additional information about the Dataset. The schema
            is defined as an OpenAPI 3.0.2 Schema Object. The schema files
            that can be used here are found in
            gs://google-cloud-aiplatform/schema/dataset/metadata/.
        import_schema_uri (str):
            Points to a YAML file stored on Google Cloud
            Storage describing the import format. Validation will be
            done against the schema. The schema is defined as an
            `OpenAPI 3.0.2 Schema
            Object <https://tinyurl.com/y538mdwt>`__.
        gcs_source (Union[str, Sequence[str]]):
            The Google Cloud Storage location for the input content.
            Google Cloud Storage URI(-s) to the input file(s).

            Examples:
                str: "gs://bucket/file.csv"
                Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
        bq_source (str):
            BigQuery URI to the input table.
            example:
                "bq://project.dataset.table_name"
        data_item_labels (Dict):
            Labels that will be applied to newly imported DataItems. If
            an identical DataItem as one being imported already exists
            in the Dataset, then these labels will be appended to those
            of the already existing one, and if a label with an identical
            key was imported before, the old label value will be
            overwritten. If two DataItems are identical in the same
            import data operation, the labels will be combined and if a
            key collision happens in this case, one of the values will
            be picked randomly. Two DataItems are considered identical
            if their content bytes are identical (e.g. image bytes or
            pdf bytes). These labels will be overridden by Annotation
            labels specified inside the index file referenced by
            ``import_schema_uri``,
            e.g. a jsonl file.

    Returns:
        datasource (Datasource)

    Raises:
        ValueError: When any of the below scenarios happen:
            - import_schema_uri is identified for creating TabularDatasource
            - either import_schema_uri or gcs_source is missing for creating NonTabularDatasourceImportable
    """

    if metadata_schema_uri == schema.dataset.metadata.tabular:
        if import_schema_uri:
            raise ValueError("tabular dataset does not support data import.")
        return TabularDatasource(gcs_source, bq_source)

    if metadata_schema_uri == schema.dataset.metadata.time_series:
        if import_schema_uri:
            raise ValueError("time series dataset does not support data import.")
        return TabularDatasource(gcs_source, bq_source)

    if not import_schema_uri and not gcs_source:
        return NonTabularDatasource()
    elif import_schema_uri and gcs_source:
        return NonTabularDatasourceImportable(
            gcs_source, import_schema_uri, data_item_labels
        )
    else:
        raise ValueError(
            "nontabular dataset requires both import_schema_uri and gcs_source for data import."
        )
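A quick sketch of how `create_datasource` dispatches on the metadata schema. Hedged: the `schema.dataset.metadata.image` and `schema.dataset.ioformat.image.single_label_classification` constants are assumed to come from the `aiplatform.schema` module (they are not defined in this commit), and the GCS paths are placeholders.

    from google.cloud.aiplatform import schema
    from google.cloud.aiplatform.datasets import _datasources

    # Tabular metadata schema -> TabularDatasource (no import config).
    tabular = _datasources.create_datasource(
        metadata_schema_uri=schema.dataset.metadata.tabular,
        gcs_source="gs://my-bucket/data.csv",
    )
    print(type(tabular).__name__)  # TabularDatasource

    # Non-tabular metadata schema with both gcs_source and import_schema_uri
    # -> NonTabularDatasourceImportable, which also builds an ImportDataConfig.
    image = _datasources.create_datasource(
        metadata_schema_uri=schema.dataset.metadata.image,
        gcs_source="gs://my-bucket/annotations.jsonl",
        import_schema_uri=schema.dataset.ioformat.image.single_label_classification,
    )
    print(image.import_data_config)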
@@ -0,0 +1,261 @@
# -*- coding: utf-8 -*-

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import csv
import logging
from typing import List, Optional, Set, TYPE_CHECKING
from google.auth import credentials as auth_credentials

from google.cloud import storage

from google.cloud.aiplatform import utils
from google.cloud.aiplatform import datasets

if TYPE_CHECKING:
    from google.cloud import bigquery


class _ColumnNamesDataset(datasets._Dataset):
    @property
    def column_names(self) -> List[str]:
        """Retrieve the columns for the dataset by extracting them from the Google Cloud Storage or
        Google BigQuery source.

        Returns:
            List[str]
                A list of column names.

        Raises:
            RuntimeError: When no valid source is found.
        """

        self._assert_gca_resource_is_available()

        metadata = self._gca_resource.metadata

        if metadata is None:
            raise RuntimeError("No metadata found for dataset")

        input_config = metadata.get("inputConfig")

        if input_config is None:
            raise RuntimeError("No inputConfig found for dataset")

        gcs_source = input_config.get("gcsSource")
        bq_source = input_config.get("bigquerySource")

        if gcs_source:
            gcs_source_uris = gcs_source.get("uri")

            if gcs_source_uris and len(gcs_source_uris) > 0:
                # Lexicographically sort the files
                gcs_source_uris.sort()

                # Get the first file in the sorted list
                # TODO(b/193044977): Return as Set instead of List
                return list(
                    self._retrieve_gcs_source_columns(
                        project=self.project,
                        gcs_csv_file_path=gcs_source_uris[0],
                        credentials=self.credentials,
                    )
                )
        elif bq_source:
            bq_table_uri = bq_source.get("uri")
            if bq_table_uri:
                # TODO(b/193044977): Return as Set instead of List
                return list(
                    self._retrieve_bq_source_columns(
                        project=self.project,
                        bq_table_uri=bq_table_uri,
                        credentials=self.credentials,
                    )
                )

        raise RuntimeError("No valid CSV or BigQuery datasource found.")

    @staticmethod
    def _retrieve_gcs_source_columns(
        project: str,
        gcs_csv_file_path: str,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> Set[str]:
        """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage.

        Example Usage:

            column_names = _retrieve_gcs_source_columns(
                "project_id",
                "gs://example-bucket/path/to/csv_file"
            )

            # column_names = {"column_1", "column_2"}

        Args:
            project (str):
                Required. Project to initiate the Google Cloud Storage client with.
            gcs_csv_file_path (str):
                Required. A full path to a CSV file stored on Google Cloud Storage.
                Must include the "gs://" prefix.
            credentials (auth_credentials.Credentials):
                Credentials to use with the GCS client.

        Returns:
            Set[str]
                A set of column names in the CSV file.

        Raises:
            RuntimeError: When the retrieved CSV file is invalid.
        """

        gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
            gcs_csv_file_path
        )
        client = storage.Client(project=project, credentials=credentials)
        bucket = client.bucket(gcs_bucket)
        blob = bucket.blob(gcs_blob)

        # Incrementally download the CSV file until the header is retrieved
        first_new_line_index = -1
        start_index = 0
        increment = 1000
        line = ""

        try:
            logger = logging.getLogger("google.resumable_media._helpers")
            logging_warning_filter = utils.LoggingFilter(logging.INFO)
            logger.addFilter(logging_warning_filter)

            while first_new_line_index == -1:
                line += blob.download_as_bytes(
                    start=start_index, end=start_index + increment - 1
                ).decode("utf-8")

                first_new_line_index = line.find("\n")
                start_index += increment

            header_line = line[:first_new_line_index]

            # Split to make it an iterable
            header_line = header_line.split("\n")[:1]

            csv_reader = csv.reader(header_line, delimiter=",")
        except (ValueError, RuntimeError) as err:
            raise RuntimeError(
                "There was a problem extracting the headers from the CSV file at '{}': {}".format(
                    gcs_csv_file_path, err
                )
            ) from err
        finally:
            logger.removeFilter(logging_warning_filter)

        return set(next(csv_reader))

    @staticmethod
    def _get_bq_schema_field_names_recursively(
        schema_field: "bigquery.SchemaField",
    ) -> Set[str]:
        """Retrieve the name for a schema field along with ancestor fields.
        Nested schema fields are flattened and concatenated with a ".".
        Schema fields with child fields are not included, but the children are.

        Args:
            schema_field (bigquery.SchemaField):
                Required. The BigQuery schema field to extract column names from.

        Returns:
            Set[str]
                A set of flattened column names for the schema field.
        """

        ancestor_names = {
            nested_field_name
            for field in schema_field.fields
            for nested_field_name in _ColumnNamesDataset._get_bq_schema_field_names_recursively(
                field
            )
        }

        # Only return "leaf nodes", basically any field that doesn't have children
        if len(ancestor_names) == 0:
            return {schema_field.name}
        else:
            return {f"{schema_field.name}.{name}" for name in ancestor_names}

    @staticmethod
    def _retrieve_bq_source_columns(
        project: str,
        bq_table_uri: str,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> Set[str]:
        """Retrieve the column names from a table on Google BigQuery.
        Nested schema fields are flattened and concatenated with a ".".
        Schema fields with child fields are not included, but the children are.

        Example Usage:

            column_names = _retrieve_bq_source_columns(
                "project_id",
                "bq://project_id.dataset.table"
            )

            # column_names = {"column_1", "column_2", "column_3.nested_field"}

        Args:
            project (str):
                Required. Project to initiate the BigQuery client with.
            bq_table_uri (str):
                Required. A URI to a BigQuery table.
                Can include the "bq://" prefix, but it is not required.
            credentials (auth_credentials.Credentials):
                Credentials to use with the BQ client.

        Returns:
            Set[str]
                A set of column names in the BigQuery table.
        """

        # Remove bq:// prefix
        prefix = "bq://"
        if bq_table_uri.startswith(prefix):
            bq_table_uri = bq_table_uri[len(prefix) :]

        # The colon-based "project:dataset.table" format is no longer supported:
        #   Invalid dataset ID "bigquery-public-data:chicago_taxi_trips".
        #   Dataset IDs must be alphanumeric (plus underscores and dashes) and must be at most 1024 characters long.
        # Use the dot-based "project.dataset.table" format instead.
        bq_table_uri = bq_table_uri.replace(":", ".")

        # Load bigquery lazily to avoid auto-loading it when importing vertexai
        from google.cloud import bigquery  # pylint: disable=g-import-not-at-top

        client = bigquery.Client(project=project, credentials=credentials)
        table = client.get_table(bq_table_uri)
        schema = table.schema

        return {
            field_name
            for field in schema
            for field_name in _ColumnNamesDataset._get_bq_schema_field_names_recursively(
                field
            )
        }
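To illustrate the flattening done by `_get_bq_schema_field_names_recursively`, a small standalone sketch (the field names are made up; it only needs the `google-cloud-bigquery` client library installed and makes no API calls):

    from google.cloud import bigquery

    from google.cloud.aiplatform.datasets.column_names_dataset import _ColumnNamesDataset

    # A RECORD field with two children; only the leaves are returned, prefixed with the parent name.
    address = bigquery.SchemaField(
        "address",
        "RECORD",
        fields=[
            bigquery.SchemaField("city", "STRING"),
            bigquery.SchemaField("zip", "STRING"),
        ],
    )
    print(_ColumnNamesDataset._get_bq_schema_field_names_recursively(address))
    # Expected: {"address.city", "address.zip"}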
@@ -0,0 +1,927 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from google.api_core import operation
from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils

from google.cloud.aiplatform.compat.services import dataset_service_client
from google.cloud.aiplatform.compat.types import (
    dataset as gca_dataset,
    dataset_service as gca_dataset_service,
    encryption_spec as gca_encryption_spec,
    io as gca_io,
)
from google.cloud.aiplatform.datasets import _datasources
from google.protobuf import field_mask_pb2
from google.protobuf import json_format

_LOGGER = base.Logger(__name__)


class _Dataset(base.VertexAiResourceNounWithFutureManager):
    """Managed dataset resource for Vertex AI."""

    client_class = utils.DatasetClientWithOverride
    _resource_noun = "datasets"
    _getter_method = "get_dataset"
    _list_method = "list_datasets"
    _delete_method = "delete_dataset"
    _parse_resource_name_method = "parse_dataset_path"
    _format_resource_name_method = "dataset_path"

    _supported_metadata_schema_uris: Tuple[str] = ()

    def __init__(
        self,
        dataset_name: str,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ):
        """Retrieves an existing managed dataset given a dataset name or ID.

        Args:
            dataset_name (str):
                Required. A fully-qualified dataset resource name or dataset ID.
                Example: "projects/123/locations/us-central1/datasets/456" or
                "456" when project and location are initialized or passed.
            project (str):
                Optional project to retrieve dataset from. If not set, project
                set in aiplatform.init will be used.
            location (str):
                Optional location to retrieve dataset from. If not set, location
                set in aiplatform.init will be used.
            credentials (auth_credentials.Credentials):
                Custom credentials to use to retrieve this Dataset. Overrides
                credentials set in aiplatform.init.
        """

        super().__init__(
            project=project,
            location=location,
            credentials=credentials,
            resource_name=dataset_name,
        )
        self._gca_resource = self._get_gca_resource(resource_name=dataset_name)
        self._validate_metadata_schema_uri()

    @property
    def metadata_schema_uri(self) -> str:
        """The metadata schema uri of this dataset resource."""
        self._assert_gca_resource_is_available()
        return self._gca_resource.metadata_schema_uri

    def _validate_metadata_schema_uri(self) -> None:
        """Validate the metadata_schema_uri of the retrieved dataset resource.

        Raises:
            ValueError: If the dataset type of the retrieved dataset resource is
                not supported by the class.
        """
        if self._supported_metadata_schema_uris and (
            self.metadata_schema_uri not in self._supported_metadata_schema_uris
        ):
            raise ValueError(
                f"{self.__class__.__name__} class can not be used to retrieve "
                f"dataset resource {self.resource_name}, check the dataset type"
            )

    @classmethod
    def create(
        cls,
        # TODO(b/223262536): Make the display_name parameter optional in the next major release
        display_name: str,
        metadata_schema_uri: str,
        gcs_source: Optional[Union[str, Sequence[str]]] = None,
        bq_source: Optional[str] = None,
        import_schema_uri: Optional[str] = None,
        data_item_labels: Optional[Dict] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
        labels: Optional[Dict[str, str]] = None,
        encryption_spec_key_name: Optional[str] = None,
        sync: bool = True,
        create_request_timeout: Optional[float] = None,
    ) -> "_Dataset":
        """Creates a new dataset and optionally imports data into the dataset
        when source and import_schema_uri are passed.

        Args:
            display_name (str):
                Required. The user-defined name of the Dataset.
                The name can be up to 128 characters long and can consist
                of any UTF-8 characters.
            metadata_schema_uri (str):
                Required. Points to a YAML file stored on Google Cloud Storage
                describing additional information about the Dataset. The schema
                is defined as an OpenAPI 3.0.2 Schema Object. The schema files
                that can be used here are found in
                gs://google-cloud-aiplatform/schema/dataset/metadata/.
            gcs_source (Union[str, Sequence[str]]):
                Google Cloud Storage URI(-s) to the
                input file(s). May contain wildcards. For more
                information on wildcards, see
                https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
                examples:
                    str: "gs://bucket/file.csv"
                    Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
            bq_source (str):
                BigQuery URI to the input table.
                example:
                    "bq://project.dataset.table_name"
            import_schema_uri (str):
                Points to a YAML file stored on Google Cloud
                Storage describing the import format. Validation will be
                done against the schema. The schema is defined as an
                `OpenAPI 3.0.2 Schema
                Object <https://tinyurl.com/y538mdwt>`__.
            data_item_labels (Dict):
                Labels that will be applied to newly imported DataItems. If
                an identical DataItem as one being imported already exists
                in the Dataset, then these labels will be appended to those
                of the already existing one, and if a label with an identical
                key was imported before, the old label value will be
                overwritten. If two DataItems are identical in the same
                import data operation, the labels will be combined and if a
                key collision happens in this case, one of the values will
                be picked randomly. Two DataItems are considered identical
                if their content bytes are identical (e.g. image bytes or
                pdf bytes). These labels will be overridden by Annotation
                labels specified inside the index file referenced by
                ``import_schema_uri``,
                e.g. a jsonl file.
                This arg is not for specifying the annotation name or the
                training target of your data, but for some global labels of
                the dataset. E.g.,
                'data_item_labels={"aiplatform.googleapis.com/ml_use":"training"}'
                specifies that all the uploaded data are used for training.
            project (str):
                Project to upload this dataset to. Overrides project set in
                aiplatform.init.
            location (str):
                Location to upload this dataset to. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Custom credentials to use to upload this dataset. Overrides
                credentials set in aiplatform.init.
            request_metadata (Sequence[Tuple[str, str]]):
                Strings which should be sent along with the request as metadata.
            labels (Dict[str, str]):
                Optional. Labels with user-defined metadata to organize your datasets.
                Label keys and values can be no longer than 64 characters
                (Unicode codepoints), can only contain lowercase letters, numeric
                characters, underscores and dashes. International characters are allowed.
                No more than 64 user labels can be associated with one Dataset
                (System labels are excluded).
                See https://goo.gl/xmQnxf for more information and examples of labels.
                System reserved label keys are prefixed with "aiplatform.googleapis.com/"
                and are immutable.
            encryption_spec_key_name (Optional[str]):
                Optional. The Cloud KMS resource identifier of the customer
                managed encryption key used to protect the dataset. Has the
                form:
                ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
                The key needs to be in the same region as where the compute
                resource is created.

                If set, this Dataset and all sub-resources of this Dataset will be secured by this key.

                Overrides encryption_spec_key_name set in aiplatform.init.
            sync (bool):
                Whether to execute this method synchronously. If False, this method
                will be executed in a concurrent Future and any downstream object will
                be immediately returned and synced when the Future has completed.
            create_request_timeout (float):
                Optional. The timeout for the create request in seconds.

        Returns:
            dataset (Dataset):
                Instantiated representation of the managed dataset resource.
        """
        if not display_name:
            display_name = cls._generate_display_name()
        utils.validate_display_name(display_name)
        if labels:
            utils.validate_labels(labels)

        api_client = cls._instantiate_client(location=location, credentials=credentials)

        datasource = _datasources.create_datasource(
            metadata_schema_uri=metadata_schema_uri,
            import_schema_uri=import_schema_uri,
            gcs_source=gcs_source,
            bq_source=bq_source,
            data_item_labels=data_item_labels,
        )

        return cls._create_and_import(
            api_client=api_client,
            parent=initializer.global_config.common_location_path(
                project=project, location=location
            ),
            display_name=display_name,
            metadata_schema_uri=metadata_schema_uri,
            datasource=datasource,
            project=project or initializer.global_config.project,
            location=location or initializer.global_config.location,
            credentials=credentials or initializer.global_config.credentials,
            request_metadata=request_metadata,
            labels=labels,
            encryption_spec=initializer.global_config.get_encryption_spec(
                encryption_spec_key_name=encryption_spec_key_name
            ),
            sync=sync,
            create_request_timeout=create_request_timeout,
        )

    @classmethod
    @base.optional_sync()
    def _create_and_import(
        cls,
        api_client: dataset_service_client.DatasetServiceClient,
        parent: str,
        display_name: str,
        metadata_schema_uri: str,
        datasource: _datasources.Datasource,
        project: str,
        location: str,
        credentials: Optional[auth_credentials.Credentials],
        request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
        labels: Optional[Dict[str, str]] = None,
        encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None,
        sync: bool = True,
        create_request_timeout: Optional[float] = None,
        import_request_timeout: Optional[float] = None,
    ) -> "_Dataset":
        """Creates a new dataset and optionally imports data into the dataset
        when source and import_schema_uri are passed.

        Args:
            api_client (dataset_service_client.DatasetServiceClient):
                An instance of DatasetServiceClient with the correct api_endpoint
                already set based on user's preferences.
            parent (str):
                Required. Also known as the common location path; it usually contains the
                project and location that the user provided to the upstream method.
                Example: "projects/my-prj/locations/us-central1"
            display_name (str):
                Required. The user-defined name of the Dataset.
                The name can be up to 128 characters long and can consist
                of any UTF-8 characters.
            metadata_schema_uri (str):
                Required. Points to a YAML file stored on Google Cloud Storage
                describing additional information about the Dataset. The schema
                is defined as an OpenAPI 3.0.2 Schema Object. The schema files
                that can be used here are found in
                gs://google-cloud-aiplatform/schema/dataset/metadata/.
            datasource (_datasources.Datasource):
                Required. Datasource for creating a dataset for Vertex AI.
            project (str):
                Required. Project to upload this dataset to. Overrides project set in
                aiplatform.init.
            location (str):
                Required. Location to upload this dataset to. Overrides location set in
                aiplatform.init.
            credentials (Optional[auth_credentials.Credentials]):
                Custom credentials to use to upload this dataset. Overrides
                credentials set in aiplatform.init.
            request_metadata (Sequence[Tuple[str, str]]):
                Strings which should be sent along with the request as metadata.
            labels (Dict[str, str]):
                Optional. Labels with user-defined metadata to organize your datasets.
                Label keys and values can be no longer than 64 characters
                (Unicode codepoints), can only contain lowercase letters, numeric
                characters, underscores and dashes. International characters are allowed.
                No more than 64 user labels can be associated with one Dataset
                (System labels are excluded).
                See https://goo.gl/xmQnxf for more information and examples of labels.
                System reserved label keys are prefixed with "aiplatform.googleapis.com/"
                and are immutable.
            encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]):
                Optional. The Cloud KMS customer managed encryption key used to protect the dataset.
                The key needs to be in the same region as where the compute
                resource is created.

                If set, this Dataset and all sub-resources of this Dataset will be secured by this key.
            sync (bool):
                Whether to execute this method synchronously. If False, this method
                will be executed in a concurrent Future and any downstream object will
                be immediately returned and synced when the Future has completed.
            create_request_timeout (float):
                Optional. The timeout for the create request in seconds.
            import_request_timeout (float):
                Optional. The timeout for the import request in seconds.

        Returns:
            dataset (Dataset):
                Instantiated representation of the managed dataset resource.
        """

        create_dataset_lro = cls._create(
            api_client=api_client,
            parent=parent,
            display_name=display_name,
            metadata_schema_uri=metadata_schema_uri,
            datasource=datasource,
            request_metadata=request_metadata,
            labels=labels,
            encryption_spec=encryption_spec,
            create_request_timeout=create_request_timeout,
        )

        _LOGGER.log_create_with_lro(cls, create_dataset_lro)

        created_dataset = create_dataset_lro.result(timeout=None)

        _LOGGER.log_create_complete(cls, created_dataset, "ds")

        dataset_obj = cls(
            dataset_name=created_dataset.name,
            project=project,
            location=location,
            credentials=credentials,
        )

        # Import if the datasource is DatasourceImportable
        if isinstance(datasource, _datasources.DatasourceImportable):
            dataset_obj._import_and_wait(
                datasource, import_request_timeout=import_request_timeout
            )

        return dataset_obj

    def _import_and_wait(
        self,
        datasource,
        import_request_timeout: Optional[float] = None,
    ):
        _LOGGER.log_action_start_against_resource(
            "Importing",
            "data",
            self,
        )

        import_lro = self._import(
            datasource=datasource, import_request_timeout=import_request_timeout
        )

        _LOGGER.log_action_started_against_resource_with_lro(
            "Import", "data", self.__class__, import_lro
        )

        import_lro.result(timeout=None)

        _LOGGER.log_action_completed_against_resource("data", "imported", self)

    @classmethod
    def _create(
        cls,
        api_client: dataset_service_client.DatasetServiceClient,
        parent: str,
        display_name: str,
        metadata_schema_uri: str,
        datasource: _datasources.Datasource,
        request_metadata: Sequence[Tuple[str, str]] = (),
        labels: Optional[Dict[str, str]] = None,
        encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None,
        create_request_timeout: Optional[float] = None,
    ) -> operation.Operation:
        """Creates a new managed dataset by directly calling the API client.

        Args:
            api_client (dataset_service_client.DatasetServiceClient):
                An instance of DatasetServiceClient with the correct api_endpoint
                already set based on user's preferences.
            parent (str):
                Required. Also known as the common location path; it usually contains the
                project and location that the user provided to the upstream method.
                Example: "projects/my-prj/locations/us-central1"
            display_name (str):
                Required. The user-defined name of the Dataset.
                The name can be up to 128 characters long and can consist
                of any UTF-8 characters.
            metadata_schema_uri (str):
                Required. Points to a YAML file stored on Google Cloud Storage
                describing additional information about the Dataset. The schema
                is defined as an OpenAPI 3.0.2 Schema Object. The schema files
                that can be used here are found in
                gs://google-cloud-aiplatform/schema/dataset/metadata/.
            datasource (_datasources.Datasource):
                Required. Datasource for creating a dataset for Vertex AI.
            request_metadata (Sequence[Tuple[str, str]]):
                Strings which should be sent along with the create_dataset
                request as metadata. Usually to specify special dataset config.
            labels (Dict[str, str]):
                Optional. Labels with user-defined metadata to organize your datasets.
                Label keys and values can be no longer than 64 characters
                (Unicode codepoints), can only contain lowercase letters, numeric
                characters, underscores and dashes. International characters are allowed.
                No more than 64 user labels can be associated with one Dataset
                (System labels are excluded).
                See https://goo.gl/xmQnxf for more information and examples of labels.
                System reserved label keys are prefixed with "aiplatform.googleapis.com/"
                and are immutable.
            encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]):
                Optional. The Cloud KMS customer managed encryption key used to protect the dataset.
                The key needs to be in the same region as where the compute
                resource is created.

                If set, this Dataset and all sub-resources of this Dataset will be secured by this key.
            create_request_timeout (float):
                Optional. The timeout for the create request in seconds.
        Returns:
            operation (Operation):
                An object representing a long-running operation.
        """

        gapic_dataset = gca_dataset.Dataset(
            display_name=display_name,
            metadata_schema_uri=metadata_schema_uri,
            metadata=datasource.dataset_metadata,
            labels=labels,
            encryption_spec=encryption_spec,
        )

        return api_client.create_dataset(
            parent=parent,
            dataset=gapic_dataset,
            metadata=request_metadata,
            timeout=create_request_timeout,
        )

    def _import(
        self,
        datasource: _datasources.DatasourceImportable,
        import_request_timeout: Optional[float] = None,
    ) -> operation.Operation:
        """Imports data into the managed dataset by directly calling the API client.

        Args:
            datasource (_datasources.DatasourceImportable):
                Required. Datasource for importing data to an existing dataset for Vertex AI.
            import_request_timeout (float):
                Optional. The timeout for the import request in seconds.

        Returns:
            operation (Operation):
                An object representing a long-running operation.
        """
        return self.api_client.import_data(
            name=self.resource_name,
            import_configs=[datasource.import_data_config],
            timeout=import_request_timeout,
        )

    @base.optional_sync(return_input_arg="self")
    def import_data(
        self,
        gcs_source: Union[str, Sequence[str]],
        import_schema_uri: str,
        data_item_labels: Optional[Dict] = None,
        sync: bool = True,
        import_request_timeout: Optional[float] = None,
    ) -> "_Dataset":
        """Upload data to an existing managed dataset.

        Args:
            gcs_source (Union[str, Sequence[str]]):
                Required. Google Cloud Storage URI(-s) to the
                input file(s). May contain wildcards. For more
                information on wildcards, see
                https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
                examples:
                    str: "gs://bucket/file.csv"
                    Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
            import_schema_uri (str):
                Required. Points to a YAML file stored on Google Cloud
                Storage describing the import format. Validation will be
                done against the schema. The schema is defined as an
                `OpenAPI 3.0.2 Schema
                Object <https://tinyurl.com/y538mdwt>`__.
            data_item_labels (Dict):
                Labels that will be applied to newly imported DataItems. If
                an identical DataItem as one being imported already exists
                in the Dataset, then these labels will be appended to those
                of the already existing one, and if a label with an identical
                key was imported before, the old label value will be
                overwritten. If two DataItems are identical in the same
                import data operation, the labels will be combined and if a
                key collision happens in this case, one of the values will
                be picked randomly. Two DataItems are considered identical
                if their content bytes are identical (e.g. image bytes or
                pdf bytes). These labels will be overridden by Annotation
                labels specified inside the index file referenced by
                ``import_schema_uri``,
                e.g. a jsonl file.
                This arg is not for specifying the annotation name or the
                training target of your data, but for some global labels of
                the dataset. E.g.,
                'data_item_labels={"aiplatform.googleapis.com/ml_use":"training"}'
                specifies that all the uploaded data are used for training.
            sync (bool):
                Whether to execute this method synchronously. If False, this method
                will be executed in a concurrent Future and any downstream object will
                be immediately returned and synced when the Future has completed.
            import_request_timeout (float):
                Optional. The timeout for the import request in seconds.

        Returns:
            dataset (Dataset):
                Instantiated representation of the managed dataset resource.
        """
        datasource = _datasources.create_datasource(
            metadata_schema_uri=self.metadata_schema_uri,
            import_schema_uri=import_schema_uri,
            gcs_source=gcs_source,
            data_item_labels=data_item_labels,
        )

        self._import_and_wait(
            datasource=datasource, import_request_timeout=import_request_timeout
        )
        return self

    def _validate_and_convert_export_split(
        self,
        split: Union[Dict[str, str], Dict[str, float]],
    ) -> Union[gca_dataset.ExportFilterSplit, gca_dataset.ExportFractionSplit]:
        """
        Validates the split for data export. Valid splits are dicts
        encoding the contents of the proto messages ExportFilterSplit or
        ExportFractionSplit. If the split is valid, this function returns
        the corresponding converted proto message.

        Args:
            split (Union[Dict[str, str], Dict[str, float]]):
                The instructions how the export data should be split between the
                training, validation and test sets.
        """
        if len(split) != 3:
            raise ValueError(
                "The provided split for data export does not provide enough "
                "information. It must have three fields, mapping to training, "
                "validation and test splits respectively."
            )

        if not ("training_filter" in split or "training_fraction" in split):
            raise ValueError(
                "The provided filter for data export does not provide enough "
                "information. It must have three fields, mapping to training, "
                "validation and test respectively."
            )

        if "training_filter" in split:
            if (
                "validation_filter" in split
                and "test_filter" in split
                and isinstance(split["training_filter"], str)
                and isinstance(split["validation_filter"], str)
                and isinstance(split["test_filter"], str)
            ):
                return gca_dataset.ExportFilterSplit(
                    training_filter=split["training_filter"],
                    validation_filter=split["validation_filter"],
                    test_filter=split["test_filter"],
                )
            else:
                raise ValueError(
                    "The provided ExportFilterSplit does not contain all "
                    "three required fields: training_filter, "
                    "validation_filter and test_filter."
                )
        else:
            if (
                "validation_fraction" in split
                and "test_fraction" in split
                and isinstance(split["training_fraction"], float)
                and isinstance(split["validation_fraction"], float)
                and isinstance(split["test_fraction"], float)
            ):
                return gca_dataset.ExportFractionSplit(
                    training_fraction=split["training_fraction"],
                    validation_fraction=split["validation_fraction"],
                    test_fraction=split["test_fraction"],
                )
            else:
                raise ValueError(
                    "The provided ExportFractionSplit does not contain all "
                    "three required fields: training_fraction, "
                    "validation_fraction and test_fraction."
                )

    def _get_completed_export_data_operation(
        self,
        output_dir: str,
        export_use: Optional[gca_dataset.ExportDataConfig.ExportUse] = None,
        annotation_filter: Optional[str] = None,
        saved_query_id: Optional[str] = None,
        annotation_schema_uri: Optional[str] = None,
        split: Optional[
            Union[gca_dataset.ExportFilterSplit, gca_dataset.ExportFractionSplit]
        ] = None,
    ) -> gca_dataset_service.ExportDataResponse:
        self.wait()

        # TODO(b/171311614): Add support for BigQuery export path
        export_data_config = gca_dataset.ExportDataConfig(
            gcs_destination=gca_io.GcsDestination(output_uri_prefix=output_dir)
        )
        if export_use is not None:
            export_data_config.export_use = export_use
        if annotation_filter is not None:
            export_data_config.annotation_filter = annotation_filter
        if saved_query_id is not None:
            export_data_config.saved_query_id = saved_query_id
        if annotation_schema_uri is not None:
            export_data_config.annotation_schema_uri = annotation_schema_uri
        if split is not None:
            if isinstance(split, gca_dataset.ExportFilterSplit):
                export_data_config.filter_split = split
            elif isinstance(split, gca_dataset.ExportFractionSplit):
                export_data_config.fraction_split = split

        _LOGGER.log_action_start_against_resource("Exporting", "data", self)

        export_lro = self.api_client.export_data(
            name=self.resource_name, export_config=export_data_config
        )

        _LOGGER.log_action_started_against_resource_with_lro(
            "Export", "data", self.__class__, export_lro
        )

        export_data_response = export_lro.result()

        _LOGGER.log_action_completed_against_resource("data", "export", self)

        return export_data_response

    # TODO(b/174751568) add optional sync support
    def export_data(self, output_dir: str) -> Sequence[str]:
        """Exports data from the dataset to the given Google Cloud Storage output directory.

        Args:
            output_dir (str):
                Required. The Google Cloud Storage location where the output is to
                be written to. In the given directory a new directory will be
                created with name:
                ``export-data-<dataset-display-name>-<timestamp-of-export-call>``
                where the timestamp is in YYYYMMDDHHMMSS format. All export
                output will be written into that directory. Inside that
                directory, annotations with the same schema will be grouped
                into sub directories which are named with the corresponding
                annotations' schema title. Inside these sub directories, a
                schema.yaml will be created to describe the output format.

                If the uri doesn't end with '/', a '/' will be automatically
                appended. The directory is created if it doesn't exist.

        Returns:
            exported_files (Sequence[str]):
                All of the files that are exported in this export operation.
        """
        return self._get_completed_export_data_operation(output_dir).exported_files

    def export_data_for_custom_training(
        self,
        output_dir: str,
        annotation_filter: Optional[str] = None,
        saved_query_id: Optional[str] = None,
        annotation_schema_uri: Optional[str] = None,
        split: Optional[Union[Dict[str, str], Dict[str, float]]] = None,
    ) -> Dict[str, Any]:
        """Exports data to the given Google Cloud Storage output directory for the custom training use case.

        Example annotation_schema_uri (image classification):
            gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml

        Example split (filter split):
            {
                "training_filter": "labels.aiplatform.googleapis.com/ml_use=training",
                "validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation",
                "test_filter": "labels.aiplatform.googleapis.com/ml_use=test",
            }
        Example split (fraction split):
            {
                "training_fraction": 0.7,
                "validation_fraction": 0.2,
                "test_fraction": 0.1,
            }

        Args:
            output_dir (str):
                Required. The Google Cloud Storage location where the output is to
                be written to. In the given directory a new directory will be
                created with name:
                ``export-data-<dataset-display-name>-<timestamp-of-export-call>``
                where the timestamp is in YYYYMMDDHHMMSS format. All export
                output will be written into that directory. Inside that
                directory, annotations with the same schema will be grouped
                into sub directories which are named with the corresponding
                annotations' schema title. Inside these sub directories, a
                schema.yaml will be created to describe the output format.

                If the uri doesn't end with '/', a '/' will be automatically
                appended. The directory is created if it doesn't exist.
            annotation_filter (str):
                Optional. An expression for filtering what part of the Dataset
                is to be exported.
                Only Annotations that match this filter will be exported.
                The filter syntax is the same as in
                [ListAnnotations][DatasetService.ListAnnotations].
            saved_query_id (str):
                Optional. The ID of a SavedQuery (annotation set) under this
                Dataset used for filtering Annotations for training.

                Only used for custom training data export use cases.
                Only applicable to Datasets that have SavedQueries.

                Only Annotations that are associated with this SavedQuery are
                used in training. When used in conjunction with
                annotations_filter, the Annotations used for training are
                filtered by both saved_query_id and annotations_filter.

                Only one of saved_query_id and annotation_schema_uri should be
                specified as both of them represent the same thing: problem
                type.
            annotation_schema_uri (str):
                Optional. The Cloud Storage URI that points to a YAML file
                describing the annotation schema. The schema is defined as an
                OpenAPI 3.0.2 Schema Object. The schema files that can be used
                here are found in
                gs://google-cloud-aiplatform/schema/dataset/annotation/, note
                that the chosen schema must be consistent with the
                metadata_schema_uri of this Dataset.

                Only used for custom training data export use cases.
                Only applicable if this Dataset has DataItems and
                Annotations.

                Only Annotations that both match this schema and belong to
                DataItems not ignored by the split method are used in the
                respective training, validation or test role, depending on the
                role of the DataItem they are on.

                When used in conjunction with annotations_filter, the
                Annotations used for training are filtered by both
                annotations_filter and annotation_schema_uri.
            split (Union[Dict[str, str], Dict[str, float]]):
                The instructions how the export data should be split between the
                training, validation and test sets.

        Returns:
            export_data_response (Dict):
                Response message for DatasetService.ExportData in Dictionary
                format.
        """
        split = self._validate_and_convert_export_split(split)

        return json_format.MessageToDict(
            self._get_completed_export_data_operation(
                output_dir,
                gca_dataset.ExportDataConfig.ExportUse.CUSTOM_CODE_TRAINING,
                annotation_filter,
                saved_query_id,
                annotation_schema_uri,
                split,
            )._pb
        )

    def update(
        self,
        *,
        display_name: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
        description: Optional[str] = None,
        update_request_timeout: Optional[float] = None,
    ) -> "_Dataset":
        """Update the dataset.

        Updatable fields:
            - ``display_name``
            - ``description``
            - ``labels``

        Args:
            display_name (str):
                Optional. The user-defined name of the Dataset.
                The name can be up to 128 characters long and can consist
                of any UTF-8 characters.
            labels (Dict[str, str]):
                Optional. Labels with user-defined metadata to organize your datasets.
                Label keys and values can be no longer than 64 characters
                (Unicode codepoints), can only contain lowercase letters, numeric
                characters, underscores and dashes. International characters are allowed.
                No more than 64 user labels can be associated with one Dataset
                (System labels are excluded).
                See https://goo.gl/xmQnxf for more information and examples of labels.
                System reserved label keys are prefixed with "aiplatform.googleapis.com/"
                and are immutable.
            description (str):
                Optional. The description of the Dataset.
            update_request_timeout (float):
                Optional. The timeout for the update request in seconds.

        Returns:
            dataset (Dataset):
                Updated dataset.
        """

        update_mask = field_mask_pb2.FieldMask()
        if display_name:
            update_mask.paths.append("display_name")

        if labels:
            update_mask.paths.append("labels")

        if description:
            update_mask.paths.append("description")

        update_dataset = gca_dataset.Dataset(
            name=self.resource_name,
            display_name=display_name,
            description=description,
            labels=labels,
        )

        self._gca_resource = self.api_client.update_dataset(
            dataset=update_dataset,
            update_mask=update_mask,
            timeout=update_request_timeout,
        )

        return self

@classmethod
|
||||
def list(
|
||||
cls,
|
||||
filter: Optional[str] = None,
|
||||
order_by: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> List[base.VertexAiResourceNoun]:
|
||||
"""List all instances of this Dataset resource.
|
||||
|
||||
Example Usage:
|
||||
|
||||
aiplatform.TabularDataset.list(
|
||||
filter='labels.my_key="my_value"',
|
||||
order_by='display_name'
|
||||
)
|
||||
|
||||
Args:
|
||||
filter (str):
|
||||
Optional. An expression for filtering the results of the request.
|
||||
For field names both snake_case and camelCase are supported.
|
||||
order_by (str):
|
||||
Optional. A comma-separated list of fields to order by, sorted in
|
||||
ascending order. Use "desc" after a field name for descending.
|
||||
Supported fields: `display_name`, `create_time`, `update_time`
|
||||
project (str):
|
||||
Optional. Project to retrieve list from. If not set, project
|
||||
set in aiplatform.init will be used.
|
||||
location (str):
|
||||
Optional. Location to retrieve list from. If not set, location
|
||||
set in aiplatform.init will be used.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials to use to retrieve list. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
|
||||
Returns:
|
||||
List[base.VertexAiResourceNoun] - A list of Dataset resource objects
|
||||
"""
|
||||
|
||||
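# Client-side filter: keep only Dataset resources whose metadata schema URI belongs to this subclass (for example, image rather than tabular).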
dataset_subclass_filter = (
|
||||
lambda gapic_obj: gapic_obj.metadata_schema_uri
|
||||
in cls._supported_metadata_schema_uris
|
||||
)
|
||||
|
||||
return cls._list_with_local_order(
|
||||
cls_filter=dataset_subclass_filter,
|
||||
filter=filter,
|
||||
order_by=order_by,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
@@ -0,0 +1,198 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud.aiplatform import datasets
|
||||
from google.cloud.aiplatform.datasets import _datasources
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import schema
|
||||
from google.cloud.aiplatform import utils
|
||||
|
||||
|
||||
class ImageDataset(datasets._Dataset):
|
||||
"""A managed image dataset resource for Vertex AI.
|
||||
|
||||
Use this class to work with a managed image dataset. To create a managed
|
||||
image dataset, you need a datasource file in CSV format and a schema file in
|
||||
YAML format. A schema is optional for a custom model. You put the CSV file
|
||||
and the schema into Cloud Storage buckets.
|
||||
|
||||
Use image data for the following objectives:
|
||||
|
||||
* Single-label classification. For more information, see
|
||||
[Prepare image training data for single-label classification](https://cloud.google.com/vertex-ai/docs/image-data/classification/prepare-data#single-label-classification).
|
||||
* Multi-label classification. For more information, see [Prepare image training data for multi-label classification](https://cloud.google.com/vertex-ai/docs/image-data/classification/prepare-data#multi-label-classification).
|
||||
* Object detection. For more information, see [Prepare image training data
|
||||
for object detection](https://cloud.google.com/vertex-ai/docs/image-data/object-detection/prepare-data).
|
||||
|
||||
The following code shows you how to create an image dataset by importing data from
|
||||
a CSV datasource file and a YAML schema file. The schema file you use
|
||||
depends on whether your image dataset is used for single-label
|
||||
classification, multi-label classification, or object detection.
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.ImageDataset.create(
|
||||
display_name="my-image-dataset",
|
||||
gcs_source=['gs://path/to/my/image-dataset.csv'],
|
||||
import_schema_uri='gs://path/to/my/schema.yaml'
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
|
||||
schema.dataset.metadata.image,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
display_name: Optional[str] = None,
|
||||
gcs_source: Optional[Union[str, Sequence[str]]] = None,
|
||||
import_schema_uri: Optional[str] = None,
|
||||
data_item_labels: Optional[Dict] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
encryption_spec_key_name: Optional[str] = None,
|
||||
sync: bool = True,
|
||||
create_request_timeout: Optional[float] = None,
|
||||
) -> "ImageDataset":
|
||||
"""Creates a new image dataset.
|
||||
|
||||
Optionally imports data into the dataset when a source and
|
||||
`import_schema_uri` are passed in.
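The following is a minimal sketch of how this method can be used; the
bucket path is a placeholder, and the import schema shown
(`aiplatform.schema.dataset.ioformat.image.single_label_classification`)
is assumed to match your labeling format:

```py
ds = aiplatform.ImageDataset.create(
    display_name="my-image-dataset",
    gcs_source="gs://my-bucket/image-dataset.csv",
    import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification,
)
```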
|
||||
|
||||
Args:
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the dataset. The name must
|
||||
contain 128 or fewer UTF-8 characters.
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
Optional. The URI to one or more Google Cloud Storage buckets
|
||||
that contain your datasets. For example, `str:
|
||||
"gs://bucket/file.csv"` or `Sequence[str]:
|
||||
["gs://bucket/file1.csv", "gs://bucket/file2.csv"]`.
|
||||
import_schema_uri (str):
|
||||
Optional. A URI for a YAML file stored in Cloud Storage that
|
||||
describes the import schema used to validate the
|
||||
dataset. The schema is an
|
||||
[OpenAPI 3.0.2 Schema](https://tinyurl.com/y538mdwt) object.
|
||||
data_item_labels (Dict):
|
||||
Optional. A dictionary of label information. Each dictionary
|
||||
item contains a label and a label key. Each image in the dataset
|
||||
includes one dictionary of label information. If a data item is
|
||||
added or merged into a dataset, and that data item contains an
|
||||
image that's identical to an image that’s already in the
|
||||
dataset, then the data items are merged. If two identical labels
|
||||
are detected during the merge, each with a different label key,
|
||||
then one of the label and label key dictionary items is randomly
|
||||
chosen to be kept in the merged data item. Images and documents are
|
||||
compared using their binary data (bytes), not their content.
|
||||
If annotation labels are referenced in a schema specified by the
|
||||
`import_schema_uri` parameter, then the labels in the
|
||||
`data_item_labels` dictionary are overridden by the annotations.
|
||||
project (str):
|
||||
Optional. The name of the Google Cloud project to which this
|
||||
`ImageDataset` is uploaded. This overrides the project that
|
||||
was set by `aiplatform.init`.
|
||||
location (str):
|
||||
Optional. The Google Cloud region where this dataset is uploaded. This
|
||||
region overrides the region that was set by `aiplatform.init`.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. The credentials that are used to upload the
|
||||
`ImageDataset`. These credentials override the credentials set
|
||||
by `aiplatform.init`.
|
||||
request_metadata (Sequence[Tuple[str, str]]):
|
||||
Optional. Strings that contain metadata that's sent with the request.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your
|
||||
Vertex AI datasets. The maximum length of a key and of a
|
||||
value is 64 unicode characters. Labels and keys can contain only
|
||||
lowercase letters, numeric characters, underscores, and dashes.
|
||||
International characters are allowed. No more than 64 user
|
||||
labels can be associated with one dataset (system labels are
|
||||
excluded). For more information and examples of using labels, see
|
||||
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
|
||||
System reserved label keys are prefixed with
|
||||
`aiplatform.googleapis.com/` and are immutable.
|
||||
encryption_spec_key_name (Optional[str]):
|
||||
Optional. The Cloud KMS resource identifier of the customer
|
||||
managed encryption key that's used to protect the dataset. The
|
||||
format of the key is
|
||||
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
|
||||
The key needs to be in the same region as where the compute
|
||||
resource is created.
|
||||
|
||||
If `encryption_spec_key_name` is set, this image dataset and
|
||||
all of its sub-resources are secured by this key.
|
||||
|
||||
This `encryption_spec_key_name` overrides the
|
||||
`encryption_spec_key_name` set by `aiplatform.init`.
|
||||
sync (bool):
|
||||
If `true`, the `create` method creates an image dataset
|
||||
synchronously. If `false`, the `create` method creates an image
|
||||
dataset asynchronously.
|
||||
create_request_timeout (float):
|
||||
Optional. The number of seconds for the timeout of the create
|
||||
request.
|
||||
|
||||
Returns:
|
||||
image_dataset (ImageDataset):
|
||||
An instantiated representation of the managed `ImageDataset`
|
||||
resource.
|
||||
"""
|
||||
if not display_name:
|
||||
display_name = cls._generate_display_name()
|
||||
|
||||
utils.validate_display_name(display_name)
|
||||
if labels:
|
||||
utils.validate_labels(labels)
|
||||
|
||||
api_client = cls._instantiate_client(location=location, credentials=credentials)
|
||||
|
||||
metadata_schema_uri = schema.dataset.metadata.image
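# Bundle the GCS source, optional import schema, and per-item labels into a datasource for the create-and-import call below.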
|
||||
|
||||
datasource = _datasources.create_datasource(
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
import_schema_uri=import_schema_uri,
|
||||
gcs_source=gcs_source,
|
||||
data_item_labels=data_item_labels,
|
||||
)
|
||||
|
||||
return cls._create_and_import(
|
||||
api_client=api_client,
|
||||
parent=initializer.global_config.common_location_path(
|
||||
project=project, location=location
|
||||
),
|
||||
display_name=display_name,
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
datasource=datasource,
|
||||
project=project or initializer.global_config.project,
|
||||
location=location or initializer.global_config.location,
|
||||
credentials=credentials or initializer.global_config.credentials,
|
||||
request_metadata=request_metadata,
|
||||
labels=labels,
|
||||
encryption_spec=initializer.global_config.get_encryption_spec(
|
||||
encryption_spec_key_name=encryption_spec_key_name
|
||||
),
|
||||
sync=sync,
|
||||
create_request_timeout=create_request_timeout,
|
||||
)
|
||||
@@ -0,0 +1,318 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union, TYPE_CHECKING
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud.aiplatform import base
|
||||
from google.cloud.aiplatform import datasets
|
||||
from google.cloud.aiplatform.datasets import _datasources
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import schema
|
||||
from google.cloud.aiplatform import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.cloud import bigquery
|
||||
|
||||
_AUTOML_TRAINING_MIN_ROWS = 1000
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
|
||||
|
||||
class TabularDataset(datasets._ColumnNamesDataset):
|
||||
"""A managed tabular dataset resource for Vertex AI.
|
||||
|
||||
Use this class to work with tabular datasets. You can use a CSV file, BigQuery, or a pandas
|
||||
[`DataFrame`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)
|
||||
to create a tabular dataset. For more information about paging through
|
||||
BigQuery data, see [Read data with BigQuery API using
|
||||
pagination](https://cloud.google.com/bigquery/docs/paging-results). For more
|
||||
information about tabular data, see [Tabular
|
||||
data](https://cloud.google.com/vertex-ai/docs/training-overview#tabular_data).
|
||||
|
||||
The following code shows you how to create and import a tabular
|
||||
dataset with a CSV file.
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.TabularDataset.create(
|
||||
display_name="my-dataset", gcs_source=['gs://path/to/my/dataset.csv'])
|
||||
```
|
||||
Unlike unstructured datasets, creating and importing a tabular dataset
|
||||
can only be done in a single step.
|
||||
|
||||
If you create a tabular dataset with a pandas
|
||||
[`DataFrame`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html),
|
||||
you need to use a BigQuery table to stage the data for Vertex AI:
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.TabularDataset.create_from_dataframe(
|
||||
df_source=my_pandas_dataframe,
|
||||
staging_path=f"bq://{bq_dataset_id}.table-unique"
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
|
||||
schema.dataset.metadata.tabular,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
display_name: Optional[str] = None,
|
||||
gcs_source: Optional[Union[str, Sequence[str]]] = None,
|
||||
bq_source: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
encryption_spec_key_name: Optional[str] = None,
|
||||
sync: bool = True,
|
||||
create_request_timeout: Optional[float] = None,
|
||||
) -> "TabularDataset":
|
||||
"""Creates a tabular dataset.
|
||||
|
||||
Args:
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the dataset. The name must
|
||||
contain 128 or fewer UTF-8 characters.
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
Optional. The URI to one or more Google Cloud Storage buckets that contain
|
||||
your datasets. For example, `str: "gs://bucket/file.csv"` or
|
||||
`Sequence[str]: ["gs://bucket/file1.csv",
|
||||
"gs://bucket/file2.csv"]`. Either `gcs_source` or `bq_source` must be specified.
|
||||
bq_source (str):
|
||||
Optional. The URI to a BigQuery table that's used as an input source. For
|
||||
example, `bq://project.dataset.table_name`. Either `gcs_source`
|
||||
or `bq_source` must be specified.
|
||||
project (str):
|
||||
Optional. The name of the Google Cloud project to which this
|
||||
`TabularDataset` is uploaded. This overrides the project that
|
||||
was set by `aiplatform.init`.
|
||||
location (str):
|
||||
Optional. The Google Cloud region where this dataset is uploaded. This
|
||||
region overrides the region that was set by `aiplatform.init`.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. The credentials that are used to upload the `TabularDataset`.
|
||||
These credentials override the credentials set by
|
||||
`aiplatform.init`.
|
||||
request_metadata (Sequence[Tuple[str, str]]):
|
||||
Optional. Strings that contain metadata that's sent with the request.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your
|
||||
Vertex AI datasets. The maximum length of a key and of a
|
||||
value is 64 unicode characters. Labels and keys can contain only
|
||||
lowercase letters, numeric characters, underscores, and dashes.
|
||||
International characters are allowed. No more than 64 user
|
||||
labels can be associated with one dataset (system labels are
|
||||
excluded). For more information and examples of using labels, see
|
||||
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
|
||||
System reserved label keys are prefixed with
|
||||
`aiplatform.googleapis.com/` and are immutable.
|
||||
encryption_spec_key_name (Optional[str]):
|
||||
Optional. The Cloud KMS resource identifier of the customer
|
||||
managed encryption key that's used to protect the dataset. The
|
||||
format of the key is
|
||||
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
|
||||
The key needs to be in the same region as where the compute
|
||||
resource is created.
|
||||
|
||||
If `encryption_spec_key_name` is set, this `TabularDataset` and
|
||||
all of its sub-resources are secured by this key.
|
||||
|
||||
This `encryption_spec_key_name` overrides the
|
||||
`encryption_spec_key_name` set by `aiplatform.init`.
|
||||
sync (bool):
|
||||
If `true`, the `create` method creates a tabular dataset
|
||||
synchronously. If `false`, the `create` method creates a tabular
|
||||
dataset asynchronously.
|
||||
create_request_timeout (float):
|
||||
Optional. The number of seconds for the timeout of the create
|
||||
request.
|
||||
|
||||
Returns:
|
||||
tabular_dataset (TabularDataset):
|
||||
An instantiated representation of the managed `TabularDataset` resource.
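The following is a minimal sketch showing both supported sources; the
bucket path and BigQuery table are placeholders:

```py
ds = aiplatform.TabularDataset.create(
    display_name="my-tabular-dataset",
    gcs_source="gs://my-bucket/dataset.csv",
)

# or, equivalently, from a BigQuery table
ds = aiplatform.TabularDataset.create(
    display_name="my-tabular-dataset",
    bq_source="bq://my-project.my_dataset.my_table",
)
```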
|
||||
"""
|
||||
if not display_name:
|
||||
display_name = cls._generate_display_name()
|
||||
utils.validate_display_name(display_name)
|
||||
if labels:
|
||||
utils.validate_labels(labels)
|
||||
|
||||
api_client = cls._instantiate_client(location=location, credentials=credentials)
|
||||
|
||||
metadata_schema_uri = schema.dataset.metadata.tabular
|
||||
|
||||
datasource = _datasources.create_datasource(
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
gcs_source=gcs_source,
|
||||
bq_source=bq_source,
|
||||
)
|
||||
|
||||
return cls._create_and_import(
|
||||
api_client=api_client,
|
||||
parent=initializer.global_config.common_location_path(
|
||||
project=project, location=location
|
||||
),
|
||||
display_name=display_name,
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
datasource=datasource,
|
||||
project=project or initializer.global_config.project,
|
||||
location=location or initializer.global_config.location,
|
||||
credentials=credentials or initializer.global_config.credentials,
|
||||
request_metadata=request_metadata,
|
||||
labels=labels,
|
||||
encryption_spec=initializer.global_config.get_encryption_spec(
|
||||
encryption_spec_key_name=encryption_spec_key_name
|
||||
),
|
||||
sync=sync,
|
||||
create_request_timeout=create_request_timeout,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_from_dataframe(
|
||||
cls,
|
||||
df_source: "pd.DataFrame", # noqa: F821 - skip check for undefined name 'pd'
|
||||
staging_path: str,
|
||||
bq_schema: Optional[Union[str, "bigquery.SchemaField"]] = None,
|
||||
display_name: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> "TabularDataset":
|
||||
"""Creates a new tabular dataset from a pandas `DataFrame`.
|
||||
|
||||
Args:
|
||||
df_source (pd.DataFrame):
|
||||
Required. A pandas
|
||||
[`DataFrame`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)
|
||||
containing the source data for ingestion as a `TabularDataset`.
|
||||
This method uses the data types from the provided `DataFrame`
|
||||
when the `TabularDataset` is created.
|
||||
staging_path (str):
|
||||
Required. The BigQuery table used to stage the data for Vertex
|
||||
AI. Because Vertex AI maintains a reference to this source to
|
||||
create the `TabularDataset`, you shouldn't delete this BigQuery
|
||||
table. For example: `bq://my-project.my-dataset.my-table`.
|
||||
If the specified BigQuery table doesn't exist, then the table is
|
||||
created for you. If the provided BigQuery table already exists,
|
||||
and the schemas of the BigQuery table and your DataFrame match,
|
||||
then the data in your local `DataFrame` is appended to the table.
|
||||
The location of the BigQuery table must conform to the
|
||||
[BigQuery location requirements](https://cloud.google.com/vertex-ai/docs/general/locations#bq-locations).
|
||||
bq_schema (Optional[Union[str, bigquery.SchemaField]]):
|
||||
Optional. If not set, BigQuery autodetects the schema using the
|
||||
column types of your `DataFrame`. If set, BigQuery uses the
|
||||
schema you provide when the staging table is created. For more
|
||||
information,
|
||||
see the BigQuery
|
||||
[`LoadJobConfig.schema`](https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.LoadJobConfig#google_cloud_bigquery_job_LoadJobConfig_schema)
|
||||
property.
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the `Dataset`. The name must
|
||||
contain 128 or fewer UTF-8 characters.
|
||||
project (str):
|
||||
Optional. The project to upload this dataset to. This overrides
|
||||
the project set using `aiplatform.init`.
|
||||
location (str):
|
||||
Optional. The location to upload this dataset to. This overrides
|
||||
the location set using `aiplatform.init`.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. The custom credentials used to upload this dataset.
|
||||
This overrides credentials set using `aiplatform.init`.
|
||||
Returns:
|
||||
tabular_dataset (TabularDataset):
|
||||
An instantiated representation of the managed `TabularDataset` resource.
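The following is a minimal sketch; the DataFrame contents, project, and
staging table are placeholders:

```py
import pandas as pd

df = pd.DataFrame({"feature": [1.0, 2.0], "label": [0, 1]})

ds = aiplatform.TabularDataset.create_from_dataframe(
    df_source=df,
    staging_path="bq://my-project.my_dataset.staging_table",
    display_name="my-dataset-from-dataframe",
)
```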
|
||||
"""
|
||||
|
||||
if staging_path.startswith("bq://"):
|
||||
bq_staging_path = staging_path[len("bq://") :]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only BigQuery staging paths are supported. Provide a staging path in the format `bq://your-project.your-dataset.your-table`."
|
||||
)
|
||||
|
||||
try:
|
||||
import pyarrow # noqa: F401 - skip check for 'pyarrow' which is required when using 'google.cloud.bigquery'
|
||||
except ImportError:
|
||||
raise ImportError(
"Pyarrow is not installed, and is required to use the BigQuery client. "
|
||||
'Please install the SDK using "pip install google-cloud-aiplatform[datasets]"'
|
||||
)
|
||||
import pandas.api.types as pd_types
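# BigQuery's pandas integration changed how datetime-like columns are mapped between major versions, so log a heads-up when the DataFrame contains any such column.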
|
||||
|
||||
if any(
|
||||
[
|
||||
pd_types.is_datetime64_any_dtype(df_source[column])
|
||||
for column in df_source.columns
|
||||
]
|
||||
):
|
||||
_LOGGER.info(
|
||||
"Received datetime-like column in the dataframe. Please note that the column could be interpreted differently in BigQuery depending on which major version you are using. For more information, please reference the BigQuery v3 release notes here: https://github.com/googleapis/python-bigquery/releases/tag/v3.0.0"
|
||||
)
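# AutoML tabular training needs at least _AUTOML_TRAINING_MIN_ROWS rows; smaller DataFrames can still back custom training jobs after upload.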
|
||||
|
||||
if len(df_source) < _AUTOML_TRAINING_MIN_ROWS:
|
||||
_LOGGER.info(
|
||||
"Your DataFrame has %s rows and AutoML requires %s rows to train on tabular data. You can still train a custom model once your dataset has been uploaded to Vertex, but you will not be able to use AutoML for training."
|
||||
% (len(df_source), _AUTOML_TRAINING_MIN_ROWS),
|
||||
)
|
||||
|
||||
# Loading bigquery lazily to avoid auto-loading it when importing vertexai
|
||||
from google.cloud import bigquery # pylint: disable=g-import-not-at-top
|
||||
|
||||
bigquery_client = bigquery.Client(
|
||||
project=project or initializer.global_config.project,
|
||||
credentials=credentials or initializer.global_config.credentials,
|
||||
)
|
||||
|
||||
try:
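# Stage the DataFrame in BigQuery via a Parquet load job; list inference keeps array-valued columns as REPEATED fields rather than nested list/element records.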
|
||||
parquet_options = bigquery.format_options.ParquetOptions()
|
||||
parquet_options.enable_list_inference = True
|
||||
|
||||
job_config = bigquery.LoadJobConfig(
|
||||
source_format=bigquery.SourceFormat.PARQUET,
|
||||
parquet_options=parquet_options,
|
||||
)
|
||||
|
||||
if bq_schema:
|
||||
job_config.schema = bq_schema
|
||||
|
||||
job = bigquery_client.load_table_from_dataframe(
|
||||
dataframe=df_source, destination=bq_staging_path, job_config=job_config
|
||||
)
|
||||
|
||||
job.result()
|
||||
|
||||
finally:
|
||||
dataset_from_dataframe = cls.create(
|
||||
display_name=display_name,
|
||||
bq_source=staging_path,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
return dataset_from_dataframe
|
||||
|
||||
def import_data(self):
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} class does not support 'import_data'"
|
||||
)
|
||||
@@ -0,0 +1,207 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud.aiplatform import datasets
|
||||
from google.cloud.aiplatform.datasets import _datasources
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import schema
|
||||
from google.cloud.aiplatform import utils
|
||||
|
||||
|
||||
class TextDataset(datasets._Dataset):
|
||||
"""A managed text dataset resource for Vertex AI.
|
||||
|
||||
Use this class to work with a managed text dataset. To create a managed
|
||||
text dataset, you need a datasource file in CSV format and a schema file in
|
||||
YAML format. A schema is optional for a custom model. The CSV file and the
|
||||
schema are accessed in Cloud Storage buckets.
|
||||
|
||||
Use text data for the following objectives:
|
||||
|
||||
* Classification. For more information, see
|
||||
[Prepare text training data for classification](https://cloud.google.com/vertex-ai/docs/text-data/classification/prepare-data).
|
||||
* Entity extraction. For more information, see
|
||||
[Prepare text training data for entity extraction](https://cloud.google.com/vertex-ai/docs/text-data/entity-extraction/prepare-data).
|
||||
* Sentiment analysis. For more information, see
|
||||
[Prepare text training data for sentiment analysis](https://cloud.google.com/vertex-ai/docs/text-data/sentiment-analysis/prepare-data).
|
||||
|
||||
The following code shows you how to create and import a text dataset with
|
||||
a CSV datasource file and a YAML schema file. The schema file you use
|
||||
depends on whether your text dataset is used for
|
||||
classification, entity extraction, or sentiment analysis.
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.TextDataset.create(
|
||||
display_name="my-text-dataset",
|
||||
gcs_source=['gs://path/to/my/text-dataset.csv'],
|
||||
import_schema_uri='gs://path/to/my/schema.yaml',
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
|
||||
schema.dataset.metadata.text,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
display_name: Optional[str] = None,
|
||||
gcs_source: Optional[Union[str, Sequence[str]]] = None,
|
||||
import_schema_uri: Optional[str] = None,
|
||||
data_item_labels: Optional[Dict] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
encryption_spec_key_name: Optional[str] = None,
|
||||
sync: bool = True,
|
||||
create_request_timeout: Optional[float] = None,
|
||||
) -> "TextDataset":
|
||||
"""Creates a new text dataset.
|
||||
|
||||
Optionally imports data into this dataset when a source and
|
||||
`import_schema_uri` are passed in. The following is an example of how
|
||||
this method is used:
|
||||
|
||||
```py
|
||||
ds = aiplatform.TextDataset.create(
|
||||
display_name='my-dataset',
|
||||
gcs_source='gs://my-bucket/dataset.csv',
|
||||
import_schema_uri=aiplatform.schema.dataset.ioformat.text.multi_label_classification
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the dataset. The name must
|
||||
contain 128 or fewer UTF-8 characters.
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
Optional. The URI to one or more Google Cloud Storage buckets
|
||||
that contain your datasets. For example, `str:
|
||||
"gs://bucket/file.csv"` or `Sequence[str]:
|
||||
["gs://bucket/file1.csv", "gs://bucket/file2.csv"]`.
|
||||
import_schema_uri (str):
|
||||
Optional. A URI for a YAML file stored in Cloud Storage that
|
||||
describes the import schema used to validate the
|
||||
dataset. The schema is an
|
||||
[OpenAPI 3.0.2 Schema](https://tinyurl.com/y538mdwt) object.
|
||||
data_item_labels (Dict):
|
||||
Optional. A dictionary of label information. Each dictionary
|
||||
item contains a label and a label key. Each item in the dataset
|
||||
includes one dictionary of label information. If a data item is
|
||||
added or merged into a dataset, and that data item contains an
|
||||
document that's identical to a document that's already in the
|
||||
dataset, then the data items are merged. If two identical labels
|
||||
are detected during the merge, each with a different label key,
|
||||
then one of the label and label key dictionary items is randomly
|
||||
chosen to be kept in the merged data item. Data items are
|
||||
compared using their binary data (bytes), not their content.
|
||||
If annotation labels are referenced in a schema specified by the
|
||||
`import_schema_uri` parameter, then the labels in the
|
||||
`data_item_labels` dictionary are overridden by the annotations.
|
||||
project (str):
|
||||
Optional. The name of the Google Cloud project to which this
|
||||
`TextDataset` is uploaded. This overrides the project that
|
||||
was set by `aiplatform.init`.
|
||||
location (str):
|
||||
Optional. The Google Cloud region where this dataset is uploaded. This
|
||||
region overrides the region that was set by `aiplatform.init`.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. The credentials that are used to upload the `TextDataset`.
|
||||
These credentials override the credentials set by
|
||||
`aiplatform.init`.
|
||||
request_metadata (Sequence[Tuple[str, str]]):
|
||||
Optional. Strings that contain metadata that's sent with the request.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your
|
||||
Vertex AI datasets. The maximum length of a key and of a
|
||||
value is 64 unicode characters. Labels and keys can contain only
|
||||
lowercase letters, numeric characters, underscores, and dashes.
|
||||
International characters are allowed. No more than 64 user
|
||||
labels can be associated with one dataset (system labels are
|
||||
excluded). For more information and examples of using labels, see
|
||||
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
|
||||
System reserved label keys are prefixed with
|
||||
`aiplatform.googleapis.com/` and are immutable.
|
||||
encryption_spec_key_name (Optional[str]):
|
||||
Optional. The Cloud KMS resource identifier of the customer
|
||||
managed encryption key that's used to protect the dataset. The
|
||||
format of the key is
|
||||
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
|
||||
The key needs to be in the same region as where the compute
|
||||
resource is created.
|
||||
|
||||
If `encryption_spec_key_name` is set, this `TextDataset` and
|
||||
all of its sub-resources are secured by this key.
|
||||
|
||||
This `encryption_spec_key_name` overrides the
|
||||
`encryption_spec_key_name` set by `aiplatform.init`.
|
||||
sync (bool):
|
||||
If `true`, the `create` method creates a text dataset
|
||||
synchronously. If `false`, the `create` method creates a text
|
||||
dataset asynchronously.
|
||||
create_request_timeout (float):
|
||||
Optional. The number of seconds for the timeout of the create
|
||||
request.
|
||||
|
||||
Returns:
|
||||
text_dataset (TextDataset):
|
||||
An instantiated representation of the managed `TextDataset`
|
||||
resource.
|
||||
"""
|
||||
if not display_name:
|
||||
display_name = cls._generate_display_name()
|
||||
utils.validate_display_name(display_name)
|
||||
if labels:
|
||||
utils.validate_labels(labels)
|
||||
|
||||
api_client = cls._instantiate_client(location=location, credentials=credentials)
|
||||
|
||||
metadata_schema_uri = schema.dataset.metadata.text
|
||||
|
||||
datasource = _datasources.create_datasource(
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
import_schema_uri=import_schema_uri,
|
||||
gcs_source=gcs_source,
|
||||
data_item_labels=data_item_labels,
|
||||
)
|
||||
|
||||
return cls._create_and_import(
|
||||
api_client=api_client,
|
||||
parent=initializer.global_config.common_location_path(
|
||||
project=project, location=location
|
||||
),
|
||||
display_name=display_name,
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
datasource=datasource,
|
||||
project=project or initializer.global_config.project,
|
||||
location=location or initializer.global_config.location,
|
||||
credentials=credentials or initializer.global_config.credentials,
|
||||
request_metadata=request_metadata,
|
||||
labels=labels,
|
||||
encryption_spec=initializer.global_config.get_encryption_spec(
|
||||
encryption_spec_key_name=encryption_spec_key_name
|
||||
),
|
||||
sync=sync,
|
||||
create_request_timeout=create_request_timeout,
|
||||
)
|
||||
@@ -0,0 +1,186 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud.aiplatform import datasets
|
||||
from google.cloud.aiplatform.datasets import _datasources
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import schema
|
||||
from google.cloud.aiplatform import utils
|
||||
|
||||
|
||||
class TimeSeriesDataset(datasets._ColumnNamesDataset):
|
||||
"""A managed time series dataset resource for Vertex AI.
|
||||
|
||||
Use this class to work with time series datasets. A time series is a dataset
|
||||
that contains data recorded at different time intervals. The dataset
|
||||
includes time and at least one variable that's dependent on time. You use a
|
||||
time series dataset for forecasting predictions. For more information, see
|
||||
[Forecasting overview](https://cloud.google.com/vertex-ai/docs/tabular-data/forecasting/overview).
|
||||
|
||||
You can create a managed time series dataset from CSV files in a Cloud
|
||||
Storage bucket or from a BigQuery table.
|
||||
|
||||
The following code shows you how to create a `TimeSeriesDataset` with a CSV
|
||||
file that contains the time series data:
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.TimeSeriesDataset.create(
|
||||
display_name="my-dataset",
|
||||
gcs_source=['gs://path/to/my/dataset.csv'],
|
||||
)
|
||||
```
|
||||
|
||||
The following code shows you how to create a `TimeSeriesDataset` from a
|
||||
BigQuery table that contains the time series data:
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.TimeSeriesDataset.create(
|
||||
display_name="my-dataset",
|
||||
bq_source='bq://path/to/my/bigquerydataset.train',
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
|
||||
schema.dataset.metadata.time_series,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
display_name: Optional[str] = None,
|
||||
gcs_source: Optional[Union[str, Sequence[str]]] = None,
|
||||
bq_source: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
encryption_spec_key_name: Optional[str] = None,
|
||||
sync: bool = True,
|
||||
create_request_timeout: Optional[float] = None,
|
||||
) -> "TimeSeriesDataset":
|
||||
"""Creates a new time series dataset.
|
||||
|
||||
Args:
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the dataset. The name must
|
||||
contain 128 or fewer UTF-8 characters.
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
The URI to one or more Google Cloud Storage buckets that contain
|
||||
your datasets. For example, `str: "gs://bucket/file.csv"` or
|
||||
`Sequence[str]: ["gs://bucket/file1.csv",
|
||||
"gs://bucket/file2.csv"]`.
|
||||
bq_source (str):
|
||||
A BigQuery URI for the input table. For example,
|
||||
`bq://project.dataset.table_name`.
|
||||
project (str):
|
||||
The name of the Google Cloud project to which this
|
||||
`TimeSeriesDataset` is uploaded. This overrides the project that
|
||||
was set by `aiplatform.init`.
|
||||
location (str):
|
||||
The Google Cloud region where this dataset is uploaded. This
|
||||
region overrides the region that was set by `aiplatform.init`.
|
||||
credentials (auth_credentials.Credentials):
|
||||
The credentials that are used to upload the `TimeSeriesDataset`.
|
||||
These credentials override the credentials set by
|
||||
`aiplatform.init`.
|
||||
request_metadata (Sequence[Tuple[str, str]]):
|
||||
Strings that contain metadata that's sent with the request.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your
|
||||
Vertex AI datasets. The maximum length of a key and of a
|
||||
value is 64 unicode characters. Labels and keys can contain only
|
||||
lowercase letters, numeric characters, underscores, and dashes.
|
||||
International characters are allowed. No more than 64 user
|
||||
labels can be associated with one dataset (system labels are
|
||||
excluded). For more information and examples of using labels, see
|
||||
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
|
||||
System reserved label keys are prefixed with
|
||||
`aiplatform.googleapis.com/` and are immutable.
|
||||
encryption_spec_key_name (Optional[str]):
|
||||
Optional. The Cloud KMS resource identifier of the customer
|
||||
managed encryption key that's used to protect the dataset. The
|
||||
format of the key is
|
||||
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
|
||||
The key needs to be in the same region as where the compute
|
||||
resource is created.
|
||||
|
||||
If `encryption_spec_key_name` is set, this time series dataset
|
||||
and all of its sub-resources are secured by this key.
|
||||
|
||||
This `encryption_spec_key_name` overrides the
|
||||
`encryption_spec_key_name` set by `aiplatform.init`.
|
||||
create_request_timeout (float):
|
||||
Optional. The number of seconds for the timeout of the create
|
||||
request.
|
||||
sync (bool):
|
||||
If `true`, the `create` method creates a time series dataset
|
||||
synchronously. If `false`, the `create` method creates a time
|
||||
series dataset asynchronously.
|
||||
|
||||
Returns:
|
||||
time_series_dataset (TimeSeriesDataset):
|
||||
An instantiated representation of the managed
|
||||
`TimeSeriesDataset` resource.
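The following is a minimal sketch of a call to this method; the
BigQuery table is a placeholder:

```py
ds = aiplatform.TimeSeriesDataset.create(
    display_name="my-time-series-dataset",
    bq_source="bq://my-project.my_dataset.sales_history",
)
```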
|
||||
|
||||
"""
|
||||
if not display_name:
|
||||
display_name = cls._generate_display_name()
|
||||
utils.validate_display_name(display_name)
|
||||
if labels:
|
||||
utils.validate_labels(labels)
|
||||
|
||||
api_client = cls._instantiate_client(location=location, credentials=credentials)
|
||||
|
||||
metadata_schema_uri = schema.dataset.metadata.time_series
|
||||
|
||||
datasource = _datasources.create_datasource(
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
gcs_source=gcs_source,
|
||||
bq_source=bq_source,
|
||||
)
|
||||
|
||||
return cls._create_and_import(
|
||||
api_client=api_client,
|
||||
parent=initializer.global_config.common_location_path(
|
||||
project=project, location=location
|
||||
),
|
||||
display_name=display_name,
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
datasource=datasource,
|
||||
project=project or initializer.global_config.project,
|
||||
location=location or initializer.global_config.location,
|
||||
credentials=credentials or initializer.global_config.credentials,
|
||||
request_metadata=request_metadata,
|
||||
labels=labels,
|
||||
encryption_spec=initializer.global_config.get_encryption_spec(
|
||||
encryption_spec_key_name=encryption_spec_key_name
|
||||
),
|
||||
sync=sync,
|
||||
create_request_timeout=create_request_timeout,
|
||||
)
|
||||
|
||||
def import_data(self):
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} class does not support 'import_data'"
|
||||
)
|
||||
@@ -0,0 +1,199 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud.aiplatform import datasets
|
||||
from google.cloud.aiplatform.datasets import _datasources
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import schema
|
||||
from google.cloud.aiplatform import utils
|
||||
|
||||
|
||||
class VideoDataset(datasets._Dataset):
|
||||
"""A managed video dataset resource for Vertex AI.
|
||||
|
||||
Use this class to work with a managed video dataset. To create a video
|
||||
dataset, you need a datasource in CSV format and a schema in YAML format.
|
||||
The CSV file and the schema are accessed in Cloud Storage buckets.
|
||||
|
||||
Use video data for the following objectives:
|
||||
|
||||
* Classification. For more information, see Classification schema files.
* Action recognition. For more information, see Action recognition schema files.
* Object tracking. For more information, see Object tracking schema files.

The following code shows you how to create and import a dataset to
train a video classification model. The schema file you use depends on
whether you use your video dataset for classification, action recognition,
or object tracking.
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.VideoDataset.create(
|
||||
gcs_source=['gs://path/to/my/dataset.csv'],
|
||||
import_schema_uri='gs://aip.schema.dataset.ioformat.video.classification.yaml'
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
|
||||
schema.dataset.metadata.video,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
display_name: Optional[str] = None,
|
||||
gcs_source: Optional[Union[str, Sequence[str]]] = None,
|
||||
import_schema_uri: Optional[str] = None,
|
||||
data_item_labels: Optional[Dict] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
encryption_spec_key_name: Optional[str] = None,
|
||||
sync: bool = True,
|
||||
create_request_timeout: Optional[float] = None,
|
||||
) -> "VideoDataset":
|
||||
"""Creates a new video dataset.
|
||||
|
||||
Optionally imports data into the dataset when a source and
|
||||
`import_schema_uri` are passed in. The following is an example of how
|
||||
this method is used:
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.VideoDataset.create(
|
||||
gcs_source=['gs://path/to/my/dataset.csv'],
|
||||
import_schema_uri='gs://aip.schema.dataset.ioformat.video.classification.yaml'
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the dataset. The name must
|
||||
contain 128 or fewer UTF-8 characters.
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
The URI to one or more Google Cloud Storage buckets that contain
|
||||
your datasets. For example, `str: "gs://bucket/file.csv"` or
|
||||
`Sequence[str]: ["gs://bucket/file1.csv",
|
||||
"gs://bucket/file2.csv"]`.
|
||||
import_schema_uri (str):
|
||||
A URI for a YAML file stored in Cloud Storage that
|
||||
describes the import schema used to validate the
|
||||
dataset. The schema is an
|
||||
[OpenAPI 3.0.2 Schema](https://tinyurl.com/y538mdwt) object.
|
||||
data_item_labels (Dict):
|
||||
Optional. A dictionary of label information. Each dictionary
|
||||
item contains a label and a label key. Each item in the dataset
|
||||
includes one dictionary of label information. If a data item is
|
||||
added or merged into a dataset, and that data item contains an
|
||||
video that's identical to a video that's already in the
|
||||
dataset, then the data items are merged. If two identical labels
|
||||
are detected during the merge, each with a different label key,
|
||||
then one of the label and label key dictionary items is randomly
|
||||
chosen to be kept in the merged data item. Dataset items are
|
||||
compared using their binary data (bytes), not their content.
|
||||
If annotation labels are referenced in a schema specified by the
|
||||
`import_schema_uri` parameter, then the labels in the
|
||||
`data_item_labels` dictionary are overridden by the annotations.
|
||||
project (str):
|
||||
The name of the Google Cloud project to which this
|
||||
`VideoDataset` is uploaded. This overrides the project that
|
||||
was set by `aiplatform.init`.
|
||||
location (str):
|
||||
The Google Cloud region where this dataset is uploaded. This
|
||||
region overrides the region that was set by `aiplatform.init`.
|
||||
credentials (auth_credentials.Credentials):
|
||||
The credentials that are used to upload the `VideoDataset`.
|
||||
These credentials override the credentials set by
|
||||
`aiplatform.init`.
|
||||
request_metadata (Sequence[Tuple[str, str]]):
|
||||
Strings that contain metadata that's sent with the request.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your
|
||||
Vertex AI datasets. The maximum length of a key and of a
|
||||
value is 64 unicode characters. Labels and keys can contain only
|
||||
lowercase letters, numeric characters, underscores, and dashes.
|
||||
International characters are allowed. No more than 64 user
|
||||
labels can be associated with one dataset (system labels are
|
||||
excluded). For more information and examples of using labels, see
|
||||
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
|
||||
System reserved label keys are prefixed with
|
||||
`aiplatform.googleapis.com/` and are immutable.
|
||||
encryption_spec_key_name (Optional[str]):
|
||||
Optional. The Cloud KMS resource identifier of the customer
|
||||
managed encryption key that's used to protect the dataset. The
|
||||
format of the key is
|
||||
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
|
||||
The key needs to be in the same region as where the compute
|
||||
resource is created.
|
||||
|
||||
If `encryption_spec_key_name` is set, this `VideoDataset` and
|
||||
all of its sub-resources are secured by this key.
|
||||
|
||||
This `encryption_spec_key_name` overrides the
|
||||
`encryption_spec_key_name` set by `aiplatform.init`.
|
||||
sync (bool):
|
||||
If `true`, the `create` method creates a video dataset
|
||||
synchronously. If `false`, the `create` method creates a video
|
||||
dataset asynchronously.
|
||||
create_request_timeout (float):
|
||||
Optional. The number of seconds for the timeout of the create
|
||||
request.
|
||||
Returns:
|
||||
video_dataset (VideoDataset):
|
||||
An instantiated representation of the managed
|
||||
`VideoDataset` resource.
|
||||
"""
|
||||
if not display_name:
|
||||
display_name = cls._generate_display_name()
|
||||
utils.validate_display_name(display_name)
|
||||
if labels:
|
||||
utils.validate_labels(labels)
|
||||
|
||||
api_client = cls._instantiate_client(location=location, credentials=credentials)
|
||||
|
||||
metadata_schema_uri = schema.dataset.metadata.video
|
||||
|
||||
datasource = _datasources.create_datasource(
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
import_schema_uri=import_schema_uri,
|
||||
gcs_source=gcs_source,
|
||||
data_item_labels=data_item_labels,
|
||||
)
|
||||
|
||||
return cls._create_and_import(
|
||||
api_client=api_client,
|
||||
parent=initializer.global_config.common_location_path(
|
||||
project=project, location=location
|
||||
),
|
||||
display_name=display_name,
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
datasource=datasource,
|
||||
project=project or initializer.global_config.project,
|
||||
location=location or initializer.global_config.location,
|
||||
credentials=credentials or initializer.global_config.credentials,
|
||||
request_metadata=request_metadata,
|
||||
labels=labels,
|
||||
encryption_spec=initializer.global_config.get_encryption_spec(
|
||||
encryption_spec_key_name=encryption_spec_key_name
|
||||
),
|
||||
sync=sync,
|
||||
create_request_timeout=create_request_timeout,
|
||||
)
|
||||