structure saas with tools
@@ -0,0 +1,291 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Base class for working with Model Garden models."""

import dataclasses
from typing import Dict, Optional, Type, TypeVar
from google.auth import exceptions as auth_exceptions

from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform import models as aiplatform_models
from google.cloud.aiplatform import _publisher_models

_SUPPORTED_PUBLISHERS = ["google"]

_SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP = {
    "text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0",
    "text-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0",
    "code-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0",
    "code-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0",
    "chat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
    "chat-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
    "codechat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
    "codechat-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
    "textembedding-gecko": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3",
    "textembedding-gecko-multilingual": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3",
    "text-embedding-004": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3",
    "text-embedding-005": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.4",
    "text-multilingual-embedding-002": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3",
}

_LOGGER = base.Logger(__name__)

T = TypeVar("T", bound="_ModelGardenModel")

# When this module is initialized, _SUBCLASSES contains a mapping of SDK class to the
# instance schema URI for that class. The key is the SDK class since multiple classes
# can share a schema URI (e.g. _PreviewTextGenerationModel and TextGenerationModel).
# For example:
# {<class 'google.cloud.aiplatform.vertexai.language_models._language_models._TextGenerationModel'>:
#  "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"}
_SUBCLASSES = {}


def _get_model_class_from_schema_uri(
    schema_uri: str,
) -> "_ModelGardenModel":
    """Gets the _ModelGardenModel class for the provided PublisherModel schema uri.

    Args:
        schema_uri (str): The schema_uri for the provided PublisherModel, for example:
            "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"

    Returns:
        The _ModelGardenModel class associated with the provided schema uri.

    Raises:
        ValueError: If the provided PublisherModel schema_uri isn't supported
            by the SDK in Preview.
    """

    for sdk_class in _SUBCLASSES:
        class_schema_uri = _SUBCLASSES[sdk_class]
        if class_schema_uri == schema_uri and "preview" in sdk_class.__module__:
            return sdk_class

    raise ValueError("This model is not supported in Preview by the Vertex SDK.")


@dataclasses.dataclass
class _ModelInfo:
    endpoint_name: str
    interface_class: Type["_ModelGardenModel"]
    publisher_model_resource: _publisher_models._PublisherModel
    tuning_pipeline_uri: Optional[str] = None
    tuning_model_id: Optional[str] = None


def _get_model_info(
    model_id: str,
    schema_to_class_map: Optional[Dict[str, "_ModelGardenModel"]] = None,
    interface_class: Optional[Type["_ModelGardenModel"]] = None,
    publisher_model_res: Optional[_publisher_models._PublisherModel] = None,
    tuned_vertex_model: Optional[aiplatform.Model] = None,
) -> _ModelInfo:
    """Gets the model information by model ID.

    Args:
        model_id (str):
            Identifier of a Model Garden Model. Example: "text-bison@001"
        schema_to_class_map (Dict[str, "_ModelGardenModel"]):
            Mapping of schema URI to model class.

    Returns:
        _ModelInfo:
            Instance of _ModelInfo with details on the provided model_id.

    Raises:
        ValueError:
            If a publisher other than Google is provided in the publisher resource name.
            If the model's schema uri is not in the provided schema_to_class_map.
        RuntimeError:
            If the provided model doesn't have an associated Publisher Model.
    """

    # The default publisher is Google
    if "/" not in model_id:
        model_id = "publishers/google/models/" + model_id

    if not publisher_model_res:
        publisher_model_res = (
            _publisher_models._PublisherModel(  # pylint: disable=protected-access
                resource_name=model_id
            )._gca_resource
        )

    if not publisher_model_res.name.startswith("publishers/google/models/"):
        raise ValueError(
            f"Only Google models are currently supported. {publisher_model_res.name}"
        )
    short_model_id = publisher_model_res.name.rsplit("/", 1)[-1]

    # e.g. "projects/{project}/locations/{location}/publishers/google/models/text-bison@001"
    publisher_model_template = publisher_model_res.publisher_model_template.replace(
        "{user-project}", "{project}"
    )
    if not publisher_model_template:
        raise RuntimeError(
            f"The model does not have an associated Publisher Model. {publisher_model_res.name}"
        )

    if not tuned_vertex_model:
        endpoint_name = publisher_model_template.format(
            project=aiplatform_initializer.global_config.project,
            location=aiplatform_initializer.global_config.location,
        )
    else:
        tuned_model_deployments = tuned_vertex_model.gca_resource.deployed_models
        if len(tuned_model_deployments) == 0:
            # Deploying the model
            endpoint_name = tuned_vertex_model.deploy().resource_name
        else:
            endpoint_name = tuned_model_deployments[0].endpoint

    if short_model_id in _SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP:
        tuning_pipeline_uri = _SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP[short_model_id]
        tuning_model_id = publisher_model_template.rsplit("/", 1)[-1]
    else:
        tuning_pipeline_uri = None
        tuning_model_id = None

    if schema_to_class_map:
        interface_class = schema_to_class_map.get(
            publisher_model_res.predict_schemata.instance_schema_uri
        )

        if not interface_class:
            raise ValueError(
                f"Unknown model {publisher_model_res.name}; {schema_to_class_map}"
            )

    return _ModelInfo(
        endpoint_name=endpoint_name,
        interface_class=interface_class,
        publisher_model_resource=publisher_model_res,
        tuning_pipeline_uri=tuning_pipeline_uri,
        tuning_model_id=tuning_model_id,
    )

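
# Illustrative call (hypothetical values; a sketch, not part of the original module):
#
#   info = _get_model_info(model_id="text-bison@001")
#
# The short ID is expanded to "publishers/google/models/text-bison@001", the
# PublisherModel resource is fetched, and info.endpoint_name is built from the
# publisher model template using the project and location configured in
# aiplatform_initializer.global_config.
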
def _from_pretrained(
    *,
    interface_class: Optional[Type[T]] = None,
    model_name: Optional[str] = None,
    publisher_model: Optional[_publisher_models._PublisherModel] = None,
    tuned_vertex_model: Optional[aiplatform.Model] = None,
) -> T:
    """Loads a _ModelGardenModel.

    Args:
        model_name: Name of the model.

    Returns:
        An instance of a class derived from `_ModelGardenModel`.

    Raises:
        ValueError: If model_name is unknown.
        ValueError: If model does not support this class.
    """
    if interface_class:
        if not interface_class._INSTANCE_SCHEMA_URI:
            raise ValueError(
                f"Class {interface_class} is not a correct model interface class since it does not have an instance schema URI."
            )

        model_info = _get_model_info(
            model_id=model_name,
            schema_to_class_map={interface_class._INSTANCE_SCHEMA_URI: interface_class},
        )

    else:
        schema_uri = publisher_model._gca_resource.predict_schemata.instance_schema_uri
        interface_class = _get_model_class_from_schema_uri(schema_uri)

        model_info = _get_model_info(
            model_id=model_name,
            interface_class=interface_class,
            publisher_model_res=publisher_model._gca_resource,
            tuned_vertex_model=tuned_vertex_model,
        )

    if not issubclass(model_info.interface_class, interface_class):
        raise ValueError(
            f"{model_name} is of type {model_info.interface_class.__name__} not of type {interface_class.__name__}"
        )

    return model_info.interface_class(
        model_id=model_name,
        endpoint_name=model_info.endpoint_name,
    )


class _ModelGardenModel:
    """Base class for shared methods and properties across Model Garden models."""

    # Subclasses override this attribute to specify their instance schema URI.
    _INSTANCE_SCHEMA_URI: Optional[str] = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        _SUBCLASSES[cls] = cls._INSTANCE_SCHEMA_URI

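    # A minimal sketch (hypothetical subclass; the real subclasses live in other
    # vertexai modules such as language_models): setting _INSTANCE_SCHEMA_URI is
    # enough for __init_subclass__ to register the class in _SUBCLASSES.
    #
    #   class _ExampleTextModel(_ModelGardenModel):
    #       _INSTANCE_SCHEMA_URI = (
    #           "gs://google-cloud-aiplatform/schema/predict/instance/"
    #           "text_generation_1.0.0.yaml"
    #       )
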
    def __init__(self, model_id: str, endpoint_name: Optional[str] = None):
        """Creates a _ModelGardenModel.

        This constructor should not be called directly.
        Use `{model_class}.from_pretrained(model_name=...)` instead.

        Args:
            model_id: Identifier of a Model Garden Model. Example: "text-bison@001"
            endpoint_name: Vertex Endpoint resource name for the model
        """

        self._model_id = model_id
        self._endpoint_name = endpoint_name
        # TODO(b/280879204)
        # A workaround for not being able to directly instantiate the
        # high-level Endpoint with the PublisherModel resource name.
        self._endpoint = aiplatform.Endpoint._construct_sdk_resource_from_gapic(
            aiplatform_models.gca_endpoint_compat.Endpoint(name=endpoint_name),
        )

    @classmethod
    def from_pretrained(cls: Type[T], model_name: str) -> T:
        """Loads a _ModelGardenModel.

        Args:
            model_name: Name of the model.

        Returns:
            An instance of a class derived from `_ModelGardenModel`.

        Raises:
            ValueError: If model_name is unknown.
            ValueError: If model does not support this class.
        """

        credential_exception_str = (
            "\nUnable to authenticate your request."
            "\nDepending on your runtime environment, you can complete authentication by:"
            "\n- if in local JupyterLab instance: `!gcloud auth login` "
            "\n- if in Colab:"
            "\n -`from google.colab import auth`"
            "\n -`auth.authenticate_user()`"
            "\n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication"
        )

        try:
            return _from_pretrained(interface_class=cls, model_name=model_name)
        except auth_exceptions.GoogleAuthError as e:
            raise auth_exceptions.GoogleAuthError(credential_exception_str) from e
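
# Illustrative usage (hypothetical model name and subclass; a sketch, not part
# of the original module):
#
#   model = _ExampleTextModel.from_pretrained("text-bison@001")
#
# from_pretrained resolves the short name to
# "publishers/google/models/text-bison@001", matches the PublisherModel's
# instance schema URI against the calling class, and returns an instance bound
# to the publisher endpoint for the configured project and location. If
# authentication fails, the GoogleAuthError is re-raised with the guidance in
# credential_exception_str.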