structure saas with tools
This commit is contained in:
107
.venv/lib/python3.10/site-packages/vertexai/rag/__init__.py
Normal file
107
.venv/lib/python3.10/site-packages/vertexai/rag/__init__.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from vertexai.rag.rag_data import (
|
||||
create_corpus,
|
||||
update_corpus,
|
||||
list_corpora,
|
||||
get_corpus,
|
||||
delete_corpus,
|
||||
upload_file,
|
||||
import_files,
|
||||
import_files_async,
|
||||
get_file,
|
||||
list_files,
|
||||
delete_file,
|
||||
)
|
||||
|
||||
from vertexai.rag.rag_retrieval import (
|
||||
retrieval_query,
|
||||
)
|
||||
|
||||
from vertexai.rag.rag_store import (
|
||||
Retrieval,
|
||||
VertexRagStore,
|
||||
)
|
||||
from vertexai.rag.utils.resources import (
|
||||
ChunkingConfig,
|
||||
Filter,
|
||||
JiraQuery,
|
||||
JiraSource,
|
||||
LayoutParserConfig,
|
||||
LlmRanker,
|
||||
Pinecone,
|
||||
RagCorpus,
|
||||
RagEmbeddingModelConfig,
|
||||
RagFile,
|
||||
RagManagedDb,
|
||||
RagResource,
|
||||
RagRetrievalConfig,
|
||||
RagVectorDbConfig,
|
||||
Ranking,
|
||||
RankService,
|
||||
SharePointSource,
|
||||
SharePointSources,
|
||||
SlackChannel,
|
||||
SlackChannelsSource,
|
||||
TransformationConfig,
|
||||
VertexAiSearchConfig,
|
||||
VertexPredictionEndpoint,
|
||||
VertexVectorSearch,
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"ChunkingConfig",
|
||||
"Filter",
|
||||
"JiraQuery",
|
||||
"JiraSource",
|
||||
"LayoutParserConfig",
|
||||
"LlmRanker",
|
||||
"Pinecone",
|
||||
"RagCorpus",
|
||||
"RagEmbeddingModelConfig",
|
||||
"RagFile",
|
||||
"RagManagedDb",
|
||||
"RagResource",
|
||||
"RagRetrievalConfig",
|
||||
"RagVectorDbConfig",
|
||||
"Ranking",
|
||||
"RankService",
|
||||
"Retrieval",
|
||||
"SharePointSource",
|
||||
"SharePointSources",
|
||||
"SlackChannel",
|
||||
"SlackChannelsSource",
|
||||
"TransformationConfig",
|
||||
"VertexAiSearchConfig",
|
||||
"VertexRagStore",
|
||||
"VertexPredictionEndpoint",
|
||||
"VertexVectorSearch",
|
||||
"create_corpus",
|
||||
"delete_corpus",
|
||||
"delete_file",
|
||||
"get_corpus",
|
||||
"get_file",
|
||||
"import_files",
|
||||
"import_files_async",
|
||||
"list_corpora",
|
||||
"list_files",
|
||||
"retrieval_query",
|
||||
"upload_file",
|
||||
"update_corpus",
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
870
.venv/lib/python3.10/site-packages/vertexai/rag/rag_data.py
Normal file
870
.venv/lib/python3.10/site-packages/vertexai/rag/rag_data.py
Normal file
@@ -0,0 +1,870 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""RAG data management SDK."""
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
from google import auth
|
||||
from google.api_core import operation_async
|
||||
from google.auth.transport import requests as google_auth_requests
|
||||
from google.cloud import aiplatform
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import utils
|
||||
from google.cloud.aiplatform_v1 import (
|
||||
CreateRagCorpusRequest,
|
||||
DeleteRagCorpusRequest,
|
||||
DeleteRagFileRequest,
|
||||
GetRagCorpusRequest,
|
||||
GetRagFileRequest,
|
||||
ImportRagFilesResponse,
|
||||
ListRagCorporaRequest,
|
||||
ListRagFilesRequest,
|
||||
RagCorpus as GapicRagCorpus,
|
||||
UpdateRagCorpusRequest,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.vertex_rag_data_service.pagers import (
|
||||
ListRagCorporaPager,
|
||||
ListRagFilesPager,
|
||||
)
|
||||
from vertexai.rag.utils import (
|
||||
_gapic_utils,
|
||||
)
|
||||
from vertexai.rag.utils.resources import (
|
||||
JiraSource,
|
||||
LayoutParserConfig,
|
||||
RagCorpus,
|
||||
RagFile,
|
||||
RagVectorDbConfig,
|
||||
SharePointSources,
|
||||
SlackChannelsSource,
|
||||
VertexAiSearchConfig,
|
||||
TransformationConfig,
|
||||
)
|
||||
|
||||
|
||||
def create_corpus(
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    vertex_ai_search_config: Optional[VertexAiSearchConfig] = None,
    backend_config: Optional[
        Union[
            RagVectorDbConfig,
            None,
        ]
    ] = None,
) -> RagCorpus:
    """Create a new RagCorpus and block until the LRO completes.

    Args:
        display_name: Display name of the RagCorpus (up to 128 UTF-8
            characters). A timestamped name is generated when omitted.
        description: Description of the RagCorpus.
        vertex_ai_search_config: Vertex AI Search config. Mutually
            exclusive with ``backend_config``.
        backend_config: Backend config (data store and/or embedding
            model) for the RagCorpus.

    Returns:
        The created RagCorpus.

    Raises:
        ValueError: If both config arguments are supplied.
        RuntimeError: If the create request or the operation fails.
    """
    # The two backend configurations are mutually exclusive.
    if vertex_ai_search_config and backend_config:
        raise ValueError(
            "Only one of vertex_ai_search_config or backend_config can be set."
        )

    # Fall back to a generated, timestamped display name.
    if not display_name:
        display_name = "vertex-" + utils.timestamped_unique_name()

    corpus = GapicRagCorpus(display_name=display_name, description=description)
    if backend_config:
        _gapic_utils.set_backend_config(
            backend_config=backend_config,
            rag_corpus=corpus,
        )
    elif vertex_ai_search_config:
        _gapic_utils.set_vertex_ai_search_config(
            vertex_ai_search_config=vertex_ai_search_config,
            rag_corpus=corpus,
        )

    parent_path = initializer.global_config.common_location_path(
        project=None, location=None
    )
    data_client = _gapic_utils.create_rag_data_service_client()
    try:
        operation = data_client.create_rag_corpus(
            request=CreateRagCorpusRequest(parent=parent_path, rag_corpus=corpus)
        )
    except Exception as e:
        raise RuntimeError("Failed in RagCorpus creation due to: ", e) from e
    # Wait for the long-running operation before converting the result.
    return _gapic_utils.convert_gapic_to_rag_corpus(operation.result(timeout=600))
|
||||
|
||||
|
||||
def update_corpus(
    corpus_name: str,
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    vertex_ai_search_config: Optional[VertexAiSearchConfig] = None,
    backend_config: Optional[
        Union[
            RagVectorDbConfig,
            None,
        ]
    ] = None,
) -> RagCorpus:
    """Update an existing RagCorpus and block until the LRO completes.

    Intended for third-party vector DB backends (Vector Search, Vertex AI
    Feature Store, Weaviate, Pinecone), not Vertex RagManagedDb.

    Args:
        corpus_name: Resource name to update. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
            or ``{rag_corpus}``.
        display_name: New display name; unchanged when omitted.
        description: New description; unchanged when omitted.
        vertex_ai_search_config: Vertex AI Search config. Mutually
            exclusive with ``backend_config``.
        backend_config: Backend config (data store and/or embedding model).

    Returns:
        The updated RagCorpus.

    Raises:
        ValueError: If both config arguments are supplied.
        RuntimeError: If the update request or the operation fails.
    """
    if vertex_ai_search_config and backend_config:
        raise ValueError(
            "Only one of vertex_ai_search_config or backend_config can be set."
        )

    corpus_name = _gapic_utils.get_corpus_name(corpus_name)
    # Only pass the fields the caller actually wants to update.
    corpus_kwargs = {"name": corpus_name}
    if display_name:
        corpus_kwargs["display_name"] = display_name
    if description:
        corpus_kwargs["description"] = description
    rag_corpus = GapicRagCorpus(**corpus_kwargs)

    if backend_config:
        _gapic_utils.set_backend_config(
            backend_config=backend_config,
            rag_corpus=rag_corpus,
        )
    if vertex_ai_search_config:
        _gapic_utils.set_vertex_ai_search_config(
            vertex_ai_search_config=vertex_ai_search_config,
            rag_corpus=rag_corpus,
        )

    data_client = _gapic_utils.create_rag_data_service_client()
    try:
        operation = data_client.update_rag_corpus(
            request=UpdateRagCorpusRequest(rag_corpus=rag_corpus)
        )
    except Exception as e:
        raise RuntimeError("Failed in RagCorpus update due to: ", e) from e
    return _gapic_utils.convert_gapic_to_rag_corpus_no_embedding_model_config(
        operation.result(timeout=600)
    )
|
||||
|
||||
|
||||
def get_corpus(name: str) -> RagCorpus:
    """Fetch an existing RagCorpus by resource name.

    Args:
        name: Resource name. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
            or ``{rag_corpus}``.

    Returns:
        The RagCorpus.

    Raises:
        RuntimeError: If the get request fails.
    """
    full_name = _gapic_utils.get_corpus_name(name)
    data_client = _gapic_utils.create_rag_data_service_client()
    try:
        gapic_corpus = data_client.get_rag_corpus(
            request=GetRagCorpusRequest(name=full_name)
        )
    except Exception as e:
        raise RuntimeError("Failed in getting the RagCorpus due to: ", e) from e
    return _gapic_utils.convert_gapic_to_rag_corpus(gapic_corpus)
|
||||
|
||||
|
||||
def list_corpora(
    page_size: Optional[int] = None, page_token: Optional[str] = None
) -> ListRagCorporaPager:
    """List RagCorpora in the configured project and location.

    Example:
        ```
        # All corpora at once:
        rag_corpora = list(rag.list_corpora())

        # Or page manually:
        pager_1 = rag.list_corpora(page_size=10)
        pager_2 = rag.list_corpora(page_size=10, page_token=pager_1.next_page_token)
        ```

    Args:
        page_size: Standard list page size; omitting it returns all results.
        page_token: Standard list page token from a previous pager.

    Returns:
        A ListRagCorporaPager over the corpora.

    Raises:
        RuntimeError: If the list request fails.
    """
    parent_path = initializer.global_config.common_location_path(
        project=None, location=None
    )
    list_request = ListRagCorporaRequest(
        parent=parent_path,
        page_size=page_size,
        page_token=page_token,
    )
    data_client = _gapic_utils.create_rag_data_service_client()
    try:
        corpora_pager = data_client.list_rag_corpora(request=list_request)
    except Exception as e:
        raise RuntimeError("Failed in listing the RagCorpora due to: ", e) from e
    return corpora_pager
|
||||
|
||||
|
||||
def delete_corpus(name: str) -> None:
    """Delete an existing RagCorpus.

    Args:
        name: Resource name. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
            or ``{rag_corpus}``.

    Raises:
        RuntimeError: If the delete request fails.
    """
    delete_request = DeleteRagCorpusRequest(
        name=_gapic_utils.get_corpus_name(name)
    )
    data_client = _gapic_utils.create_rag_data_service_client()
    try:
        data_client.delete_rag_corpus(request=delete_request)
    except Exception as e:
        raise RuntimeError("Failed in RagCorpus deletion due to: ", e) from e
    # Matches the existing SDK behavior of confirming success on stdout.
    print("Successfully deleted the RagCorpus.")
|
||||
|
||||
|
||||
def upload_file(
    corpus_name: str,
    path: Union[str, Sequence[str]],
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    transformation_config: Optional[TransformationConfig] = None,
) -> RagFile:
    """Synchronously upload a local file to an existing RagCorpus.

    Example usage:

    ```
    import vertexai
    from vertexai import rag

    vertexai.init(project="my-project")

    rag_file = rag.upload_file(
        corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1",
        display_name="my_file.txt",
        path="usr/home/my_file.txt",
    )
    ```

    Args:
        corpus_name: The name of the RagCorpus resource into which to upload
            the file. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
            or ``{rag_corpus}``.
        path: A local file path, e.g. "usr/home/my_file.txt".
        display_name: The display name of the data file. A timestamped name
            is generated when omitted.
        description: The description of the RagFile.
        transformation_config: The config for transforming the RagFile,
            like chunking.

    Returns:
        RagFile.

    Raises:
        RuntimeError: Failed in RagFile upload, or indexing failed.
        ValueError: RagCorpus is not found.
    """
    corpus_name = _gapic_utils.get_corpus_name(corpus_name)
    location = initializer.global_config.location
    # GAPIC doesn't expose a path (scotty). Use the requests API instead.
    if display_name is None:
        display_name = "vertex-" + utils.timestamped_unique_name()
    headers = {"X-Goog-Upload-Protocol": "multipart"}
    if not initializer.global_config.api_endpoint:
        request_endpoint = "{}-{}".format(
            location, aiplatform.constants.base.API_BASE_PATH
        )
    else:
        request_endpoint = initializer.global_config.api_endpoint
    upload_request_uri = "https://{}/upload/v1/{}/ragFiles:upload".format(
        request_endpoint,
        corpus_name,
    )
    js_rag_file = {"rag_file": {"display_name": display_name}}

    if description:
        js_rag_file["rag_file"]["description"] = description

    if transformation_config and transformation_config.chunking_config:
        chunk_size = transformation_config.chunking_config.chunk_size
        chunk_overlap = transformation_config.chunking_config.chunk_overlap
        js_rag_file["upload_rag_file_config"] = {
            "rag_file_transformation_config": {
                "rag_file_chunking_config": {
                    "fixed_length_chunking": {
                        "chunk_size": chunk_size,
                        "chunk_overlap": chunk_overlap,
                    }
                }
            }
        }

    credentials, _ = auth.default()
    authorized_session = google_auth_requests.AuthorizedSession(credentials=credentials)
    # BUGFIX: open the file in a context manager so the handle is always
    # closed; the previous implementation leaked an open file on every call.
    # (open() stays outside the try so a missing file still raises the
    # original OSError rather than being wrapped in RuntimeError.)
    with open(path, "rb") as file_obj:
        files = {
            # NOTE(review): metadata is serialized with str() (Python dict
            # repr), preserving the pre-existing wire behavior — confirm
            # against the upload endpoint before switching to json.dumps.
            "metadata": (None, str(js_rag_file)),
            "file": file_obj,
        }
        try:
            response = authorized_session.post(
                url=upload_request_uri,
                files=files,
                headers=headers,
            )
        except Exception as e:
            raise RuntimeError("Failed in uploading the RagFile due to: ", e) from e

    if response.status_code == 404:
        # BUGFIX: the previous code passed logging-style %s args to
        # ValueError, so the message was never interpolated.
        raise ValueError(
            "RagCorpus '%s' is not found: %s" % (corpus_name, upload_request_uri)
        )
    if response.json().get("error"):
        raise RuntimeError(
            "Failed in indexing the RagFile due to: ", response.json().get("error")
        )
    return _gapic_utils.convert_json_to_rag_file(response.json())
|
||||
|
||||
|
||||
def import_files(
    corpus_name: str,
    paths: Optional[Sequence[str]] = None,
    source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None,
    transformation_config: Optional[TransformationConfig] = None,
    timeout: int = 600,
    max_embedding_requests_per_min: int = 1000,
    import_result_sink: Optional[str] = None,
    partial_failures_sink: Optional[str] = None,
    parser: Optional[LayoutParserConfig] = None,
) -> ImportRagFilesResponse:
    """Import files into an existing RagCorpus and wait for completion.

    Exactly one of ``paths`` (GCS/Drive URIs) or ``source`` (Slack, Jira,
    or SharePoint) must be provided.

    Args:
        corpus_name: Target RagCorpus. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
            or ``{rag_corpus}``.
        paths: URIs: a GCS directory ("gs://my-bucket/my_dir") or Google
            Drive file/folder URLs.
        source: A SlackChannelsSource, JiraSource, or SharePointSources.
        transformation_config: Config for transforming the imported
            RagFiles (e.g. chunking).
        timeout: Seconds to wait for the operation. Default 600.
        max_embedding_requests_per_min: Optional per-job QPM cap for the
            corpus embedding model; defaults to 1,000 QPM. Consult the
            project Quotas page for an appropriate value.
        import_result_sink: GCS path ("gs://my-bucket/my/object.ndjson",
            must not already exist) or BigQuery table
            ("bq://my-project.my-dataset.my-table", created if missing,
            appended otherwise) to store import results.
        partial_failures_sink: Deprecated; prefer ``import_result_sink``.
            Same format and semantics, for partial failures.
        parser: None for the default parser, or a LayoutParserConfig to
            parse documents with a Document AI Layout Parser processor.

    Returns:
        ImportRagFilesResponse with e.g. ``imported_rag_files_count``.

    Raises:
        ValueError: If neither or both of ``paths``/``source`` are given.
        RuntimeError: If the import request or the operation fails.
    """
    if source is not None and paths is not None:
        raise ValueError("Only one of source or paths must be passed in at a time")
    if source is None and paths is None:
        raise ValueError("One of source or paths must be passed in")

    import_request = _gapic_utils.prepare_import_files_request(
        corpus_name=_gapic_utils.get_corpus_name(corpus_name),
        paths=paths,
        source=source,
        transformation_config=transformation_config,
        max_embedding_requests_per_min=max_embedding_requests_per_min,
        import_result_sink=import_result_sink,
        partial_failures_sink=partial_failures_sink,
        parser=parser,
    )
    data_client = _gapic_utils.create_rag_data_service_client()
    try:
        import_lro = data_client.import_rag_files(request=import_request)
    except Exception as e:
        raise RuntimeError("Failed in importing the RagFiles due to: ", e) from e
    # Block until the long-running import finishes.
    return import_lro.result(timeout=timeout)
|
||||
|
||||
|
||||
async def import_files_async(
    corpus_name: str,
    paths: Optional[Sequence[str]] = None,
    source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None,
    transformation_config: Optional[TransformationConfig] = None,
    max_embedding_requests_per_min: int = 1000,
    import_result_sink: Optional[str] = None,
    partial_failures_sink: Optional[str] = None,
    parser: Optional[LayoutParserConfig] = None,
) -> operation_async.AsyncOperation:
    """Start an asynchronous import of files into an existing RagCorpus.

    Exactly one of ``paths`` (GCS/Drive URIs) or ``source`` (Slack, Jira,
    or SharePoint) must be provided. Unlike ``import_files``, this returns
    the pending operation; await ``response.result()`` for completion.

    Args:
        corpus_name: Target RagCorpus. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
            or ``{rag_corpus}``.
        paths: URIs: a GCS directory ("gs://my-bucket/my_dir") or Google
            Drive file/folder URLs.
        source: A SlackChannelsSource, JiraSource, or SharePointSources.
        transformation_config: Config for transforming the imported
            RagFiles (e.g. chunking).
        max_embedding_requests_per_min: Optional per-job QPM cap for the
            corpus embedding model; defaults to 1,000 QPM. Consult the
            project Quotas page for an appropriate value.
        import_result_sink: GCS path ("gs://my-bucket/my/object.ndjson",
            must not already exist) or BigQuery table
            ("bq://my-project.my-dataset.my-table", created if missing,
            appended otherwise) to store import results.
        partial_failures_sink: Deprecated; prefer ``import_result_sink``.
            Same format and semantics, for partial failures.
        parser: None for the default parser, or a LayoutParserConfig to
            parse documents with a Document AI Layout Parser processor.

    Returns:
        operation_async.AsyncOperation for the pending import.

    Raises:
        ValueError: If neither or both of ``paths``/``source`` are given.
        RuntimeError: If issuing the import request fails.
    """
    if source is not None and paths is not None:
        raise ValueError("Only one of source or paths must be passed in at a time")
    if source is None and paths is None:
        raise ValueError("One of source or paths must be passed in")

    import_request = _gapic_utils.prepare_import_files_request(
        corpus_name=_gapic_utils.get_corpus_name(corpus_name),
        paths=paths,
        source=source,
        transformation_config=transformation_config,
        max_embedding_requests_per_min=max_embedding_requests_per_min,
        import_result_sink=import_result_sink,
        partial_failures_sink=partial_failures_sink,
        parser=parser,
    )
    async_client = _gapic_utils.create_rag_data_service_async_client()
    try:
        pending_operation = await async_client.import_rag_files(
            request=import_request
        )
    except Exception as e:
        raise RuntimeError("Failed in importing the RagFiles due to: ", e) from e
    return pending_operation
|
||||
|
||||
|
||||
def get_file(name: str, corpus_name: Optional[str] = None) -> RagFile:
    """Fetch a single existing RagFile.

    Args:
        name: Either a full RagFile resource name, or a bare RagFile id when
            ``corpus_name`` supplies the parent. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}``
            or ``{rag_file}``.
        corpus_name: If ``name`` is not a full resource name, an existing
            RagCorpus name must be provided. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
            or ``{rag_corpus}``.

    Returns:
        RagFile.

    Raises:
        RuntimeError: If the GetRagFile call fails.
    """
    resolved_corpus = _gapic_utils.get_corpus_name(corpus_name)
    resolved_name = _gapic_utils.get_file_name(name, resolved_corpus)
    data_client = _gapic_utils.create_rag_data_service_client()
    try:
        gapic_rag_file = data_client.get_rag_file(
            request=GetRagFileRequest(name=resolved_name)
        )
    except Exception as e:
        raise RuntimeError("Failed in getting the RagFile due to: ", e) from e
    return _gapic_utils.convert_gapic_to_rag_file(gapic_rag_file)
|
||||
|
||||
|
||||
def list_files(
    corpus_name: str, page_size: Optional[int] = None, page_token: Optional[str] = None
) -> ListRagFilesPager:
    """List the RagFiles of an existing RagCorpus.

    Example usage:
    ```
    import vertexai

    vertexai.init(project="my-project")
    # List all corpora.
    rag_corpora = list(rag.list_corpora())

    # List all files of the first corpus.
    rag_files = list(rag.list_files(corpus_name=rag_corpora[0].name))

    # Alternatively, page through results with a ListRagFilesPager.
    pager_1 = rag.list_files(corpus_name=rag_corpora[0].name, page_size=10)
    # Use the generated next_page_token from the last pager to get the next page.
    pager_2 = rag.list_files(
        corpus_name=rag_corpora[0].name,
        page_size=10,
        page_token=pager_1.next_page_token,
    )
    ```

    Args:
        corpus_name: An existing RagCorpus name. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
            or ``{rag_corpus}``.
        page_size: The standard list page size. Leaving out the page_size
            causes all of the results to be returned.
        page_token: The standard list page token.

    Returns:
        ListRagFilesPager.

    Raises:
        RuntimeError: If the ListRagFiles call fails.
    """
    parent = _gapic_utils.get_corpus_name(corpus_name)
    list_request = ListRagFilesRequest(
        parent=parent,
        page_size=page_size,
        page_token=page_token,
    )
    data_client = _gapic_utils.create_rag_data_service_client()
    try:
        files_pager = data_client.list_rag_files(request=list_request)
    except Exception as e:
        raise RuntimeError("Failed in listing the RagFiles due to: ", e) from e

    return files_pager
|
||||
|
||||
|
||||
def delete_file(name: str, corpus_name: Optional[str] = None) -> None:
    """Delete a RagFile from an existing RagCorpus.

    Args:
        name: Either a full RagFile resource name, or a bare RagFile id when
            ``corpus_name`` supplies the parent. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}``
            or ``{rag_file}``.
        corpus_name: If ``name`` is not a full resource name, an existing
            RagCorpus name must be provided. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
            or ``{rag_corpus}``.

    Raises:
        RuntimeError: If the DeleteRagFile call fails.
    """
    corpus_name = _gapic_utils.get_corpus_name(corpus_name)
    name = _gapic_utils.get_file_name(name, corpus_name)
    request = DeleteRagFileRequest(name=name)

    client = _gapic_utils.create_rag_data_service_client()
    try:
        client.delete_rag_file(request=request)
        # NOTE(review): print() in library code; kept for backward-compatible
        # console feedback, but a logger would be preferable.
        print("Successfully deleted the RagFile.")
    except Exception as e:
        raise RuntimeError("Failed in RagFile deletion due to: ", e) from e
    # Fix: dropped the redundant trailing `return None` — the function is
    # annotated `-> None` and falls off the end anyway.
|
||||
167
.venv/lib/python3.10/site-packages/vertexai/rag/rag_retrieval.py
Normal file
167
.venv/lib/python3.10/site-packages/vertexai/rag/rag_retrieval.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Retrieval query to get relevant contexts."""
|
||||
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from google.cloud import aiplatform_v1
|
||||
from google.cloud.aiplatform import initializer
|
||||
from vertexai.rag.utils import _gapic_utils
|
||||
from vertexai.rag.utils import resources
|
||||
|
||||
|
||||
def retrieval_query(
    text: str,
    rag_resources: Optional[List[resources.RagResource]] = None,
    rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
) -> aiplatform_v1.RetrieveContextsResponse:
    """Retrieve top k relevant docs/chunks.

    Example usage:
    ```
    import vertexai

    vertexai.init(project="my-project")

    config = vertexai.rag.RagRetrievalConfig(
        top_k=2,
        filter=vertexai.rag.Filter(
            vector_distance_threshold=0.5
        ),
        ranking=vertexai.rag.Ranking(
            llm_ranker=vertexai.rag.LlmRanker(
                model_name="gemini-1.5-flash-002"
            )
        )
    )

    results = vertexai.rag.retrieval_query(
        text="Why is the sky blue?",
        rag_resources=[vertexai.rag.RagResource(
            rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1",
            rag_file_ids=["rag-file-1", "rag-file-2", ...],
        )],
        rag_retrieval_config=config,
    )
    ```

    Args:
        text: The query in text format to get relevant contexts.
        rag_resources: A list of RagResource. It can be used to specify corpus
            only or ragfiles. Currently only support one corpus or multiple files
            from one corpus. In the future we may open up multiple corpora support.
        rag_retrieval_config: Optional. The config containing the retrieval
            parameters, including similarity_top_k and vector_distance_threshold.

    Returns:
        RetrieveContextsResponse.

    Raises:
        ValueError: If rag_resources is missing, contains more than one entry,
            or names an invalid RagCorpus.
        RuntimeError: If the RetrieveContexts call fails.
    """
    parent = initializer.global_config.common_location_path()

    client = _gapic_utils.create_rag_service_client()

    # Exactly one RagResource is supported today.
    if rag_resources:
        if len(rag_resources) > 1:
            raise ValueError("Currently only support 1 RagResource.")
        name = rag_resources[0].rag_corpus
    else:
        raise ValueError("rag_resources must be specified.")

    # Accept either a full corpus resource name or a bare corpus id; a bare id
    # is expanded against the globally-configured project/location.
    data_client = _gapic_utils.create_rag_data_service_client()
    if data_client.parse_rag_corpus_path(name):
        rag_corpus_name = name
    elif re.match("^{}$".format(_gapic_utils._VALID_RESOURCE_NAME_REGEX), name):
        rag_corpus_name = parent + "/ragCorpora/" + name
    else:
        raise ValueError(
            f"Invalid RagCorpus name: {name}. Proper format should be:"
            " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}"
        )

    if rag_resources:
        gapic_rag_resource = (
            aiplatform_v1.RetrieveContextsRequest.VertexRagStore.RagResource(
                rag_corpus=rag_corpus_name,
                rag_file_ids=rag_resources[0].rag_file_ids,
            )
        )
        vertex_rag_store = aiplatform_v1.RetrieveContextsRequest.VertexRagStore(
            rag_resources=[gapic_rag_resource],
        )
    else:
        # NOTE(review): unreachable today — rag_resources is mandatory above;
        # kept as a defensive fallback to the deprecated rag_corpora field.
        vertex_rag_store = aiplatform_v1.RetrieveContextsRequest.VertexRagStore(
            rag_corpora=[rag_corpus_name],
        )

    # If rag_retrieval_config is not specified, set it to default values.
    if not rag_retrieval_config:
        api_retrieval_config = aiplatform_v1.RagRetrievalConfig()
    else:
        # If rag_retrieval_config is specified, copy it field-by-field into the
        # gapic config, validating mutually-exclusive options as we go.
        api_retrieval_config = aiplatform_v1.RagRetrievalConfig()
        api_retrieval_config.top_k = rag_retrieval_config.top_k
        # Set vector_distance_threshold to config value if specified
        if rag_retrieval_config.filter:
            # vector_distance_threshold and vector_similarity_threshold are
            # mutually exclusive.
            if (
                rag_retrieval_config.filter
                and rag_retrieval_config.filter.vector_distance_threshold
                and rag_retrieval_config.filter.vector_similarity_threshold
            ):
                raise ValueError(
                    "Only one of vector_distance_threshold or"
                    " vector_similarity_threshold can be specified at a time"
                    " in rag_retrieval_config."
                )
            api_retrieval_config.filter.vector_distance_threshold = (
                rag_retrieval_config.filter.vector_distance_threshold
            )
            api_retrieval_config.filter.vector_similarity_threshold = (
                rag_retrieval_config.filter.vector_similarity_threshold
            )
        # rank_service and llm_ranker are mutually exclusive.
        if (
            rag_retrieval_config.ranking
            and rag_retrieval_config.ranking.rank_service
            and rag_retrieval_config.ranking.llm_ranker
        ):
            raise ValueError("Only one of rank_service and llm_ranker can be set.")
        if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service:
            api_retrieval_config.ranking.rank_service.model_name = (
                rag_retrieval_config.ranking.rank_service.model_name
            )
        elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker:
            api_retrieval_config.ranking.llm_ranker.model_name = (
                rag_retrieval_config.ranking.llm_ranker.model_name
            )

    query = aiplatform_v1.RagQuery(
        text=text,
        rag_retrieval_config=api_retrieval_config,
    )
    request = aiplatform_v1.RetrieveContextsRequest(
        vertex_rag_store=vertex_rag_store,
        parent=parent,
        query=query,
    )
    try:
        response = client.retrieve_contexts(request=request)
    except Exception as e:
        raise RuntimeError("Failed in retrieving contexts due to: ", e) from e

    return response
|
||||
168
.venv/lib/python3.10/site-packages/vertexai/rag/rag_store.py
Normal file
168
.venv/lib/python3.10/site-packages/vertexai/rag/rag_store.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""RAG retrieval tool for content generation."""
|
||||
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from google.cloud import aiplatform_v1beta1
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types
|
||||
from vertexai import generative_models
|
||||
from vertexai.rag.utils import _gapic_utils
|
||||
from vertexai.rag.utils import resources
|
||||
|
||||
|
||||
class Retrieval(generative_models.grounding.Retrieval):
    """Defines a retrieval tool that a model can call to access external knowledge."""

    def __init__(
        self,
        # Annotation fix: single-argument Union["VertexRagStore"] is equivalent
        # to the plain forward reference.
        source: "VertexRagStore",
        disable_attribution: Optional[bool] = False,
    ):
        """Wraps a VertexRagStore as a gapic Retrieval proto.

        Args:
            source: The VertexRagStore to retrieve grounding contexts from.
            disable_attribution: Optional. If True, the service skips
                returning attribution for retrieved contexts.
        """
        # Re-use the store's already-built raw proto; this class only adds the
        # Retrieval wrapper around it.
        self._raw_retrieval = gapic_tool_types.Retrieval(
            vertex_rag_store=source._raw_vertex_rag_store,
            disable_attribution=disable_attribution,
        )
|
||||
|
||||
|
||||
class VertexRagStore:
    """Retrieve from Vertex RAG Store."""

    def __init__(
        self,
        rag_resources: Optional[List[resources.RagResource]] = None,
        rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
    ):
        """Initializes a Vertex RAG store tool.

        Example usage:
        ```
        import vertexai

        vertexai.init(project="my-project")

        config = vertexai.rag.RagRetrievalConfig(
            top_k=2,
            filter=vertexai.rag.Filter(
                vector_distance_threshold=0.5
            ),
            ranking=vertexai.rag.Ranking(
                llm_ranker=vertexai.rag.LlmRanker(
                    model_name="gemini-1.5-flash-002"
                )
            )
        )

        tool = Tool.from_retrieval(
            retrieval=vertexai.rag.Retrieval(
                source=vertexai.rag.VertexRagStore(
                    rag_resources=[vertexai.rag.RagResource(
                        rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1",
                    )],
                    rag_retrieval_config=config,
                ),
            )
        )
        ```

        Args:
            rag_resources: List of RagResource to retrieve from. It can be used
                to specify corpus only or ragfiles. Currently only support one
                corpus or multiple files from one corpus. In the future we
                may open up multiple corpora support.
            rag_retrieval_config: Optional. The config containing the retrieval
                parameters, including similarity_top_k and vector_distance_threshold.

        Raises:
            ValueError: If rag_resources is missing, has more than one entry,
                names an invalid RagCorpus, or the retrieval config combines
                mutually-exclusive options.
        """

        # Exactly one RagResource is supported today.
        if rag_resources:
            if len(rag_resources) > 1:
                raise ValueError("Currently only support 1 RagResource.")
            name = rag_resources[0].rag_corpus
        else:
            raise ValueError("rag_resources must be specified.")

        # Accept either a full corpus resource name or a bare corpus id; a
        # bare id is expanded against the globally-configured project/location.
        data_client = _gapic_utils.create_rag_data_service_client()
        if data_client.parse_rag_corpus_path(name):
            rag_corpus_name = name
        elif re.match("^{}$".format(_gapic_utils._VALID_RESOURCE_NAME_REGEX), name):
            parent = initializer.global_config.common_location_path()
            rag_corpus_name = parent + "/ragCorpora/" + name
        else:
            raise ValueError(
                f"Invalid RagCorpus name: {name}. Proper format should be:"
                " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}"
            )

        # If rag_retrieval_config is not specified, set it to default values.
        api_retrieval_config = aiplatform_v1beta1.RagRetrievalConfig()
        # If rag_retrieval_config is specified, populate the default config.
        if rag_retrieval_config:
            api_retrieval_config.top_k = rag_retrieval_config.top_k
            # Set vector_distance_threshold to config value if specified
            if rag_retrieval_config.filter:
                # vector_distance_threshold and vector_similarity_threshold
                # are mutually exclusive.
                if (
                    rag_retrieval_config.filter
                    and rag_retrieval_config.filter.vector_distance_threshold
                    and rag_retrieval_config.filter.vector_similarity_threshold
                ):
                    raise ValueError(
                        "Only one of vector_distance_threshold or"
                        " vector_similarity_threshold can be specified at a time"
                        " in rag_retrieval_config."
                    )
                api_retrieval_config.filter.vector_distance_threshold = (
                    rag_retrieval_config.filter.vector_distance_threshold
                )
                api_retrieval_config.filter.vector_similarity_threshold = (
                    rag_retrieval_config.filter.vector_similarity_threshold
                )
            # rank_service and llm_ranker are mutually exclusive (here only
            # when both carry a model_name).
            if (
                rag_retrieval_config.ranking
                and rag_retrieval_config.ranking.rank_service
                and rag_retrieval_config.ranking.rank_service.model_name
                and rag_retrieval_config.ranking.llm_ranker
                and rag_retrieval_config.ranking.llm_ranker.model_name
            ):
                raise ValueError(
                    "Only one of rank_service or llm_ranker can be specified"
                    " at a time in rag_retrieval_config."
                )
            # Set rank_service to config value if specified
            if (
                rag_retrieval_config.ranking
                and rag_retrieval_config.ranking.rank_service
            ):
                api_retrieval_config.ranking.rank_service.model_name = (
                    rag_retrieval_config.ranking.rank_service.model_name
                )
            # Set llm_ranker to config value if specified
            if rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker:
                api_retrieval_config.ranking.llm_ranker.model_name = (
                    rag_retrieval_config.ranking.llm_ranker.model_name
                )

        gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource(
            rag_corpus=rag_corpus_name,
            rag_file_ids=rag_resources[0].rag_file_ids,
        )
        self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore(
            rag_resources=[gapic_rag_resource],
            rag_retrieval_config=api_retrieval_config,
        )
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,656 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
from typing import Any, Dict, Optional, Sequence, Union
|
||||
from google.cloud.aiplatform_v1.types import api_auth
|
||||
from google.cloud.aiplatform_v1 import (
|
||||
RagEmbeddingModelConfig as GapicRagEmbeddingModelConfig,
|
||||
GoogleDriveSource,
|
||||
ImportRagFilesConfig,
|
||||
ImportRagFilesRequest,
|
||||
RagFileChunkingConfig,
|
||||
RagFileParsingConfig,
|
||||
RagFileTransformationConfig,
|
||||
RagCorpus as GapicRagCorpus,
|
||||
RagFile as GapicRagFile,
|
||||
SharePointSources as GapicSharePointSources,
|
||||
SlackSource as GapicSlackSource,
|
||||
JiraSource as GapicJiraSource,
|
||||
RagVectorDbConfig as GapicRagVectorDbConfig,
|
||||
VertexAiSearchConfig as GapicVertexAiSearchConfig,
|
||||
)
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform.utils import (
|
||||
VertexRagDataAsyncClientWithOverride,
|
||||
VertexRagDataClientWithOverride,
|
||||
VertexRagClientWithOverride,
|
||||
)
|
||||
from vertexai.rag.utils.resources import (
|
||||
LayoutParserConfig,
|
||||
Pinecone,
|
||||
RagCorpus,
|
||||
RagEmbeddingModelConfig,
|
||||
RagFile,
|
||||
RagManagedDb,
|
||||
RagVectorDbConfig,
|
||||
SharePointSources,
|
||||
SlackChannelsSource,
|
||||
TransformationConfig,
|
||||
JiraSource,
|
||||
VertexAiSearchConfig,
|
||||
VertexVectorSearch,
|
||||
VertexPredictionEndpoint,
|
||||
)
|
||||
|
||||
|
||||
_VALID_RESOURCE_NAME_REGEX = "[a-z][a-zA-Z0-9._-]{0,127}"
|
||||
_VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX = (
|
||||
r"projects/[^/]+/locations/[^/]+/processors/[^/]+(?:/processorVersions/[^/]+)?"
|
||||
)
|
||||
|
||||
|
||||
def create_rag_data_service_client():
    """Build a v1 VertexRagData service client from the global config."""
    versioned_client = initializer.global_config.create_client(
        client_class=VertexRagDataClientWithOverride,
    )
    return versioned_client.select_version("v1")
|
||||
|
||||
|
||||
def create_rag_data_service_async_client():
    """Build a v1 async VertexRagData service client from the global config."""
    versioned_client = initializer.global_config.create_client(
        client_class=VertexRagDataAsyncClientWithOverride,
    )
    return versioned_client.select_version("v1")
|
||||
|
||||
|
||||
def create_rag_service_client():
    """Build a v1 VertexRag (retrieval) service client from the global config."""
    versioned_client = initializer.global_config.create_client(
        client_class=VertexRagClientWithOverride,
    )
    return versioned_client.select_version("v1")
|
||||
|
||||
|
||||
def convert_gapic_to_rag_embedding_model_config(
    gapic_embedding_model_config: GapicRagEmbeddingModelConfig,
) -> RagEmbeddingModelConfig:
    """Convert GapicRagEmbeddingModelConfig to RagEmbeddingModelConfig.

    The endpoint path is classified by regex: a publisher-model path
    (``.../publishers/google/models/...``) or a user endpoint path
    (``.../endpoints/...``). The two patterns cannot match the same string,
    so at most one branch below fires; if neither matches, the default
    (empty) config is returned.
    """
    embedding_model_config = RagEmbeddingModelConfig()
    path = gapic_embedding_model_config.vertex_prediction_endpoint.endpoint
    publisher_model = re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
        path,
    )
    endpoint = re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
        path,
    )
    if publisher_model:
        # The full resource path is stored as the publisher_model value.
        embedding_model_config.vertex_prediction_endpoint = VertexPredictionEndpoint(
            publisher_model=path
        )
    if endpoint:
        embedding_model_config.vertex_prediction_endpoint = VertexPredictionEndpoint(
            endpoint=path,
            model=gapic_embedding_model_config.vertex_prediction_endpoint.model,
            model_version_id=gapic_embedding_model_config.vertex_prediction_endpoint.model_version_id,
        )
    return embedding_model_config
|
||||
|
||||
|
||||
def _check_weaviate(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
|
||||
try:
|
||||
return gapic_vector_db.__contains__("weaviate")
|
||||
except AttributeError:
|
||||
return gapic_vector_db.weaviate.ByteSize() > 0
|
||||
|
||||
|
||||
def _check_rag_managed_db(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
|
||||
try:
|
||||
return gapic_vector_db.__contains__("rag_managed_db")
|
||||
except AttributeError:
|
||||
return gapic_vector_db.rag_managed_db.ByteSize() > 0
|
||||
|
||||
|
||||
def _check_vertex_feature_store(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
|
||||
try:
|
||||
return gapic_vector_db.__contains__("vertex_feature_store")
|
||||
except AttributeError:
|
||||
return gapic_vector_db.vertex_feature_store.ByteSize() > 0
|
||||
|
||||
|
||||
def _check_pinecone(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
|
||||
try:
|
||||
return gapic_vector_db.__contains__("pinecone")
|
||||
except AttributeError:
|
||||
return gapic_vector_db.pinecone.ByteSize() > 0
|
||||
|
||||
|
||||
def _check_vertex_vector_search(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
|
||||
try:
|
||||
return gapic_vector_db.__contains__("vertex_vector_search")
|
||||
except AttributeError:
|
||||
return gapic_vector_db.vertex_vector_search.ByteSize() > 0
|
||||
|
||||
|
||||
def _check_rag_embedding_model_config(
|
||||
gapic_vector_db: GapicRagVectorDbConfig,
|
||||
) -> bool:
|
||||
try:
|
||||
return gapic_vector_db.__contains__("rag_embedding_model_config")
|
||||
except AttributeError:
|
||||
return gapic_vector_db.rag_embedding_model_config.ByteSize() > 0
|
||||
|
||||
|
||||
def convert_gapic_to_backend_config(
    gapic_vector_db: GapicRagVectorDbConfig,
) -> RagVectorDbConfig:
    """Convert Gapic RagVectorDbConfig to VertexVectorSearch, Pinecone, or RagManagedDb.

    The vector-db branches are mutually exclusive (first match wins); the
    embedding model config, if present, is attached independently of which
    vector db (if any) was detected.
    """
    vector_config = RagVectorDbConfig()
    if _check_pinecone(gapic_vector_db):
        vector_config.vector_db = Pinecone(
            index_name=gapic_vector_db.pinecone.index_name,
            # The Pinecone api key lives on the shared api_auth field, not
            # inside the pinecone message itself.
            api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
        )
    elif _check_vertex_vector_search(gapic_vector_db):
        vector_config.vector_db = VertexVectorSearch(
            index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint,
            index=gapic_vector_db.vertex_vector_search.index,
        )
    elif _check_rag_managed_db(gapic_vector_db):
        vector_config.vector_db = RagManagedDb()
    if _check_rag_embedding_model_config(gapic_vector_db):
        vector_config.rag_embedding_model_config = (
            convert_gapic_to_rag_embedding_model_config(
                gapic_vector_db.rag_embedding_model_config
            )
        )
    return vector_config
|
||||
|
||||
|
||||
def convert_gapic_to_vertex_ai_search_config(
    gapic_vertex_ai_search_config: VertexAiSearchConfig,
) -> Optional[VertexAiSearchConfig]:
    # Annotation fix: the function returns None when no serving_config is set,
    # so the return type is Optional.
    """Convert Gapic VertexAiSearchConfig to VertexAiSearchConfig.

    Returns None when the gapic config has no serving_config set.
    """
    if gapic_vertex_ai_search_config.serving_config:
        return VertexAiSearchConfig(
            serving_config=gapic_vertex_ai_search_config.serving_config,
        )
    return None
|
||||
|
||||
|
||||
def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
    """Convert a GapicRagCorpus into the user-facing RagCorpus dataclass."""
    search_config = convert_gapic_to_vertex_ai_search_config(
        gapic_rag_corpus.vertex_ai_search_config
    )
    backend_config = convert_gapic_to_backend_config(gapic_rag_corpus.vector_db_config)
    return RagCorpus(
        name=gapic_rag_corpus.name,
        display_name=gapic_rag_corpus.display_name,
        description=gapic_rag_corpus.description,
        vertex_ai_search_config=search_config,
        backend_config=backend_config,
    )
|
||||
|
||||
|
||||
def convert_gapic_to_rag_corpus_no_embedding_model_config(
    gapic_rag_corpus: GapicRagCorpus,
) -> RagCorpus:
    """Convert GapicRagCorpus without embedding model config (for UpdateRagCorpus) to RagCorpus.

    NOTE(review): this mutates the input — it clears
    ``rag_embedding_model_config`` on the corpus's own vector_db_config
    (the assignment below aliases, not copies, the nested message) before
    converting. Callers should not reuse the gapic corpus afterwards.
    """
    rag_vector_db_config_no_embedding_model_config = gapic_rag_corpus.vector_db_config
    rag_vector_db_config_no_embedding_model_config.rag_embedding_model_config = None
    rag_corpus = RagCorpus(
        name=gapic_rag_corpus.name,
        display_name=gapic_rag_corpus.display_name,
        description=gapic_rag_corpus.description,
        vertex_ai_search_config=convert_gapic_to_vertex_ai_search_config(
            gapic_rag_corpus.vertex_ai_search_config
        ),
        backend_config=convert_gapic_to_backend_config(
            rag_vector_db_config_no_embedding_model_config
        ),
    )
    return rag_corpus
|
||||
|
||||
|
||||
def convert_gapic_to_rag_file(gapic_rag_file: GapicRagFile) -> RagFile:
    """Convert a GapicRagFile into the user-facing RagFile dataclass."""
    return RagFile(
        name=gapic_rag_file.name,
        display_name=gapic_rag_file.display_name,
        description=gapic_rag_file.description,
    )
|
||||
|
||||
|
||||
def convert_json_to_rag_file(upload_rag_file_response: Dict[str, Any]) -> RagFile:
    """Converts a JSON response to a RagFile.

    Args:
        upload_rag_file_response: Parsed JSON body of an UploadRagFile
            response; expected to contain a "ragFile" object with "name",
            "displayName" and "description" keys.

    Returns:
        RagFile built from the response fields (missing keys become None).
    """
    # Hoist the nested lookup: the original fetched "ragFile" three times.
    rag_file_json = upload_rag_file_response.get("ragFile")
    return RagFile(
        name=rag_file_json.get("name"),
        display_name=rag_file_json.get("displayName"),
        description=rag_file_json.get("description"),
    )
|
||||
|
||||
|
||||
def convert_path_to_resource_id(
    path: str,
) -> Union[str, GoogleDriveSource.ResourceId]:
    """Converts a path to a Google Cloud storage uri or GoogleDriveSource.ResourceId.

    Args:
        path: A ``gs://`` uri (returned unchanged) or a Google Drive file or
            folder url (``https://drive.google.com/...``).

    Returns:
        The ``gs://`` uri unchanged, or a GoogleDriveSource.ResourceId for a
        Drive file/folder.

    Raises:
        ValueError: If the path is neither a GCS uri nor a recognizable
            Google Drive url.
    """
    if path.startswith("gs://"):
        # Google Cloud Storage source
        return path
    elif path.startswith("https://drive.google.com/"):
        # Google Drive source
        path_list = path.split("/")
        if "file" in path_list:
            # .../file/d/<id>/... — the id sits two segments after "file";
            # strip any query string.
            index = path_list.index("file") + 2
            resource_id = path_list[index].split("?")[0]
            resource_type = GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FILE
        elif "folders" in path_list:
            # .../folders/<id> — the id directly follows "folders".
            index = path_list.index("folders") + 1
            resource_id = path_list[index].split("?")[0]
            resource_type = (
                GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER
            )
        else:
            # Fix: the original passed ("...%s...", path) to ValueError —
            # logging-style args that were never formatted, yielding a tuple
            # message. Format the message explicitly.
            raise ValueError(f"path {path} is not a valid Google Drive url.")

        return GoogleDriveSource.ResourceId(
            resource_id=resource_id,
            resource_type=resource_type,
        )
    else:
        raise ValueError(
            "path must be a Google Cloud Storage uri or a Google Drive url."
        )
|
||||
|
||||
|
||||
def convert_source_for_rag_import(
    source: Union[SlackChannelsSource, JiraSource, SharePointSources]
) -> Union[GapicSlackSource, GapicJiraSource]:
    """Converts a SlackChannelsSource or JiraSource to a GapicSlackSource or GapicJiraSource.

    SharePointSources are likewise converted to GapicSharePointSources
    (NOTE(review): the return annotation omits that case).

    Raises:
        ValueError: If a SharePoint entry sets conflicting folder or drive
            identifiers, or neither drive_name nor drive_id.
        TypeError: If ``source`` is not one of the supported source types.
    """
    if isinstance(source, SlackChannelsSource):
        result_source_channels = []
        for channel in source.channels:
            api_key = channel.api_key
            cid = channel.channel_id
            start_time = channel.start_time
            end_time = channel.end_time
            # One SlackChannels message per input channel, each carrying its
            # own api key config.
            result_channels = GapicSlackSource.SlackChannels(
                channels=[
                    GapicSlackSource.SlackChannels.SlackChannel(
                        channel_id=cid,
                        start_time=start_time,
                        end_time=end_time,
                    )
                ],
                api_key_config=api_auth.ApiAuth.ApiKeyConfig(
                    api_key_secret_version=api_key
                ),
            )
            result_source_channels.append(result_channels)
        return GapicSlackSource(
            channels=result_source_channels,
        )
    elif isinstance(source, JiraSource):
        result_source_queries = []
        for query in source.queries:
            api_key = query.api_key
            custom_queries = query.custom_queries
            projects = query.jira_projects
            email = query.email
            server_uri = query.server_uri
            result_query = GapicJiraSource.JiraQueries(
                custom_queries=custom_queries,
                projects=projects,
                email=email,
                server_uri=server_uri,
                api_key_config=api_auth.ApiAuth.ApiKeyConfig(
                    api_key_secret_version=api_key
                ),
            )
            result_source_queries.append(result_query)
        return GapicJiraSource(
            jira_queries=result_source_queries,
        )
    elif isinstance(source, SharePointSources):
        result_source_share_point_sources = []
        for share_point_source in source.share_point_sources:
            sharepoint_folder_path = share_point_source.sharepoint_folder_path
            sharepoint_folder_id = share_point_source.sharepoint_folder_id
            drive_name = share_point_source.drive_name
            drive_id = share_point_source.drive_id
            client_id = share_point_source.client_id
            client_secret = share_point_source.client_secret
            tenant_id = share_point_source.tenant_id
            sharepoint_site_name = share_point_source.sharepoint_site_name
            # The client_secret rides on the api-key secret-version config.
            result_share_point_source = GapicSharePointSources.SharePointSource(
                client_id=client_id,
                client_secret=api_auth.ApiAuth.ApiKeyConfig(
                    api_key_secret_version=client_secret
                ),
                tenant_id=tenant_id,
                sharepoint_site_name=sharepoint_site_name,
            )
            # folder_path and folder_id are mutually exclusive; both may be
            # omitted (whole site).
            if sharepoint_folder_path is not None and sharepoint_folder_id is not None:
                raise ValueError(
                    "sharepoint_folder_path and sharepoint_folder_id cannot both be set."
                )
            elif sharepoint_folder_path is not None:
                result_share_point_source.sharepoint_folder_path = (
                    sharepoint_folder_path
                )
            elif sharepoint_folder_id is not None:
                result_share_point_source.sharepoint_folder_id = sharepoint_folder_id
            # Exactly one of drive_name / drive_id must be set.
            if drive_name is not None and drive_id is not None:
                raise ValueError("drive_name and drive_id cannot both be set.")
            elif drive_name is not None:
                result_share_point_source.drive_name = drive_name
            elif drive_id is not None:
                result_share_point_source.drive_id = drive_id
            else:
                raise ValueError("Either drive_name and drive_id must be set.")
            result_source_share_point_sources.append(result_share_point_source)
        return GapicSharePointSources(
            share_point_sources=result_source_share_point_sources,
        )
    else:
        raise TypeError(
            "source must be a SlackChannelsSource or JiraSource or SharePointSources."
        )
|
||||
|
||||
|
||||
def prepare_import_files_request(
    corpus_name: str,
    paths: Optional[Sequence[str]] = None,
    source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None,
    transformation_config: Optional[TransformationConfig] = None,
    max_embedding_requests_per_min: int = 1000,
    import_result_sink: Optional[str] = None,
    partial_failures_sink: Optional[str] = None,
    parser: Optional[LayoutParserConfig] = None,
) -> ImportRagFilesRequest:
    """Builds an ImportRagFilesRequest from user-facing arguments.

    Args:
        corpus_name: Full RagCorpus resource name of the form
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``.
        paths: GCS / Google Drive paths to import. Used only when ``source``
            is None.
        source: A SlackChannelsSource, JiraSource, or SharePointSources to
            import from instead of ``paths``.
        transformation_config: Optional chunking configuration; defaults to
            chunk_size=1024 and chunk_overlap=200 when absent.
        max_embedding_requests_per_min: QPM limit for embedding requests.
        import_result_sink: ``gs://`` or ``bq://`` sink for import results.
        partial_failures_sink: Deprecated alias for ``import_result_sink``;
            used only when ``import_result_sink`` is not provided.
        parser: Optional Document AI layout parser configuration.

    Returns:
        The assembled ImportRagFilesRequest.

    Raises:
        ValueError: If ``corpus_name``, ``parser.processor_name``, or the
            result sink is malformed, or if neither ``source`` nor ``paths``
            is provided.
    """
    if len(corpus_name.split("/")) != 6:
        raise ValueError(
            "corpus_name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`"
        )

    rag_file_parsing_config = RagFileParsingConfig()
    if parser is not None:
        if (
            re.fullmatch(_VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX, parser.processor_name)
            is None
        ):
            raise ValueError(
                "processor_name must be of the format "
                "`projects/{project_id}/locations/{location}/processors/{processor_id}`"
                "or "
                "`projects/{project_id}/locations/{location}/processors/{processor_id}/processorVersions/{processor_version_id}`, "
                f"got {parser.processor_name!r}"
            )
        rag_file_parsing_config.layout_parser = RagFileParsingConfig.LayoutParser(
            processor_name=parser.processor_name,
            max_parsing_requests_per_min=parser.max_parsing_requests_per_min,
        )

    # Defaults used when no chunking configuration is supplied.
    chunk_size = 1024
    chunk_overlap = 200
    if transformation_config and transformation_config.chunking_config:
        chunk_size = transformation_config.chunking_config.chunk_size
        chunk_overlap = transformation_config.chunking_config.chunk_overlap

    rag_file_transformation_config = RagFileTransformationConfig(
        rag_file_chunking_config=RagFileChunkingConfig(
            fixed_length_chunking=RagFileChunkingConfig.FixedLengthChunking(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            ),
        ),
    )

    import_rag_files_config = ImportRagFilesConfig(
        rag_file_transformation_config=rag_file_transformation_config,
        rag_file_parsing_config=rag_file_parsing_config,
        max_embedding_requests_per_min=max_embedding_requests_per_min,
    )

    # partial_failures_sink is the deprecated name for import_result_sink;
    # the new argument takes precedence when both are given.
    import_result_sink = import_result_sink or partial_failures_sink

    if import_result_sink is not None:
        if import_result_sink.startswith("gs://"):
            import_rag_files_config.partial_failure_gcs_sink.output_uri_prefix = (
                import_result_sink
            )
        elif import_result_sink.startswith("bq://"):
            import_rag_files_config.partial_failure_bigquery_sink.output_uri = (
                import_result_sink
            )
        else:
            raise ValueError(
                "import_result_sink must be a GCS path or a BigQuery table."
            )

    if source is not None:
        gapic_source = convert_source_for_rag_import(source)
        # A converted source is exactly one of these types, so elif suffices.
        if isinstance(gapic_source, GapicSlackSource):
            import_rag_files_config.slack_source = gapic_source
        elif isinstance(gapic_source, GapicJiraSource):
            import_rag_files_config.jira_source = gapic_source
        elif isinstance(gapic_source, GapicSharePointSources):
            import_rag_files_config.share_point_sources = gapic_source
    else:
        if paths is None:
            # Previously this fell through to `for p in paths` and raised a
            # confusing TypeError; fail fast with a clear message instead.
            raise ValueError("Either paths or source must be provided.")
        uris = []
        resource_ids = []
        for p in paths:
            output = convert_path_to_resource_id(p)
            if isinstance(output, str):
                uris.append(p)
            else:
                resource_ids.append(output)
        if uris:
            import_rag_files_config.gcs_source.uris = uris
        if resource_ids:
            import_rag_files_config.google_drive_source = GoogleDriveSource(
                resource_ids=resource_ids,
            )

    return ImportRagFilesRequest(
        parent=corpus_name, import_rag_files_config=import_rag_files_config
    )
|
||||
|
||||
|
||||
def get_corpus_name(
    name: str,
) -> str:
    """Resolves *name* to a full RagCorpus resource name.

    A full resource name is returned unchanged; a bare ``{rag_corpus}`` id
    is expanded using the globally configured project and location. A falsy
    *name* is passed through unchanged.
    """
    if not name:
        return name
    client = create_rag_data_service_client()
    if client.parse_rag_corpus_path(name):
        return name
    if re.match(f"^{_VALID_RESOURCE_NAME_REGEX}$", name):
        return client.rag_corpus_path(
            project=initializer.global_config.project,
            location=initializer.global_config.location,
            rag_corpus=name,
        )
    raise ValueError(
        "name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}` or `{rag_corpus}`"
    )
|
||||
|
||||
|
||||
def get_file_name(
    name: str,
    corpus_name: str,
) -> str:
    """Resolves *name* to a full RagFile resource name.

    A full resource name is returned unchanged; a bare ``{rag_file}`` id is
    expanded using the globally configured project/location and the given
    ``corpus_name``.
    """
    client = create_rag_data_service_client()
    if client.parse_rag_file_path(name):
        return name
    if re.match(f"^{_VALID_RESOURCE_NAME_REGEX}$", name):
        if not corpus_name:
            raise ValueError(
                "corpus_name must be provided if name is a `{rag_file}`, not a "
                "full resource name (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`). "
            )
        return client.rag_file_path(
            project=initializer.global_config.project,
            location=initializer.global_config.location,
            rag_corpus=get_corpus_name(corpus_name),
            rag_file=name,
        )
    raise ValueError(
        "name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}` or `{rag_file}`"
    )
|
||||
|
||||
|
||||
def set_embedding_model_config(
    embedding_model_config: RagEmbeddingModelConfig,
    rag_corpus: GapicRagCorpus,
) -> None:
    """Writes the embedding model endpoint onto ``rag_corpus`` in place.

    Exactly one of ``publisher_model`` or ``endpoint`` must be set on the
    prediction endpoint config; short resource names are prefixed with the
    globally configured project/location.
    """
    prediction_endpoint = embedding_model_config.vertex_prediction_endpoint
    if prediction_endpoint is None:
        return
    if prediction_endpoint.publisher_model and prediction_endpoint.endpoint:
        raise ValueError("publisher_model and endpoint cannot be set at the same time.")
    if not prediction_endpoint.publisher_model and not prediction_endpoint.endpoint:
        raise ValueError("At least one of publisher_model and endpoint must be set.")
    parent = initializer.global_config.common_location_path(project=None, location=None)

    def _canonicalize(value: str, full_pattern: str, short_pattern: str, error_message: str) -> str:
        # Full resource names pass through; short names get the current
        # project/location prefix; anything else is rejected.
        if re.match(full_pattern, value):
            return value
        if re.match(short_pattern, value):
            return parent + "/" + value
        raise ValueError(error_message)

    if prediction_endpoint.publisher_model:
        rag_corpus.vector_db_config.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = _canonicalize(
            prediction_endpoint.publisher_model,
            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
            r"^publishers/google/models/(?P<model_id>.+?)$",
            "publisher_model must be of the format `projects/{project}/locations/{location}/publishers/google/models/{model_id}` or `publishers/google/models/{model_id}`",
        )

    if prediction_endpoint.endpoint:
        rag_corpus.vector_db_config.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = _canonicalize(
            prediction_endpoint.endpoint,
            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
            r"^endpoints/(?P<endpoint>.+?)$",
            "endpoint must be of the format `projects/{project}/locations/{location}/endpoints/{endpoint}` or `endpoints/{endpoint}`",
        )
|
||||
|
||||
|
||||
def set_backend_config(
    backend_config: Optional[
        Union[
            RagVectorDbConfig,
            None,
        ]
    ],
    rag_corpus: GapicRagCorpus,
) -> None:
    """Sets the vector db configuration for the rag corpus.

    Args:
        backend_config: The RagVectorDbConfig to apply; a no-op when None.
        rag_corpus: The GapicRagCorpus mutated in place.

    Raises:
        TypeError: If ``backend_config.vector_db`` is not one of
            VertexVectorSearch, RagManagedDb, or Pinecone.
    """
    if backend_config is None:
        return

    vector_config = backend_config.vector_db
    if vector_config is not None:
        if isinstance(vector_config, RagManagedDb):
            rag_corpus.vector_db_config.rag_managed_db.CopyFrom(
                GapicRagVectorDbConfig.RagManagedDb()
            )
        elif isinstance(vector_config, VertexVectorSearch):
            rag_corpus.vector_db_config.vertex_vector_search.index_endpoint = (
                vector_config.index_endpoint
            )
            rag_corpus.vector_db_config.vertex_vector_search.index = (
                vector_config.index
            )
        elif isinstance(vector_config, Pinecone):
            rag_corpus.vector_db_config.pinecone.index_name = (
                vector_config.index_name
            )
            rag_corpus.vector_db_config.api_auth.api_key_config.api_key_secret_version = (
                vector_config.api_key
            )
        else:
            # Fixed: the previous message named VertexFeatureStore, which is
            # not an accepted vector_db type in this function.
            raise TypeError(
                "backend_config.vector_db must be a VertexVectorSearch,"
                "RagManagedDb, or Pinecone."
            )
    if backend_config.rag_embedding_model_config:
        set_embedding_model_config(
            backend_config.rag_embedding_model_config, rag_corpus
        )
|
||||
|
||||
|
||||
def set_vertex_ai_search_config(
    vertex_ai_search_config: VertexAiSearchConfig,
    rag_corpus: GapicRagCorpus,
) -> None:
    """Validates the serving config and copies it onto ``rag_corpus``."""
    serving_config = vertex_ai_search_config.serving_config
    if not serving_config:
        raise ValueError("serving_config must be set.")
    # A serving config may belong to either an engine or a data store.
    engine_pattern = r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/collections/(?P<collection>.+?)/engines/(?P<engine>.+?)/servingConfigs/(?P<serving_config>.+?)$"
    data_store_pattern = r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/collections/(?P<collection>.+?)/dataStores/(?P<data_store>.+?)/servingConfigs/(?P<serving_config>.+?)$"
    if not (
        re.match(engine_pattern, serving_config)
        or re.match(data_store_pattern, serving_config)
    ):
        raise ValueError(
            "serving_config must be of the format `projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}` or `projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}`"
        )
    rag_corpus.vertex_ai_search_config = GapicVertexAiSearchConfig(
        serving_config=serving_config,
    )
|
||||
@@ -0,0 +1,447 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import dataclasses
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from google.protobuf import timestamp_pb2
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RagFile:
    """RAG file (output only).

    Attributes:
        name: Generated resource name. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}/ragFiles/{rag_file}``
        display_name: Display name that was configured at client side.
        description: The description of the RagFile.
    """

    # Output only: all fields are populated by the service.
    name: Optional[str] = None
    display_name: Optional[str] = None
    description: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class VertexPredictionEndpoint:
    """Vertex AI Prediction Endpoint config for the embedding model.

    Attributes:
        publisher_model: 1P publisher model resource name. Format:
            ``publishers/google/models/{model}`` or
            ``projects/{project}/locations/{location}/publishers/google/models/{model}``
        endpoint: 1P fine tuned embedding model resource name. Format:
            ``endpoints/{endpoint}`` or
            ``projects/{project}/locations/{location}/endpoints/{endpoint}``.
        model:
            Output only. The resource name of the model that is deployed
            on the endpoint. Present only when the endpoint is not a
            publisher model. Pattern:
            ``projects/{project}/locations/{location}/models/{model}``
        model_version_id:
            Output only. Version ID of the model that is
            deployed on the endpoint. Present only when the
            endpoint is not a publisher model.
    """

    endpoint: Optional[str] = None
    publisher_model: Optional[str] = None
    # model and model_version_id are output only per the docstring above.
    model: Optional[str] = None
    model_version_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RagEmbeddingModelConfig:
    """Embedding model configuration for a RAG corpus.

    Attributes:
        vertex_prediction_endpoint: The Vertex AI Prediction Endpoint resource
            name. Format:
            ``projects/{project}/locations/{location}/endpoints/{endpoint}``
    """

    vertex_prediction_endpoint: Optional[VertexPredictionEndpoint] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Weaviate:
    """Weaviate vector database configuration.

    Attributes:
        weaviate_http_endpoint: The Weaviate DB instance HTTP endpoint
        collection_name: The corresponding Weaviate collection this corpus maps to
        api_key: The SecretManager resource name for the Weaviate DB API token. Format:
            ``projects/{project}/secrets/{secret}/versions/{version}``
    """

    weaviate_http_endpoint: Optional[str] = None
    collection_name: Optional[str] = None
    api_key: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class VertexFeatureStore:
    """Vertex Feature Store configuration.

    Attributes:
        resource_name: The resource name of the FeatureView. Format:
            ``projects/{project}/locations/{location}/featureOnlineStores/
            {feature_online_store}/featureViews/{feature_view}``
    """

    resource_name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class VertexVectorSearch:
    """Vertex Vector Search configuration.

    Attributes:
        index_endpoint (str):
            The resource name of the Index Endpoint. Format:
            ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}``
        index (str):
            The resource name of the Index. Format:
            ``projects/{project}/locations/{location}/indexes/{index}``
    """

    index_endpoint: Optional[str] = None
    index: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RagManagedDb:
    """RagManagedDb.

    Marker configuration selecting the RAG-managed vector database; it has
    no configurable fields.
    """
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Pinecone:
    """Pinecone vector database configuration.

    Attributes:
        index_name: The Pinecone index name.
        api_key: The SecretManager resource name for the Pinecone DB API token. Format:
            ``projects/{project}/secrets/{secret}/versions/{version}``
    """

    index_name: Optional[str] = None
    api_key: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class VertexAiSearchConfig:
    """Vertex AI Search configuration.

    Attributes:
        serving_config: The resource name of the Vertex AI Search serving config.
            Format:
            ``projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}``
            or
            ``projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}``
    """

    serving_config: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RagVectorDbConfig:
    """Vector database configuration for a RAG corpus.

    Attributes:
        vector_db: Can be one of the following: RagManagedDb, Pinecone,
            VertexVectorSearch.
        rag_embedding_model_config: The embedding model config of the Vector DB.
    """

    vector_db: Optional[
        Union[
            VertexVectorSearch,
            Pinecone,
            RagManagedDb,
        ]
    ] = None
    rag_embedding_model_config: Optional[RagEmbeddingModelConfig] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RagCorpus:
    """RAG corpus (output only).

    Attributes:
        name: Generated resource name. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}``
        display_name: Display name that was configured at client side.
        description: The description of the RagCorpus.
        vertex_ai_search_config: The Vertex AI Search config of the RagCorpus.
        backend_config: The backend config of the RagCorpus. It can be a data
            store and/or retrieval engine.
    """

    name: Optional[str] = None
    display_name: Optional[str] = None
    description: Optional[str] = None
    vertex_ai_search_config: Optional[VertexAiSearchConfig] = None
    # Optional[Union[RagVectorDbConfig, None]] simplifies to
    # Optional[RagVectorDbConfig]; the redundant form is dropped.
    backend_config: Optional[RagVectorDbConfig] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RagResource:
    """RagResource.

    The representation of the rag source. It can be used to specify corpus only
    or ragfiles. Currently only support one corpus or multiple files from one
    corpus. In the future we may open up multiple corpora support.

    Attributes:
        rag_corpus: A Rag corpus resource name or corpus id. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}``
            or ``{rag_corpus_id}``.
        rag_file_ids: List of Rag file resource name or file ids in the same corpus. Format:
            ``{rag_file}``.
    """

    rag_corpus: Optional[str] = None
    rag_file_ids: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SlackChannel:
    """A single Slack channel to import messages from.

    Attributes:
        channel_id: The Slack channel ID.
        api_key: The SecretManager resource name for the Slack API token. Format:
            ``projects/{project}/secrets/{secret}/versions/{version}``
            See: https://api.slack.com/tutorials/tracks/getting-a-token.
        start_time: The starting timestamp for messages to import.
        end_time: The ending timestamp for messages to import.
    """

    channel_id: str
    api_key: str
    # When unset, the service decides the import time window.
    start_time: Optional[timestamp_pb2.Timestamp] = None
    end_time: Optional[timestamp_pb2.Timestamp] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SlackChannelsSource:
    """An import source made of one or more Slack channels.

    Attributes:
        channels: The Slack channels.
    """

    channels: Sequence[SlackChannel]
|
||||
|
||||
|
||||
@dataclasses.dataclass
class JiraQuery:
    """A Jira query describing what to import from a Jira server.

    Attributes:
        email: The Jira email address.
        jira_projects: A list of Jira projects to import in their entirety.
        custom_queries: A list of custom JQL Jira queries to import.
        api_key: The SecretManager version resource name for Jira API access. Format:
            ``projects/{project}/secrets/{secret}/versions/{version}``
            See: https://support.atlassian.com/atlassian-account/docs/manage-api-tokens-for-your-atlassian-account/
        server_uri: The Jira server URI. Format:
            ``{server}.atlassian.net``
    """

    email: str
    jira_projects: Sequence[str]
    custom_queries: Sequence[str]
    api_key: str
    server_uri: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
class JiraSource:
    """An import source made of one or more Jira queries.

    Attributes:
        queries: The Jira queries.
    """

    queries: Sequence[JiraQuery]
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SharePointSource:
    """A single SharePoint drive/folder to import files from.

    Attributes:
        sharepoint_folder_path: The path of the SharePoint folder to download
            from.
        sharepoint_folder_id: The ID of the SharePoint folder to download
            from.
        drive_name: The name of the drive to download from.
        drive_id: The ID of the drive to download from.
        client_id: The Application ID for the app registered in
            Microsoft Azure Portal. The application must
            also be configured with MS Graph permissions
            "Files.ReadAll", "Sites.ReadAll" and
            BrowserSiteLists.Read.All.
        client_secret: The application secret for the app registered
            in Azure.
        tenant_id: Unique identifier of the Azure Active
            Directory Instance.
        sharepoint_site_name: The name of the SharePoint site to download
            from. This can be the site name or the site id.
    """

    sharepoint_folder_path: Optional[str] = None
    sharepoint_folder_id: Optional[str] = None
    drive_name: Optional[str] = None
    drive_id: Optional[str] = None
    # Annotations corrected from `str = None` to Optional[str] to match the
    # None defaults.
    client_id: Optional[str] = None
    client_secret: Optional[str] = None
    tenant_id: Optional[str] = None
    sharepoint_site_name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SharePointSources:
    """An import source made of one or more SharePoint sources.

    Attributes:
        share_point_sources: The SharePoint sources.
    """

    share_point_sources: Sequence[SharePointSource]
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Filter:
    """Retrieval filter configuration.

    Attributes:
        vector_distance_threshold: Only returns contexts with vector
            distance smaller than the threshold.
        vector_similarity_threshold: Only returns contexts with vector
            similarity larger than the threshold.
        metadata_filter: String for metadata filtering.
    """

    vector_distance_threshold: Optional[float] = None
    vector_similarity_threshold: Optional[float] = None
    metadata_filter: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class LlmRanker:
    """LLM-based reranker configuration.

    Attributes:
        model_name: The model name used for ranking. Only Gemini models are
            supported for now.
    """

    model_name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RankService:
    """Rank Service reranker configuration.

    Attributes:
        model_name: The model name of the rank service. Format:
            ``semantic-ranker-512@latest``
    """

    model_name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Ranking:
    """Reranking configuration: one of rank service or LLM ranker.

    Attributes:
        rank_service: Config for Rank Service.
        llm_ranker: Config for LlmRanker.
    """

    rank_service: Optional[RankService] = None
    llm_ranker: Optional[LlmRanker] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RagRetrievalConfig:
    """Retrieval configuration for RAG queries.

    Attributes:
        top_k: The number of contexts to retrieve.
        filter: Config for filters.
        ranking: Config for ranking.
    """

    top_k: Optional[int] = None
    filter: Optional[Filter] = None
    ranking: Optional[Ranking] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ChunkingConfig:
    """Fixed-length chunking configuration for file import.

    Attributes:
        chunk_size: The size of each chunk.
        chunk_overlap: The size of the overlap between chunks.
    """

    chunk_size: int
    chunk_overlap: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
class TransformationConfig:
    """File transformation configuration applied during import.

    Attributes:
        chunking_config: The chunking config.
    """

    chunking_config: Optional[ChunkingConfig] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class LayoutParserConfig:
    """Configuration for the Document AI Layout Parser Processor.

    Attributes:
        processor_name: The full resource name of a Document AI processor or
            processor version. The processor must have type
            `LAYOUT_PARSER_PROCESSOR`.
            Format must be one of the following:
            - `projects/{project_id}/locations/{location}/processors/{processor_id}`
            - `projects/{project_id}/locations/{location}/processors/{processor_id}/processorVersions/{processor_version_id}`
        max_parsing_requests_per_min: The maximum number of requests the job is
            allowed to make to the Document AI processor per minute. Consult
            https://cloud.google.com/document-ai/quotas and the Quota page for
            your project to set an appropriate value here. If unspecified, a
            default value of 120 QPM will be used.
    """

    processor_name: str
    max_parsing_requests_per_min: Optional[int] = None
|
||||
Reference in New Issue
Block a user