392 lines
15 KiB
Python
392 lines
15 KiB
Python
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
|
|
import json
|
|
from typing import List, Optional, Sequence, Union
|
|
|
|
from google.cloud.aiplatform import base
|
|
from google.cloud.aiplatform import initializer
|
|
from google.cloud.aiplatform import utils as aip_utils
|
|
from google.cloud.aiplatform_v1beta1 import types
|
|
from vertexai.generative_models import _generative_models
|
|
from vertexai.reasoning_engines import _utils
|
|
from google.protobuf import struct_pb2
|
|
|
|
# Module-level logger; used below to report the progress of extension creation.
_LOGGER = base.Logger(__name__)

# Type aliases for arguments that callers may supply either as a JSON-style
# dict or as the corresponding proto message.
_AuthConfigOrJson = Union[_utils.JsonDict, types.AuthConfig]
_StructOrJson = Union[_utils.JsonDict, struct_pb2.Struct]
_RuntimeConfigOrJson = Union[_utils.JsonDict, types.RuntimeConfig]
|
|
|
|
|
|
# Registry of the first-party extensions that `Extension.from_hub` can create,
# keyed by the hub name passed to that method. Each entry supplies the
# `display_name`, `description` and `manifest` that are forwarded to
# `Extension.create`. Every manifest references its OpenAPI spec through a
# `gs://` URI and declares Google service-account authentication.
_VERTEX_EXTENSION_HUB = {
    "code_interpreter": {
        "display_name": "Code Interpreter",
        "description": (
            "This extension generates and executes code in the specified language"
        ),
        "manifest": {
            # NOTE: the manifest name carries a "_tool" suffix, unlike the hub key.
            "name": "code_interpreter_tool",
            "description": "Google Code Interpreter Extension",
            "api_spec": {
                "open_api_gcs_uri": (
                    "gs://vertex-extension-public/code_interpreter.yaml"
                ),
            },
            "auth_config": {
                "auth_type": "GOOGLE_SERVICE_ACCOUNT_AUTH",
                "google_service_account_config": {},
            },
        },
    },
    "vertex_ai_search": {
        "display_name": "Vertex AI Search",
        "description": "This extension generates and executes search queries",
        "manifest": {
            "name": "vertex_ai_search",
            "description": "Vertex AI Search Extension",
            "api_spec": {
                "open_api_gcs_uri": (
                    "gs://vertex-extension-public/vertex_ai_search.yaml"
                ),
            },
            "auth_config": {
                "auth_type": "GOOGLE_SERVICE_ACCOUNT_AUTH",
                "google_service_account_config": {},
            },
        },
    },
    "webpage_browser": {
        "display_name": "Webpage Browser",
        "description": "This extension fetches the content of a webpage",
        "manifest": {
            "name": "webpage_browser",
            "description": "Vertex Webpage Browser Extension",
            "api_spec": {
                "open_api_gcs_uri": (
                    "gs://vertex-extension-public/webpage_browser.yaml"
                ),
            },
            "auth_config": {
                "auth_type": "GOOGLE_SERVICE_ACCOUNT_AUTH",
                "google_service_account_config": {},
            },
        },
    },
}
|
|
|
|
|
|
class Extension(base.VertexAiResourceNounWithFutureManager):
    """A Vertex AI Extension resource.

    An instance holds the fetched extension resource together with an
    execution client that `execute` and `query` use to run the extension.
    """

    client_class = aip_utils.ExtensionRegistryClientWithOverride
    _resource_noun = "extension"
    _getter_method = "get_extension"
    _list_method = "list_extensions"
    _delete_method = "delete_extension"
    _parse_resource_name_method = "parse_extension_path"
    _format_resource_name_method = "extension_path"

    def __init__(self, extension_name: str):
        """Fetches an existing extension resource.

        Args:
            extension_name (str):
                Required. A fully-qualified resource name or ID such as
                "projects/123/locations/us-central1/extensions/456" or
                "456" when project and location are initialized or passed.
        """
        super().__init__(resource_name=extension_name)
        global_config = initializer.global_config
        self.execution_api_client = global_config.create_client(
            client_class=aip_utils.ExtensionExecutionClientWithOverride,
        )
        self._gca_resource = self._get_gca_resource(resource_name=extension_name)
        # Caches filled on first use by `api_spec()` / `operation_schemas()`.
        self._api_spec = self._operation_schemas = None

    @classmethod
    def create(
        cls,
        manifest: Union[_utils.JsonDict, types.ExtensionManifest],
        *,
        extension_name: Optional[str] = None,
        display_name: Optional[str] = None,
        description: Optional[str] = None,
        runtime_config: Optional[_RuntimeConfigOrJson] = None,
    ):
        """Imports a new Extension and returns a handle to it.

        Args:
            manifest (Union[dict[str, Any], ExtensionManifest]):
                Required. The manifest for the Extension to be created.
            extension_name (str):
                Optional. A fully-qualified extension resource name or extension
                ID such as "projects/123/locations/us-central1/extensions/456" or
                "456" when project and location are initialized or passed. If
                specifying the extension ID, it should be 4-63 characters, valid
                characters are lowercase letters, numbers and hyphens ("-"),
                and it should start with a number or a lower-case letter. If not
                provided, Vertex AI will generate a value for this ID.
            display_name (str):
                Optional. The user-defined name of the Extension.
                The name can be up to 128 characters long and can comprise any
                UTF-8 character. When omitted (or empty), a display name is
                generated.
            description (str):
                Optional. The description of the Extension.
            runtime_config (Union[dict[str, Any], RuntimeConfig]):
                Optional. Runtime config controlling the runtime behavior of
                this Extension. Defaults to None.

        Returns:
            Extension: The extension that was created.
        """
        # Build the instance without running `Extension.__init__`, which would
        # try to fetch a resource that does not exist yet.
        instance = cls.__new__(cls)
        base.VertexAiResourceNounWithFutureManager.__init__(
            instance,
            resource_name=extension_name,
        )

        resolved_display_name = display_name or cls._generate_display_name()
        manifest_proto = _utils.to_proto(manifest, types.ExtensionManifest())
        extension_proto = types.Extension(
            name=extension_name,
            display_name=resolved_display_name,
            description=description,
            manifest=manifest_proto,
        )
        if runtime_config:
            extension_proto.runtime_config = _utils.to_proto(
                runtime_config,
                types.RuntimeConfig(),
            )

        import_lro = instance.api_client.import_extension(
            parent=initializer.global_config.common_location_path(),
            extension=extension_proto,
        )
        _LOGGER.log_create_with_lro(cls, import_lro)
        imported = import_lro.result()
        _LOGGER.log_create_complete(
            cls,
            imported,
            cls._resource_noun,
            module_name="vertexai.preview.extensions",
        )

        # Re-fetch through `._get_gca_resource(...)` rather than reusing the
        # operation result so that every attribute of the extension is
        # populated.
        instance._gca_resource = instance._get_gca_resource(
            resource_name=imported.name
        )
        instance.execution_api_client = initializer.global_config.create_client(
            client_class=aip_utils.ExtensionExecutionClientWithOverride,
        )
        instance._api_spec = instance._operation_schemas = None
        return instance

    @property
    def resource_name(self) -> str:
        """Fully-qualified resource name of the extension."""
        return self._gca_resource.name

    def api_spec(self) -> _utils.JsonDict:
        """Returns the (Open)API spec of the extension, loading it on first use."""
        if self._api_spec is not None:
            return self._api_spec
        self._api_spec = _load_api_spec(self._gca_resource.manifest.api_spec)
        return self._api_spec

    def operation_schemas(self) -> Sequence[_utils.JsonDict]:
        """Returns the (Open)API schema of each operation of the extension.

        The list is built on the first call and reused afterwards.
        """
        if self._operation_schemas is not None:
            return self._operation_schemas
        schemas = []
        for operation in self._gca_resource.extension_operations:
            schemas.append(_utils.to_dict(operation.function_declaration))
        self._operation_schemas = schemas
        return self._operation_schemas

    def execute(
        self,
        operation_id: str,
        operation_params: Optional[_StructOrJson] = None,
        runtime_auth_config: Optional[_AuthConfigOrJson] = None,
    ) -> Union[_utils.JsonDict, str]:
        """Runs one operation of the extension with the given parameters.

        Args:
            operation_id (str):
                Required. The ID of the operation to be executed.
            operation_params (Union[dict[str, Any], Struct]):
                Optional. Parameters used for executing the operation. It should
                be in a form of map with param name as the key and actual param
                value as the value. E.g. if this operation requires a param
                "name" to be set to "abc", you can set this to {"name": "abc"}.
                Defaults to None, in which case no parameters are sent.
            runtime_auth_config (Union[dict[str, Any], AuthConfig]):
                Optional. The Auth configuration to execute the operation.

        Returns:
            The response content, decoded from JSON when it is valid JSON and
            returned as the raw string otherwise.
        """
        execute_request = types.ExecuteExtensionRequest(
            name=self.resource_name,
            operation_id=operation_id,
            operation_params=operation_params,
        )
        if runtime_auth_config:
            auth_proto = _utils.to_proto(runtime_auth_config, types.AuthConfig())
            execute_request.runtime_auth_config = auth_proto
        execute_response = self.execution_api_client.execute_extension(
            execute_request
        )
        return _try_parse_execution_response(execute_response)

    def query(
        self,
        contents: _generative_models.ContentsType,
    ) -> "QueryExtensionResponse":
        """Queries the extension with the given conversation contents.

        Args:
            contents (ContentsType):
                Required. The content of the current
                conversation with the model.
                For single-turn queries, this is a single
                instance. For multi-turn queries, this is a
                repeated field that contains conversation
                history + latest request.

        Returns:
            The result of querying the extension.

        Raises:
            RuntimeError: If the response contains an error.
        """
        gapic_contents = _generative_models._content_types_to_gapic_contents(
            contents
        )
        query_request = types.QueryExtensionRequest(
            name=self.resource_name,
            contents=gapic_contents,
        )
        query_response = self.execution_api_client.query_extension(query_request)
        failure = query_response.failure_message
        if failure:
            raise RuntimeError(failure)
        return QueryExtensionResponse._from_gapic(query_response)

    @classmethod
    def from_hub(
        cls,
        name: str,
        *,
        runtime_config: Optional[_RuntimeConfigOrJson] = None,
    ):
        """Creates a new Extension from the set of first party extensions.

        Args:
            name (str):
                Required. The name of the extension in the hub to be created.
                Supported values are "code_interpreter", "vertex_ai_search" and
                "webpage_browser".
            runtime_config (Union[dict[str, Any], RuntimeConfig]):
                Optional. Runtime config controlling the runtime behavior of
                the Extension. Defaults to None.

        Returns:
            Extension: The extension that was created.

        Raises:
            ValueError: If the `name` is not supported in the hub.
            ValueError: If `name` is "vertex_ai_search" and no `runtime_config`
                is given.
            ValueError: If the `runtime_config` is specified but inconsistent
                with the name (e.g. the name was "code_interpreter" but the
                runtime_config was based on "vertex_ai_search_runtime_config").
        """
        hub_runtime_config = runtime_config
        if hub_runtime_config:
            hub_runtime_config = _utils.to_proto(
                hub_runtime_config,
                types.RuntimeConfig(),
            )

        # Pick the runtime-config field (if any) that must be populated for
        # the requested extension whenever a runtime config is supplied.
        if name == "code_interpreter":
            required_field = "code_interpreter_runtime_config"
            mismatch_message = (
                "code_interpreter_runtime_config is required for "
                "code_interpreter extension"
            )
        elif name == "vertex_ai_search":
            if not hub_runtime_config:
                raise ValueError(
                    "runtime_config is required for vertex_ai_search extension"
                )
            required_field = "vertex_ai_search_runtime_config"
            mismatch_message = (
                "vertex_ai_search_runtime_config is required for "
                "vertex_ai_search extension"
            )
        elif name == "webpage_browser":
            # No runtime-config validation is performed for this extension.
            required_field = mismatch_message = None
        else:
            raise ValueError(f"Unsupported 1P extension name: {name}")

        if (
            required_field is not None
            and hub_runtime_config
            and not getattr(hub_runtime_config, required_field, None)
        ):
            raise ValueError(mismatch_message)

        hub_entry = _VERTEX_EXTENSION_HUB[name]
        return cls.create(
            display_name=hub_entry["display_name"],
            description=hub_entry["description"],
            manifest=hub_entry["manifest"],
            runtime_config=hub_runtime_config,
        )
|
|
|
|
|
|
class QueryExtensionResponse:
    """The response obtained from querying an extension.

    Attributes are limited to `steps`, the list of content objects that make
    up the response.
    """

    def __init__(self, steps: List[_generative_models.Content]):
        """Stores the given steps on the response object."""
        self.steps = steps

    @classmethod
    def _from_gapic(
        cls, response: types.QueryExtensionResponse
    ) -> "QueryExtensionResponse":
        """Builds a QueryExtensionResponse from a gapic response.

        Each step of the gapic response is converted into a
        `_generative_models.Content` that keeps the step's role and wraps
        each of its parts with `Part._from_gapic`.
        """
        converted_steps = []
        for gapic_step in response.steps:
            wrapped_parts = [
                _generative_models.Part._from_gapic(gapic_part)
                for gapic_part in gapic_step.parts
            ]
            converted_steps.append(
                _generative_models.Content(
                    parts=wrapped_parts,
                    role=gapic_step.role,
                )
            )
        return cls(steps=converted_steps)
|
|
|
|
|
|
def _try_parse_execution_response(
    response: types.ExecuteExtensionResponse,
) -> Union[_utils.JsonDict, str]:
    """Returns the response content, decoded from JSON when possible.

    Args:
        response: The execution response whose `content` is inspected.

    Returns:
        The value produced by `json.loads` when the content is valid JSON;
        otherwise the content unchanged.
    """
    raw_content = response.content
    try:
        return json.loads(raw_content)
    except json.JSONDecodeError:
        # Not JSON: hand the content back exactly as received.
        return raw_content
|
|
|
|
|
|
def _load_api_spec(api_spec) -> _utils.JsonDict:
    """Loads the (Open)API spec of an extension as a JSON-style object.

    Args:
        api_spec: The manifest's API spec; its `open_api_yaml` and
            `open_api_gcs_uri` fields are consulted, in that order.

    Returns:
        The parsed inline YAML when `open_api_yaml` is set, otherwise the
        YAML loaded from `open_api_gcs_uri` when that is set, otherwise an
        empty dict.
    """
    inline_yaml = api_spec.open_api_yaml
    if inline_yaml:
        yaml_module = aip_utils.yaml_utils._maybe_import_yaml()
        return yaml_module.safe_load(inline_yaml)

    gcs_uri = api_spec.open_api_gcs_uri
    if gcs_uri:
        return aip_utils.yaml_utils.load_yaml(gcs_uri)

    # Neither source is populated.
    return {}
|