structure saas with tools
This commit is contained in:
349
.venv/lib/python3.10/site-packages/vertexai/tuning/_tuning.py
Normal file
349
.venv/lib/python3.10/site-packages/vertexai/tuning/_tuning.py
Normal file
@@ -0,0 +1,349 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# pylint: disable=protected-access
|
||||
"""Classes to support Tuning."""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud import aiplatform
|
||||
from google.cloud.aiplatform import base as aiplatform_base
|
||||
from google.cloud.aiplatform import compat
|
||||
from google.cloud.aiplatform import initializer as aiplatform_initializer
|
||||
from google.cloud.aiplatform import jobs
|
||||
from google.cloud.aiplatform import utils as aiplatform_utils
|
||||
from google.cloud.aiplatform.utils import _ipython_utils
|
||||
from google.cloud.aiplatform_v1.services import (
|
||||
gen_ai_tuning_service as gen_ai_tuning_service_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services import (
|
||||
gen_ai_tuning_service as gen_ai_tuning_service_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.types import (
|
||||
tuning_job as gca_tuning_job_types,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1 import types as gca_types
|
||||
|
||||
from google.rpc import status_pb2 # type: ignore
|
||||
|
||||
|
||||
_LOGGER = aiplatform_base.Logger(__name__)
|
||||
|
||||
|
||||
class TuningJobClientWithOverride(aiplatform_utils.ClientWithOverride):
|
||||
_is_temporary = True
|
||||
_default_version = compat.V1BETA1
|
||||
_version_map = (
|
||||
(compat.V1, gen_ai_tuning_service_v1.client.GenAiTuningServiceClient),
|
||||
(compat.V1BETA1, gen_ai_tuning_service_v1beta1.client.GenAiTuningServiceClient),
|
||||
)
|
||||
|
||||
|
||||
class TuningJob(aiplatform_base._VertexAiResourceNounPlus):
|
||||
"""Represents a TuningJob that runs with Google owned models."""
|
||||
|
||||
_resource_noun = "tuningJobs"
|
||||
_getter_method = "get_tuning_job"
|
||||
_list_method = "list_tuning_jobs"
|
||||
_cancel_method = "cancel_tuning_job"
|
||||
_delete_method = "delete_tuning_job"
|
||||
_parse_resource_name_method = "parse_tuning_job_path"
|
||||
_format_resource_name_method = "tuning_job_path"
|
||||
_job_type = "tuning/tuningJob"
|
||||
_has_displayed_experiments_button = False
|
||||
|
||||
client_class = TuningJobClientWithOverride
|
||||
|
||||
_gca_resource: gca_tuning_job_types.TuningJob
|
||||
api_client: gen_ai_tuning_service_v1beta1.client.GenAiTuningServiceClient
|
||||
|
||||
def __init__(self, tuning_job_name: str):
|
||||
super().__init__(resource_name=tuning_job_name)
|
||||
self._gca_resource: gca_tuning_job_types.TuningJob = self._get_gca_resource(
|
||||
resource_name=tuning_job_name
|
||||
)
|
||||
|
||||
def refresh(self) -> "TuningJob":
|
||||
"""Refreshed the tuning job from the service."""
|
||||
self._gca_resource: gca_tuning_job_types.TuningJob = self._get_gca_resource(
|
||||
resource_name=self.resource_name
|
||||
)
|
||||
if self.experiment and not self._has_displayed_experiments_button:
|
||||
self._has_displayed_experiments_button = True
|
||||
_ipython_utils.display_experiment_button(self.experiment)
|
||||
return self
|
||||
|
||||
@property
|
||||
def tuned_model_name(self) -> Optional[str]:
|
||||
return self._gca_resource.tuned_model.model
|
||||
|
||||
@property
|
||||
def tuned_model_endpoint_name(self) -> Optional[str]:
|
||||
return self._gca_resource.tuned_model.endpoint
|
||||
|
||||
@property
|
||||
def _experiment_name(self) -> Optional[str]:
|
||||
return self._gca_resource.experiment
|
||||
|
||||
@property
|
||||
def experiment(self) -> Optional[aiplatform.Experiment]:
|
||||
if self._experiment_name:
|
||||
return aiplatform.Experiment(experiment_name=self._experiment_name)
|
||||
|
||||
@property
|
||||
def state(self) -> gca_types.JobState:
|
||||
return self._gca_resource.state
|
||||
|
||||
@property
|
||||
def service_account(self) -> Optional[str]:
|
||||
self._assert_gca_resource_is_available()
|
||||
return self._gca_resource.service_account
|
||||
|
||||
@property
|
||||
def has_ended(self):
|
||||
return self.state in jobs._JOB_COMPLETE_STATES
|
||||
|
||||
@property
|
||||
def has_succeeded(self):
|
||||
return self.state == gca_types.JobState.JOB_STATE_SUCCEEDED
|
||||
|
||||
@property
|
||||
def error(self) -> Optional[status_pb2.Status]:
|
||||
return self._gca_resource.error
|
||||
|
||||
@property
|
||||
def tuning_data_statistics(self) -> gca_tuning_job_types.TuningDataStats:
|
||||
return self._gca_resource.tuning_data_stats
|
||||
|
||||
@classmethod
|
||||
def _create(
|
||||
cls,
|
||||
*,
|
||||
base_model: str,
|
||||
tuning_spec: Union[
|
||||
gca_tuning_job_types.SupervisedTuningSpec,
|
||||
gca_tuning_job_types.DistillationSpec,
|
||||
],
|
||||
tuned_model_display_name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> "TuningJob":
|
||||
r"""Submits TuningJob.
|
||||
|
||||
Args:
|
||||
base_model (str):
|
||||
Model name for tuning, e.g., "gemini-1.0-pro"
|
||||
or "gemini-1.0-pro-001".
|
||||
|
||||
This field is a member of `oneof`_ ``source_model``.
|
||||
tuning_spec: Tuning Spec for Fine Tuning.
|
||||
Supported types: SupervisedTuningSpec, DistillationSpec.
|
||||
tuned_model_display_name: The display name of the
|
||||
[TunedModel][google.cloud.aiplatform.v1.Model]. The name can
|
||||
be up to 128 characters long and can consist of any UTF-8
|
||||
characters.
|
||||
description: The description of the `TuningJob`.
|
||||
labels: The labels with user-defined metadata to organize
|
||||
[TuningJob][google.cloud.aiplatform.v1.TuningJob] and
|
||||
generated resources such as
|
||||
[Model][google.cloud.aiplatform.v1.Model] and
|
||||
[Endpoint][google.cloud.aiplatform.v1.Endpoint].
|
||||
|
||||
Label keys and values can be no longer than 64 characters
|
||||
(Unicode codepoints), can only contain lowercase letters,
|
||||
numeric characters, underscores and dashes. International
|
||||
characters are allowed.
|
||||
|
||||
See https://goo.gl/xmQnxf for more information and examples
|
||||
of labels.
|
||||
project: Project to run the tuning job in.
|
||||
Overrides project set in aiplatform.init.
|
||||
location: Location to run the tuning job in.
|
||||
Overrides location set in aiplatform.init.
|
||||
credentials: Custom credentials to use to call tuning job service.
|
||||
Overrides credentials set in aiplatform.init.
|
||||
|
||||
Returns:
|
||||
Submitted TuningJob.
|
||||
|
||||
Raises:
|
||||
RuntimeError is tuning_spec kind is unsupported
|
||||
"""
|
||||
_LOGGER.log_create_with_lro(cls)
|
||||
|
||||
if not tuned_model_display_name:
|
||||
tuned_model_display_name = cls._generate_display_name()
|
||||
|
||||
gca_tuning_job = gca_tuning_job_types.TuningJob(
|
||||
base_model=base_model,
|
||||
tuned_model_display_name=tuned_model_display_name,
|
||||
description=description,
|
||||
labels=labels,
|
||||
# The tuning_spec one_of is set later
|
||||
)
|
||||
|
||||
if isinstance(tuning_spec, gca_tuning_job_types.SupervisedTuningSpec):
|
||||
gca_tuning_job.supervised_tuning_spec = tuning_spec
|
||||
elif isinstance(tuning_spec, gca_tuning_job_types.DistillationSpec):
|
||||
gca_tuning_job.distillation_spec = tuning_spec
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported tuning_spec kind: {tuning_spec}")
|
||||
|
||||
if aiplatform_initializer.global_config.encryption_spec_key_name:
|
||||
gca_tuning_job.encryption_spec.kms_key_name = (
|
||||
aiplatform_initializer.global_config.encryption_spec_key_name
|
||||
)
|
||||
gca_tuning_job.service_account = (
|
||||
aiplatform_initializer.global_config.service_account
|
||||
)
|
||||
|
||||
tuning_job: TuningJob = cls._construct_sdk_resource_from_gapic(
|
||||
gapic_resource=gca_tuning_job,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
parent = aiplatform_initializer.global_config.common_location_path(
|
||||
project=project, location=location
|
||||
)
|
||||
|
||||
created_gca_tuning_job = tuning_job.api_client.create_tuning_job(
|
||||
parent=parent,
|
||||
tuning_job=gca_tuning_job,
|
||||
)
|
||||
tuning_job._gca_resource = created_gca_tuning_job
|
||||
|
||||
_LOGGER.log_create_complete(
|
||||
cls=cls,
|
||||
resource=created_gca_tuning_job,
|
||||
variable_name="tuning_job",
|
||||
module_name="sft",
|
||||
)
|
||||
_LOGGER.info(f"View Tuning Job:\n{tuning_job._dashboard_url()}")
|
||||
if tuning_job._experiment_name:
|
||||
_LOGGER.info(f"View experiment:\n{tuning_job._experiment.dashboard_url}")
|
||||
|
||||
return tuning_job
|
||||
|
||||
def cancel(self):
|
||||
self.api_client.cancel_tuning_job(name=self.resource_name)
|
||||
|
||||
@classmethod
|
||||
def list(cls, filter: Optional[str] = None) -> List["TuningJob"]:
|
||||
"""Lists TuningJobs.
|
||||
|
||||
Args:
|
||||
filter: The standard list filter.
|
||||
|
||||
Returns:
|
||||
A list of TuningJob objects.
|
||||
"""
|
||||
return cls._list(filter=filter)
|
||||
|
||||
def _dashboard_url(self) -> str:
|
||||
"""Returns the Google Cloud console URL where job can be viewed."""
|
||||
fields = self._parse_resource_name(self.resource_name)
|
||||
location = fields.pop("location")
|
||||
project = fields.pop("project")
|
||||
job = list(fields.values())[0]
|
||||
url = f"https://console.cloud.google.com/vertex-ai/generative/language/locations/{location}/tuning/tuningJob/{job}?project={project}"
|
||||
return url
|
||||
|
||||
|
||||
def rebase_tuned_model(
|
||||
tuned_model_ref: str,
|
||||
*,
|
||||
# TODO(b/372291558): Add support for overriding tuning job config
|
||||
# tuning_job_config: Optional["TuningJob"] = None,
|
||||
artifact_destination: Optional[str] = None,
|
||||
deploy_to_same_endpoint: Optional[bool] = False,
|
||||
):
|
||||
"""Re-runs fine tuning on top of a new foundational model.
|
||||
|
||||
Takes a legacy Tuned GenAI model Reference and creates a TuningJob based
|
||||
on a new model.
|
||||
|
||||
Args:
|
||||
tuned_model_ref: Required. TunedModel reference to retrieve
|
||||
the legacy model information.
|
||||
tuning_job_config: The TuningJob to be updated. Users
|
||||
can use this TuningJob field to overwrite tuning
|
||||
configs.
|
||||
artifact_destination: The Google Cloud Storage location to write the artifacts.
|
||||
deploy_to_same_endpoint:
|
||||
Optional. By default, bison to gemini
|
||||
migration will always create new model/endpoint,
|
||||
but for gemini-1.0 to gemini-1.5 migration, we
|
||||
default deploy to the same endpoint. See details
|
||||
in this Section.
|
||||
|
||||
Returns:
|
||||
The new TuningJob.
|
||||
"""
|
||||
parent = aiplatform_initializer.global_config.common_location_path(
|
||||
project=aiplatform_initializer.global_config.project,
|
||||
location=aiplatform_initializer.global_config.location,
|
||||
)
|
||||
|
||||
if "/tuningJobs/" in tuned_model_ref:
|
||||
gapic_tuned_model_ref = gca_types.TunedModelRef(
|
||||
tuning_job=tuned_model_ref,
|
||||
)
|
||||
elif "/pipelineJobs/" in tuned_model_ref:
|
||||
gapic_tuned_model_ref = gca_types.TunedModelRef(
|
||||
pipeline_job=tuned_model_ref,
|
||||
)
|
||||
elif "/models/" in tuned_model_ref:
|
||||
gapic_tuned_model_ref = gca_types.TunedModelRef(
|
||||
tuned_model=tuned_model_ref,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported tuned_model_ref: {tuned_model_ref}.")
|
||||
|
||||
# gapic_tuning_job_config = tuning_job._gca_resource if tuning_job else None
|
||||
gapic_tuning_job_config = None
|
||||
|
||||
gapic_artifact_destination = (
|
||||
gca_types.GcsDestination(output_uri_prefix=artifact_destination)
|
||||
if artifact_destination
|
||||
else None
|
||||
)
|
||||
|
||||
api_client: gen_ai_tuning_service_v1beta1.GenAiTuningServiceClient = (
|
||||
TuningJob._instantiate_client(
|
||||
location=aiplatform_initializer.global_config.location,
|
||||
credentials=aiplatform_initializer.global_config.credentials,
|
||||
)
|
||||
)
|
||||
rebase_operation = api_client.rebase_tuned_model(
|
||||
gca_types.RebaseTunedModelRequest(
|
||||
parent=parent,
|
||||
tuned_model_ref=gapic_tuned_model_ref,
|
||||
tuning_job=gapic_tuning_job_config,
|
||||
artifact_destination=gapic_artifact_destination,
|
||||
deploy_to_same_endpoint=deploy_to_same_endpoint,
|
||||
)
|
||||
)
|
||||
_LOGGER.log_create_with_lro(TuningJob, lro=rebase_operation)
|
||||
gapic_rebase_tuning_job: gca_types.TuningJob = rebase_operation.result()
|
||||
rebase_tuning_job = TuningJob._construct_sdk_resource_from_gapic(
|
||||
gapic_resource=gapic_rebase_tuning_job,
|
||||
)
|
||||
return rebase_tuning_job
|
||||
Reference in New Issue
Block a user