# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import json
from typing import List, Optional
from typing_extensions import override

from google.cloud.aiplatform import base as aiplatform_base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform import utils as aiplatform_utils
from google.cloud.aiplatform.compat.types import (
    cached_content_v1beta1 as gca_cached_content,
)
from google.cloud.aiplatform_v1.services import (
    gen_ai_cache_service as gen_ai_cache_service_v1,
)
from google.cloud.aiplatform_v1beta1.types.cached_content import (
    CachedContent as GapicCachedContent,
)
from google.cloud.aiplatform_v1beta1.types import (
    content as gapic_content_types,
)
from google.cloud.aiplatform_v1beta1.types.gen_ai_cache_service import (
    CreateCachedContentRequest,
    GetCachedContentRequest,
    UpdateCachedContentRequest,
)
from google.cloud.aiplatform_v1 import types as types_v1
from vertexai.generative_models import _generative_models
from vertexai.generative_models._generative_models import (
    Content,
    PartsType,
    Tool,
    ToolConfig,
    ContentsType,
)
from google.protobuf import field_mask_pb2


def _prepare_create_request(
    model_name: str,
    *,
    system_instruction: Optional[PartsType] = None,
    tools: Optional[List[Tool]] = None,
    tool_config: Optional[ToolConfig] = None,
    contents: Optional[ContentsType] = None,
    expire_time: Optional[datetime.datetime] = None,
    ttl: Optional[datetime.timedelta] = None,
    display_name: Optional[str] = None,
) -> types_v1.CreateCachedContentRequest:
    """Prepares the request for the create_cached_content RPC."""
    (
        project,
        location,
    ) = aiplatform_initializer.global_config._get_default_project_and_location()
    if contents:
        _generative_models._validate_contents_type_as_valid_sequence(contents)
    if tools:
        _generative_models._validate_tools_type_as_valid_sequence(tools)
    if tool_config:
        _generative_models._validate_tool_config_type(tool_config)

    # contents can be anything accepted by ContentsType; the most generic case
    # is a list of Content objects. Normalize it to GAPIC Content protos.
    if contents:
        contents = _generative_models._content_types_to_gapic_contents(contents)

    gapic_system_instruction: Optional[gapic_content_types.Content] = None
    if system_instruction:
        gapic_system_instruction = _generative_models._to_content(system_instruction)

    gapic_tools = None
    if tools:
        gapic_tools = _generative_models._tool_types_to_gapic_tools(tools)

    gapic_tool_config = None
    if tool_config:
        gapic_tool_config = tool_config._gapic_tool_config

    if ttl and expire_time:
        raise ValueError("Only one of ttl and expire_time can be set.")

    request_v1beta1 = CreateCachedContentRequest(
        parent=f"projects/{project}/locations/{location}",
        cached_content=GapicCachedContent(
            model=model_name,
            system_instruction=gapic_system_instruction,
            tools=gapic_tools,
            tool_config=gapic_tool_config,
            contents=contents,
            expire_time=expire_time,
            ttl=ttl,
            display_name=display_name,
        ),
    )
    # The request is assembled with v1beta1 types and converted to v1 via a
    # serialize/deserialize round trip, since the cache service client is v1.
    serialized_message_v1beta1 = type(request_v1beta1).serialize(request_v1beta1)
    try:
        request_v1 = types_v1.CreateCachedContentRequest.deserialize(
            serialized_message_v1beta1
        )
    except Exception as ex:
        raise ValueError(
            "Failed to convert CreateCachedContentRequest from v1beta1 to v1:\n"
            f"{serialized_message_v1beta1}"
        ) from ex
    return request_v1


def _prepare_get_cached_content_request(name: str) -> types_v1.GetCachedContentRequest:
    """Prepares the request for the get_cached_content RPC."""
    return types_v1.GetCachedContentRequest(name=name)


class CachedContent(aiplatform_base._VertexAiResourceNounPlus):
    """A cached content resource."""

    _resource_noun = "cachedContents"
    _getter_method = "get_cached_content"
    _list_method = "list_cached_contents"
    _delete_method = "delete_cached_content"
    _parse_resource_name_method = "parse_cached_content_path"
    _format_resource_name_method = "cached_content_path"

    client_class = aiplatform_utils.GenAiCacheServiceClientWithOverride

    _gen_ai_cache_service_client_value: Optional[
        gen_ai_cache_service_v1.GenAiCacheServiceClient
    ] = None

    def __init__(self, cached_content_name: str):
        """Represents a cached content resource.

        This resource can be used with vertexai.generative_models.GenerativeModel
        to cache the prefix so it can be reused across multiple generate_content
        requests.

        Args:
            cached_content_name (str):
                Required. The name of the cached content resource. It can be a
                fully-qualified CachedContent resource name or a CachedContent
                ID. Example: "projects/.../locations/../cachedContents/456" or
                "456".
        """
        super().__init__(resource_name=cached_content_name)
        self._gca_resource = self._get_gca_resource(cached_content_name)

    @property
    def _raw_cached_content(self) -> gca_cached_content.CachedContent:
        return self._gca_resource

    @property
    def model_name(self) -> str:
        return self._gca_resource.model

    @classmethod
    def create(
        cls,
        *,
        model_name: str,
        system_instruction: Optional[Content] = None,
        tools: Optional[List[Tool]] = None,
        tool_config: Optional[ToolConfig] = None,
        contents: Optional[List[Content]] = None,
        expire_time: Optional[datetime.datetime] = None,
        ttl: Optional[datetime.timedelta] = None,
        display_name: Optional[str] = None,
    ) -> "CachedContent":
        """Creates a new cached content through the gen ai cache service.

        Usage:
            See the example at the end of this docstring.

        Args:
            model_name: Immutable. The name of the publisher model to use for
                cached content. Allowed formats:
                projects/{project}/locations/{location}/publishers/{publisher}/models/{model},
                publishers/{publisher}/models/{model}, or a single model name.
            system_instruction: Optional. Immutable. Developer-set system
                instruction. Currently, text only.
            tools: Optional. Immutable. A list of ``Tools`` the model may use
                to generate the next response.
            tool_config: Optional. Immutable. Tool config. This config is
                shared for all tools.
            contents: Optional. Immutable. The content to cache, as a list of
                Content.
            expire_time: Timestamp of when this resource is considered expired.
                At most one of expire_time and ttl can be set. If neither is
                set, the default TTL on the API side will be used (currently
                1 hour).
            ttl: The TTL for this resource. If provided, the expiration time is
                computed as created_time + TTL. At most one of expire_time and
                ttl can be set. If neither is set, the default TTL on the API
                side will be used (currently 1 hour).
            display_name: The user-provided, meaningful display name of the
                cached content.

        Returns:
            A CachedContent object with only name and model_name specified.

        Raises:
            ValueError: If both expire_time and ttl are set.
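
        Example:
            An illustrative sketch only; the model name, bucket URI, and the
            use of GenerativeModel.from_cached_content below are assumptions
            for demonstration, not part of this method's contract:

                import datetime
                import vertexai
                from vertexai.generative_models import Content, GenerativeModel, Part

                vertexai.init(project="my-project", location="us-central1")

                cache = CachedContent.create(
                    model_name="gemini-1.5-pro-002",
                    system_instruction="Answer questions about the cached document.",
                    contents=[
                        Content(
                            role="user",
                            parts=[
                                Part.from_uri(
                                    "gs://my-bucket/my-document.pdf",
                                    mime_type="application/pdf",
                                )
                            ],
                        )
                    ],
                    ttl=datetime.timedelta(hours=1),
                    display_name="example-cache",
                )
                model = GenerativeModel.from_cached_content(cached_content=cache)
                response = model.generate_content("Summarize the cached document.")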
""" project = aiplatform_initializer.global_config.project location = aiplatform_initializer.global_config.location if model_name.startswith("publishers/"): model_name = f"projects/{project}/locations/{location}/{model_name}" elif not model_name.startswith("projects/"): model_name = f"projects/{project}/locations/{location}/publishers/google/models/{model_name}" if ttl and expire_time: raise ValueError("Only one of ttl and expire_time can be set.") request = _prepare_create_request( model_name=model_name, system_instruction=system_instruction, tools=tools, tool_config=tool_config, contents=contents, expire_time=expire_time, ttl=ttl, display_name=display_name, ) client = cls._instantiate_client(location=location) cached_content_resource = client.create_cached_content(request) obj = cls(cached_content_resource.name) obj._gca_resource = cached_content_resource return obj def refresh(self): """Syncs the local cached content with the remote resource.""" self._sync_gca_resource() def update( self, *, expire_time: Optional[datetime.datetime] = None, ttl: Optional[datetime.timedelta] = None, ): """Updates the expiration time of the cached content.""" if expire_time and ttl: raise ValueError("Only one of ttl and expire_time can be set.") update_mask: List[str] = [] if ttl: update_mask.append("ttl") if expire_time: update_mask.append("expire_time") update_mask = field_mask_pb2.FieldMask(paths=update_mask) request_v1beta1 = UpdateCachedContentRequest( cached_content=GapicCachedContent( name=self.resource_name, expire_time=expire_time, ttl=ttl, ), update_mask=update_mask, ) serialized_message_v1beta1 = type(request_v1beta1).serialize(request_v1beta1) try: request_v1 = types_v1.UpdateCachedContentRequest.deserialize( serialized_message_v1beta1 ) except Exception as ex: raise ValueError( "Failed to convert UpdateCachedContentRequest from v1beta1 to v1:\n" f"{serialized_message_v1beta1}" ) from ex self.api_client.update_cached_content(request_v1) @property def expire_time(self) -> datetime.datetime: """Time this resource is considered expired. The returned value may be stale. Use refresh() to get the latest value. Returns: The expiration time of the cached content resource. """ return self._gca_resource.expire_time def delete(self): """Deletes the current cached content resource.""" self._delete() @override def __repr__(self) -> str: return f"{object.__repr__(self)}: {json.dumps(self.to_dict(), indent=2)}" @classmethod def list(cls) -> List["CachedContent"]: """Lists the active cached content resources.""" # TODO(b/345326114): Make list() interface richer after aligning with # Google AI SDK return cls._list() @classmethod def get(cls, cached_content_name: str) -> "CachedContent": """Retrieves an existing cached content resource.""" cache = cls(cached_content_name) return cache @override @property def display_name(self) -> str: """Display name of this resource.""" return self._gca_resource.display_name