structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions


@@ -0,0 +1,739 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform.compat.types import dataset as gca_dataset
from google.cloud.aiplatform_v1.types import (
dataset_version as gca_dataset_version,
)
from google.cloud.aiplatform_v1beta1.types import (
prediction_service as gapic_prediction_service_types,
)
from vertexai.generative_models import (
Part,
Image,
GenerativeModel,
Tool,
ToolConfig,
)
from vertexai.generative_models._generative_models import (
_proto_to_dict,
_dict_to_proto,
_tool_types_to_gapic_tools,
PartsType,
)
from vertexai.prompts._prompts import Prompt
from google.protobuf import field_mask_pb2 as field_mask
import dataclasses
from typing import (
Any,
Dict,
Optional,
)
_LOGGER = base.Logger(__name__)
_dataset_client_value = None
DEFAULT_API_SCHEMA_VERSION = "1.0.0"
PROMPT_SCHEMA_URI = (
"gs://google-cloud-aiplatform/schema/dataset/metadata/text_prompt_1.0.0.yaml"
)
def _format_function_declaration_parameters(obj: Any):
"""Recursively replaces type_ and format_ fields in-place."""
if isinstance(obj, (str, int, float)):
return obj
if isinstance(obj, dict):
new = obj.__class__()
for key, value in obj.items():
key = key.replace("type_", "type")
key = key.replace("format_", "format")
new[key] = _format_function_declaration_parameters(value)
elif isinstance(obj, (list, set, tuple)):
new = obj.__class__(
_format_function_declaration_parameters(value) for value in obj
)
else:
return obj
return new
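# Editor's sketch (hypothetical input, not part of the original file): proto
# field names that collide with reserved words come out of _proto_to_dict with
# a trailing underscore, which the helper above strips back off, e.g.
#     _format_function_declaration_parameters(
#         {"type_": "OBJECT", "properties": {"x": {"type_": "STRING", "format_": "date"}}}
#     )
#     # -> {"type": "OBJECT", "properties": {"x": {"type": "STRING", "format": "date"}}}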
@dataclasses.dataclass
class Arguments:
"""Arguments. Child of Execution.
Attributes:
variables: The arguments of the execution.
"""
variables: dict[str, list[Part]]
def to_dict(self) -> Dict[str, Any]:
dct = {}
for variable_name in self.variables:
dct[variable_name] = {
"partList": {
"parts": [part.to_dict() for part in self.variables[variable_name]]
}
}
return dct
@classmethod
def from_dict(cls, dct: Dict[str, Any]) -> "Arguments":
variables = {}
for variable_name in dct:
variables[variable_name] = [
Part.from_dict(part) for part in dct[variable_name]["partList"]["parts"]
]
arguments = cls(variables=variables)
return arguments
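# Editor's note: the serialized shape produced by Arguments.to_dict (and
# consumed by from_dict) looks roughly like
#     {"name": {"partList": {"parts": [{"text": "Alice"}]}}}
# with one entry per variable and each part serialized via Part.to_dict.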
@dataclasses.dataclass
class Execution:
"""Execution. Child of MultimodalPrompt.
Attributes:
arguments: The arguments of the execution.
"""
arguments: Arguments
def __init__(self, arguments: dict[str, list[Part]]):
self.arguments = Arguments(variables=arguments)
def to_dict(self) -> Dict[str, Any]:
dct = {}
dct["arguments"] = self.arguments.to_dict()
return dct
@classmethod
def from_dict(cls, dct: Dict[str, Any]) -> "Execution":
arguments = dct.get("arguments", None)
execution = cls(arguments=arguments)
return execution
@dataclasses.dataclass
class MultimodalPrompt:
"""MultimodalPrompt. Child of PromptDatasetMetadata.
Attributes:
prompt_message: The schema for the prompt. A subset of the GenerateContentRequest schema.
api_schema_version: The api schema version of the prompt when it was last modified.
executions: Contains data related to an execution of a prompt (e.g., variables)
"""
prompt_message: gapic_prediction_service_types.GenerateContentRequest
api_schema_version: Optional[str] = DEFAULT_API_SCHEMA_VERSION
executions: Optional[list[Execution]] = None
def to_dict(self) -> Dict[str, Any]:
dct = {"multimodalPrompt": {}}
dct["apiSchemaVersion"] = self.api_schema_version
dct["multimodalPrompt"]["promptMessage"] = _proto_to_dict(self.prompt_message)
# Fix type_ and format_ fields
if dct["multimodalPrompt"]["promptMessage"].get("tools", None):
tools = dct["multimodalPrompt"]["promptMessage"]["tools"]
for tool in tools:
for function_declaration in tool.get("function_declarations", []):
function_declaration[
"parameters"
] = _format_function_declaration_parameters(
function_declaration["parameters"]
)
if self.executions and self.executions[0]:
# Only add variable sets if they are non-empty.
execution_dcts = []
for execution in self.executions:
execution_dct = execution.to_dict()
if execution_dct and execution_dct["arguments"]:
execution_dcts.append(execution_dct)
if execution_dcts:
dct["executions"] = execution_dcts
return dct
@classmethod
def from_dict(cls, dct: Dict[str, Any]) -> "MultimodalPrompt":
api_schema_version = dct.get("apiSchemaVersion", DEFAULT_API_SCHEMA_VERSION)
if int(api_schema_version.split(".")[0]) > int(
DEFAULT_API_SCHEMA_VERSION.split(".")[0]
):
# Disallow loading prompts with higher major schema version
raise ValueError(
"This prompt was saved with a newer schema version and cannot be loaded."
)
prompt_message_dct = dct.get("multimodalPrompt", {}).get("promptMessage", None)
if not prompt_message_dct:
raise ValueError("This prompt is not supported in the SDK.")
# Tool function declarations would fail the direct proto conversion, so handle them separately
tools = prompt_message_dct.get("tools", None)
if tools:
tools = [Tool.from_dict(tool) for tool in tools]
prompt_message_dct.pop("tools")
prompt_message = _dict_to_proto(
gapic_prediction_service_types.GenerateContentRequest, prompt_message_dct
)
if tools:
# Convert Tools to gapic to store in the prompt_message
prompt_message.tools = _tool_types_to_gapic_tools(tools)
executions_dct = dct.get("executions", [])
executions = [Execution.from_dict(execution) for execution in executions_dct]
if not executions:
executions = None
multimodal_prompt = cls(
prompt_message=prompt_message,
api_schema_version=api_schema_version,
executions=executions,
)
return multimodal_prompt
@dataclasses.dataclass
class PromptDatasetMetadata:
"""PromptDatasetMetadata.
Attributes:
prompt_type: Required. SDK only supports "freeform" or "multimodal_freeform"
prompt_api_schema: Required. SDK only supports multimodalPrompt
"""
prompt_type: str
prompt_api_schema: MultimodalPrompt
def to_dict(self) -> Dict[str, Any]:
dct = {}
dct["promptType"] = self.prompt_type
dct["promptApiSchema"] = self.prompt_api_schema.to_dict()
return dct
@classmethod
def from_dict(cls, dct: Dict[str, Any]) -> "PromptDatasetMetadata":
metadata = cls(
prompt_type=dct.get("promptType", None),
prompt_api_schema=MultimodalPrompt.from_dict(
dct.get("promptApiSchema", None)
),
)
return metadata
@dataclasses.dataclass
class PromptMetadata:
"""Metadata containing the display name and prompt id of a prompt.
Returned by the `list_prompts` method.
Attributes:
display_name: The display name of the prompt version.
prompt_id: The id of the prompt.
"""
display_name: str
prompt_id: str
@dataclasses.dataclass
class PromptVersionMetadata:
"""Metadata containing the display name, prompt id, and version id of a prompt version.
Returned by the `list_prompt_versions` method.
Attributes:
display_name: The display name of the prompt version.
prompt_id: The id of the prompt.
version_id: The version id of the prompt.
"""
display_name: str
prompt_id: str
version_id: str
def create_version(
prompt: Prompt,
prompt_id: Optional[str] = None,
version_name: Optional[str] = None,
) -> Prompt:
"""Creates a Prompt or Prompt Version in the online prompt store
Args:
prompt: The Prompt object to create a new version of.
prompt_id: The id of the prompt resource to create a new version under.
If it is not provided and the prompt has no prompt resource
associated with it, a new prompt resource will be created.
version_name: Optional display name of the new prompt version.
If not specified, a default name including a timestamp will be used.
Returns:
A new Prompt object with a reference to the newly created or updated
prompt resource. This new Prompt object is nearly identical to the
original Prompt object, except it has references to the new
prompt version.
"""
if not (prompt_id or prompt._dataset):
# Case 1: Neither prompt id nor prompt._dataset exists, so we
# create a new prompt resource
return _create_prompt_resource(prompt=prompt, version_name=version_name)
# Case 2: No prompt_id override is given, so we update the existing prompt resource
if not prompt_id:
return _create_prompt_version_resource(prompt=prompt, version_name=version_name)
# Case 3: Save a new version to the prompt_id provided as an arg
# prompt_id is guaranteed to exist due to Cases 1 & 2 being handled before
# Store the original prompt resource name, if it exists
original_prompt_resource = None if not prompt._dataset else prompt._dataset.name
# Create a gapic dataset object if it doesn't exist
if not prompt._dataset:
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
name = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
dataset_metadata = _format_dataset_metadata_dict(prompt=prompt)
prompt._dataset = gca_dataset.Dataset(
name=name,
display_name=prompt.prompt_name or "Untitled Prompt",
metadata_schema_uri=PROMPT_SCHEMA_URI,
metadata=dataset_metadata,
model_reference=prompt.model_name,
)
# Override the dataset prompt id with the new prompt id
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
prompt._dataset.name = (
f"projects/{project}/locations/{location}/datasets/{prompt_id}"
)
result = _create_prompt_version_resource(prompt=prompt, version_name=version_name)
# Restore the original prompt resource name. This is a no-op if there
# was no original prompt resource name.
prompt._dataset.name = original_prompt_resource
return result
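# Usage sketch for create_version (editor's illustration; names and ids are
# hypothetical, and the calls assume vertexai has been initialized with a
# project and location):
#     prompt = Prompt(prompt_data="Hello, {name}!")
#     saved = create_version(prompt=prompt, version_name="v1")  # Case 1: new resource
#     again = create_version(prompt=saved)                      # Case 2: new version
#     copied = create_version(prompt=saved, prompt_id="123")    # Case 3: explicit target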
def _check_multimodal_contents(prompt_data: PartsType):
if isinstance(prompt_data, Image):
return "multimodal_freeform"
elif isinstance(prompt_data, list):
for part in prompt_data:
check = _check_multimodal_contents(part)
if check == "multimodal_freeform":
return "multimodal_freeform"
elif isinstance(prompt_data, Part):
if "text" not in prompt_data._raw_part:
return "multimodal_freeform"
return "freeform"
def _format_dataset_metadata_dict(prompt: Prompt) -> dict[str, Any]:
"""Helper function to convert the configs and prompt data stored in the Prompt object to a dataset metadata dict."""
model = GenerativeModel(model_name=prompt.model_name)
prompt_message = model._prepare_request(
contents=prompt.prompt_data or "temporary data",
model=prompt.model_name,
system_instruction=prompt.system_instruction,
tools=prompt.tools,
tool_config=prompt.tool_config,
safety_settings=prompt.safety_settings,
generation_config=prompt.generation_config,
)
# Remove temporary contents
if not prompt.prompt_data:
prompt_message.contents = None
# Stopgap: check for multimodal contents to set the prompt type flag used by the UI
if prompt.prompt_data:
prompt_type = _check_multimodal_contents(prompt.prompt_data)
else:
prompt_type = "freeform"
return PromptDatasetMetadata(
prompt_type=prompt_type,
prompt_api_schema=MultimodalPrompt(
prompt_message=prompt_message,
executions=[Execution(variable_set) for variable_set in prompt.variables],
),
).to_dict()
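# Editor's note: the resulting metadata dict has roughly this shape:
#     {
#         "promptType": "freeform",
#         "promptApiSchema": {
#             "apiSchemaVersion": "1.0.0",
#             "multimodalPrompt": {"promptMessage": {...}},
#             "executions": [...],  # only present when variable sets are non-empty
#         },
#     }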
def _create_dataset(prompt: Prompt, parent: str) -> gca_dataset.Dataset:
dataset_metadata = _format_dataset_metadata_dict(prompt=prompt)
dataset = gca_dataset.Dataset(
name=parent,
display_name=prompt.prompt_name or "Untitled Prompt",
metadata_schema_uri=PROMPT_SCHEMA_URI,
metadata=dataset_metadata,
model_reference=prompt.model_name,
)
operation = prompt._dataset_client.create_dataset(
parent=parent,
dataset=dataset,
)
dataset = operation.result()
# Purge labels
dataset.labels = None
return dataset
def _create_dataset_version(
prompt: Prompt, parent: str, version_name: Optional[str] = None
):
dataset_version = gca_dataset_version.DatasetVersion(
display_name=version_name,
)
operation = prompt._dataset_client.create_dataset_version(
parent=parent,
dataset_version=dataset_version,
)
return operation.result()
def _update_dataset(
prompt: Prompt,
dataset: gca_dataset.Dataset,
) -> gca_dataset_version.DatasetVersion:
dataset.metadata = _format_dataset_metadata_dict(prompt=prompt)
mask_paths = ["modelReference", "metadata"]
if dataset.display_name != "Untitled Prompt":
mask_paths.append("displayName")
updated_dataset = prompt._dataset_client.update_dataset(
dataset=dataset,
update_mask=field_mask.FieldMask(paths=mask_paths),
)
# Remove etag to avoid error for repeated dataset updates
updated_dataset.etag = None
return updated_dataset
def _create_prompt_resource(
prompt: Prompt, version_name: Optional[str] = None
) -> Prompt:
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
# Step 1: Create prompt dataset API call
parent = f"projects/{project}/locations/{location}"
dataset = _create_dataset(prompt=prompt, parent=parent)
# Step 2: Create prompt version API call
dataset_version = _create_dataset_version(
prompt=prompt,
parent=dataset.name,
version_name=version_name,
)
# Step 3: Create new Prompt object to return
new_prompt = Prompt._clone(prompt=prompt)
new_prompt._dataset = dataset
new_prompt._version_id = dataset_version.name.split("/")[-1]
new_prompt._version_name = dataset_version.display_name
prompt_id = new_prompt._dataset.name.split("/")[5]
_LOGGER.info(
f"Created prompt resource with id {prompt_id} with version number {new_prompt._version_id}"
)
return new_prompt
def _create_prompt_version_resource(
prompt: Prompt,
version_name: Optional[str] = None,
) -> Prompt:
# Step 1: Update prompt API call
updated_dataset = _update_dataset(prompt=prompt, dataset=prompt._dataset)
# Step 2: Create prompt version API call
dataset_version = _create_dataset_version(
prompt=prompt,
parent=updated_dataset.name,
version_name=version_name,
)
# Step 3: Create new Prompt object to return
new_prompt = Prompt._clone(prompt=prompt)
new_prompt._dataset = updated_dataset
new_prompt._version_id = dataset_version.name.split("/")[-1]
new_prompt._version_name = dataset_version.display_name
prompt_id = prompt._dataset.name.split("/")[5]
_LOGGER.info(
f"Updated prompt resource with id {prompt_id} as version number {new_prompt._version_id}"
)
return new_prompt
def _get_prompt_resource(prompt: Prompt, prompt_id: str) -> gca_dataset.Dataset:
"""Helper function to get a prompt resource from a prompt id."""
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
name = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
dataset = prompt._dataset_client.get_dataset(name=name)
return dataset
def _get_prompt_resource_from_version(
prompt: Prompt, prompt_id: str, version_id: str
) -> gca_dataset.Dataset:
"""Helper function to get a prompt resource from a prompt version id."""
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
name = f"projects/{project}/locations/{location}/datasets/{prompt_id}/datasetVersions/{version_id}"
# Step 1: Get dataset version object
dataset_version = prompt._dataset_client.get_dataset_version(name=name)
prompt._version_name = dataset_version.display_name
# Step 2: Fetch dataset object to get the dataset display name
name = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
dataset = prompt._dataset_client.get_dataset(name=name)
# Step 3: Convert the DatasetVersion object to a Dataset object
dataset = gca_dataset.Dataset(
name=name,
display_name=dataset.display_name,
metadata_schema_uri=PROMPT_SCHEMA_URI,
metadata=dataset_version.metadata,
model_reference=dataset_version.model_reference,
)
return dataset
def restore_version(prompt_id: str, version_id: str) -> PromptVersionMetadata:
"""Restores a previous version of the prompt resource and
loads that version into the current Prompt object.
Args:
prompt_id: The id of the prompt resource to restore a version of.
version_id: The version id of the online prompt resource.
"""
# Step 1: Make restore dataset version API call
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
name = f"projects/{project}/locations/{location}/datasets/{prompt_id}/datasetVersions/{version_id}"
# Create a temporary Prompt object for a dataset client
temp_prompt = Prompt()
operation = temp_prompt._dataset_client.restore_dataset_version(name=name)
result = operation.result()
new_version_id = result.name.split("/")[-1]
prompt_id = result.name.split("/")[5]
_LOGGER.info(
f"Restored prompt version {version_id} under prompt id {prompt_id} as version number {new_version_id}"
)
# Step 2: Create PromptVersionMetadata object to return
return PromptVersionMetadata(
display_name=result.display_name,
prompt_id=prompt_id,
version_id=new_version_id,
)
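# Usage sketch (editor's illustration; ids are hypothetical):
#     meta = restore_version(prompt_id="123", version_id="2")
#     print(meta.version_id)  # id of the new version created by the restore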
def get(prompt_id: str, version_id: Optional[str] = None) -> Prompt:
"""Creates a Prompt object from an online resource.
Args:
prompt_id: The id of the prompt resource.
version_id: Optional version id of the prompt resource.
If not specified, the latest version will be used.
Returns:
A prompt loaded from the online resource as a `Prompt` object.
"""
prompt = Prompt()
if version_id:
dataset = _get_prompt_resource_from_version(
prompt=prompt,
prompt_id=prompt_id,
version_id=version_id,
)
else:
dataset = _get_prompt_resource(prompt=prompt, prompt_id=prompt_id)
# Remove etag to avoid error for repeated dataset updates
dataset.etag = None
prompt._dataset = dataset
prompt._version_id = version_id
dataset_dict = _proto_to_dict(dataset)
metadata = PromptDatasetMetadata.from_dict(dataset_dict["metadata"])
_populate_fields_from_metadata(prompt=prompt, metadata=metadata)
return prompt
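# Usage sketch (editor's illustration; the id is hypothetical):
#     latest = get(prompt_id="123")                  # latest version
#     pinned = get(prompt_id="123", version_id="2")  # a specific version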
def _populate_fields_from_metadata(
prompt: Prompt, metadata: PromptDatasetMetadata
) -> None:
"""Helper function to populate Promptfields from metadata object"""
# Populate model_name (Required, raw deserialized type is str)
prompt.model_name = metadata.prompt_api_schema.prompt_message.model
# Populate prompt_data (raw deserialized type is list[Content])
contents = metadata.prompt_api_schema.prompt_message.contents
if contents:
if len(contents) > 1:
raise ValueError("Multi-turn prompts are not supported yet.")
prompt_data = [Part._from_gapic(part) for part in list(contents[0].parts)]
# Unwrap single text part into str
if len(prompt_data) == 1 and "text" in prompt_data[0]._raw_part:
prompt.prompt_data = prompt_data[0].text
else:
prompt.prompt_data = prompt_data
# Populate system_instruction (raw deserialized type is single Content)
system_instruction = metadata.prompt_api_schema.prompt_message.system_instruction
if system_instruction:
system_instruction_parts = [
Part._from_gapic(part) for part in list(system_instruction.parts)
]
# Unwrap single text part into str
if len(system_instruction_parts) == 1 and system_instruction_parts[0].text:
prompt.system_instruction = system_instruction_parts[0].text
else:
prompt.system_instruction = system_instruction_parts
# Populate variables
executions = metadata.prompt_api_schema.executions
variables = []
if executions:
for execution in executions:
serialized_variable_set = execution.arguments
variable_set = {}
if serialized_variable_set:
for name, value in serialized_variable_set.variables.items():
# Parts are dicts, not gapic objects for variables
variable_set[name] = [
Part.from_dict(part)
for part in list(value["partList"]["parts"])
]
variables.append(variable_set)
# Unwrap variable single text part into str
for variable_set in variables:
for name, value in variable_set.items():
if len(value) == 1 and "text" in value[0]._raw_part:
variable_set[name] = value[0].text
prompt.variables = variables
# Populate generation_config (raw deserialized type is GenerationConfig)
generation_config = metadata.prompt_api_schema.prompt_message.generation_config
if generation_config:
prompt.generation_config = generation_config
# Populate safety_settings (raw deserialized type is RepeatedComposite of SafetySetting)
safety_settings = metadata.prompt_api_schema.prompt_message.safety_settings
if safety_settings:
prompt.safety_settings = list(safety_settings)
# Populate tools (raw deserialized type is RepeatedComposite of Tool)
tools = metadata.prompt_api_schema.prompt_message.tools
if tools:
prompt.tools = list(tools)
# Populate tool_config (raw deserialized type is ToolConfig)
tool_config = metadata.prompt_api_schema.prompt_message.tool_config
if tool_config:
prompt.tool_config = ToolConfig._from_gapic(tool_config)
def list_prompts() -> list[PromptMetadata]:
"""Lists all prompt resources in the online prompt store associated with the project."""
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
parent = f"projects/{project}/locations/{location}"
# Create a temporary Prompt object for a dataset client
temp_prompt = Prompt()
prompts_pager = temp_prompt._dataset_client.list_datasets(
parent=parent,
)
prompts_list = []
for prompt in prompts_pager:
prompts_list.append(
PromptMetadata(
display_name=prompt.display_name,
prompt_id=prompt.name.split("/")[5],
)
)
return prompts_list
def list_versions(prompt_id: str) -> list[PromptVersionMetadata]:
"""Returns a list of PromptVersionMetadata objects for the prompt resource.
Args:
prompt_id: The id of the prompt resource to list versions of.
Returns:
A list of PromptVersionMetadata objects for the prompt resource.
"""
# Create a temporary Prompt object for a dataset client
temp_prompt = Prompt()
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
parent = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
versions_pager = temp_prompt._dataset_client.list_dataset_versions(
parent=parent,
)
version_history = []
for version in versions_pager:
version_history.append(
PromptVersionMetadata(
display_name=version.display_name,
prompt_id=version.name.split("/")[5],
version_id=version.name.split("/")[-1],
)
)
return version_history
def delete(prompt_id: str) -> None:
"""Deletes the online prompt resource associated with the prompt id."""
# Create a temporary Prompt object for a dataset client
temp_prompt = Prompt()
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
name = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
operation = temp_prompt._dataset_client.delete_dataset(
name=name,
)
operation.result()
_LOGGER.info(f"Deleted prompt resource with id {prompt_id}.")
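# Editor's sketch tying the listing helpers together (hypothetical usage):
#     for p in list_prompts():
#         for v in list_versions(p.prompt_id):
#             print(p.display_name, v.version_id)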


@@ -0,0 +1,687 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from copy import deepcopy
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform.compat.services import dataset_service_client
from vertexai.generative_models import (
Content,
Image,
Part,
GenerativeModel,
GenerationConfig,
SafetySetting,
Tool,
ToolConfig,
)
from vertexai.generative_models._generative_models import (
_to_content,
_validate_generate_content_parameters,
_reconcile_model_name,
_get_resource_name_from_model_name,
ContentsType,
GenerationConfigType,
GenerationResponse,
PartsType,
SafetySettingsType,
)
import re
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
Union,
)
_LOGGER = base.Logger(__name__)
DEFAULT_MODEL_NAME = "gemini-1.5-flash-002"
VARIABLE_NAME_REGEX = r"(\{[^\W0-9]\w*\})"
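# Editor's note: the capturing group makes re.split keep the variable tokens,
# e.g. re.split(VARIABLE_NAME_REGEX, "Hi, {name}!") -> ['Hi, ', '{name}', '!'].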
class Prompt:
"""A prompt which may be a template with variables.
The `Prompt` class allows users to define a template string with
variables represented in curly braces `{variable}`. The variable
name must be a valid Python variable name (no spaces, must start with a
letter). These placeholders can be replaced with specific values using the
`assemble_contents` method, providing flexibility in generating dynamic prompts.
Usage:
Generate content from a single set of variables:
```
prompt = Prompt(
prompt_data="Hello, {name}! Today is {day}. How are you?",
variables=[{"name": "Alice", "day": "Monday"}]
generation_config=GenerationConfig(
temperature=0.1,
top_p=0.95,
top_k=20,
candidate_count=1,
max_output_tokens=100,
),
model_name="gemini-1.0-pro-002",
safety_settings=[SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
method=SafetySetting.HarmBlockMethod.SEVERITY,
)],
system_instruction="Please answer in a short sentence.",
)
# Generate content using the assembled prompt.
prompt.generate_content(
contents=prompt.assemble_contents(**prompt.variables)
)
```
Generate content with multiple sets of variables:
```
prompt = Prompt(
prompt_data="Hello, {name}! Today is {day}. How are you?",
variables=[
{"name": "Alice", "day": "Monday"},
{"name": "Bob", "day": "Tuesday"},
],
generation_config=GenerationConfig(
temperature=0.1,
top_p=0.95,
top_k=20,
candidate_count=1,
max_output_tokens=100,
),
model_name="gemini-1.0-pro-002",
safety_settings=[SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
method=SafetySetting.HarmBlockMethod.SEVERITY,
)],
system_instruction="Please answer in a short sentence.",
)
# Generate content using the assembled prompt for each variable set.
for i in range(len(prompt.variables)):
prompt.generate_content(
contents=prompt.assemble_contents(**prompt.variables[i])
)
```
"""
def __init__(
self,
prompt_data: Optional[PartsType] = None,
*,
variables: Optional[List[Dict[str, PartsType]]] = None,
prompt_name: Optional[str] = None,
generation_config: Optional[GenerationConfig] = None,
model_name: Optional[str] = None,
safety_settings: Optional[List[SafetySetting]] = None,
system_instruction: Optional[PartsType] = None,
tools: Optional[List[Tool]] = None,
tool_config: Optional[ToolConfig] = None,
):
"""Initializes the Prompt with a given prompt, and variables.
Args:
prompt: A PartsType prompt which may be a template with variables or a prompt with no variables.
variables: A list of dictionaries containing the variable names and values.
prompt_name: The display name of the prompt, if stored in an online resource.
generation_config: A GenerationConfig object containing parameters for generation.
model_name: Model Garden model resource name.
Alternatively, a tuned model endpoint resource name can be provided.
If no model is provided, the default latest model will be used.
safety_settings: A list of SafetySetting objects containing safety settings for generation.
system_instruction: A PartsType object representing the system instruction.
tools: A list of Tool objects for function calling.
tool_config: A ToolConfig object for function calling.
"""
self._prompt_data = None
self._variables = None
self._model_name = None
self._generation_config = None
self._safety_settings = None
self._system_instruction = None
self._tools = None
self._tool_config = None
# Prompt Management
self._dataset_client_value = None
self._dataset = None
self._prompt_name = None
self._version_id = None
self._version_name = None
self.prompt_data = prompt_data
self.variables = variables if variables else [{}]
self.prompt_name = prompt_name
self.model_name = model_name
self.generation_config = generation_config
self.safety_settings = safety_settings
self.system_instruction = system_instruction
self.tools = tools
self.tool_config = tool_config
@property
def prompt_data(self) -> Optional[PartsType]:
return self._prompt_data
@property
def variables(self) -> Optional[List[Dict[str, PartsType]]]:
return self._variables
@property
def prompt_name(self) -> Optional[str]:
return self._prompt_name
@property
def generation_config(self) -> Optional[GenerationConfig]:
return self._generation_config
@property
def model_name(self) -> Optional[str]:
if self._model_name:
return self._model_name
else:
return Prompt._format_model_resource_name(DEFAULT_MODEL_NAME)
@property
def safety_settings(self) -> Optional[List[SafetySetting]]:
return self._safety_settings
@property
def system_instruction(self) -> Optional[PartsType]:
return self._system_instruction
@property
def tools(self) -> Optional[List[Tool]]:
return self._tools
@property
def tool_config(self) -> Optional[ToolConfig]:
return self._tool_config
@property
def prompt_id(self) -> Optional[str]:
if self._dataset:
return self._dataset.name.split("/")[-1]
return None
@property
def version_id(self) -> Optional[str]:
return self._version_id
@property
def version_name(self) -> Optional[str]:
return self._version_name
@prompt_data.setter
def prompt_data(self, prompt_data: Optional[PartsType]) -> None:
"""Overwrites the existing saved local prompt_data.
Args:
prompt_data: A PartsType prompt.
"""
if prompt_data is not None:
self._validate_parts_type_data(prompt_data)
self._prompt_data = prompt_data
@variables.setter
def variables(self, variables: List[Dict[str, PartsType]]) -> None:
"""Overwrites the existing saved local variables.
Args:
variables: A list of dictionaries containing the variable names and values.
"""
if isinstance(variables, list):
for i in range(len(variables)):
variables[i] = variables[i].copy()
Prompt._format_variable_value_to_parts(variables[i])
self._variables = variables
else:
raise TypeError(
f"Variables must be a list of dictionaries, not {type(variables)}"
)
@prompt_name.setter
def prompt_name(self, prompt_name: Optional[str]) -> None:
"""Overwrites the existing saved local prompt_name."""
if prompt_name:
self._prompt_name = prompt_name
else:
self._prompt_name = None
@model_name.setter
def model_name(self, model_name: Optional[str]) -> None:
"""Overwrites the existing saved local model_name."""
if model_name:
self._model_name = Prompt._format_model_resource_name(model_name)
else:
self._model_name = None
@staticmethod
def _format_model_resource_name(model_name: Optional[str]) -> str:
"""Formats the model resource name."""
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
model_name = _reconcile_model_name(model_name, project, location)
prediction_resource_name = _get_resource_name_from_model_name(
model_name, project, location
)
return prediction_resource_name
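# Editor's note: this expands a short model id into a full prediction resource
# name, roughly "gemini-1.5-flash-002" ->
# "projects/<project>/locations/<location>/publishers/google/models/gemini-1.5-flash-002"
# (the exact shape is delegated to _reconcile_model_name and
# _get_resource_name_from_model_name).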
def _validate_configs(
self,
generation_config: Optional[GenerationConfig] = None,
safety_settings: Optional[SafetySetting] = None,
system_instruction: Optional[PartsType] = None,
tools: Optional[List[Tool]] = None,
tool_config: Optional[ToolConfig] = None,
):
generation_config = generation_config or self._generation_config
safety_settings = safety_settings or self._safety_settings
tools = tools or self._tools
tool_config = tool_config or self._tool_config
system_instruction = system_instruction or self._system_instruction
return _validate_generate_content_parameters(
contents="test",
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
tools=tools,
tool_config=tool_config,
)
@generation_config.setter
def generation_config(self, generation_config: Optional[GenerationConfig]) -> None:
"""Overwrites the existing saved local generation_config.
Args:
generation_config: A GenerationConfig object containing parameters for generation.
"""
self._validate_configs(generation_config=generation_config)
self._generation_config = generation_config
@safety_settings.setter
def safety_settings(self, safety_settings: Optional[List[SafetySetting]]) -> None:
"""Overwrites the existing saved local safety_settings.
Args:
safety_settings: A list of SafetySetting objects containing safety settings for generation.
"""
self._validate_configs(safety_settings=safety_settings)
self._safety_settings = safety_settings
@system_instruction.setter
def system_instruction(self, system_instruction: Optional[PartsType]) -> None:
"""Overwrites the existing saved local system_instruction.
Args:
system_instruction: A PartsType object representing the system instruction.
"""
if system_instruction:
self._validate_parts_type_data(system_instruction)
self._system_instruction = system_instruction
@tools.setter
def tools(self, tools: Optional[List[Tool]]) -> None:
"""Overwrites the existing saved local tools.
Args:
tools: A list of Tool objects for function calling.
"""
self._validate_configs(tools=tools)
self._tools = tools
@tool_config.setter
def tool_config(self, tool_config: Optional[ToolConfig] = None) -> None:
"""Overwrites the existing saved local tool_config.
Args:
tool_config: A ToolConfig object for function calling.
"""
self._validate_configs(tool_config=tool_config)
self._tool_config = tool_config
@staticmethod
def _format_variable_value_to_parts(variables_dict: Dict[str, PartsType]) -> None:
"""Formats the variable values to be List[Part].
Args:
variables_dict: A single dictionary containing the variable names and values.
Raises:
TypeError: If a variable value is not a PartsType Object.
"""
for key in variables_dict.keys():
# Disallow Content as variable value.
if isinstance(variables_dict[key], Content):
raise TypeError(
"Variable values must be a PartsType object, not Content"
)
# Rely on type checks in _to_content for validation.
content = Content._from_gapic(_to_content(value=variables_dict[key]))
variables_dict[key] = content.parts
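# Editor's example (hypothetical values): string variable values are
# normalized into lists of Part objects in place, e.g.
#     d = {"name": "Alice"}
#     Prompt._format_variable_value_to_parts(d)
#     # d["name"] is now roughly [Part.from_text("Alice")]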
def _validate_parts_type_data(self, data: Any) -> None:
"""
Args:
data: The prompt input to validate.
Raises:
TypeError: If prompt_data is not a PartsType Object.
"""
# Disallow Content as prompt_data.
if isinstance(data, Content):
raise TypeError("Prompt data must be a PartsType object, not Content")
# Rely on type checks in _to_content.
_to_content(value=data)
def assemble_contents(self, **variables_dict: PartsType) -> List[Content]:
"""Returns the prompt data, as a List[Content], assembled with variables if applicable.
Can be ingested into model.generate_content to make API calls.
Returns:
A List[Content] prompt.
Usage:
```
prompt = Prompt(
prompt_data="Hello, {name}! Today is {day}. How are you?",
)
model.generate_content(
contents=prompt.assemble_contents(name="Alice", day="Monday")
)
```
"""
# If prompt_data is None, raise an error.
if self.prompt_data is None:
raise ValueError("prompt_data must not be empty.")
variables_dict = variables_dict.copy()
# If there are no variables, return the prompt_data as a Content object.
if not variables_dict:
return [Content._from_gapic(_to_content(value=self.prompt_data))]
# Step 1) Convert the variables values to List[Part].
Prompt._format_variable_value_to_parts(variables_dict)
# Step 2) Assemble the prompt.
# prompt_data must have been previously validated using _validate_parts_type_data.
assembled_prompt = []
assembled_variables_cnt = {}
if isinstance(self.prompt_data, list):
# User inputted a List of Parts as prompt_data.
for part in self.prompt_data:
assembled_prompt.extend(
self._assemble_singular_part(
part, variables_dict, assembled_variables_cnt
)
)
else:
# User inputted a single str, Image, or Part as prompt_data.
assembled_prompt.extend(
self._assemble_singular_part(
self.prompt_data, variables_dict, assembled_variables_cnt
)
)
# Step 3) Simplify adjacent string Parts
simplified_assembled_prompt = [assembled_prompt[0]]
for i in range(1, len(assembled_prompt)):
# If the previous and current Parts are both text Parts, concatenate them.
try:
prev_text = simplified_assembled_prompt[-1].text
curr_text = assembled_prompt[i].text
if isinstance(prev_text, str) and isinstance(curr_text, str):
simplified_assembled_prompt[-1] = Part.from_text(
prev_text + curr_text
)
else:
simplified_assembled_prompt.append(assembled_prompt[i])
except AttributeError:
simplified_assembled_prompt.append(assembled_prompt[i])
# Step 4) Validate that all variables were used, if specified.
for key in variables_dict:
if key not in assembled_variables_cnt:
raise ValueError(f"Variable {key} is not present in prompt_data.")
assemble_cnt_msg = "Assembled prompt replacing: "
for key in assembled_variables_cnt:
assemble_cnt_msg += (
f"{assembled_variables_cnt[key]} instances of variable {key}, "
)
if assemble_cnt_msg[-2:] == ", ":
assemble_cnt_msg = assemble_cnt_msg[:-2]
_LOGGER.info(assemble_cnt_msg)
# Step 5) Wrap List[Part] as a single Content object.
return [
Content(
parts=simplified_assembled_prompt,
role="user",
)
]
def _assemble_singular_part(
self,
prompt_data_part: Union[str, Image, Part],
formatted_variables_set: Dict[str, List[Part]],
assembled_variables_cnt: Dict[str, int],
) -> List[Part]:
"""Assemble a str, Image, or Part."""
if isinstance(prompt_data_part, Image):
# Templating is not supported for Image prompt_data.
return [Part.from_image(prompt_data_part)]
elif isinstance(prompt_data_part, str):
# Assemble a single string
return self._assemble_single_str(
prompt_data_part, formatted_variables_set, assembled_variables_cnt
)
elif isinstance(prompt_data_part, Part):
# If the Part is a text Part, assemble it.
try:
text = prompt_data_part.text
except AttributeError:
return [prompt_data_part]
return self._assemble_single_str(
text, formatted_variables_set, assembled_variables_cnt
)
def _assemble_single_str(
self,
prompt_data_str: str,
formatted_variables_set: Dict[str, List[Part]],
assembled_variables_cnt: Dict[str, int],
) -> List[Part]:
"""Assemble a single string with 0 or more variables within the string."""
# Step 1) Find and isolate variables as their own string.
prompt_data_str_split = re.split(VARIABLE_NAME_REGEX, prompt_data_str)
assembled_data = []
# Step 2) Assemble variables with their values, creating a list of Parts.
for s in prompt_data_str_split:
if not s:
continue
variable_name = s[1:-1]
if (
re.match(VARIABLE_NAME_REGEX, s)
and variable_name in formatted_variables_set
):
assembled_data.extend(formatted_variables_set[variable_name])
assembled_variables_cnt[variable_name] = (
assembled_variables_cnt.get(variable_name, 0) + 1
)
else:
assembled_data.append(Part.from_text(s))
return assembled_data
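# Editor's walk-through (hypothetical values): with prompt_data
# "Hello, {name}!" and name bound to [Part.from_text("Alice")], the regex
# split yields ['Hello, ', '{name}', '!']; the middle token is swapped for
# the bound Parts, the rest become text Parts, and Step 3 of
# assemble_contents later merges adjacent text back into a single
# "Hello, Alice!" Part.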
def generate_content(
self,
contents: ContentsType,
*,
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
model_name: Optional[str] = None,
tools: Optional[List["Tool"]] = None,
tool_config: Optional["ToolConfig"] = None,
stream: bool = False,
system_instruction: Optional[PartsType] = None,
) -> Union["GenerationResponse", Iterable["GenerationResponse"],]:
"""Generates content using the saved Prompt configs.
Args:
contents: Contents to send to the model.
Supports either a list of Content objects (passing a multi-turn conversation)
or a value that can be converted to a single Content object (passing a single message).
Supports
* str, Image, Part,
* List[Union[str, Image, Part]],
* List[Content]
generation_config: Parameters for the generation.
model_name: Prediction model resource name.
safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold.
tools: A list of tools (functions) that the model can try calling.
tool_config: Config shared for all tools provided in the request.
stream: Whether to stream the response.
system_instruction: System instruction to pass to the model.
Returns:
A single GenerationResponse object if stream == False
A stream of GenerationResponse objects if stream == True
Usage:
```
prompt = Prompt(
prompt_data="Hello, {name}! Today is {day}. How are you?",
variables={"name": "Alice", "day": "Monday"},
generation_config=GenerationConfig(temperature=0.1,),
system_instruction="Please answer in a short sentence.",
model_name="gemini-1.0-pro-002",
)
prompt.generate_content(
contents=prompt.assemble_contents(**prompt.variables)
)
```
"""
if not (model_name or self._model_name):
_LOGGER.info(
"No model name specified, falling back to default model: %s",
self.model_name,
)
model_name = model_name or self.model_name
generation_config = generation_config or self.generation_config
safety_settings = safety_settings or self.safety_settings
tools = tools or self.tools
tool_config = tool_config or self.tool_config
system_instruction = system_instruction or self.system_instruction
if not model_name:
raise ValueError(
"Model name must be specified to use Prompt.generate_content()"
)
model_name = Prompt._format_model_resource_name(model_name)
model = GenerativeModel(
model_name=model_name, system_instruction=system_instruction
)
return model.generate_content(
contents=contents,
generation_config=generation_config,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
stream=stream,
)
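# Editor's note: per the fallbacks above, saved configs act as defaults and
# per-call arguments win for that call only, e.g.
#     prompt.generate_content(
#         contents,
#         generation_config=GenerationConfig(temperature=0.9),  # overrides the saved config
#     )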
@property
def _dataset_client(self) -> dataset_service_client.DatasetServiceClient:
if not getattr(self, "_dataset_client_value", None):
self._dataset_client_value = (
aiplatform_initializer.global_config.create_client(
client_class=dataset_service_client.DatasetServiceClient,
)
)
return self._dataset_client_value
@classmethod
def _clone(cls, prompt: "Prompt") -> "Prompt":
"""Returns a copy of the Prompt."""
return Prompt(
prompt_data=deepcopy(prompt.prompt_data),
variables=deepcopy(prompt.variables),
generation_config=deepcopy(prompt.generation_config),
safety_settings=deepcopy(prompt.safety_settings),
tools=deepcopy(prompt.tools),
tool_config=deepcopy(prompt.tool_config),
system_instruction=deepcopy(prompt.system_instruction),
model_name=prompt.model_name,
)
def get_unassembled_prompt_data(self) -> PartsType:
"""Returns the prompt data, without any variables replaced."""
return self.prompt_data
def __str__(self) -> str:
"""Returns the prompt data as a string, without any variables replaced."""
return str(self.prompt_data or "")
def __repr__(self) -> str:
"""Returns a string representation of the unassembled prompt."""
result = "Prompt("
if self.prompt_data:
result += f"prompt_data='{self.prompt_data}', "
if self.variables and self.variables[0]:
result += f"variables={self.variables}, "
if self.system_instruction:
result += f"system_instruction={self.system_instruction}, "
if self._model_name:
# Don't display default model in repr
result += f"model_name={self._model_name}, "
if self.generation_config:
result += f"generation_config={self.generation_config}, "
if self.safety_settings:
result += f"safety_settings={self.safety_settings}, "
if self.tools:
result += f"tools={self.tools}, "
if self.tool_config:
result += f"tool_config={self.tool_config}, "
if self.prompt_id:
result += f"prompt_id={self.prompt_id}, "
if self.version_id:
result += f"version_id={self.version_id}, "
if self.prompt_name:
result += f"prompt_name={self.prompt_name}, "
if self.version_name:
result += f"version_name={self.version_name}, "
# Remove trailing ", "
if result[-2:] == ", ":
result = result[:-2]
result += ")"
return result
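# End-to-end sketch (editor's illustration; assumes vertexai has been
# initialized with a project and location):
#     prompt = Prompt(
#         prompt_data="Hello, {name}!",
#         variables=[{"name": "Alice"}],
#     )
#     response = prompt.generate_content(
#         contents=prompt.assemble_contents(**prompt.variables[0])
#     )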