350 lines
13 KiB
Python
350 lines
13 KiB
Python
# Copyright 2024 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
# pylint: disable=protected-access
|
|
"""Classes to support Tuning."""
|
|
|
|
from typing import Dict, List, Optional, Union
|
|
|
|
from google.auth import credentials as auth_credentials
|
|
|
|
from google.cloud import aiplatform
|
|
from google.cloud.aiplatform import base as aiplatform_base
|
|
from google.cloud.aiplatform import compat
|
|
from google.cloud.aiplatform import initializer as aiplatform_initializer
|
|
from google.cloud.aiplatform import jobs
|
|
from google.cloud.aiplatform import utils as aiplatform_utils
|
|
from google.cloud.aiplatform.utils import _ipython_utils
|
|
from google.cloud.aiplatform_v1.services import (
|
|
gen_ai_tuning_service as gen_ai_tuning_service_v1,
|
|
)
|
|
from google.cloud.aiplatform_v1beta1.services import (
|
|
gen_ai_tuning_service as gen_ai_tuning_service_v1beta1,
|
|
)
|
|
from google.cloud.aiplatform_v1beta1.types import (
|
|
tuning_job as gca_tuning_job_types,
|
|
)
|
|
from google.cloud.aiplatform_v1beta1 import types as gca_types
|
|
|
|
from google.rpc import status_pb2 # type: ignore
|
|
|
|
|
|
# Module-level logger used by TuningJob and rebase_tuned_model to report
# resource creation and long-running-operation progress.
_LOGGER = aiplatform_base.Logger(__name__)
|
|
|
|
|
|
class TuningJobClientWithOverride(aiplatform_utils.ClientWithOverride):
    """GenAI tuning service client wrapper with API-version selection.

    Maps the v1 and v1beta1 API versions to their respective
    GenAiTuningServiceClient classes, defaulting to v1beta1.
    """

    # Flag consumed by ClientWithOverride; its exact semantics are defined
    # in the base class.
    _is_temporary = True

    # Version used when no explicit API version is requested.
    _default_version = compat.V1BETA1

    # (api_version, client_class) pairs available to this wrapper.
    _version_map = (
        (
            compat.V1,
            gen_ai_tuning_service_v1.client.GenAiTuningServiceClient,
        ),
        (
            compat.V1BETA1,
            gen_ai_tuning_service_v1beta1.client.GenAiTuningServiceClient,
        ),
    )
|
|
|
|
|
|
class TuningJob(aiplatform_base._VertexAiResourceNounPlus):
    """Represents a TuningJob that runs with Google owned models."""

    _resource_noun = "tuningJobs"
    _getter_method = "get_tuning_job"
    _list_method = "list_tuning_jobs"
    _cancel_method = "cancel_tuning_job"
    _delete_method = "delete_tuning_job"
    _parse_resource_name_method = "parse_tuning_job_path"
    _format_resource_name_method = "tuning_job_path"
    _job_type = "tuning/tuningJob"
    # Class-level default; `refresh` sets an instance attribute to True once
    # it has displayed the experiment button, so it is shown at most once per
    # instance.
    _has_displayed_experiments_button = False

    client_class = TuningJobClientWithOverride

    _gca_resource: gca_tuning_job_types.TuningJob
    api_client: gen_ai_tuning_service_v1beta1.client.GenAiTuningServiceClient

    def __init__(self, tuning_job_name: str):
        """Retrieves an existing TuningJob by resource name.

        Args:
            tuning_job_name: Resource name of the tuning job to load.
        """
        super().__init__(resource_name=tuning_job_name)
        self._gca_resource: gca_tuning_job_types.TuningJob = self._get_gca_resource(
            resource_name=tuning_job_name
        )

    def refresh(self) -> "TuningJob":
        """Refreshes the tuning job from the service.

        The first time a refresh finds an associated experiment, an experiment
        button is displayed via the IPython utilities.

        Returns:
            This TuningJob, to allow call chaining.
        """
        self._gca_resource: gca_tuning_job_types.TuningJob = self._get_gca_resource(
            resource_name=self.resource_name
        )
        if self.experiment and not self._has_displayed_experiments_button:
            self._has_displayed_experiments_button = True
            _ipython_utils.display_experiment_button(self.experiment)
        return self

    @property
    def tuned_model_name(self) -> Optional[str]:
        """Value of the job's `tuned_model.model` field."""
        return self._gca_resource.tuned_model.model

    @property
    def tuned_model_endpoint_name(self) -> Optional[str]:
        """Value of the job's `tuned_model.endpoint` field."""
        return self._gca_resource.tuned_model.endpoint

    @property
    def _experiment_name(self) -> Optional[str]:
        """Value of the job's `experiment` field."""
        return self._gca_resource.experiment

    @property
    def experiment(self) -> Optional[aiplatform.Experiment]:
        """The Experiment associated with this job, or None if there is none."""
        if not self._experiment_name:
            return None
        return aiplatform.Experiment(experiment_name=self._experiment_name)

    @property
    def state(self) -> gca_types.JobState:
        """Current state of the job as last fetched from the service."""
        return self._gca_resource.state

    @property
    def service_account(self) -> Optional[str]:
        """Service account of the job; requires the resource to be loaded."""
        self._assert_gca_resource_is_available()
        return self._gca_resource.service_account

    @property
    def has_ended(self) -> bool:
        """True if the job's state is one of `jobs._JOB_COMPLETE_STATES`."""
        return self.state in jobs._JOB_COMPLETE_STATES

    @property
    def has_succeeded(self) -> bool:
        """True if the job's state is JOB_STATE_SUCCEEDED."""
        return self.state == gca_types.JobState.JOB_STATE_SUCCEEDED

    @property
    def error(self) -> Optional[status_pb2.Status]:
        """Value of the job's `error` field."""
        return self._gca_resource.error

    @property
    def tuning_data_statistics(self) -> gca_tuning_job_types.TuningDataStats:
        """Value of the job's `tuning_data_stats` field."""
        return self._gca_resource.tuning_data_stats

    @classmethod
    def _create(
        cls,
        *,
        base_model: str,
        tuning_spec: Union[
            gca_tuning_job_types.SupervisedTuningSpec,
            gca_tuning_job_types.DistillationSpec,
        ],
        tuned_model_display_name: Optional[str] = None,
        description: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "TuningJob":
        r"""Submits TuningJob.

        Args:
            base_model (str):
                Model name for tuning, e.g., "gemini-1.0-pro"
                or "gemini-1.0-pro-001".

                This field is a member of `oneof`_ ``source_model``.
            tuning_spec: Tuning Spec for Fine Tuning.
                Supported types: SupervisedTuningSpec, DistillationSpec.
            tuned_model_display_name: The display name of the
                [TunedModel][google.cloud.aiplatform.v1.Model]. The name can
                be up to 128 characters long and can consist of any UTF-8
                characters. A display name is generated when omitted.
            description: The description of the `TuningJob`.
            labels: The labels with user-defined metadata to organize
                [TuningJob][google.cloud.aiplatform.v1.TuningJob] and
                generated resources such as
                [Model][google.cloud.aiplatform.v1.Model] and
                [Endpoint][google.cloud.aiplatform.v1.Endpoint].

                Label keys and values can be no longer than 64 characters
                (Unicode codepoints), can only contain lowercase letters,
                numeric characters, underscores and dashes. International
                characters are allowed.

                See https://goo.gl/xmQnxf for more information and examples
                of labels.
            project: Project to run the tuning job in.
                Overrides project set in aiplatform.init.
            location: Location to run the tuning job in.
                Overrides location set in aiplatform.init.
            credentials: Custom credentials to use to call tuning job service.
                Overrides credentials set in aiplatform.init.

        Returns:
            Submitted TuningJob.

        Raises:
            RuntimeError: If the tuning_spec kind is unsupported.
        """
        _LOGGER.log_create_with_lro(cls)

        if not tuned_model_display_name:
            tuned_model_display_name = cls._generate_display_name()

        gca_tuning_job = gca_tuning_job_types.TuningJob(
            base_model=base_model,
            tuned_model_display_name=tuned_model_display_name,
            description=description,
            labels=labels,
            # The tuning_spec one_of is set later
        )

        # Route the spec into the matching member of the tuning_spec oneof.
        if isinstance(tuning_spec, gca_tuning_job_types.SupervisedTuningSpec):
            gca_tuning_job.supervised_tuning_spec = tuning_spec
        elif isinstance(tuning_spec, gca_tuning_job_types.DistillationSpec):
            gca_tuning_job.distillation_spec = tuning_spec
        else:
            raise RuntimeError(f"Unsupported tuning_spec kind: {tuning_spec}")

        if aiplatform_initializer.global_config.encryption_spec_key_name:
            gca_tuning_job.encryption_spec.kms_key_name = (
                aiplatform_initializer.global_config.encryption_spec_key_name
            )
        gca_tuning_job.service_account = (
            aiplatform_initializer.global_config.service_account
        )

        tuning_job: TuningJob = cls._construct_sdk_resource_from_gapic(
            gapic_resource=gca_tuning_job,
            project=project,
            location=location,
            credentials=credentials,
        )

        parent = aiplatform_initializer.global_config.common_location_path(
            project=project, location=location
        )

        created_gca_tuning_job = tuning_job.api_client.create_tuning_job(
            parent=parent,
            tuning_job=gca_tuning_job,
        )
        tuning_job._gca_resource = created_gca_tuning_job

        _LOGGER.log_create_complete(
            cls=cls,
            resource=created_gca_tuning_job,
            variable_name="tuning_job",
            module_name="sft",
        )
        _LOGGER.info(f"View Tuning Job:\n{tuning_job._dashboard_url()}")
        if tuning_job._experiment_name:
            # Use the public `experiment` property; no `_experiment`
            # attribute exists on this class.
            _LOGGER.info(f"View experiment:\n{tuning_job.experiment.dashboard_url}")

        return tuning_job

    def cancel(self):
        """Requests cancellation of this tuning job through the API client."""
        self.api_client.cancel_tuning_job(name=self.resource_name)

    @classmethod
    def list(cls, filter: Optional[str] = None) -> List["TuningJob"]:
        """Lists TuningJobs.

        Args:
            filter: The standard list filter.

        Returns:
            A list of TuningJob objects.
        """
        return cls._list(filter=filter)

    def _dashboard_url(self) -> str:
        """Returns the Google Cloud console URL where job can be viewed."""
        fields = self._parse_resource_name(self.resource_name)
        location = fields.pop("location")
        project = fields.pop("project")
        # The job ID is taken to be the first field left after removing
        # location and project.
        job = list(fields.values())[0]
        url = f"https://console.cloud.google.com/vertex-ai/generative/language/locations/{location}/tuning/tuningJob/{job}?project={project}"
        return url
|
|
|
|
|
|
def rebase_tuned_model(
    tuned_model_ref: str,
    *,
    # TODO(b/372291558): Add support for overriding tuning job config
    # tuning_job_config: Optional["TuningJob"] = None,
    artifact_destination: Optional[str] = None,
    deploy_to_same_endpoint: Optional[bool] = False,
) -> "TuningJob":
    """Re-runs fine tuning on top of a new foundational model.

    Takes a legacy Tuned GenAI model reference and creates a TuningJob based
    on a new model. The request is sent to the project and location set in
    aiplatform.init, and this function waits for the rebase operation's
    result before returning.

    Overriding the tuning job config is not supported yet (see the TODO in
    the signature).

    Args:
        tuned_model_ref: Required. TunedModel reference to retrieve
            the legacy model information. Must be a resource name containing
            "/tuningJobs/", "/pipelineJobs/" or "/models/".
        artifact_destination: The Google Cloud Storage location to write the
            artifacts. Used as the output URI prefix.
        deploy_to_same_endpoint:
            Optional. By default, bison to gemini
            migration will always create new model/endpoint,
            but for gemini-1.0 to gemini-1.5 migration, we
            default deploy to the same endpoint.

    Returns:
        The new TuningJob.

    Raises:
        ValueError: If tuned_model_ref does not match a supported resource
            type.
    """
    parent = aiplatform_initializer.global_config.common_location_path(
        project=aiplatform_initializer.global_config.project,
        location=aiplatform_initializer.global_config.location,
    )

    # The resource-name segment determines which TunedModelRef field is set.
    if "/tuningJobs/" in tuned_model_ref:
        gapic_tuned_model_ref = gca_types.TunedModelRef(
            tuning_job=tuned_model_ref,
        )
    elif "/pipelineJobs/" in tuned_model_ref:
        gapic_tuned_model_ref = gca_types.TunedModelRef(
            pipeline_job=tuned_model_ref,
        )
    elif "/models/" in tuned_model_ref:
        gapic_tuned_model_ref = gca_types.TunedModelRef(
            tuned_model=tuned_model_ref,
        )
    else:
        raise ValueError(f"Unsupported tuned_model_ref: {tuned_model_ref}.")

    # Config override is not supported yet (see TODO above); once it is, this
    # becomes: tuning_job._gca_resource if tuning_job else None
    gapic_tuning_job_config = None

    gapic_artifact_destination = (
        gca_types.GcsDestination(output_uri_prefix=artifact_destination)
        if artifact_destination
        else None
    )

    api_client: gen_ai_tuning_service_v1beta1.GenAiTuningServiceClient = (
        TuningJob._instantiate_client(
            location=aiplatform_initializer.global_config.location,
            credentials=aiplatform_initializer.global_config.credentials,
        )
    )
    rebase_operation = api_client.rebase_tuned_model(
        gca_types.RebaseTunedModelRequest(
            parent=parent,
            tuned_model_ref=gapic_tuned_model_ref,
            tuning_job=gapic_tuning_job_config,
            artifact_destination=gapic_artifact_destination,
            deploy_to_same_endpoint=deploy_to_same_endpoint,
        )
    )
    _LOGGER.log_create_with_lro(TuningJob, lro=rebase_operation)
    gapic_rebase_tuning_job: gca_types.TuningJob = rebase_operation.result()
    rebase_tuning_job = TuningJob._construct_sdk_resource_from_gapic(
        gapic_resource=gapic_rebase_tuning_job,
    )
    return rebase_tuning_job
|