Files
evo-ai/.venv/lib/python3.10/site-packages/vertexai/language_models/_distillation.py
2025-04-25 15:30:54 -03:00

142 lines
5.9 KiB
Python

from typing import Optional, Union
from google.cloud import aiplatform
from google.cloud.aiplatform import initializer as aiplatform_initializer
from vertexai.language_models import _language_models
from vertexai.language_models import _language_models as tuning
# Template URI of the distillation pipeline; it is used as the `template_path`
# of the `aiplatform.PipelineJob` built by `submit_distillation_pipeline_job`.
_DISTILLATION_PIPELINE_URI = "https://us-kfp.pkg.dev/ml-pipeline/distillation/distillation/v1.0.0"
class DistillationMixin:
    """Adds knowledge-distillation tuning to a model class.

    The class this is mixed into must provide ``_endpoint_name``. The model
    the method is called on acts as the *student*. Another model (the
    *teacher*) supplies the knowledge that is distilled into it.
    """

    def distill_from(
        self,
        *,
        dataset: str,
        teacher_model: Union[str, _language_models._TextGenerationModel],
        train_steps: Optional[int] = None,
        learning_rate_multiplier: Optional[float] = None,
        evaluation_spec: Optional[tuning.TuningEvaluationSpec] = None,
        accelerator_type: Optional[tuning._ACCELERATOR_TYPE_TYPE] = None,
        model_display_name: Optional[str] = None,
        max_context_length: Optional[str] = None,
    ):
        """Tunes this (smaller) model with help from a bigger teacher model.

        Args:
            dataset: A URI pointing to data in JSON lines format.
            teacher_model: The teacher model to use for distillation. This is
                either a model reference string, which is passed through
                unchanged, or a language model object whose endpoint name
                contains "/models/".
            train_steps: Number of training batches to use (batch size is 8
                samples).
            learning_rate_multiplier: Learning rate multiplier to use in tuning.
            evaluation_spec: Specification for the model evaluation during
                tuning.
            accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
            model_display_name: Custom display name for the tuned model.
            max_context_length: The max context length used for tuning. Can be
                either '8k' or '32k'.

        Returns:
            A ``tuning._LanguageModelTuningJob`` that wraps the submitted
            distillation pipeline job, with this model as the base model.

        Raises:
            RuntimeError: If this model or the teacher model does not support
                distillation (its endpoint name has no "/models/" segment), or
                if ``teacher_model`` is neither a string nor a language model.
        """
        student_endpoint = self._endpoint_name
        if "/models/" not in student_endpoint:
            raise RuntimeError(
                f"Model does not support distillation: {student_endpoint}"
            )
        # The short id is the final path segment of the endpoint resource name.
        student_id = student_endpoint.rpartition("/")[2]

        if isinstance(teacher_model, str):
            # String references are forwarded as-is; the pipeline helper
            # derives the short id itself when it needs one.
            teacher_id = teacher_model
        elif isinstance(teacher_model, _language_models._LanguageModel):
            teacher_endpoint = teacher_model._endpoint_name
            if "/models/" not in teacher_endpoint:
                raise RuntimeError(
                    f"Teacher model does not support distillation: {teacher_endpoint}"
                )
            teacher_id = teacher_endpoint.rpartition("/")[2]
        else:
            raise RuntimeError(f"Unsupported teacher model type: {teacher_model}")

        submitted_job = submit_distillation_pipeline_job(
            teacher_model=teacher_id,
            student_model=student_id,
            dataset=dataset,
            train_steps=train_steps,
            learning_rate_multiplier=learning_rate_multiplier,
            evaluation_spec=evaluation_spec,
            accelerator_type=accelerator_type,
            model_display_name=model_display_name,
            max_context_length=max_context_length,
        )
        return tuning._LanguageModelTuningJob(
            base_model=self,
            job=submitted_job,
        )
def submit_distillation_pipeline_job(
    *,
    teacher_model: str,
    student_model: str,
    dataset: str,
    train_steps: Optional[int] = None,
    learning_rate_multiplier: Optional[float] = None,
    evaluation_spec: Optional[tuning.TuningEvaluationSpec] = None,
    accelerator_type: Optional[tuning._ACCELERATOR_TYPE_TYPE] = None,
    model_display_name: Optional[str] = None,
    max_context_length: Optional[str] = None,
):
    """Builds the distillation pipeline parameters and submits the pipeline job.

    Project, location and (optionally) the CMEK key name are read from the
    global ``aiplatform`` initializer config.

    Args:
        teacher_model: Teacher model reference, passed as
            ``teacher_model_reference``.
        student_model: Student model reference, passed as
            ``student_model_reference``.
        dataset: URI of the training data, passed as ``dataset_uri``.
        train_steps: Number of training steps; omitted from the pipeline
            parameters when None.
        learning_rate_multiplier: Learning rate multiplier; omitted when None.
        evaluation_spec: Optional evaluation settings. Its ``tensorboard`` may
            be a resource-name string or an ``aiplatform.Tensorboard`` object.
            An object is converted to its resource name, because pipeline
            parameter values must be plain serializable values.
        accelerator_type: Accelerator type; omitted when None.
        model_display_name: Display name for the tuned model. Defaults to
            "<student short id> distilled from <teacher short id>", where a
            short id is the part of the reference after the last "/".
        max_context_length: Max context length for tuning; omitted when None.

    Returns:
        The ``aiplatform.PipelineJob`` built from ``_DISTILLATION_PIPELINE_URI``,
        after ``submit()`` has been called on it.
    """
    pipeline_arguments = {
        "teacher_model_reference": teacher_model,
        "student_model_reference": student_model,
        "dataset_uri": dataset,
        "project": aiplatform_initializer.global_config.project,
        "location": aiplatform_initializer.global_config.location,
    }
    if train_steps is not None:
        pipeline_arguments["train_steps"] = train_steps
    if learning_rate_multiplier is not None:
        pipeline_arguments["learning_rate_multiplier"] = learning_rate_multiplier
    if evaluation_spec is not None:
        pipeline_arguments["evaluation_data_uri"] = evaluation_spec.evaluation_data
        pipeline_arguments["evaluation_interval"] = evaluation_spec.evaluation_interval
        pipeline_arguments[
            "enable_early_stopping"
        ] = evaluation_spec.enable_early_stopping
        pipeline_arguments[
            "enable_checkpoint_selection"
        ] = evaluation_spec.enable_checkpoint_selection
        tensorboard = evaluation_spec.tensorboard
        if isinstance(tensorboard, aiplatform.Tensorboard):
            # A Tensorboard object cannot be a pipeline parameter value;
            # pass its fully-qualified resource name instead.
            tensorboard = tensorboard.resource_name
        pipeline_arguments["tensorboard_resource_id"] = tensorboard
    if accelerator_type is not None:
        pipeline_arguments["accelerator_type"] = accelerator_type
    encryption_key = aiplatform_initializer.global_config.encryption_spec_key_name
    if encryption_key is not None:
        pipeline_arguments["encryption_spec_key_name"] = encryption_key
    if max_context_length is not None:
        pipeline_arguments["max_context_length"] = max_context_length
    if model_display_name is None:
        # Short ids are only needed for the generated default name.
        teacher_short_model_id = teacher_model.split("/")[-1]
        student_short_model_id = student_model.split("/")[-1]
        model_display_name = (
            f"{student_short_model_id} distilled from {teacher_short_model_id}"
        )
    pipeline_arguments["model_display_name"] = model_display_name
    # NOTE: pipeline parameters intentionally not exposed here:
    # temperature, tpu_training_skip_cmek, api_endpoint, version.

    pipeline_job = aiplatform.PipelineJob(
        template_path=_DISTILLATION_PIPELINE_URI,
        display_name=None,
        parameter_values=pipeline_arguments,
    )
    pipeline_job.submit()
    return pipeline_job