evo-ai/.venv/lib/python3.10/site-packages/vertexai/prompts/_prompt_management.py

# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations

from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform.compat.types import dataset as gca_dataset
from google.cloud.aiplatform_v1.types import (
    dataset_version as gca_dataset_version,
)
from google.cloud.aiplatform_v1beta1.types import (
    prediction_service as gapic_prediction_service_types,
)
from vertexai.generative_models import (
    Part,
    Image,
    GenerativeModel,
    Tool,
    ToolConfig,
)
from vertexai.generative_models._generative_models import (
    _proto_to_dict,
    _dict_to_proto,
    _tool_types_to_gapic_tools,
    PartsType,
)
from vertexai.prompts._prompts import Prompt
from google.protobuf import field_mask_pb2 as field_mask

import dataclasses
from typing import (
    Any,
    Dict,
    Optional,
)

_LOGGER = base.Logger(__name__)

_dataset_client_value = None

DEFAULT_API_SCHEMA_VERSION = "1.0.0"
PROMPT_SCHEMA_URI = (
    "gs://google-cloud-aiplatform/schema/dataset/metadata/text_prompt_1.0.0.yaml"
)

def _format_function_declaration_parameters(obj: Any):
    """Recursively replaces `type_` and `format_` keys, returning a new object."""
    if isinstance(obj, (str, int, float)):
        return obj
    if isinstance(obj, dict):
        new = obj.__class__()
        for key, value in obj.items():
            key = key.replace("type_", "type")
            key = key.replace("format_", "format")
            new[key] = _format_function_declaration_parameters(value)
    elif isinstance(obj, (list, set, tuple)):
        new = obj.__class__(
            _format_function_declaration_parameters(value) for value in obj
        )
    else:
        return obj
    return new

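# For illustration (not part of this module): the helper restores the JSON
# spellings of proto field names in a function declaration's parameters, e.g.
#   _format_function_declaration_parameters(
#       {"type_": "OBJECT", "properties": {"x": {"type_": "NUMBER", "format_": "float"}}}
#   )
#   returns
#   {"type": "OBJECT", "properties": {"x": {"type": "NUMBER", "format": "float"}}}
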
@dataclasses.dataclass
class Arguments:
    """Arguments. Child of Execution.

    Attributes:
        variables: The arguments of the execution.
    """

    variables: dict[str, list[Part]]

    def to_dict(self) -> Dict[str, Any]:
        dct = {}
        for variable_name in self.variables:
            dct[variable_name] = {
                "partList": {
                    "parts": [
                        part.to_dict()
                        for part in self.variables[variable_name]
                    ]
                }
            }
        return dct

    @classmethod
    def from_dict(cls, dct: Dict[str, Any]) -> "Arguments":
        variables = {}
        for variable_name in dct:
            variables[variable_name] = [
                Part.from_dict(part)
                for part in dct[variable_name]["partList"]["parts"]
            ]
        arguments = cls(variables=variables)
        return arguments

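# For illustration: a single variable {"name": [Part.from_text("Alice")]}
# round-trips through Arguments.to_dict() as a JSON-style dict shaped like
#   {"name": {"partList": {"parts": [{"text": "Alice"}]}}}
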
@dataclasses.dataclass
class Execution:
    """Execution. Child of MultimodalPrompt.

    Attributes:
        arguments: The arguments of the execution.
    """

    arguments: Arguments

    def __init__(self, arguments: dict[str, list[Part]]):
        self.arguments = Arguments(variables=arguments)

    def to_dict(self) -> Dict[str, Any]:
        dct = {}
        dct["arguments"] = self.arguments.to_dict()
        return dct

    @classmethod
    def from_dict(cls, dct: Dict[str, Any]) -> "Execution":
        arguments = dct.get("arguments", None)
        execution = cls(arguments=arguments)
        return execution

@dataclasses.dataclass
class MultimodalPrompt:
    """MultimodalPrompt. Child of PromptDatasetMetadata.

    Attributes:
        prompt_message: The schema for the prompt. A subset of the GenerateContentRequest schema.
        api_schema_version: The api schema version of the prompt when it was last modified.
        executions: Contains data related to an execution of a prompt (ex. variables)
    """

    prompt_message: gapic_prediction_service_types.GenerateContentRequest
    api_schema_version: Optional[str] = DEFAULT_API_SCHEMA_VERSION
    executions: Optional[list[Execution]] = None

    def to_dict(self) -> Dict[str, Any]:
        dct = {"multimodalPrompt": {}}
        dct["apiSchemaVersion"] = self.api_schema_version
        dct["multimodalPrompt"]["promptMessage"] = _proto_to_dict(self.prompt_message)

        # Fix type_ and format_ fields
        if dct["multimodalPrompt"]["promptMessage"].get("tools", None):
            tools = dct["multimodalPrompt"]["promptMessage"]["tools"]
            for tool in tools:
                for function_declaration in tool.get("function_declarations", []):
                    function_declaration[
                        "parameters"
                    ] = _format_function_declaration_parameters(
                        function_declaration["parameters"]
                    )

        if self.executions and self.executions[0]:
            # Only add variable sets if they are non-empty.
            execution_dcts = []
            for execution in self.executions:
                execution_dct = execution.to_dict()
                if execution_dct and execution_dct["arguments"]:
                    execution_dcts.append(execution_dct)
            if execution_dcts:
                dct["executions"] = execution_dcts
        return dct

    @classmethod
    def from_dict(cls, dct: Dict[str, Any]) -> "MultimodalPrompt":
        api_schema_version = dct.get("apiSchemaVersion", DEFAULT_API_SCHEMA_VERSION)
        if int(api_schema_version.split(".")[0]) > int(
            DEFAULT_API_SCHEMA_VERSION.split(".")[0]
        ):
            # Disallow loading prompts with a higher major schema version.
            raise ValueError(
                "This prompt was saved with a newer schema version and cannot be loaded."
            )
        prompt_message_dct = dct.get("multimodalPrompt", {}).get("promptMessage", None)
        if not prompt_message_dct:
            raise ValueError("This prompt is not supported in the SDK.")

        # Tool function declarations would fail the proto conversion, so
        # convert them separately.
        tools = prompt_message_dct.get("tools", None)
        if tools:
            tools = [Tool.from_dict(tool) for tool in tools]
            prompt_message_dct.pop("tools")
        prompt_message = _dict_to_proto(
            gapic_prediction_service_types.GenerateContentRequest, prompt_message_dct
        )
        if tools:
            # Convert Tools to gapic to store in the prompt_message
            prompt_message.tools = _tool_types_to_gapic_tools(tools)

        executions_dct = dct.get("executions", [])
        executions = [Execution.from_dict(execution) for execution in executions_dct]
        if not executions:
            executions = None
        multimodal_prompt = cls(
            prompt_message=prompt_message,
            api_schema_version=api_schema_version,
            executions=executions,
        )
        return multimodal_prompt

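# For orientation, the serialized form produced by to_dict() (and accepted by
# from_dict()) is shaped roughly like this abridged, illustrative sketch:
#
#   {
#       "apiSchemaVersion": "1.0.0",
#       "multimodalPrompt": {
#           "promptMessage": {"model": "...", "contents": [...], "tools": [...]},
#       },
#       "executions": [{"arguments": {...}}],
#   }
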
@dataclasses.dataclass
class PromptDatasetMetadata:
    """PromptDatasetMetadata.

    Attributes:
        prompt_type: Required. SDK only supports "freeform" or "multimodal_freeform"
        prompt_api_schema: Required. SDK only supports multimodalPrompt
    """

    prompt_type: str
    prompt_api_schema: MultimodalPrompt

    def to_dict(self) -> Dict[str, Any]:
        dct = {}
        dct["promptType"] = self.prompt_type
        dct["promptApiSchema"] = self.prompt_api_schema.to_dict()
        return dct

    @classmethod
    def from_dict(cls, dct: Dict[str, Any]) -> "PromptDatasetMetadata":
        metadata = cls(
            prompt_type=dct.get("promptType", None),
            prompt_api_schema=MultimodalPrompt.from_dict(
                dct.get("promptApiSchema", None)
            ),
        )
        return metadata

@dataclasses.dataclass
class PromptMetadata:
    """Metadata containing the display name and prompt id of a prompt.

    Returned by the `list_prompts` method.

    Attributes:
        display_name: The display name of the prompt.
        prompt_id: The id of the prompt.
    """

    display_name: str
    prompt_id: str


@dataclasses.dataclass
class PromptVersionMetadata:
    """Metadata containing the display name, prompt id, and version id of a prompt version.

    Returned by the `list_prompt_versions` method.

    Attributes:
        display_name: The display name of the prompt version.
        prompt_id: The id of the prompt.
        version_id: The version id of the prompt.
    """

    display_name: str
    prompt_id: str
    version_id: str

def create_version(
    prompt: Prompt,
    prompt_id: Optional[str] = None,
    version_name: Optional[str] = None,
) -> Prompt:
    """Creates a Prompt or Prompt Version in the online prompt store.

    Args:
        prompt: The Prompt object to create a new version of.
        prompt_id: The id of the prompt resource to create a new version under.
            If it is not provided and the prompt has no prompt resource
            associated with it, a new prompt resource will be created.
        version_name: Optional display name of the new prompt version.
            If not specified, a default name including a timestamp will be used.

    Returns:
        A new Prompt object with a reference to the newly created or updated
        prompt resource. This new Prompt object is nearly identical to the
        original Prompt object, except it has references to the new
        prompt version.
    """
    if not (prompt_id or prompt._dataset):
        # Case 1: Neither prompt_id nor prompt._dataset exists, so we
        # create a new prompt resource.
        return _create_prompt_resource(prompt=prompt, version_name=version_name)

    # Case 2: No prompt_id override is given, so we update the existing prompt resource.
    if not prompt_id:
        return _create_prompt_version_resource(prompt=prompt, version_name=version_name)

    # Case 3: Save a new version under the prompt_id provided as an arg.
    # prompt_id is guaranteed to exist here because Cases 1 & 2 were handled above.

    # Store the original prompt resource name, if it exists.
    original_prompt_resource = None if not prompt._dataset else prompt._dataset.name

    # Create a gapic dataset object if it doesn't exist.
    if not prompt._dataset:
        project = aiplatform_initializer.global_config.project
        location = aiplatform_initializer.global_config.location
        name = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
        dataset_metadata = _format_dataset_metadata_dict(prompt=prompt)
        prompt._dataset = gca_dataset.Dataset(
            name=name,
            display_name=prompt.prompt_name or "Untitled Prompt",
            metadata_schema_uri=PROMPT_SCHEMA_URI,
            metadata=dataset_metadata,
            model_reference=prompt.model_name,
        )

    # Override the dataset prompt id with the new prompt id.
    project = aiplatform_initializer.global_config.project
    location = aiplatform_initializer.global_config.location
    prompt._dataset.name = (
        f"projects/{project}/locations/{location}/datasets/{prompt_id}"
    )
    result = _create_prompt_version_resource(prompt=prompt, version_name=version_name)

    # Restore the original prompt resource name. This is a no-op if there
    # was no original prompt resource name.
    prompt._dataset.name = original_prompt_resource
    return result

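# Illustrative usage (a sketch; the model name and prompt text are made-up
# examples, and this module is normally reached through the public
# `vertexai.preview.prompts` surface rather than imported directly):
#
#   prompt = Prompt(prompt_data="Hello, {name}!", model_name="gemini-1.0-pro")
#   saved = create_version(prompt, version_name="v1")  # Case 1: new resource
#   saved = create_version(saved)                      # Case 2: new version
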
def _check_multimodal_contents(prompt_data: PartsType):
    if isinstance(prompt_data, Image):
        return "multimodal_freeform"
    elif isinstance(prompt_data, list):
        for part in prompt_data:
            check = _check_multimodal_contents(part)
            if check == "multimodal_freeform":
                return "multimodal_freeform"
    elif isinstance(prompt_data, Part):
        if "text" not in prompt_data._raw_part:
            return "multimodal_freeform"
    return "freeform"

def _format_dataset_metadata_dict(prompt: Prompt) -> dict[str, Any]:
    """Helper function to convert the configs and prompt data stored in the Prompt object to a dataset metadata dict."""
    model = GenerativeModel(model_name=prompt.model_name)
    prompt_message = model._prepare_request(
        contents=prompt.prompt_data or "temporary data",
        model=prompt.model_name,
        system_instruction=prompt.system_instruction,
        tools=prompt.tools,
        tool_config=prompt.tool_config,
        safety_settings=prompt.safety_settings,
        generation_config=prompt.generation_config,
    )
    # Remove temporary contents
    if not prompt.prompt_data:
        prompt_message.contents = None

    # Stopgap solution to check for multimodal contents to set flag for UI
    if prompt.prompt_data:
        prompt_type = _check_multimodal_contents(prompt.prompt_data)
    else:
        prompt_type = "freeform"

    return PromptDatasetMetadata(
        prompt_type=prompt_type,
        prompt_api_schema=MultimodalPrompt(
            prompt_message=prompt_message,
            executions=[Execution(variable_set) for variable_set in prompt.variables],
        ),
    ).to_dict()

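# Note for readers: `model._prepare_request` is reused above purely to assemble
# a GenerateContentRequest proto from the Prompt's configs; the "temporary data"
# placeholder is stripped back out when the prompt has no prompt_data.
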
def _create_dataset(prompt: Prompt, parent: str) -> gca_dataset.Dataset:
    dataset_metadata = _format_dataset_metadata_dict(prompt=prompt)
    dataset = gca_dataset.Dataset(
        name=parent,
        display_name=prompt.prompt_name or "Untitled Prompt",
        metadata_schema_uri=PROMPT_SCHEMA_URI,
        metadata=dataset_metadata,
        model_reference=prompt.model_name,
    )
    operation = prompt._dataset_client.create_dataset(
        parent=parent,
        dataset=dataset,
    )
    dataset = operation.result()
    # Purge labels
    dataset.labels = None
    return dataset

def _create_dataset_version(
    prompt: Prompt, parent: str, version_name: Optional[str] = None
):
    dataset_version = gca_dataset_version.DatasetVersion(
        display_name=version_name,
    )
    operation = prompt._dataset_client.create_dataset_version(
        parent=parent,
        dataset_version=dataset_version,
    )
    return operation.result()

def _update_dataset(
    prompt: Prompt,
    dataset: gca_dataset.Dataset,
) -> gca_dataset.Dataset:
    dataset.metadata = _format_dataset_metadata_dict(prompt=prompt)
    mask_paths = ["modelReference", "metadata"]
    if dataset.display_name != "Untitled Prompt":
        mask_paths.append("displayName")
    updated_dataset = prompt._dataset_client.update_dataset(
        dataset=dataset,
        update_mask=field_mask.FieldMask(paths=mask_paths),
    )
    # Remove the etag to avoid an error on repeated dataset updates
    updated_dataset.etag = None
    return updated_dataset

def _create_prompt_resource(
    prompt: Prompt, version_name: Optional[str] = None
) -> Prompt:
    project = aiplatform_initializer.global_config.project
    location = aiplatform_initializer.global_config.location

    # Step 1: Create prompt dataset API call
    parent = f"projects/{project}/locations/{location}"
    dataset = _create_dataset(prompt=prompt, parent=parent)

    # Step 2: Create prompt version API call
    dataset_version = _create_dataset_version(
        prompt=prompt,
        parent=dataset.name,
        version_name=version_name,
    )

    # Step 3: Create new Prompt object to return
    new_prompt = Prompt._clone(prompt=prompt)
    new_prompt._dataset = dataset
    new_prompt._version_id = dataset_version.name.split("/")[-1]
    new_prompt._version_name = dataset_version.display_name

    prompt_id = new_prompt._dataset.name.split("/")[5]
    _LOGGER.info(
        f"Created prompt resource with id {prompt_id} with version number {new_prompt._version_id}"
    )
    return new_prompt

def _create_prompt_version_resource(
    prompt: Prompt,
    version_name: Optional[str] = None,
) -> Prompt:
    # Step 1: Update prompt API call
    updated_dataset = _update_dataset(prompt=prompt, dataset=prompt._dataset)

    # Step 2: Create prompt version API call
    dataset_version = _create_dataset_version(
        prompt=prompt,
        parent=updated_dataset.name,
        version_name=version_name,
    )

    # Step 3: Create new Prompt object to return
    new_prompt = Prompt._clone(prompt=prompt)
    new_prompt._dataset = updated_dataset
    new_prompt._version_id = dataset_version.name.split("/")[-1]
    new_prompt._version_name = dataset_version.display_name

    prompt_id = prompt._dataset.name.split("/")[5]
    _LOGGER.info(
        f"Updated prompt resource with id {prompt_id} as version number {new_prompt._version_id}"
    )
    return new_prompt

def _get_prompt_resource(prompt: Prompt, prompt_id: str) -> gca_dataset.Dataset:
    """Helper function to get a prompt resource from a prompt id."""
    project = aiplatform_initializer.global_config.project
    location = aiplatform_initializer.global_config.location
    name = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
    dataset = prompt._dataset_client.get_dataset(name=name)
    return dataset

def _get_prompt_resource_from_version(
    prompt: Prompt, prompt_id: str, version_id: str
) -> gca_dataset.Dataset:
    """Helper function to get a prompt resource from a prompt version id."""
    project = aiplatform_initializer.global_config.project
    location = aiplatform_initializer.global_config.location
    name = f"projects/{project}/locations/{location}/datasets/{prompt_id}/datasetVersions/{version_id}"

    # Step 1: Get the dataset version object
    dataset_version = prompt._dataset_client.get_dataset_version(name=name)
    prompt._version_name = dataset_version.display_name

    # Step 2: Fetch the dataset object to get the dataset display name
    name = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
    dataset = prompt._dataset_client.get_dataset(name=name)

    # Step 3: Convert the DatasetVersion object to a Dataset object
    dataset = gca_dataset.Dataset(
        name=name,
        display_name=dataset.display_name,
        metadata_schema_uri=PROMPT_SCHEMA_URI,
        metadata=dataset_version.metadata,
        model_reference=dataset_version.model_reference,
    )
    return dataset

def restore_version(prompt_id: str, version_id: str) -> PromptVersionMetadata:
    """Restores a previous version of the prompt resource as its latest version.

    Args:
        prompt_id: The id of the prompt resource to restore a version of.
        version_id: The version id of the online prompt resource.

    Returns:
        A `PromptVersionMetadata` object referencing the newly created version.
    """
    # Step 1: Make restore dataset version API call
    project = aiplatform_initializer.global_config.project
    location = aiplatform_initializer.global_config.location
    name = f"projects/{project}/locations/{location}/datasets/{prompt_id}/datasetVersions/{version_id}"

    # Create a temporary Prompt object for a dataset client
    temp_prompt = Prompt()
    operation = temp_prompt._dataset_client.restore_dataset_version(name=name)
    result = operation.result()
    new_version_id = result.name.split("/")[-1]
    prompt_id = result.name.split("/")[5]
    _LOGGER.info(
        f"Restored prompt version {version_id} under prompt id {prompt_id} as version number {new_version_id}"
    )

    # Step 2: Create PromptVersionMetadata object to return
    return PromptVersionMetadata(
        display_name=result.display_name,
        prompt_id=result.name.split("/")[5],
        version_id=new_version_id,
    )

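# Illustrative usage (ids are placeholders): restoring creates a brand-new
# head version containing the old content, whose id comes back in the metadata:
#
#   meta = restore_version(prompt_id="1234567890", version_id="2")
#   print(meta.version_id)  # id of the newly created version
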
def get(prompt_id: str, version_id: Optional[str] = None) -> Prompt:
    """Creates a Prompt object from an online resource.

    Args:
        prompt_id: The id of the prompt resource.
        version_id: Optional version id of the prompt resource.
            If not specified, the latest version will be used.

    Returns:
        A prompt loaded from the online resource as a `Prompt` object.
    """
    prompt = Prompt()
    if version_id:
        dataset = _get_prompt_resource_from_version(
            prompt=prompt,
            prompt_id=prompt_id,
            version_id=version_id,
        )
    else:
        dataset = _get_prompt_resource(prompt=prompt, prompt_id=prompt_id)

    # Remove the etag to avoid an error on repeated dataset updates
    dataset.etag = None
    prompt._dataset = dataset
    prompt._version_id = version_id
    dataset_dict = _proto_to_dict(dataset)
    metadata = PromptDatasetMetadata.from_dict(dataset_dict["metadata"])
    _populate_fields_from_metadata(prompt=prompt, metadata=metadata)
    return prompt

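# Illustrative usage (ids are placeholders):
#
#   prompt = get("1234567890")                  # latest version
#   pinned = get("1234567890", version_id="3")  # a specific pinned version
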
def _populate_fields_from_metadata(
    prompt: Prompt, metadata: PromptDatasetMetadata
) -> None:
    """Helper function to populate Prompt fields from a metadata object."""
    # Populate model_name (Required, raw deserialized type is str)
    prompt.model_name = metadata.prompt_api_schema.prompt_message.model

    # Populate prompt_data (raw deserialized type is list[Content])
    contents = metadata.prompt_api_schema.prompt_message.contents
    if contents:
        if len(contents) > 1:
            raise ValueError("Multi-turn prompts are not supported yet.")
        prompt_data = [Part._from_gapic(part) for part in list(contents[0].parts)]
        # Unwrap a single text part into a str
        if len(prompt_data) == 1 and "text" in prompt_data[0]._raw_part:
            prompt.prompt_data = prompt_data[0].text
        else:
            prompt.prompt_data = prompt_data

    # Populate system_instruction (raw deserialized type is a single Content)
    system_instruction = metadata.prompt_api_schema.prompt_message.system_instruction
    if system_instruction:
        system_instruction_parts = [
            Part._from_gapic(part) for part in list(system_instruction.parts)
        ]
        # Unwrap a single text part into a str
        if len(system_instruction_parts) == 1 and system_instruction_parts[0].text:
            prompt.system_instruction = system_instruction_parts[0].text
        else:
            prompt.system_instruction = system_instruction_parts

    # Populate variables
    executions = metadata.prompt_api_schema.executions
    variables = []
    if executions:
        for execution in executions:
            serialized_variable_set = execution.arguments
            variable_set = {}
            if serialized_variable_set:
                for name, value in serialized_variable_set.variables.items():
                    # Parts are dicts here, not gapic objects, for variables
                    variable_set[name] = [
                        Part.from_dict(part)
                        for part in list(value["partList"]["parts"])
                    ]
            variables.append(variable_set)

    # Unwrap single text part variables into strs
    for variable_set in variables:
        for name, value in variable_set.items():
            if len(value) == 1 and "text" in value[0]._raw_part:
                variable_set[name] = value[0].text
    prompt.variables = variables

    # Populate generation_config (raw deserialized type is GenerationConfig)
    generation_config = metadata.prompt_api_schema.prompt_message.generation_config
    if generation_config:
        prompt.generation_config = generation_config

    # Populate safety_settings (raw deserialized type is RepeatedComposite of SafetySetting)
    safety_settings = metadata.prompt_api_schema.prompt_message.safety_settings
    if safety_settings:
        prompt.safety_settings = list(safety_settings)

    # Populate tools (raw deserialized type is RepeatedComposite of Tool)
    tools = metadata.prompt_api_schema.prompt_message.tools
    if tools:
        prompt.tools = list(tools)

    # Populate tool_config (raw deserialized type is ToolConfig)
    tool_config = metadata.prompt_api_schema.prompt_message.tool_config
    if tool_config:
        prompt.tool_config = ToolConfig._from_gapic(tool_config)

def list_prompts() -> list[PromptMetadata]:
    """Lists all prompt resources in the online prompt store associated with the project."""
    project = aiplatform_initializer.global_config.project
    location = aiplatform_initializer.global_config.location
    parent = f"projects/{project}/locations/{location}"

    # Create a temporary Prompt object for a dataset client
    temp_prompt = Prompt()
    prompts_pager = temp_prompt._dataset_client.list_datasets(
        parent=parent,
    )
    prompts_list = []
    for prompt in prompts_pager:
        prompts_list.append(
            PromptMetadata(
                display_name=prompt.display_name,
                prompt_id=prompt.name.split("/")[5],
            )
        )
    return prompts_list

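# Illustrative usage:
#
#   for meta in list_prompts():
#       print(meta.prompt_id, meta.display_name)
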
def list_versions(prompt_id: str) -> list[PromptVersionMetadata]:
    """Returns a list of PromptVersionMetadata objects for the prompt resource.

    Args:
        prompt_id: The id of the prompt resource to list versions of.

    Returns:
        A list of PromptVersionMetadata objects for the prompt resource.
    """
    # Create a temporary Prompt object for a dataset client
    temp_prompt = Prompt()
    project = aiplatform_initializer.global_config.project
    location = aiplatform_initializer.global_config.location
    parent = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
    versions_pager = temp_prompt._dataset_client.list_dataset_versions(
        parent=parent,
    )
    version_history = []
    for version in versions_pager:
        version_history.append(
            PromptVersionMetadata(
                display_name=version.display_name,
                prompt_id=version.name.split("/")[5],
                version_id=version.name.split("/")[-1],
            )
        )
    return version_history

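# Illustrative usage (the id is a placeholder):
#
#   for version in list_versions("1234567890"):
#       print(version.version_id, version.display_name)
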
def delete(prompt_id: str) -> None:
    """Deletes the online prompt resource associated with the prompt id."""
    # Create a temporary Prompt object for a dataset client
    temp_prompt = Prompt()
    project = aiplatform_initializer.global_config.project
    location = aiplatform_initializer.global_config.location
    name = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
    operation = temp_prompt._dataset_client.delete_dataset(
        name=name,
    )
    operation.result()
    _LOGGER.info(f"Deleted prompt resource with id {prompt_id}.")
