structure saas with tools
@@ -0,0 +1,141 @@
from typing import Optional, Union

from google.cloud import aiplatform
from google.cloud.aiplatform import initializer as aiplatform_initializer
from vertexai.language_models import _language_models
from vertexai.language_models import _language_models as tuning

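# URI of the prebuilt distillation pipeline template (a KFP template hosted in
# Artifact Registry). PipelineJob below pulls the template from this URI at
# submission time.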
_DISTILLATION_PIPELINE_URI = (
    "https://us-kfp.pkg.dev/ml-pipeline/distillation/distillation/v1.0.0"
)


class DistillationMixin:
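    """Mixin that adds model distillation support to a language model class."""
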
    def distill_from(
        self,
        *,
        dataset: str,
        teacher_model: Union[str, _language_models._TextGenerationModel],
        train_steps: Optional[int] = None,
        learning_rate_multiplier: Optional[float] = None,
        evaluation_spec: Optional[tuning.TuningEvaluationSpec] = None,
        accelerator_type: Optional[tuning._ACCELERATOR_TYPE_TYPE] = None,
        model_display_name: Optional[str] = None,
        max_context_length: Optional[str] = None,
    ):
        """Tunes a smaller student model with help from a larger teacher model.

        Args:
            dataset: A URI pointing to data in JSON lines format.
            teacher_model: The teacher model to use for distillation.
            train_steps: Number of training batches to use (batch size is 8 samples).
            learning_rate_multiplier: Learning rate multiplier to use in tuning.
            evaluation_spec: Specification for the model evaluation during tuning.
            accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
            model_display_name: Custom display name for the tuned model.
            max_context_length: The max context length used for tuning.
                Can be either '8k' or '32k'.

        Returns:
            A tuning job for distillation.

        Raises:
            RuntimeError: If the model does not support distillation.
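
        Example (illustrative sketch; assumes this mixin is applied to a model
        class such as ``TextGenerationModel``, and the model IDs and dataset
        path shown are placeholders):

            model = TextGenerationModel.from_pretrained("text-bison@002")
            distillation_job = model.distill_from(
                dataset="gs://my-bucket/distillation_train.jsonl",
                teacher_model="text-unicorn@001",
                train_steps=200,
            )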
        """
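        # The student must be a model whose resource name contains "/models/";
        # its short model ID is the last segment of that name.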
        if "/models/" not in self._endpoint_name:
            raise RuntimeError(
                f"Model does not support distillation: {self._endpoint_name}"
            )
        student_short_model_id = self._endpoint_name.split("/")[-1]

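        # The teacher may be passed either as a short model ID string or as a
        # loaded language model object; normalize both cases to a short model ID.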
        if isinstance(teacher_model, str):
            teacher_short_model_id = teacher_model
        elif isinstance(teacher_model, _language_models._LanguageModel):
            if "/models/" not in teacher_model._endpoint_name:
                raise RuntimeError(
                    f"Teacher model does not support distillation: {teacher_model._endpoint_name}"
                )
            teacher_short_model_id = teacher_model._endpoint_name.split("/")[-1]
        else:
            raise RuntimeError(f"Unsupported teacher model type: {teacher_model}")

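        # Submit the distillation pipeline and wrap the resulting PipelineJob in
        # a tuning job object tied to this (student) model.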
        pipeline_job = submit_distillation_pipeline_job(
            teacher_model=teacher_short_model_id,
            student_model=student_short_model_id,
            dataset=dataset,
            train_steps=train_steps,
            learning_rate_multiplier=learning_rate_multiplier,
            evaluation_spec=evaluation_spec,
            accelerator_type=accelerator_type,
            model_display_name=model_display_name,
            max_context_length=max_context_length,
        )
        tuning_job = tuning._LanguageModelTuningJob(
            base_model=self,
            job=pipeline_job,
        )
        return tuning_job


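# Translates the distill_from() arguments into parameters for the prebuilt
# distillation pipeline and submits it as a Vertex AI PipelineJob.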
def submit_distillation_pipeline_job(
    *,
    teacher_model: str,
    student_model: str,
    dataset: str,
    train_steps: Optional[int] = None,
    learning_rate_multiplier: Optional[float] = None,
    evaluation_spec: Optional[tuning.TuningEvaluationSpec] = None,
    accelerator_type: Optional[tuning._ACCELERATOR_TYPE_TYPE] = None,
    model_display_name: Optional[str] = None,
    max_context_length: Optional[str] = None,
):
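    # Derive short model IDs (used for the default display name), then assemble
    # the required pipeline parameters from the model references, the dataset,
    # and the project/location in the global aiplatform config.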
    teacher_short_model_id = teacher_model.split("/")[-1]
    student_short_model_id = student_model.split("/")[-1]
    pipeline_arguments = {
        "teacher_model_reference": teacher_model,
        "student_model_reference": student_model,
        "dataset_uri": dataset,
        "project": aiplatform_initializer.global_config.project,
        "location": aiplatform_initializer.global_config.location,
    }
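    # Optional tuning knobs are forwarded only when explicitly provided, so the
    # pipeline's own defaults apply otherwise.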
    if train_steps is not None:
        pipeline_arguments["train_steps"] = train_steps
    if learning_rate_multiplier is not None:
        pipeline_arguments["learning_rate_multiplier"] = learning_rate_multiplier
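    # Flatten the TuningEvaluationSpec into individual evaluation-related
    # pipeline parameters.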
    if evaluation_spec is not None:
        pipeline_arguments["evaluation_data_uri"] = evaluation_spec.evaluation_data
        pipeline_arguments["evaluation_interval"] = evaluation_spec.evaluation_interval
        pipeline_arguments[
            "enable_early_stopping"
        ] = evaluation_spec.enable_early_stopping
        pipeline_arguments[
            "enable_checkpoint_selection"
        ] = evaluation_spec.enable_checkpoint_selection
        pipeline_arguments["tensorboard_resource_id"] = evaluation_spec.tensorboard
        # pipeline_arguments["evaluation_output_root_dir"] = ...
    if accelerator_type is not None:
        pipeline_arguments["accelerator_type"] = accelerator_type
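    # Propagate the customer-managed encryption key (CMEK) configured globally, if any.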
    if aiplatform_initializer.global_config.encryption_spec_key_name is not None:
        pipeline_arguments[
            "encryption_spec_key_name"
        ] = aiplatform_initializer.global_config.encryption_spec_key_name
    if max_context_length is not None:
        pipeline_arguments["max_context_length"] = max_context_length
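    # Default display name: "<student> distilled from <teacher>".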
    if model_display_name is None:
        model_display_name = (
            f"{student_short_model_id} distilled from {teacher_short_model_id}"
        )
    pipeline_arguments["model_display_name"] = model_display_name
    # Not exposing these parameters:
    # temperature: Optional[float] = None,
    # tpu_training_skip_cmek: Optional[bool] = None,
    # api_endpoint: Optional[str] = None,
    # version: Optional[str] = None,
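    # submit() schedules the pipeline and returns immediately; callers can
    # monitor the returned PipelineJob for completion.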
    pipeline_job = aiplatform.PipelineJob(
        template_path=_DISTILLATION_PIPELINE_URI,
        display_name=None,
        parameter_values=pipeline_arguments,
    )
    pipeline_job.submit()
    return pipeline_job