Files
evo-ai/.venv/lib/python3.10/site-packages/vertexai/extensions/_extensions.py
2025-04-25 15:30:54 -03:00

392 lines
15 KiB
Python

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from typing import List, Optional, Sequence, Union
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils as aip_utils
from google.cloud.aiplatform_v1beta1 import types
from vertexai.generative_models import _generative_models
from vertexai.reasoning_engines import _utils
from google.protobuf import struct_pb2
# Module logger; used by ``Extension.create`` to report LRO progress.
_LOGGER = base.Logger(__name__)

# Argument aliases: each accepts either a ``_utils.JsonDict`` or the
# corresponding proto message type.
_AuthConfigOrJson = Union[_utils.JsonDict, types.AuthConfig]
_RuntimeConfigOrJson = Union[_utils.JsonDict, types.RuntimeConfig]
_StructOrJson = Union[_utils.JsonDict, struct_pb2.Struct]
# First-party extensions that ``Extension.from_hub`` knows how to create.
# Each entry supplies the display name, description and manifest that are
# forwarded to ``Extension.create``; every manifest points at an OpenAPI
# spec in the public GCS bucket and uses Google service-account auth.
_VERTEX_EXTENSION_HUB = dict(
    code_interpreter=dict(
        display_name="Code Interpreter",
        description=(
            "This extension generates and executes code in the specified language"
        ),
        manifest=dict(
            name="code_interpreter_tool",
            description="Google Code Interpreter Extension",
            api_spec=dict(
                open_api_gcs_uri=(
                    "gs://vertex-extension-public/code_interpreter.yaml"
                ),
            ),
            auth_config=dict(
                auth_type="GOOGLE_SERVICE_ACCOUNT_AUTH",
                google_service_account_config=dict(),
            ),
        ),
    ),
    vertex_ai_search=dict(
        display_name="Vertex AI Search",
        description="This extension generates and executes search queries",
        manifest=dict(
            name="vertex_ai_search",
            description="Vertex AI Search Extension",
            api_spec=dict(
                open_api_gcs_uri=(
                    "gs://vertex-extension-public/vertex_ai_search.yaml"
                ),
            ),
            auth_config=dict(
                auth_type="GOOGLE_SERVICE_ACCOUNT_AUTH",
                google_service_account_config=dict(),
            ),
        ),
    ),
    webpage_browser=dict(
        display_name="Webpage Browser",
        description="This extension fetches the content of a webpage",
        manifest=dict(
            name="webpage_browser",
            description="Vertex Webpage Browser Extension",
            api_spec=dict(
                open_api_gcs_uri=(
                    "gs://vertex-extension-public/webpage_browser.yaml"
                ),
            ),
            auth_config=dict(
                auth_type="GOOGLE_SERVICE_ACCOUNT_AUTH",
                google_service_account_config=dict(),
            ),
        ),
    ),
)
class Extension(base.VertexAiResourceNounWithFutureManager):
    """Represents a Vertex AI Extension resource."""

    client_class = aip_utils.ExtensionRegistryClientWithOverride
    _resource_noun = "extension"
    _getter_method = "get_extension"
    _list_method = "list_extensions"
    _delete_method = "delete_extension"
    _parse_resource_name_method = "parse_extension_path"
    _format_resource_name_method = "extension_path"

    def __init__(self, extension_name: str):
        """Retrieves an extension resource.

        Args:
            extension_name (str):
                Required. A fully-qualified resource name or ID such as
                "projects/123/locations/us-central1/extensions/456" or
                "456" when project and location are initialized or passed.
        """
        super().__init__(resource_name=extension_name)
        self.execution_api_client = self._create_execution_client()
        self._gca_resource = self._get_gca_resource(resource_name=extension_name)
        # Lazily populated caches for api_spec() / operation_schemas().
        self._api_spec = None
        self._operation_schemas = None

    def _create_execution_client(self):
        """Creates the client used by `execute` and `query`.

        The client is pinned to ``self.location`` and ``self.credentials``
        (as resolved by the base-class initializer) instead of the global
        defaults, so an extension addressed by a fully-qualified name in a
        non-default region is executed against that region.
        NOTE(review): relies on the base initializer setting these
        attributes — confirm if the base class changes.
        """
        return initializer.global_config.create_client(
            client_class=aip_utils.ExtensionExecutionClientWithOverride,
            credentials=self.credentials,
            location_override=self.location,
        )

    @classmethod
    def create(
        cls,
        manifest: Union[_utils.JsonDict, types.ExtensionManifest],
        *,
        extension_name: Optional[str] = None,
        display_name: Optional[str] = None,
        description: Optional[str] = None,
        runtime_config: Optional[_RuntimeConfigOrJson] = None,
    ):
        """Creates a new Extension.

        Args:
            manifest (Union[dict[str, Any], ExtensionManifest]):
                Required. The manifest for the Extension to be created.
            extension_name (str):
                Optional. A fully-qualified extension resource name or extension
                ID such as "projects/123/locations/us-central1/extensions/456" or
                "456" when project and location are initialized or passed. If
                specifying the extension ID, it should be 4-63 characters, valid
                characters are lowercase letters, numbers and hyphens ("-"),
                and it should start with a number or a lower-case letter. If not
                provided, Vertex AI will generate a value for this ID.
            display_name (str):
                Optional. The user-defined name of the Extension.
                The name can be up to 128 characters long and can comprise any
                UTF-8 character. A generated name is used when omitted.
            description (str):
                Optional. The description of the Extension.
            runtime_config (Union[dict[str, Any], RuntimeConfig]):
                Optional. Runtime config controlling the runtime behavior of
                this Extension. Defaults to None.

        Returns:
            Extension: The extension that was created.
        """
        # Skip `__init__` (it fetches an existing resource); only run the
        # base initializer so that `api_client` is available.
        sdk_resource = cls.__new__(cls)
        base.VertexAiResourceNounWithFutureManager.__init__(
            sdk_resource,
            resource_name=extension_name,
        )
        extension = types.Extension(
            name=extension_name,
            display_name=display_name or cls._generate_display_name(),
            description=description,
            manifest=_utils.to_proto(manifest, types.ExtensionManifest()),
        )
        if runtime_config:
            extension.runtime_config = _utils.to_proto(
                runtime_config,
                types.RuntimeConfig(),
            )
        operation_future = sdk_resource.api_client.import_extension(
            parent=initializer.global_config.common_location_path(),
            extension=extension,
        )
        _LOGGER.log_create_with_lro(cls, operation_future)
        created_extension = operation_future.result()
        _LOGGER.log_create_complete(
            cls,
            created_extension,
            cls._resource_noun,
            module_name="vertexai.preview.extensions",
        )
        # We use `._get_gca_resource(...)` instead of `created_extension` to
        # fully instantiate the attributes of the extension.
        sdk_resource._gca_resource = sdk_resource._get_gca_resource(
            resource_name=created_extension.name
        )
        sdk_resource.execution_api_client = sdk_resource._create_execution_client()
        sdk_resource._api_spec = None
        sdk_resource._operation_schemas = None
        return sdk_resource

    @property
    def resource_name(self) -> str:
        """Full qualified resource name for the extension."""
        return self._gca_resource.name

    def api_spec(self) -> _utils.JsonDict:
        """Returns the (Open)API Spec of the extension.

        The spec is loaded from the extension manifest on the first call and
        cached on the instance for subsequent calls.
        """
        if self._api_spec is None:
            self._api_spec = _load_api_spec(self._gca_resource.manifest.api_spec)
        return self._api_spec

    def operation_schemas(self) -> Sequence[_utils.JsonDict]:
        """Returns the (Open)API schemas for each operation of the extension.

        One dict per entry in the resource's ``extension_operations``, built
        from its ``function_declaration``; computed once and cached.
        """
        if self._operation_schemas is None:
            self._operation_schemas = [
                _utils.to_dict(op.function_declaration)
                for op in self._gca_resource.extension_operations
            ]
        return self._operation_schemas

    def execute(
        self,
        operation_id: str,
        operation_params: Optional[_StructOrJson] = None,
        runtime_auth_config: Optional[_AuthConfigOrJson] = None,
    ) -> Union[_utils.JsonDict, str]:
        """Executes an operation of the extension with the specified params.

        Args:
            operation_id (str):
                Required. The ID of the operation to be executed.
            operation_params (Union[dict[str, Any], Struct]):
                Optional. Parameters used for executing the operation. It should
                be in a form of map with param name as the key and actual param
                value as the value. E.g. if this operation requires a param
                "name" to be set to "abc", you can set this to {"name": "abc"}.
                Defaults to None, which sends no parameters.
            runtime_auth_config (Union[dict[str, Any], AuthConfig]):
                Optional. The Auth configuration to execute the operation.

        Returns:
            The result of executing the extension operation: the response
            content decoded with `json.loads` when it is valid JSON,
            otherwise the raw content string.
        """
        request = types.ExecuteExtensionRequest(
            name=self.resource_name,
            operation_id=operation_id,
            operation_params=operation_params,
        )
        if runtime_auth_config:
            request.runtime_auth_config = _utils.to_proto(
                runtime_auth_config,
                types.AuthConfig(),
            )
        response = self.execution_api_client.execute_extension(request)
        return _try_parse_execution_response(response)

    def query(
        self,
        contents: _generative_models.ContentsType,
    ) -> "QueryExtensionResponse":
        """Queries an extension with the specified contents.

        Args:
            contents (ContentsType):
                Required. The content of the current
                conversation with the model.
                For single-turn queries, this is a single
                instance. For multi-turn queries, this is a
                repeated field that contains conversation
                history + latest request.

        Returns:
            The result of querying the extension.

        Raises:
            RuntimeError: If the response contains a failure message.
        """
        request = types.QueryExtensionRequest(
            name=self.resource_name,
            contents=_generative_models._content_types_to_gapic_contents(contents),
        )
        response = self.execution_api_client.query_extension(request)
        if response.failure_message:
            raise RuntimeError(response.failure_message)
        return QueryExtensionResponse._from_gapic(response)

    @classmethod
    def from_hub(
        cls,
        name: str,
        *,
        runtime_config: Optional[_RuntimeConfigOrJson] = None,
    ):
        """Creates a new Extension from the set of first party extensions.

        Args:
            name (str):
                Required. The name of the extension in the hub to be created.
                Supported values are "code_interpreter", "vertex_ai_search" and
                "webpage_browser".
            runtime_config (Union[dict[str, Any], RuntimeConfig]):
                Optional. Runtime config controlling the runtime behavior of
                the Extension. Defaults to None. Required for
                "vertex_ai_search".

        Returns:
            Extension: The extension that was created.

        Raises:
            ValueError: If the `name` is not supported in the hub.
            ValueError: If the `runtime_config` is specified but inconsistent
                with the name (e.g. the name was "code_interpreter" but the
                runtime_config was based on "vertex_ai_search_runtime_config").
            ValueError: If `name` is "vertex_ai_search" and no
                `runtime_config` is given.
        """
        if runtime_config:
            runtime_config = _utils.to_proto(
                runtime_config,
                types.RuntimeConfig(),
            )
        # Derive the supported names from the hub table itself so the two
        # can never drift apart.
        if name not in _VERTEX_EXTENSION_HUB:
            raise ValueError(f"Unsupported 1P extension name: {name}")
        if name == "code_interpreter":
            # runtime_config is optional here, but if given it must carry the
            # code-interpreter sub-config.
            if runtime_config and not getattr(
                runtime_config,
                "code_interpreter_runtime_config",
                None,
            ):
                raise ValueError(
                    "code_interpreter_runtime_config is required for "
                    "code_interpreter extension"
                )
        elif name == "vertex_ai_search":
            if not runtime_config:
                raise ValueError(
                    "runtime_config is required for vertex_ai_search extension"
                )
            # runtime_config is known to be set at this point.
            if not getattr(runtime_config, "vertex_ai_search_runtime_config", None):
                raise ValueError(
                    "vertex_ai_search_runtime_config is required for "
                    "vertex_ai_search extension"
                )
        # "webpage_browser" needs no runtime_config validation.
        extension_info = _VERTEX_EXTENSION_HUB[name]
        return cls.create(
            display_name=extension_info["display_name"],
            description=extension_info["description"],
            manifest=extension_info["manifest"],
            runtime_config=runtime_config,
        )
class QueryExtensionResponse:
    """A class representing the response from querying an extension.

    Attributes:
        steps (List[Content]):
            The steps returned by the extension query, in the order the
            service returned them.
    """

    def __init__(self, steps: List[_generative_models.Content]):
        """Initializes the QueryExtensionResponse with the given steps."""
        self.steps = steps

    def __repr__(self) -> str:
        """Returns a debug representation listing the response steps."""
        return f"{type(self).__name__}(steps={self.steps!r})"

    @classmethod
    def _from_gapic(
        cls, response: types.QueryExtensionResponse
    ) -> "QueryExtensionResponse":
        """Creates a QueryExtensionResponse from a gapic response.

        Each gapic step is re-wrapped as an SDK ``Content``: its parts are
        converted with ``Part._from_gapic`` and its role is carried over.
        """
        return cls(
            steps=[
                _generative_models.Content(
                    parts=[_generative_models.Part._from_gapic(p) for p in c.parts],
                    role=c.role,
                )
                for c in response.steps
            ]
        )
def _try_parse_execution_response(
    response: types.ExecuteExtensionResponse,
) -> Union[_utils.JsonDict, str]:
    """Decodes an execution response's content as JSON when possible.

    Returns the value produced by `json.loads` for valid JSON content, and
    the untouched content string when it is not valid JSON.
    """
    raw_content = response.content
    try:
        return json.loads(raw_content)
    except json.JSONDecodeError:
        # Not JSON: hand the raw text back to the caller as-is.
        return raw_content
def _load_api_spec(api_spec) -> _utils.JsonDict:
    """Loads an extension's (Open)API spec and returns it as JSON data.

    An inline YAML spec (``open_api_yaml``) is preferred; otherwise the spec
    is loaded from ``open_api_gcs_uri``. If neither is set, an empty dict is
    returned.
    """
    inline_yaml = api_spec.open_api_yaml
    if inline_yaml:
        yaml_module = aip_utils.yaml_utils._maybe_import_yaml()
        return yaml_module.safe_load(inline_yaml)
    gcs_uri = api_spec.open_api_gcs_uri
    if not gcs_uri:
        return {}
    return aip_utils.yaml_utils.load_yaml(gcs_uri)