142 lines
5.9 KiB
Python
142 lines
5.9 KiB
Python
from typing import Optional, Union
|
|
|
|
from google.cloud import aiplatform
|
|
from google.cloud.aiplatform import initializer as aiplatform_initializer
|
|
from vertexai.language_models import _language_models
|
|
from vertexai.language_models import _language_models as tuning
|
|
|
|
|
|
_DISTILLATION_PIPELINE_URI = (
|
|
"https://us-kfp.pkg.dev/ml-pipeline/distillation/distillation/v1.0.0"
|
|
)
|
|
|
|
|
|
class DistillationMixin:
    """Adds ``distill_from`` to model classes that expose ``_endpoint_name``."""

    def distill_from(
        self,
        *,
        dataset: str,
        teacher_model: Union[str, _language_models._TextGenerationModel],
        train_steps: Optional[int] = None,
        learning_rate_multiplier: Optional[float] = None,
        evaluation_spec: Optional[tuning.TuningEvaluationSpec] = None,
        accelerator_type: Optional[tuning._ACCELERATOR_TYPE_TYPE] = None,
        model_display_name: Optional[str] = None,
        max_context_length: Optional[str] = None,
    ):
        """Tunes this (student) model using a larger teacher model.

        The student model ID is the last "/"-separated segment of this
        model's ``_endpoint_name``. A string ``teacher_model`` is forwarded
        unchanged. A model object contributes the last segment of its own
        ``_endpoint_name``.

        Args:
            dataset: A URI pointing to data in JSON lines format.
            teacher_model: The teacher model, as a model reference string or
                a language model object.
            train_steps: Number of training batches to use (batch size is 8
                samples).
            learning_rate_multiplier: Learning rate multiplier to use in tuning.
            evaluation_spec: Specification for the model evaluation during
                tuning.
            accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
            model_display_name: Custom display name for the tuned model.
            max_context_length: The max context length used for tuning.
                Can be either '8k' or '32k'.

        Returns:
            A ``_LanguageModelTuningJob`` wrapping the submitted distillation
            pipeline job.

        Raises:
            RuntimeError: If this model's endpoint name has no "/models/"
                segment, if a teacher model object's endpoint name has no
                "/models/" segment, or if ``teacher_model`` is neither a
                string nor a language model object.
        """
        student_endpoint = self._endpoint_name
        if "/models/" not in student_endpoint:
            raise RuntimeError(
                f"Model does not support distillation: {student_endpoint}"
            )
        student_id = student_endpoint.rpartition("/")[2]

        if isinstance(teacher_model, str):
            # String references are passed through without shortening.
            teacher_id = teacher_model
        else:
            if not isinstance(teacher_model, _language_models._LanguageModel):
                raise RuntimeError(f"Unsupported teacher model type: {teacher_model}")
            teacher_endpoint = teacher_model._endpoint_name
            if "/models/" not in teacher_endpoint:
                raise RuntimeError(
                    f"Teacher model does not support distillation: {teacher_endpoint}"
                )
            teacher_id = teacher_endpoint.rpartition("/")[2]

        submitted_job = submit_distillation_pipeline_job(
            teacher_model=teacher_id,
            student_model=student_id,
            dataset=dataset,
            train_steps=train_steps,
            learning_rate_multiplier=learning_rate_multiplier,
            evaluation_spec=evaluation_spec,
            accelerator_type=accelerator_type,
            model_display_name=model_display_name,
            max_context_length=max_context_length,
        )
        return tuning._LanguageModelTuningJob(base_model=self, job=submitted_job)
|
|
|
|
|
|
def submit_distillation_pipeline_job(
    *,
    teacher_model: str,
    student_model: str,
    dataset: str,
    train_steps: Optional[int] = None,
    learning_rate_multiplier: Optional[float] = None,
    evaluation_spec: Optional[tuning.TuningEvaluationSpec] = None,
    accelerator_type: Optional[tuning._ACCELERATOR_TYPE_TYPE] = None,
    model_display_name: Optional[str] = None,
    max_context_length: Optional[str] = None,
):
    """Builds the distillation pipeline parameters and submits the pipeline.

    Args:
        teacher_model: Teacher model reference, forwarded unchanged as the
            ``teacher_model_reference`` pipeline parameter. Its last
            "/"-separated segment is used in the default display name.
        student_model: Student model reference, forwarded unchanged as
            ``student_model_reference``. Its last "/"-separated segment is
            used in the default display name.
        dataset: URI of the training data, forwarded as ``dataset_uri``.
        train_steps: Number of training steps; omitted when None.
        learning_rate_multiplier: Learning rate multiplier; omitted when None.
        evaluation_spec: Evaluation settings. Each field that is set is
            forwarded; fields left as None are omitted. A Tensorboard
            resource object is converted to its resource name.
        accelerator_type: Accelerator type; omitted when None.
        model_display_name: Display name for the tuned model. Defaults to
            "<student id> distilled from <teacher id>".
        max_context_length: Max context length used for tuning; omitted
            when None.

    Returns:
        The ``aiplatform.PipelineJob`` on which ``submit()`` has been called.
    """
    teacher_short_model_id = teacher_model.split("/")[-1]
    student_short_model_id = student_model.split("/")[-1]
    global_config = aiplatform_initializer.global_config

    pipeline_arguments = {
        "teacher_model_reference": teacher_model,
        "student_model_reference": student_model,
        "dataset_uri": dataset,
        "project": global_config.project,
        "location": global_config.location,
    }
    if train_steps is not None:
        pipeline_arguments["train_steps"] = train_steps
    if learning_rate_multiplier is not None:
        pipeline_arguments["learning_rate_multiplier"] = learning_rate_multiplier
    if evaluation_spec is not None:
        tensorboard = evaluation_spec.tensorboard
        # TuningEvaluationSpec.tensorboard may hold an aiplatform.Tensorboard
        # object instead of a resource-name string. A pipeline parameter must
        # be a plain value, so pass the resource name rather than the object.
        if isinstance(tensorboard, aiplatform.Tensorboard):
            tensorboard = tensorboard.resource_name
        evaluation_arguments = {
            "evaluation_data_uri": evaluation_spec.evaluation_data,
            "evaluation_interval": evaluation_spec.evaluation_interval,
            "enable_early_stopping": evaluation_spec.enable_early_stopping,
            "enable_checkpoint_selection": (
                evaluation_spec.enable_checkpoint_selection
            ),
            "tensorboard_resource_id": tensorboard,
        }
        # Forward only the fields that were set, matching how the other
        # optional arguments in this function are handled.
        pipeline_arguments.update(
            {key: value for key, value in evaluation_arguments.items() if value is not None}
        )
        # NOTE(review): an "evaluation_output_root_dir" parameter was left as
        # a placeholder by the original author and is still not populated.
    if accelerator_type is not None:
        pipeline_arguments["accelerator_type"] = accelerator_type
    encryption_spec_key_name = global_config.encryption_spec_key_name
    if encryption_spec_key_name is not None:
        pipeline_arguments["encryption_spec_key_name"] = encryption_spec_key_name
    if max_context_length is not None:
        pipeline_arguments["max_context_length"] = max_context_length
    if model_display_name is None:
        model_display_name = (
            f"{student_short_model_id} distilled from {teacher_short_model_id}"
        )
    pipeline_arguments["model_display_name"] = model_display_name

    # Pipeline parameters intentionally not exposed by this function:
    # temperature, tpu_training_skip_cmek, api_endpoint, version.

    pipeline_job = aiplatform.PipelineJob(
        template_path=_DISTILLATION_PIPELINE_URI,
        # NOTE(review): display_name is left as None; presumably PipelineJob
        # derives a name from the template — confirm.
        display_name=None,
        parameter_values=pipeline_arguments,
    )
    pipeline_job.submit()
    return pipeline_job
|