structure saas with tools
This commit is contained in:
@@ -0,0 +1,391 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from google.cloud.aiplatform import base
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import utils as aip_utils
|
||||
from google.cloud.aiplatform_v1beta1 import types
|
||||
from vertexai.generative_models import _generative_models
|
||||
from vertexai.reasoning_engines import _utils
|
||||
from google.protobuf import struct_pb2
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
|
||||
_AuthConfigOrJson = Union[_utils.JsonDict, types.AuthConfig]
|
||||
_StructOrJson = Union[_utils.JsonDict, struct_pb2.Struct]
|
||||
_RuntimeConfigOrJson = Union[_utils.JsonDict, types.RuntimeConfig]
|
||||
|
||||
|
||||
_VERTEX_EXTENSION_HUB = {
|
||||
"code_interpreter": {
|
||||
"display_name": "Code Interpreter",
|
||||
"description": (
|
||||
"This extension generates and executes code in the specified language"
|
||||
),
|
||||
"manifest": {
|
||||
"name": "code_interpreter_tool",
|
||||
"description": "Google Code Interpreter Extension",
|
||||
"api_spec": {
|
||||
"open_api_gcs_uri": (
|
||||
"gs://vertex-extension-public/code_interpreter.yaml"
|
||||
),
|
||||
},
|
||||
"auth_config": {
|
||||
"auth_type": "GOOGLE_SERVICE_ACCOUNT_AUTH",
|
||||
"google_service_account_config": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
"vertex_ai_search": {
|
||||
"display_name": "Vertex AI Search",
|
||||
"description": "This extension generates and executes search queries",
|
||||
"manifest": {
|
||||
"name": "vertex_ai_search",
|
||||
"description": "Vertex AI Search Extension",
|
||||
"api_spec": {
|
||||
"open_api_gcs_uri": (
|
||||
"gs://vertex-extension-public/vertex_ai_search.yaml"
|
||||
),
|
||||
},
|
||||
"auth_config": {
|
||||
"auth_type": "GOOGLE_SERVICE_ACCOUNT_AUTH",
|
||||
"google_service_account_config": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
"webpage_browser": {
|
||||
"display_name": "Webpage Browser",
|
||||
"description": "This extension fetches the content of a webpage",
|
||||
"manifest": {
|
||||
"name": "webpage_browser",
|
||||
"description": "Vertex Webpage Browser Extension",
|
||||
"api_spec": {
|
||||
"open_api_gcs_uri": (
|
||||
"gs://vertex-extension-public/webpage_browser.yaml"
|
||||
),
|
||||
},
|
||||
"auth_config": {
|
||||
"auth_type": "GOOGLE_SERVICE_ACCOUNT_AUTH",
|
||||
"google_service_account_config": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Extension(base.VertexAiResourceNounWithFutureManager):
|
||||
"""Represents a Vertex AI Extension resource."""
|
||||
|
||||
client_class = aip_utils.ExtensionRegistryClientWithOverride
|
||||
_resource_noun = "extension"
|
||||
_getter_method = "get_extension"
|
||||
_list_method = "list_extensions"
|
||||
_delete_method = "delete_extension"
|
||||
_parse_resource_name_method = "parse_extension_path"
|
||||
_format_resource_name_method = "extension_path"
|
||||
|
||||
def __init__(self, extension_name: str):
|
||||
"""Retrieves an extension resource.
|
||||
|
||||
Args:
|
||||
extension_name (str):
|
||||
Required. A fully-qualified resource name or ID such as
|
||||
"projects/123/locations/us-central1/extensions/456" or
|
||||
"456" when project and location are initialized or passed.
|
||||
"""
|
||||
super().__init__(resource_name=extension_name)
|
||||
self.execution_api_client = initializer.global_config.create_client(
|
||||
client_class=aip_utils.ExtensionExecutionClientWithOverride,
|
||||
)
|
||||
self._gca_resource = self._get_gca_resource(resource_name=extension_name)
|
||||
self._api_spec = None
|
||||
self._operation_schemas = None
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
manifest: Union[_utils.JsonDict, types.ExtensionManifest],
|
||||
*,
|
||||
extension_name: Optional[str] = None,
|
||||
display_name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
runtime_config: Optional[_RuntimeConfigOrJson] = None,
|
||||
):
|
||||
"""Creates a new Extension.
|
||||
|
||||
Args:
|
||||
manifest (Union[dict[str, Any], ExtensionManifest]):
|
||||
Required. The manifest for the Extension to be created.
|
||||
extension_name (str):
|
||||
Optional. A fully-qualified extension resource name or extension
|
||||
ID such as "projects/123/locations/us-central1/extensions/456" or
|
||||
"456" when project and location are initialized or passed. If
|
||||
specifying the extension ID, it should be 4-63 characters, valid
|
||||
characters are lowercase letters, numbers and hyphens ("-"),
|
||||
and it should start with a number or a lower-case letter. If not
|
||||
provided, Vertex AI will generate a value for this ID.
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the Extension.
|
||||
The name can be up to 128 characters long and can comprise any
|
||||
UTF-8 character.
|
||||
description (str):
|
||||
Optional. The description of the Extension.
|
||||
runtime_config (Union[dict[str, Any], RuntimeConfig]):
|
||||
Optional. Runtime config controlling the runtime behavior of
|
||||
this Extension. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Extension: The extension that was created.
|
||||
"""
|
||||
sdk_resource = cls.__new__(cls)
|
||||
base.VertexAiResourceNounWithFutureManager.__init__(
|
||||
sdk_resource,
|
||||
resource_name=extension_name,
|
||||
)
|
||||
extension = types.Extension(
|
||||
name=extension_name,
|
||||
display_name=display_name or cls._generate_display_name(),
|
||||
description=description,
|
||||
manifest=_utils.to_proto(manifest, types.ExtensionManifest()),
|
||||
)
|
||||
if runtime_config:
|
||||
extension.runtime_config = _utils.to_proto(
|
||||
runtime_config,
|
||||
types.RuntimeConfig(),
|
||||
)
|
||||
operation_future = sdk_resource.api_client.import_extension(
|
||||
parent=initializer.global_config.common_location_path(),
|
||||
extension=extension,
|
||||
)
|
||||
_LOGGER.log_create_with_lro(cls, operation_future)
|
||||
created_extension = operation_future.result()
|
||||
_LOGGER.log_create_complete(
|
||||
cls,
|
||||
created_extension,
|
||||
cls._resource_noun,
|
||||
module_name="vertexai.preview.extensions",
|
||||
)
|
||||
# We use `._get_gca_resource(...)` instead of `created_extension` to
|
||||
# fully instantiate the attributes of the extension.
|
||||
sdk_resource._gca_resource = sdk_resource._get_gca_resource(
|
||||
resource_name=created_extension.name
|
||||
)
|
||||
sdk_resource.execution_api_client = initializer.global_config.create_client(
|
||||
client_class=aip_utils.ExtensionExecutionClientWithOverride,
|
||||
)
|
||||
sdk_resource._api_spec = None
|
||||
sdk_resource._operation_schemas = None
|
||||
return sdk_resource
|
||||
|
||||
@property
|
||||
def resource_name(self) -> str:
|
||||
"""Full qualified resource name for the extension."""
|
||||
return self._gca_resource.name
|
||||
|
||||
def api_spec(self) -> _utils.JsonDict:
|
||||
"""Returns the (Open)API Spec of the extension."""
|
||||
if self._api_spec is None:
|
||||
self._api_spec = _load_api_spec(self._gca_resource.manifest.api_spec)
|
||||
return self._api_spec
|
||||
|
||||
def operation_schemas(self) -> Sequence[_utils.JsonDict]:
|
||||
"""Returns the (Open)API schemas for each operation of the extension."""
|
||||
if self._operation_schemas is None:
|
||||
self._operation_schemas = [
|
||||
_utils.to_dict(op.function_declaration)
|
||||
for op in self._gca_resource.extension_operations
|
||||
]
|
||||
return self._operation_schemas
|
||||
|
||||
def execute(
|
||||
self,
|
||||
operation_id: str,
|
||||
operation_params: Optional[_StructOrJson] = None,
|
||||
runtime_auth_config: Optional[_AuthConfigOrJson] = None,
|
||||
) -> Union[_utils.JsonDict, str]:
|
||||
"""Executes an operation of the extension with the specified params.
|
||||
|
||||
Args:
|
||||
operation_id (str):
|
||||
Required. The ID of the operation to be executed.
|
||||
operation_params (Union[dict[str, Any], Struct]):
|
||||
Optional. Parameters used for executing the operation. It should
|
||||
be in a form of map with param name as the key and actual param
|
||||
value as the value. E.g. if this operation requires a param
|
||||
"name" to be set to "abc", you can set this to {"name": "abc"}.
|
||||
Defaults to an empty dictionary.
|
||||
runtime_auth_config (Union[dict[str, Any], AuthConfig]):
|
||||
Optional. The Auth configuration to execute the operation.
|
||||
|
||||
Returns:
|
||||
The result of executing the extension operation.
|
||||
"""
|
||||
request = types.ExecuteExtensionRequest(
|
||||
name=self.resource_name,
|
||||
operation_id=operation_id,
|
||||
operation_params=operation_params,
|
||||
)
|
||||
if runtime_auth_config:
|
||||
request.runtime_auth_config = _utils.to_proto(
|
||||
runtime_auth_config,
|
||||
types.AuthConfig(),
|
||||
)
|
||||
response = self.execution_api_client.execute_extension(request)
|
||||
return _try_parse_execution_response(response)
|
||||
|
||||
def query(
|
||||
self,
|
||||
contents: _generative_models.ContentsType,
|
||||
) -> "QueryExtensionResponse":
|
||||
"""Queries an extension with the specified contents.
|
||||
|
||||
Args:
|
||||
contents (ContentsType):
|
||||
Required. The content of the current
|
||||
conversation with the model.
|
||||
For single-turn queries, this is a single
|
||||
instance. For multi-turn queries, this is a
|
||||
repeated field that contains conversation
|
||||
history + latest request.
|
||||
|
||||
Returns:
|
||||
The result of querying the extension.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the response contains an error.
|
||||
"""
|
||||
request = types.QueryExtensionRequest(
|
||||
name=self.resource_name,
|
||||
contents=_generative_models._content_types_to_gapic_contents(contents),
|
||||
)
|
||||
response = self.execution_api_client.query_extension(request)
|
||||
if response.failure_message:
|
||||
raise RuntimeError(response.failure_message)
|
||||
return QueryExtensionResponse._from_gapic(response)
|
||||
|
||||
@classmethod
|
||||
def from_hub(
|
||||
cls,
|
||||
name: str,
|
||||
*,
|
||||
runtime_config: Optional[_RuntimeConfigOrJson] = None,
|
||||
):
|
||||
"""Creates a new Extension from the set of first party extensions.
|
||||
|
||||
Args:
|
||||
name (str):
|
||||
Required. The name of the extension in the hub to be created.
|
||||
Supported values are "code_interpreter", "vertex_ai_search" and
|
||||
"webpage_browser".
|
||||
runtime_config (Union[dict[str, Any], RuntimeConfig]):
|
||||
Optional. Runtime config controlling the runtime behavior of
|
||||
the Extension. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Extension: The extension that was created.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `name` is not supported in the hub.
|
||||
ValueError: If the `runtime_config` is specified but inconsistent
|
||||
with the name (e.g. the name was "code_interpreter" but the
|
||||
runtime_config was based on "vertex_ai_search_runtime_config").
|
||||
"""
|
||||
if runtime_config:
|
||||
runtime_config = _utils.to_proto(
|
||||
runtime_config,
|
||||
types.RuntimeConfig(),
|
||||
)
|
||||
if name == "code_interpreter":
|
||||
if runtime_config and not getattr(
|
||||
runtime_config,
|
||||
"code_interpreter_runtime_config",
|
||||
None,
|
||||
):
|
||||
raise ValueError(
|
||||
"code_interpreter_runtime_config is required for "
|
||||
"code_interpreter extension"
|
||||
)
|
||||
elif name == "vertex_ai_search":
|
||||
if not runtime_config:
|
||||
raise ValueError(
|
||||
"runtime_config is required for vertex_ai_search extension"
|
||||
)
|
||||
if runtime_config and not getattr(
|
||||
runtime_config,
|
||||
"vertex_ai_search_runtime_config",
|
||||
None,
|
||||
):
|
||||
raise ValueError(
|
||||
"vertex_ai_search_runtime_config is required for "
|
||||
"vertex_ai_search extension"
|
||||
)
|
||||
elif name == "webpage_browser":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unsupported 1P extension name: {name}")
|
||||
extension_info = _VERTEX_EXTENSION_HUB[name]
|
||||
return cls.create(
|
||||
display_name=extension_info["display_name"],
|
||||
description=extension_info["description"],
|
||||
manifest=extension_info["manifest"],
|
||||
runtime_config=runtime_config,
|
||||
)
|
||||
|
||||
|
||||
class QueryExtensionResponse:
|
||||
"""A class representing the response from querying an extension."""
|
||||
|
||||
def __init__(self, steps: List[_generative_models.Content]):
|
||||
"""Initializes the QueryExtensionResponse with the given steps."""
|
||||
self.steps = steps
|
||||
|
||||
@classmethod
|
||||
def _from_gapic(
|
||||
cls, response: types.QueryExtensionResponse
|
||||
) -> "QueryExtensionResponse":
|
||||
"""Creates a QueryExtensionResponse from a gapic response."""
|
||||
return cls(
|
||||
steps=[
|
||||
_generative_models.Content(
|
||||
parts=[_generative_models.Part._from_gapic(p) for p in c.parts],
|
||||
role=c.role,
|
||||
)
|
||||
for c in response.steps
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _try_parse_execution_response(
|
||||
response: types.ExecuteExtensionResponse,
|
||||
) -> Union[_utils.JsonDict, str]:
|
||||
content: str = response.content
|
||||
try:
|
||||
content = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return content
|
||||
|
||||
|
||||
def _load_api_spec(api_spec) -> _utils.JsonDict:
|
||||
"""Loads the (Open)API Spec of the extension and converts it to JSON."""
|
||||
if api_spec.open_api_yaml:
|
||||
yaml = aip_utils.yaml_utils._maybe_import_yaml()
|
||||
return yaml.safe_load(api_spec.open_api_yaml)
|
||||
elif api_spec.open_api_gcs_uri:
|
||||
return aip_utils.yaml_utils.load_yaml(api_spec.open_api_gcs_uri)
|
||||
return {}
|
||||
Reference in New Issue
Block a user