structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vertexai.rag.rag_data import (
create_corpus,
update_corpus,
list_corpora,
get_corpus,
delete_corpus,
upload_file,
import_files,
import_files_async,
get_file,
list_files,
delete_file,
)
from vertexai.rag.rag_retrieval import (
retrieval_query,
)
from vertexai.rag.rag_store import (
Retrieval,
VertexRagStore,
)
from vertexai.rag.utils.resources import (
ChunkingConfig,
Filter,
JiraQuery,
JiraSource,
LayoutParserConfig,
LlmRanker,
Pinecone,
RagCorpus,
RagEmbeddingModelConfig,
RagFile,
RagManagedDb,
RagResource,
RagRetrievalConfig,
RagVectorDbConfig,
Ranking,
RankService,
SharePointSource,
SharePointSources,
SlackChannel,
SlackChannelsSource,
TransformationConfig,
VertexAiSearchConfig,
VertexPredictionEndpoint,
VertexVectorSearch,
)
__all__ = (
"ChunkingConfig",
"Filter",
"JiraQuery",
"JiraSource",
"LayoutParserConfig",
"LlmRanker",
"Pinecone",
"RagCorpus",
"RagEmbeddingModelConfig",
"RagFile",
"RagManagedDb",
"RagResource",
"RagRetrievalConfig",
"RagVectorDbConfig",
"Ranking",
"RankService",
"Retrieval",
"SharePointSource",
"SharePointSources",
"SlackChannel",
"SlackChannelsSource",
"TransformationConfig",
"VertexAiSearchConfig",
"VertexRagStore",
"VertexPredictionEndpoint",
"VertexVectorSearch",
"create_corpus",
"delete_corpus",
"delete_file",
"get_corpus",
"get_file",
"import_files",
"import_files_async",
"list_corpora",
"list_files",
"retrieval_query",
"upload_file",
"update_corpus",
)
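Taken together, these exports are the public surface of the `vertexai.rag` package. A minimal end-to-end sketch of how they compose, assuming a project with the Vertex AI API enabled (the project ID, corpus name, and file path below are illustrative placeholders, not defaults of this SDK):

import vertexai
from vertexai import rag

vertexai.init(project="my-project", location="us-central1")

# Create a corpus backed by the default RagManagedDb.
corpus = rag.create_corpus(display_name="demo-corpus")

# Upload a local file into the corpus, then query it.
rag.upload_file(corpus_name=corpus.name, path="/tmp/notes.txt")
response = rag.retrieval_query(
    text="What do my notes say about deadlines?",
    rag_resources=[rag.RagResource(rag_corpus=corpus.name)],
)
print(response)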

View File

@@ -0,0 +1,870 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""RAG data management SDK."""
from typing import Optional, Sequence, Union
from google import auth
from google.api_core import operation_async
from google.auth.transport import requests as google_auth_requests
from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform_v1 import (
CreateRagCorpusRequest,
DeleteRagCorpusRequest,
DeleteRagFileRequest,
GetRagCorpusRequest,
GetRagFileRequest,
ImportRagFilesResponse,
ListRagCorporaRequest,
ListRagFilesRequest,
RagCorpus as GapicRagCorpus,
UpdateRagCorpusRequest,
)
from google.cloud.aiplatform_v1.services.vertex_rag_data_service.pagers import (
ListRagCorporaPager,
ListRagFilesPager,
)
from vertexai.rag.utils import (
_gapic_utils,
)
from vertexai.rag.utils.resources import (
JiraSource,
LayoutParserConfig,
RagCorpus,
RagFile,
RagVectorDbConfig,
SharePointSources,
SlackChannelsSource,
VertexAiSearchConfig,
TransformationConfig,
)
def create_corpus(
display_name: Optional[str] = None,
description: Optional[str] = None,
vertex_ai_search_config: Optional[VertexAiSearchConfig] = None,
backend_config: Optional[RagVectorDbConfig] = None,
) -> RagCorpus:
"""Creates a new RagCorpus resource.
Example usage:
```
import vertexai
from vertexai import rag
vertexai.init(project="my-project")
rag_corpus = rag.create_corpus(
display_name="my-corpus-1",
)
```
Args:
display_name: The display name of the RagCorpus. The name can be up to
128 characters long and can consist of any UTF-8 characters. If not
provided, the SDK generates one.
description: The description of the RagCorpus.
vertex_ai_search_config: The Vertex AI Search config of the RagCorpus.
Note: backend_config cannot be set if vertex_ai_search_config is
specified.
backend_config: The backend config of the RagCorpus, specifying a
data store and/or embedding model.
Returns:
RagCorpus.
Raises:
RuntimeError: Failed in RagCorpus creation due to exception.
RuntimeError: Failed in RagCorpus creation due to operation error.
"""
if vertex_ai_search_config and backend_config:
raise ValueError(
"Only one of vertex_ai_search_config or backend_config can be set."
)
if not display_name:
display_name = "vertex-" + utils.timestamped_unique_name()
parent = initializer.global_config.common_location_path(project=None, location=None)
rag_corpus = GapicRagCorpus(display_name=display_name, description=description)
if backend_config:
_gapic_utils.set_backend_config(
backend_config=backend_config,
rag_corpus=rag_corpus,
)
elif vertex_ai_search_config:
_gapic_utils.set_vertex_ai_search_config(
vertex_ai_search_config=vertex_ai_search_config,
rag_corpus=rag_corpus,
)
request = CreateRagCorpusRequest(
parent=parent,
rag_corpus=rag_corpus,
)
client = _gapic_utils.create_rag_data_service_client()
try:
response = client.create_rag_corpus(request=request)
except Exception as e:
raise RuntimeError("Failed in RagCorpus creation due to: ", e) from e
return _gapic_utils.convert_gapic_to_rag_corpus(response.result(timeout=600))
def update_corpus(
corpus_name: str,
display_name: Optional[str] = None,
description: Optional[str] = None,
vertex_ai_search_config: Optional[VertexAiSearchConfig] = None,
backend_config: Optional[RagVectorDbConfig] = None,
) -> RagCorpus:
"""Updates a RagCorpus resource. It is intended to update 3rd party vector
DBs (Vector Search, Vertex AI Feature Store, Weaviate, Pinecone) but not
Vertex RagManagedDb.
Example usage:
```
import vertexai
from vertexai import rag
vertexai.init(project="my-project")
rag_corpus = rag.update_corpus(
corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1",
display_name="my-corpus-1",
)
```
Args:
corpus_name: The name of the RagCorpus resource to update. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` or
``{rag_corpus}``.
display_name: The display name of the RagCorpus. The name can be up to
128 characters long and can consist of any UTF-8 characters. If not
provided, the display name is not updated.
description: The description of the RagCorpus. If not provided, the
description will not be updated.
vertex_ai_search_config: The Vertex AI Search config of the RagCorpus.
If not provided, the Vertex AI Search config will not be updated.
Note: backend_config cannot be set if vertex_ai_search_config is
specified.
backend_config: The backend config of the RagCorpus, specifying a
data store and/or embedding model.
Returns:
RagCorpus.
Raises:
RuntimeError: Failed in RagCorpus update due to exception.
RuntimeError: Failed in RagCorpus update due to operation error.
"""
if vertex_ai_search_config and backend_config:
raise ValueError(
"Only one of vertex_ai_search_config or backend_config can be set."
)
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
if display_name and description:
rag_corpus = GapicRagCorpus(
name=corpus_name, display_name=display_name, description=description
)
elif display_name:
rag_corpus = GapicRagCorpus(name=corpus_name, display_name=display_name)
elif description:
rag_corpus = GapicRagCorpus(name=corpus_name, description=description)
else:
rag_corpus = GapicRagCorpus(name=corpus_name)
if backend_config:
_gapic_utils.set_backend_config(
backend_config=backend_config,
rag_corpus=rag_corpus,
)
if vertex_ai_search_config:
_gapic_utils.set_vertex_ai_search_config(
vertex_ai_search_config=vertex_ai_search_config,
rag_corpus=rag_corpus,
)
request = UpdateRagCorpusRequest(
rag_corpus=rag_corpus,
)
client = _gapic_utils.create_rag_data_service_client()
try:
response = client.update_rag_corpus(request=request)
except Exception as e:
raise RuntimeError("Failed in RagCorpus update due to: ", e) from e
return _gapic_utils.convert_gapic_to_rag_corpus_no_embedding_model_config(
response.result(timeout=600)
)
def get_corpus(name: str) -> RagCorpus:
"""
Get an existing RagCorpus.
Args:
name: An existing RagCorpus resource name. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
Returns:
RagCorpus.
"""
corpus_name = _gapic_utils.get_corpus_name(name)
request = GetRagCorpusRequest(name=corpus_name)
client = _gapic_utils.create_rag_data_service_client()
try:
response = client.get_rag_corpus(request=request)
except Exception as e:
raise RuntimeError("Failed in getting the RagCorpus due to: ", e) from e
return _gapic_utils.convert_gapic_to_rag_corpus(response)
def list_corpora(
page_size: Optional[int] = None, page_token: Optional[str] = None
) -> ListRagCorporaPager:
"""
List all RagCorpora in the same project and location.
Example usage:
```
import vertexai
from vertexai import rag
vertexai.init(project="my-project")
# List all corpora.
rag_corpora = list(rag.list_corpora())
# Alternatively, return a ListRagCorporaPager.
pager_1 = rag.list_corpora(page_size=10)
# Then get the next page, use the generated next_page_token from the last pager.
pager_2 = rag.list_corpora(page_size=10, page_token=pager_1.next_page_token)
```
Args:
page_size: The standard list page size. Leaving out the page_size
causes all of the results to be returned.
page_token: The standard list page token.
Returns:
ListRagCorporaPager.
"""
parent = initializer.global_config.common_location_path(project=None, location=None)
request = ListRagCorporaRequest(
parent=parent,
page_size=page_size,
page_token=page_token,
)
client = _gapic_utils.create_rag_data_service_client()
try:
pager = client.list_rag_corpora(request=request)
except Exception as e:
raise RuntimeError("Failed in listing the RagCorpora due to: ", e) from e
return pager
def delete_corpus(name: str) -> None:
"""
Delete an existing RagCorpus.
Args:
name: An existing RagCorpus resource name. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
"""
corpus_name = _gapic_utils.get_corpus_name(name)
request = DeleteRagCorpusRequest(name=corpus_name)
client = _gapic_utils.create_rag_data_service_client()
try:
client.delete_rag_corpus(request=request)
print("Successfully deleted the RagCorpus.")
except Exception as e:
raise RuntimeError("Failed in RagCorpus deletion due to: ", e) from e
return None
def upload_file(
corpus_name: str,
path: Union[str, Sequence[str]],
display_name: Optional[str] = None,
description: Optional[str] = None,
transformation_config: Optional[TransformationConfig] = None,
) -> RagFile:
"""
Synchronous file upload to an existing RagCorpus.
Example usage:
```
import vertexai
from vertexai import rag
vertexai.init(project="my-project")
# Optional.
transformation_config = rag.TransformationConfig(
chunking_config=rag.ChunkingConfig(
chunk_size=1024,
chunk_overlap=200,
),
)
rag_file = rag.upload_file(
corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1",
display_name="my_file.txt",
path="usr/home/my_file.txt",
transformation_config=transformation_config,
)
```
Args:
corpus_name: The name of the RagCorpus resource into which to upload the file.
Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
path: A local file path. For example,
"usr/home/my_file.txt".
display_name: The display name of the data file.
description: The description of the RagFile.
transformation_config: The config for transforming the RagFile, like chunking.
Returns:
RagFile.
Raises:
RuntimeError: Failed in RagFile upload.
ValueError: RagCorpus is not found.
RuntimeError: Failed in indexing the RagFile.
"""
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
location = initializer.global_config.location
# GAPIC doesn't expose a path (scotty). Use requests API instead
if display_name is None:
display_name = "vertex-" + utils.timestamped_unique_name()
headers = {"X-Goog-Upload-Protocol": "multipart"}
if not initializer.global_config.api_endpoint:
request_endpoint = "{}-{}".format(
location, aiplatform.constants.base.API_BASE_PATH
)
else:
request_endpoint = initializer.global_config.api_endpoint
upload_request_uri = "https://{}/upload/v1/{}/ragFiles:upload".format(
request_endpoint,
corpus_name,
)
js_rag_file = {"rag_file": {"display_name": display_name}}
if description:
js_rag_file["rag_file"]["description"] = description
if transformation_config and transformation_config.chunking_config:
chunk_size = transformation_config.chunking_config.chunk_size
chunk_overlap = transformation_config.chunking_config.chunk_overlap
js_rag_file["upload_rag_file_config"] = {
"rag_file_transformation_config": {
"rag_file_chunking_config": {
"fixed_length_chunking": {
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
}
}
}
}
files = {
"metadata": (None, str(js_rag_file)),
"file": open(path, "rb"),
}
credentials, _ = auth.default()
authorized_session = google_auth_requests.AuthorizedSession(credentials=credentials)
try:
response = authorized_session.post(
url=upload_request_uri,
files=files,
headers=headers,
)
except Exception as e:
raise RuntimeError("Failed in uploading the RagFile due to: ", e) from e
if response.status_code == 404:
raise ValueError(
"RagCorpus '%s' is not found: %s", corpus_name, upload_request_uri
)
if response.json().get("error"):
raise RuntimeError(
"Failed in indexing the RagFile due to: ", response.json().get("error")
)
return _gapic_utils.convert_json_to_rag_file(response.json())
def import_files(
corpus_name: str,
paths: Optional[Sequence[str]] = None,
source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None,
transformation_config: Optional[TransformationConfig] = None,
timeout: int = 600,
max_embedding_requests_per_min: int = 1000,
import_result_sink: Optional[str] = None,
partial_failures_sink: Optional[str] = None,
parser: Optional[LayoutParserConfig] = None,
) -> ImportRagFilesResponse:
"""
Import files to an existing RagCorpus, wait until completion.
Example usage:
```
import vertexai
from vertexai import rag
from google.protobuf import timestamp_pb2
vertexai.init(project="my-project")
# Google Drive example
paths = [
"https://drive.google.com/file/d/123",
"https://drive.google.com/drive/folders/456"
]
# Google Cloud Storage example
paths = ["gs://my_bucket/my_files_dir", ...]
transformation_config = rag.TransformationConfig(
chunking_config=rag.ChunkingConfig(
chunk_size=1024,
chunk_overlap=200,
),
)
response = rag.import_files(
corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1",
paths=paths,
transformation_config=transformation_config,
)
# Slack example
start_time = timestamp_pb2.Timestamp()
start_time.FromJsonString('2020-12-31T21:33:44Z')
end_time = timestamp_pb2.Timestamp()
end_time.GetCurrentTime()
source = rag.SlackChannelsSource(
channels = [
SlackChannel("channel1", "api_key1"),
SlackChannel("channel2", "api_key2", start_time, end_time)
],
)
# Jira Example
jira_query = rag.JiraQuery(
email="xxx@yyy.com",
jira_projects=["project1", "project2"],
custom_queries=["query1", "query2"],
api_key="api_key",
server_uri="server.atlassian.net"
)
source = rag.JiraSource(
queries=[jira_query],
)
response = rag.import_files(
corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1",
source=source,
transformation_config=transformation_config,
)
# SharePoint Example.
sharepoint_query = rag.SharePointSource(
sharepoint_folder_path="https://my-sharepoint-site.com/my-folder",
sharepoint_site_name="my-sharepoint-site.com",
client_id="my-client-id",
client_secret="my-client-secret",
tenant_id="my-tenant-id",
drive_id="my-drive-id",
)
source = rag.SharePointSources(
share_point_sources=[sharepoint_query],
)
# Return the number of imported RagFiles after completion.
print(response.imported_rag_files_count)
# Document AI Layout Parser example.
parser = rag.LayoutParserConfig(
processor_name="projects/my-project/locations/us-central1/processors/my-processor-id",
max_parsing_requests_per_min=120,
)
response = rag.import_files(
corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1",
paths=paths,
parser=parser,
)
```
Args:
corpus_name: The name of the RagCorpus resource into which to import files.
Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
paths: A list of URIs. Eligible URIs are a Google Cloud Storage
directory ("gs://my-bucket/my_dir") or a Google Drive URL for a file
(https://drive.google.com/file/...) or folder
("https://drive.google.com/corp/drive/folders/...").
source: The source of the Slack or Jira import.
Must be either a SlackChannelsSource or JiraSource.
transformation_config: The config for transforming the imported
RagFiles.
max_embedding_requests_per_min:
Optional. The maximum number of queries per minute that this job is
allowed to make to the embedding model specified on the corpus. This
value is specific to this job and is not shared across other import
jobs. Consult the Quotas page for the project to set an appropriate
value. If unspecified, a default value of 1,000 QPM is used.
timeout: Default is 600 seconds.
import_result_sink: Either a GCS path to store import results or a
BigQuery table to store import results. The format is
"gs://my-bucket/my/object.ndjson" for GCS or
"bq://my-project.my-dataset.my-table" for BigQuery. An existing GCS
object cannot be used. However, the BigQuery table may or may not
exist - if it does not exist, it will be created. If it does exist,
the schema will be checked and the import results will be appended
to the table.
partial_failures_sink: Deprecated. Prefer to use `import_result_sink`.
Either a GCS path to store partial failures or a BigQuery table to
store partial failures. The format is
"gs://my-bucket/my/object.ndjson" for GCS or
"bq://my-project.my-dataset.my-table" for BigQuery. An existing GCS
object cannot be used. However, the BigQuery table may or may not
exist - if it does not exist, it will be created. If it does exist,
the schema will be checked and the partial failures will be appended
to the table.
parser: Document parser to use. Should be either None (default parser),
or a LayoutParserConfig (to parse documents using a Document AI
Layout Parser processor).
Returns:
ImportRagFilesResponse.
"""
if source is not None and paths is not None:
raise ValueError("Only one of source or paths must be passed in at a time")
if source is None and paths is None:
raise ValueError("One of source or paths must be passed in")
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
request = _gapic_utils.prepare_import_files_request(
corpus_name=corpus_name,
paths=paths,
source=source,
transformation_config=transformation_config,
max_embedding_requests_per_min=max_embedding_requests_per_min,
import_result_sink=import_result_sink,
partial_failures_sink=partial_failures_sink,
parser=parser,
)
client = _gapic_utils.create_rag_data_service_client()
try:
response = client.import_rag_files(request=request)
except Exception as e:
raise RuntimeError("Failed in importing the RagFiles due to: ", e) from e
return response.result(timeout=timeout)
async def import_files_async(
corpus_name: str,
paths: Optional[Sequence[str]] = None,
source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None,
transformation_config: Optional[TransformationConfig] = None,
max_embedding_requests_per_min: int = 1000,
import_result_sink: Optional[str] = None,
partial_failures_sink: Optional[str] = None,
parser: Optional[LayoutParserConfig] = None,
) -> operation_async.AsyncOperation:
"""
Import files to an existing RagCorpus asynchronously.
Example usage:
```
import vertexai
from vertexai import rag
from google.protobuf import timestamp_pb2
vertexai.init(project="my-project")
# Google Drive example
paths = [
"https://drive.google.com/file/d/123",
"https://drive.google.com/drive/folders/456"
]
# Google Cloud Storage example
paths = ["gs://my_bucket/my_files_dir", ...]
transformation_config = rag.TransformationConfig(
chunking_config=rag.ChunkingConfig(
chunk_size=1024,
chunk_overlap=200,
),
)
response = await rag.import_files_async(
corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1",
paths=paths,
transformation_config=transformation_config,
)
# Slack example
start_time = timestamp_pb2.Timestamp()
start_time.FromJsonString('2020-12-31T21:33:44Z')
end_time = timestamp_pb2.Timestamp()
end_time.GetCurrentTime()
source = rag.SlackChannelsSource(
channels = [
SlackChannel("channel1", "api_key1"),
SlackChannel("channel2", "api_key2", start_time, end_time)
],
)
# Jira Example
jira_query = rag.JiraQuery(
email="xxx@yyy.com",
jira_projects=["project1", "project2"],
custom_queries=["query1", "query2"],
api_key="api_key",
server_uri="server.atlassian.net"
)
source = rag.JiraSource(
queries=[jira_query],
)
response = await rag.import_files_async(
corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1",
source=source,
transformation_config=transformation_config,
)
# SharePoint Example.
sharepoint_query = rag.SharePointSource(
sharepoint_folder_path="https://my-sharepoint-site.com/my-folder",
sharepoint_site_name="my-sharepoint-site.com",
client_id="my-client-id",
client_secret="my-client-secret",
tenant_id="my-tenant-id",
drive_id="my-drive-id",
)
source = rag.SharePointSources(
share_point_sources=[sharepoint_query],
)
# Document AI Layout Parser example.
parser = rag.LayoutParserConfig(
processor_name="projects/my-project/locations/us-central1/processors/my-processor-id",
max_parsing_requests_per_min=120,
)
response = await rag.import_files_async(
corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1",
paths=paths,
parser=parser,
)
# Get the result.
await response.result()
```
Args:
corpus_name: The name of the RagCorpus resource into which to import files.
Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
paths: A list of URIs. Eligible URIs are a Google Cloud Storage
directory ("gs://my-bucket/my_dir") or a Google Drive URL for a file
(https://drive.google.com/file/...) or folder
("https://drive.google.com/corp/drive/folders/...").
source: The source of the Slack or Jira import.
Must be either a SlackChannelsSource or JiraSource.
transformation_config: The config for transforming the imported
RagFiles.
max_embedding_requests_per_min:
Optional. The maximum number of queries per minute that this job is
allowed to make to the embedding model specified on the corpus. This
value is specific to this job and is not shared across other import
jobs. Consult the Quotas page for the project to set an appropriate
value. If unspecified, a default value of 1,000 QPM is used.
import_result_sink: Either a GCS path to store import results or a
BigQuery table to store import results. The format is
"gs://my-bucket/my/object.ndjson" for GCS or
"bq://my-project.my-dataset.my-table" for BigQuery. An existing GCS
object cannot be used. However, the BigQuery table may or may not
exist - if it does not exist, it will be created. If it does exist,
the schema will be checked and the import results will be appended
to the table.
partial_failures_sink: Deprecated. Prefer to use `import_result_sink`.
Either a GCS path to store partial failures or a BigQuery table to
store partial failures. The format is
"gs://my-bucket/my/object.ndjson" for GCS or
"bq://my-project.my-dataset.my-table" for BigQuery. An existing GCS
object cannot be used. However, the BigQuery table may or may not
exist - if it does not exist, it will be created. If it does exist,
the schema will be checked and the partial failures will be appended
to the table.
parser: Document parser to use. Should be either None (default parser),
or a LayoutParserConfig (to parse documents using a Document AI
Layout Parser processor).
Returns:
operation_async.AsyncOperation.
"""
if source is not None and paths is not None:
raise ValueError("Only one of source or paths must be passed in at a time")
if source is None and paths is None:
raise ValueError("One of source or paths must be passed in")
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
request = _gapic_utils.prepare_import_files_request(
corpus_name=corpus_name,
paths=paths,
source=source,
transformation_config=transformation_config,
max_embedding_requests_per_min=max_embedding_requests_per_min,
import_result_sink=import_result_sink,
partial_failures_sink=partial_failures_sink,
parser=parser,
)
async_client = _gapic_utils.create_rag_data_service_async_client()
try:
response = await async_client.import_rag_files(request=request)
except Exception as e:
raise RuntimeError("Failed in importing the RagFiles due to: ", e) from e
return response
def get_file(name: str, corpus_name: Optional[str] = None) -> RagFile:
"""
Get an existing RagFile.
Args:
name: Either a full RagFile resource name must be provided, or a RagCorpus
name and a RagFile name must be provided. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}``
or ``{rag_file}``.
corpus_name: If `name` is not a full resource name, an existing RagCorpus
name must be provided. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
Returns:
RagFile.
"""
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
name = _gapic_utils.get_file_name(name, corpus_name)
request = GetRagFileRequest(name=name)
client = _gapic_utils.create_rag_data_service_client()
try:
response = client.get_rag_file(request=request)
except Exception as e:
raise RuntimeError("Failed in getting the RagFile due to: ", e) from e
return _gapic_utils.convert_gapic_to_rag_file(response)
def list_files(
corpus_name: str, page_size: Optional[int] = None, page_token: Optional[str] = None
) -> ListRagFilesPager:
"""
List all RagFiles in an existing RagCorpus.
Example usage:
```
import vertexai
from vertexai import rag
vertexai.init(project="my-project")
# List all corpora.
rag_corpora = list(rag.list_corpora())
# List all files of the first corpus.
rag_files = list(rag.list_files(corpus_name=rag_corpora[0].name))
# Alternatively, return a ListRagFilesPager.
pager_1 = rag.list_files(
corpus_name=rag_corpora[0].name,
page_size=10
)
# Then get the next page, use the generated next_page_token from the last pager.
pager_2 = rag.list_files(
corpus_name=rag_corpora[0].name,
page_size=10,
page_token=pager_1.next_page_token
)
```
Args:
corpus_name: An existing RagCorpus name. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
page_size: The standard list page size. Leaving out the page_size
causes all of the results to be returned.
page_token: The standard list page token.
Returns:
ListRagFilesPager.
"""
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
request = ListRagFilesRequest(
parent=corpus_name,
page_size=page_size,
page_token=page_token,
)
client = _gapic_utils.create_rag_data_service_client()
try:
pager = client.list_rag_files(request=request)
except Exception as e:
raise RuntimeError("Failed in listing the RagFiles due to: ", e) from e
return pager
def delete_file(name: str, corpus_name: Optional[str] = None) -> None:
"""
Delete RagFile from an existing RagCorpus.
Args:
name: Either a full RagFile resource name must be provided, or a RagCorpus
name and a RagFile name must be provided. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}``
or ``{rag_file}``.
corpus_name: If `name` is not a full resource name, an existing RagCorpus
name must be provided. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
or ``{rag_corpus}``.
"""
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
name = _gapic_utils.get_file_name(name, corpus_name)
request = DeleteRagFileRequest(name=name)
client = _gapic_utils.create_rag_data_service_client()
try:
client.delete_rag_file(request=request)
print("Successfully deleted the RagFile.")
except Exception as e:
raise RuntimeError("Failed in RagFile deletion due to: ", e) from e
return None
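The file-management helpers above (`get_file`, `list_files`, `delete_file`) accept either a full RagFile resource name or a short ID plus `corpus_name`, but carry no usage examples of their own. A hedged sketch of the round trip, with placeholder resource names:

import vertexai
from vertexai import rag

vertexai.init(project="my-project", location="us-central1")

corpus_name = "projects/my-project/locations/us-central1/ragCorpora/my-corpus-1"

# List the corpus files, fetch the first by its full resource name, then delete it.
files = list(rag.list_files(corpus_name=corpus_name))
if files:
    rag_file = rag.get_file(name=files[0].name)
    print(rag_file.display_name)
    rag.delete_file(name=rag_file.name)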

View File

@@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Retrieval query to get relevant contexts."""
import re
from typing import List, Optional
from google.cloud import aiplatform_v1
from google.cloud.aiplatform import initializer
from vertexai.rag.utils import _gapic_utils
from vertexai.rag.utils import resources
def retrieval_query(
text: str,
rag_resources: Optional[List[resources.RagResource]] = None,
rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
) -> aiplatform_v1.RetrieveContextsResponse:
"""Retrieve top k relevant docs/chunks.
Example usage:
```
import vertexai
vertexai.init(project="my-project")
config = vertexai.rag.RagRetrievalConfig(
top_k=2,
filter=vertexai.rag.Filter(
vector_distance_threshold=0.5
),
ranking=vertexai.rag.Ranking(
llm_ranker=vertexai.rag.LlmRanker(
model_name="gemini-1.5-flash-002"
)
)
)
results = vertexai.rag.retrieval_query(
text="Why is the sky blue?",
rag_resources=[vertexai.rag.RagResource(
rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1",
rag_file_ids=["rag-file-1", "rag-file-2", ...],
)],
rag_retrieval_config=config,
)
```
Args:
text: The query in text format to get relevant contexts.
rag_resources: A list of RagResource. It can be used to specify a corpus
or RagFiles. Currently only one corpus, or multiple files from a
single corpus, is supported. Support for multiple corpora may be added later.
rag_retrieval_config: Optional. The config containing the retrieval
parameters, including similarity_top_k and vector_distance_threshold
Returns:
RetrieveContextsResponse.
"""
parent = initializer.global_config.common_location_path()
client = _gapic_utils.create_rag_service_client()
if rag_resources:
if len(rag_resources) > 1:
raise ValueError("Currently only support 1 RagResource.")
name = rag_resources[0].rag_corpus
else:
raise ValueError("rag_resources must be specified.")
data_client = _gapic_utils.create_rag_data_service_client()
if data_client.parse_rag_corpus_path(name):
rag_corpus_name = name
elif re.match("^{}$".format(_gapic_utils._VALID_RESOURCE_NAME_REGEX), name):
rag_corpus_name = parent + "/ragCorpora/" + name
else:
raise ValueError(
f"Invalid RagCorpus name: {name}. Proper format should be:"
" projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}"
)
if rag_resources:
gapic_rag_resource = (
aiplatform_v1.RetrieveContextsRequest.VertexRagStore.RagResource(
rag_corpus=rag_corpus_name,
rag_file_ids=rag_resources[0].rag_file_ids,
)
)
vertex_rag_store = aiplatform_v1.RetrieveContextsRequest.VertexRagStore(
rag_resources=[gapic_rag_resource],
)
else:
vertex_rag_store = aiplatform_v1.RetrieveContextsRequest.VertexRagStore(
rag_corpora=[rag_corpus_name],
)
# If rag_retrieval_config is not specified, set it to default values.
if not rag_retrieval_config:
api_retrieval_config = aiplatform_v1.RagRetrievalConfig()
else:
# If rag_retrieval_config is specified, check for missing parameters.
api_retrieval_config = aiplatform_v1.RagRetrievalConfig()
api_retrieval_config.top_k = rag_retrieval_config.top_k
# Set vector_distance_threshold to config value if specified
if rag_retrieval_config.filter:
# Check if both vector_distance_threshold and vector_similarity_threshold
# are specified.
if (
rag_retrieval_config.filter
and rag_retrieval_config.filter.vector_distance_threshold
and rag_retrieval_config.filter.vector_similarity_threshold
):
raise ValueError(
"Only one of vector_distance_threshold or"
" vector_similarity_threshold can be specified at a time"
" in rag_retrieval_config."
)
api_retrieval_config.filter.vector_distance_threshold = (
rag_retrieval_config.filter.vector_distance_threshold
)
api_retrieval_config.filter.vector_similarity_threshold = (
rag_retrieval_config.filter.vector_similarity_threshold
)
if (
rag_retrieval_config.ranking
and rag_retrieval_config.ranking.rank_service
and rag_retrieval_config.ranking.llm_ranker
):
raise ValueError("Only one of rank_service and llm_ranker can be set.")
if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service:
api_retrieval_config.ranking.rank_service.model_name = (
rag_retrieval_config.ranking.rank_service.model_name
)
elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker:
api_retrieval_config.ranking.llm_ranker.model_name = (
rag_retrieval_config.ranking.llm_ranker.model_name
)
query = aiplatform_v1.RagQuery(
text=text,
rag_retrieval_config=api_retrieval_config,
)
request = aiplatform_v1.RetrieveContextsRequest(
vertex_rag_store=vertex_rag_store,
parent=parent,
query=query,
)
try:
response = client.retrieve_contexts(request=request)
except Exception as e:
raise RuntimeError("Failed in retrieving contexts due to: ", e) from e
return response
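`retrieval_query` returns a `RetrieveContextsResponse`; the retrieved chunks sit under `response.contexts.contexts`. A short sketch of consuming the result (field access follows the v1 proto; the corpus name is a placeholder):

import vertexai
from vertexai import rag

vertexai.init(project="my-project", location="us-central1")

response = rag.retrieval_query(
    text="Why is the sky blue?",
    rag_resources=[rag.RagResource(
        rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1",
    )],
)
# Each returned context carries the source URI and the matched text chunk.
for context in response.contexts.contexts:
    print(context.source_uri, context.text[:80])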

View File

@@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""RAG retrieval tool for content generation."""
import re
from typing import List, Optional
from google.cloud import aiplatform_v1beta1
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types
from vertexai import generative_models
from vertexai.rag.utils import _gapic_utils
from vertexai.rag.utils import resources
class Retrieval(generative_models.grounding.Retrieval):
"""Defines a retrieval tool that a model can call to access external knowledge."""
def __init__(
self,
source: Union["VertexRagStore"],
disable_attribution: Optional[bool] = False,
):
self._raw_retrieval = gapic_tool_types.Retrieval(
vertex_rag_store=source._raw_vertex_rag_store,
disable_attribution=disable_attribution,
)
class VertexRagStore:
"""Retrieve from Vertex RAG Store."""
def __init__(
self,
rag_resources: Optional[List[resources.RagResource]] = None,
rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
):
"""Initializes a Vertex RAG store tool.
Example usage:
```
import vertexai
vertexai.init(project="my-project")
config = vertexai.rag.RagRetrievalConfig(
top_k=2,
filter=vertexai.rag.RagRetrievalConfig.Filter(
vector_distance_threshold=0.5
),
ranking=vertexai.rag.Ranking(
llm_ranker=vertexai.rag.LlmRanker(
model_name="gemini-1.5-flash-002"
)
)
)
tool = vertexai.generative_models.Tool.from_retrieval(
retrieval=vertexai.rag.Retrieval(
source=vertexai.rag.VertexRagStore(
rag_resources=[vertexai.rag.RagResource(
rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1",
)],
rag_retrieval_config=config,
),
)
)
```
Args:
rag_resources: List of RagResource to retrieve from. It can be used
to specify a corpus or RagFiles. Currently only one corpus, or
multiple files from a single corpus, is supported. Support for
multiple corpora may be added later.
rag_retrieval_config: Optional. The config containing the retrieval
parameters, including similarity_top_k and vector_distance_threshold.
"""
if rag_resources:
if len(rag_resources) > 1:
raise ValueError("Currently only support 1 RagResource.")
name = rag_resources[0].rag_corpus
else:
raise ValueError("rag_resources must be specified.")
data_client = _gapic_utils.create_rag_data_service_client()
if data_client.parse_rag_corpus_path(name):
rag_corpus_name = name
elif re.match("^{}$".format(_gapic_utils._VALID_RESOURCE_NAME_REGEX), name):
parent = initializer.global_config.common_location_path()
rag_corpus_name = parent + "/ragCorpora/" + name
else:
raise ValueError(
f"Invalid RagCorpus name: {name}. Proper format should be:"
" projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}"
)
# If rag_retrieval_config is not specified, set it to default values.
api_retrieval_config = aiplatform_v1beta1.RagRetrievalConfig()
# If rag_retrieval_config is specified, populate the default config.
if rag_retrieval_config:
api_retrieval_config.top_k = rag_retrieval_config.top_k
# Set vector_distance_threshold to config value if specified
if rag_retrieval_config.filter:
# Check if both vector_distance_threshold and
# vector_similarity_threshold are specified.
if (
rag_retrieval_config.filter
and rag_retrieval_config.filter.vector_distance_threshold
and rag_retrieval_config.filter.vector_similarity_threshold
):
raise ValueError(
"Only one of vector_distance_threshold or"
" vector_similarity_threshold can be specified at a time"
" in rag_retrieval_config."
)
api_retrieval_config.filter.vector_distance_threshold = (
rag_retrieval_config.filter.vector_distance_threshold
)
api_retrieval_config.filter.vector_similarity_threshold = (
rag_retrieval_config.filter.vector_similarity_threshold
)
# Check if both rank_service and llm_ranker are specified.
if (
rag_retrieval_config.ranking
and rag_retrieval_config.ranking.rank_service
and rag_retrieval_config.ranking.rank_service.model_name
and rag_retrieval_config.ranking.llm_ranker
and rag_retrieval_config.ranking.llm_ranker.model_name
):
raise ValueError(
"Only one of rank_service or llm_ranker can be specified"
" at a time in rag_retrieval_config."
)
# Set rank_service to config value if specified
if (
rag_retrieval_config.ranking
and rag_retrieval_config.ranking.rank_service
):
api_retrieval_config.ranking.rank_service.model_name = (
rag_retrieval_config.ranking.rank_service.model_name
)
# Set llm_ranker to config value if specified
if rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker:
api_retrieval_config.ranking.llm_ranker.model_name = (
rag_retrieval_config.ranking.llm_ranker.model_name
)
gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource(
rag_corpus=rag_corpus_name,
rag_file_ids=rag_resources[0].rag_file_ids,
)
self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore(
rag_resources=[gapic_rag_resource],
rag_retrieval_config=api_retrieval_config,
)
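`VertexRagStore` is designed to be wrapped in a `Retrieval` and handed to a generative model as a tool, so that generation is grounded in the corpus. A minimal grounding sketch (the model name and corpus below are illustrative assumptions, not defaults of this module):

import vertexai
from vertexai import rag
from vertexai.generative_models import GenerativeModel, Tool

vertexai.init(project="my-project", location="us-central1")

rag_tool = Tool.from_retrieval(
    retrieval=rag.Retrieval(
        source=rag.VertexRagStore(
            rag_resources=[rag.RagResource(
                rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1",
            )],
        ),
    ),
)
model = GenerativeModel("gemini-1.5-flash-002", tools=[rag_tool])
print(model.generate_content("Why is the sky blue?").text)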

View File

@@ -0,0 +1,656 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from typing import Any, Dict, Optional, Sequence, Union
from google.cloud.aiplatform_v1.types import api_auth
from google.cloud.aiplatform_v1 import (
RagEmbeddingModelConfig as GapicRagEmbeddingModelConfig,
GoogleDriveSource,
ImportRagFilesConfig,
ImportRagFilesRequest,
RagFileChunkingConfig,
RagFileParsingConfig,
RagFileTransformationConfig,
RagCorpus as GapicRagCorpus,
RagFile as GapicRagFile,
SharePointSources as GapicSharePointSources,
SlackSource as GapicSlackSource,
JiraSource as GapicJiraSource,
RagVectorDbConfig as GapicRagVectorDbConfig,
VertexAiSearchConfig as GapicVertexAiSearchConfig,
)
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.utils import (
VertexRagDataAsyncClientWithOverride,
VertexRagDataClientWithOverride,
VertexRagClientWithOverride,
)
from vertexai.rag.utils.resources import (
LayoutParserConfig,
Pinecone,
RagCorpus,
RagEmbeddingModelConfig,
RagFile,
RagManagedDb,
RagVectorDbConfig,
SharePointSources,
SlackChannelsSource,
TransformationConfig,
JiraSource,
VertexAiSearchConfig,
VertexVectorSearch,
VertexPredictionEndpoint,
)
_VALID_RESOURCE_NAME_REGEX = "[a-z][a-zA-Z0-9._-]{0,127}"
_VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX = (
r"projects/[^/]+/locations/[^/]+/processors/[^/]+(?:/processorVersions/[^/]+)?"
)
def create_rag_data_service_client():
return initializer.global_config.create_client(
client_class=VertexRagDataClientWithOverride,
).select_version("v1")
def create_rag_data_service_async_client():
return initializer.global_config.create_client(
client_class=VertexRagDataAsyncClientWithOverride,
).select_version("v1")
def create_rag_service_client():
return initializer.global_config.create_client(
client_class=VertexRagClientWithOverride,
).select_version("v1")
def convert_gapic_to_rag_embedding_model_config(
gapic_embedding_model_config: GapicRagEmbeddingModelConfig,
) -> RagEmbeddingModelConfig:
"""Convert GapicRagEmbeddingModelConfig to RagEmbeddingModelConfig."""
embedding_model_config = RagEmbeddingModelConfig()
path = gapic_embedding_model_config.vertex_prediction_endpoint.endpoint
publisher_model = re.match(
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
path,
)
endpoint = re.match(
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
path,
)
if publisher_model:
embedding_model_config.vertex_prediction_endpoint = VertexPredictionEndpoint(
publisher_model=path
)
if endpoint:
embedding_model_config.vertex_prediction_endpoint = VertexPredictionEndpoint(
endpoint=path,
model=gapic_embedding_model_config.vertex_prediction_endpoint.model,
model_version_id=gapic_embedding_model_config.vertex_prediction_endpoint.model_version_id,
)
return embedding_model_config
def _check_weaviate(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
try:
return gapic_vector_db.__contains__("weaviate")
except AttributeError:
return gapic_vector_db.weaviate.ByteSize() > 0
def _check_rag_managed_db(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
try:
return gapic_vector_db.__contains__("rag_managed_db")
except AttributeError:
return gapic_vector_db.rag_managed_db.ByteSize() > 0
def _check_vertex_feature_store(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
try:
return gapic_vector_db.__contains__("vertex_feature_store")
except AttributeError:
return gapic_vector_db.vertex_feature_store.ByteSize() > 0
def _check_pinecone(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
try:
return gapic_vector_db.__contains__("pinecone")
except AttributeError:
return gapic_vector_db.pinecone.ByteSize() > 0
def _check_vertex_vector_search(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
try:
return gapic_vector_db.__contains__("vertex_vector_search")
except AttributeError:
return gapic_vector_db.vertex_vector_search.ByteSize() > 0
def _check_rag_embedding_model_config(
gapic_vector_db: GapicRagVectorDbConfig,
) -> bool:
try:
return gapic_vector_db.__contains__("rag_embedding_model_config")
except AttributeError:
return gapic_vector_db.rag_embedding_model_config.ByteSize() > 0
def convert_gapic_to_backend_config(
gapic_vector_db: GapicRagVectorDbConfig,
) -> RagVectorDbConfig:
"""Convert Gapic RagVectorDbConfig to VertexVectorSearch, Pinecone, or RagManagedDb."""
vector_config = RagVectorDbConfig()
if _check_pinecone(gapic_vector_db):
vector_config.vector_db = Pinecone(
index_name=gapic_vector_db.pinecone.index_name,
api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
)
elif _check_vertex_vector_search(gapic_vector_db):
vector_config.vector_db = VertexVectorSearch(
index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint,
index=gapic_vector_db.vertex_vector_search.index,
)
elif _check_rag_managed_db(gapic_vector_db):
vector_config.vector_db = RagManagedDb()
if _check_rag_embedding_model_config(gapic_vector_db):
vector_config.rag_embedding_model_config = (
convert_gapic_to_rag_embedding_model_config(
gapic_vector_db.rag_embedding_model_config
)
)
return vector_config
def convert_gapic_to_vertex_ai_search_config(
gapic_vertex_ai_search_config: GapicVertexAiSearchConfig,
) -> Optional[VertexAiSearchConfig]:
"""Convert Gapic VertexAiSearchConfig to VertexAiSearchConfig."""
if gapic_vertex_ai_search_config.serving_config:
return VertexAiSearchConfig(
serving_config=gapic_vertex_ai_search_config.serving_config,
)
return None
def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
"""Convert GapicRagCorpus to RagCorpus."""
rag_corpus = RagCorpus(
name=gapic_rag_corpus.name,
display_name=gapic_rag_corpus.display_name,
description=gapic_rag_corpus.description,
vertex_ai_search_config=convert_gapic_to_vertex_ai_search_config(
gapic_rag_corpus.vertex_ai_search_config
),
backend_config=convert_gapic_to_backend_config(
gapic_rag_corpus.vector_db_config
),
)
return rag_corpus
def convert_gapic_to_rag_corpus_no_embedding_model_config(
gapic_rag_corpus: GapicRagCorpus,
) -> RagCorpus:
"""Convert GapicRagCorpus without embedding model config (for UpdateRagCorpus) to RagCorpus."""
rag_vector_db_config_no_embedding_model_config = gapic_rag_corpus.vector_db_config
rag_vector_db_config_no_embedding_model_config.rag_embedding_model_config = None
rag_corpus = RagCorpus(
name=gapic_rag_corpus.name,
display_name=gapic_rag_corpus.display_name,
description=gapic_rag_corpus.description,
vertex_ai_search_config=convert_gapic_to_vertex_ai_search_config(
gapic_rag_corpus.vertex_ai_search_config
),
backend_config=convert_gapic_to_backend_config(
rag_vector_db_config_no_embedding_model_config
),
)
return rag_corpus
def convert_gapic_to_rag_file(gapic_rag_file: GapicRagFile) -> RagFile:
"""Convert GapicRagFile to RagFile."""
rag_file = RagFile(
name=gapic_rag_file.name,
display_name=gapic_rag_file.display_name,
description=gapic_rag_file.description,
)
return rag_file
def convert_json_to_rag_file(upload_rag_file_response: Dict[str, Any]) -> RagFile:
"""Converts a JSON response to a RagFile."""
rag_file = RagFile(
name=upload_rag_file_response.get("ragFile").get("name"),
display_name=upload_rag_file_response.get("ragFile").get("displayName"),
description=upload_rag_file_response.get("ragFile").get("description"),
)
return rag_file
def convert_path_to_resource_id(
path: str,
) -> Union[str, GoogleDriveSource.ResourceId]:
"""Converts a path to a Google Cloud storage uri or GoogleDriveSource.ResourceId."""
if path.startswith("gs://"):
# Google Cloud Storage source
return path
elif path.startswith("https://drive.google.com/"):
# Google Drive source
path_list = path.split("/")
if "file" in path_list:
index = path_list.index("file") + 2
resource_id = path_list[index].split("?")[0]
resource_type = GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FILE
elif "folders" in path_list:
index = path_list.index("folders") + 1
resource_id = path_list[index].split("?")[0]
resource_type = (
GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER
)
else:
raise ValueError("path %s is not a valid Google Drive url.", path)
return GoogleDriveSource.ResourceId(
resource_id=resource_id,
resource_type=resource_type,
)
else:
raise ValueError(
"path must be a Google Cloud Storage uri or a Google Drive url."
)
def convert_source_for_rag_import(
source: Union[SlackChannelsSource, JiraSource, SharePointSources]
) -> Union[GapicSlackSource, GapicJiraSource]:
"""Converts a SlackChannelsSource or JiraSource to a GapicSlackSource or GapicJiraSource."""
if isinstance(source, SlackChannelsSource):
result_source_channels = []
for channel in source.channels:
api_key = channel.api_key
cid = channel.channel_id
start_time = channel.start_time
end_time = channel.end_time
result_channels = GapicSlackSource.SlackChannels(
channels=[
GapicSlackSource.SlackChannels.SlackChannel(
channel_id=cid,
start_time=start_time,
end_time=end_time,
)
],
api_key_config=api_auth.ApiAuth.ApiKeyConfig(
api_key_secret_version=api_key
),
)
result_source_channels.append(result_channels)
return GapicSlackSource(
channels=result_source_channels,
)
elif isinstance(source, JiraSource):
result_source_queries = []
for query in source.queries:
api_key = query.api_key
custom_queries = query.custom_queries
projects = query.jira_projects
email = query.email
server_uri = query.server_uri
result_query = GapicJiraSource.JiraQueries(
custom_queries=custom_queries,
projects=projects,
email=email,
server_uri=server_uri,
api_key_config=api_auth.ApiAuth.ApiKeyConfig(
api_key_secret_version=api_key
),
)
result_source_queries.append(result_query)
return GapicJiraSource(
jira_queries=result_source_queries,
)
elif isinstance(source, SharePointSources):
result_source_share_point_sources = []
for share_point_source in source.share_point_sources:
sharepoint_folder_path = share_point_source.sharepoint_folder_path
sharepoint_folder_id = share_point_source.sharepoint_folder_id
drive_name = share_point_source.drive_name
drive_id = share_point_source.drive_id
client_id = share_point_source.client_id
client_secret = share_point_source.client_secret
tenant_id = share_point_source.tenant_id
sharepoint_site_name = share_point_source.sharepoint_site_name
result_share_point_source = GapicSharePointSources.SharePointSource(
client_id=client_id,
client_secret=api_auth.ApiAuth.ApiKeyConfig(
api_key_secret_version=client_secret
),
tenant_id=tenant_id,
sharepoint_site_name=sharepoint_site_name,
)
if sharepoint_folder_path is not None and sharepoint_folder_id is not None:
raise ValueError(
"sharepoint_folder_path and sharepoint_folder_id cannot both be set."
)
elif sharepoint_folder_path is not None:
result_share_point_source.sharepoint_folder_path = (
sharepoint_folder_path
)
elif sharepoint_folder_id is not None:
result_share_point_source.sharepoint_folder_id = sharepoint_folder_id
if drive_name is not None and drive_id is not None:
raise ValueError("drive_name and drive_id cannot both be set.")
elif drive_name is not None:
result_share_point_source.drive_name = drive_name
elif drive_id is not None:
result_share_point_source.drive_id = drive_id
else:
raise ValueError("Either drive_name and drive_id must be set.")
result_source_share_point_sources.append(result_share_point_source)
return GapicSharePointSources(
share_point_sources=result_source_share_point_sources,
)
else:
raise TypeError(
"source must be a SlackChannelsSource or JiraSource or SharePointSources."
)
def prepare_import_files_request(
corpus_name: str,
paths: Optional[Sequence[str]] = None,
source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None,
transformation_config: Optional[TransformationConfig] = None,
max_embedding_requests_per_min: int = 1000,
import_result_sink: Optional[str] = None,
partial_failures_sink: Optional[str] = None,
parser: Optional[LayoutParserConfig] = None,
) -> ImportRagFilesRequest:
if len(corpus_name.split("/")) != 6:
raise ValueError(
"corpus_name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`"
)
rag_file_parsing_config = RagFileParsingConfig()
if parser is not None:
if (
re.fullmatch(_VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX, parser.processor_name)
is None
):
raise ValueError(
"processor_name must be of the format "
"`projects/{project_id}/locations/{location}/processors/{processor_id}`"
"or "
"`projects/{project_id}/locations/{location}/processors/{processor_id}/processorVersions/{processor_version_id}`, "
f"got {parser.processor_name!r}"
)
rag_file_parsing_config.layout_parser = RagFileParsingConfig.LayoutParser(
processor_name=parser.processor_name,
max_parsing_requests_per_min=parser.max_parsing_requests_per_min,
)
chunk_size = 1024
chunk_overlap = 200
if transformation_config and transformation_config.chunking_config:
chunk_size = transformation_config.chunking_config.chunk_size
chunk_overlap = transformation_config.chunking_config.chunk_overlap
rag_file_transformation_config = RagFileTransformationConfig(
rag_file_chunking_config=RagFileChunkingConfig(
fixed_length_chunking=RagFileChunkingConfig.FixedLengthChunking(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
),
),
)
import_rag_files_config = ImportRagFilesConfig(
rag_file_transformation_config=rag_file_transformation_config,
rag_file_parsing_config=rag_file_parsing_config,
max_embedding_requests_per_min=max_embedding_requests_per_min,
)
import_result_sink = import_result_sink or partial_failures_sink
if import_result_sink is not None:
if import_result_sink.startswith("gs://"):
import_rag_files_config.partial_failure_gcs_sink.output_uri_prefix = (
import_result_sink
)
elif import_result_sink.startswith("bq://"):
import_rag_files_config.partial_failure_bigquery_sink.output_uri = (
import_result_sink
)
else:
raise ValueError(
"import_result_sink must be a GCS path or a BigQuery table."
)
if source is not None:
gapic_source = convert_source_for_rag_import(source)
if isinstance(gapic_source, GapicSlackSource):
import_rag_files_config.slack_source = gapic_source
if isinstance(gapic_source, GapicJiraSource):
import_rag_files_config.jira_source = gapic_source
if isinstance(gapic_source, GapicSharePointSources):
import_rag_files_config.share_point_sources = gapic_source
else:
uris = []
resource_ids = []
for p in paths:
output = convert_path_to_resource_id(p)
if isinstance(output, str):
uris.append(p)
else:
resource_ids.append(output)
if uris:
import_rag_files_config.gcs_source.uris = uris
if resource_ids:
google_drive_source = GoogleDriveSource(
resource_ids=resource_ids,
)
import_rag_files_config.google_drive_source = google_drive_source
request = ImportRagFilesRequest(
parent=corpus_name, import_rag_files_config=import_rag_files_config
)
return request
def get_corpus_name(
name: str,
) -> str:
if name:
client = create_rag_data_service_client()
if client.parse_rag_corpus_path(name):
return name
elif re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name):
return client.rag_corpus_path(
project=initializer.global_config.project,
location=initializer.global_config.location,
rag_corpus=name,
)
else:
raise ValueError(
"name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}` or `{rag_corpus}`"
)
return name
def get_file_name(
name: str,
corpus_name: str,
) -> str:
client = create_rag_data_service_client()
if client.parse_rag_file_path(name):
return name
elif re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name):
if not corpus_name:
raise ValueError(
"corpus_name must be provided if name is a `{rag_file}`, not a "
"full resource name (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`). "
)
return client.rag_file_path(
project=initializer.global_config.project,
location=initializer.global_config.location,
rag_corpus=get_corpus_name(corpus_name),
rag_file=name,
)
else:
raise ValueError(
"name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}` or `{rag_file}`"
)
def set_embedding_model_config(
embedding_model_config: RagEmbeddingModelConfig,
rag_corpus: GapicRagCorpus,
) -> None:
if embedding_model_config.vertex_prediction_endpoint is None:
return
if (
embedding_model_config.vertex_prediction_endpoint.publisher_model
and embedding_model_config.vertex_prediction_endpoint.endpoint
):
raise ValueError("publisher_model and endpoint cannot be set at the same time.")
if (
not embedding_model_config.vertex_prediction_endpoint.publisher_model
and not embedding_model_config.vertex_prediction_endpoint.endpoint
):
raise ValueError("At least one of publisher_model and endpoint must be set.")
parent = initializer.global_config.common_location_path(project=None, location=None)
if embedding_model_config.vertex_prediction_endpoint.publisher_model:
publisher_model = (
embedding_model_config.vertex_prediction_endpoint.publisher_model
)
full_resource_name = re.match(
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
publisher_model,
)
resource_name = re.match(
r"^publishers/google/models/(?P<model_id>.+?)$",
publisher_model,
)
if full_resource_name:
rag_corpus.vector_db_config.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
publisher_model
)
elif resource_name:
rag_corpus.vector_db_config.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
parent + "/" + publisher_model
)
else:
raise ValueError(
"publisher_model must be of the format `projects/{project}/locations/{location}/publishers/google/models/{model_id}` or `publishers/google/models/{model_id}`"
)
if embedding_model_config.vertex_prediction_endpoint.endpoint:
endpoint = embedding_model_config.vertex_prediction_endpoint.endpoint
full_resource_name = re.match(
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
endpoint,
)
resource_name = re.match(
r"^endpoints/(?P<endpoint>.+?)$",
endpoint,
)
if full_resource_name:
rag_corpus.vector_db_config.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
endpoint
)
elif resource_name:
rag_corpus.vector_db_config.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
parent + "/" + endpoint
)
else:
raise ValueError(
"endpoint must be of the format `projects/{project}/locations/{location}/endpoints/{endpoint}` or `endpoints/{endpoint}`"
            )


def set_backend_config(
    backend_config: Optional[RagVectorDbConfig],
    rag_corpus: GapicRagCorpus,
) -> None:
"""Sets the vector db configuration for the rag corpus."""
if backend_config is None:
return
if backend_config.vector_db is not None:
vector_config = backend_config.vector_db
        if isinstance(vector_config, RagManagedDb):
rag_corpus.vector_db_config.rag_managed_db.CopyFrom(
GapicRagVectorDbConfig.RagManagedDb()
)
elif isinstance(vector_config, VertexVectorSearch):
index_endpoint = vector_config.index_endpoint
index = vector_config.index
rag_corpus.vector_db_config.vertex_vector_search.index_endpoint = (
index_endpoint
)
rag_corpus.vector_db_config.vertex_vector_search.index = index
elif isinstance(vector_config, Pinecone):
index_name = vector_config.index_name
api_key = vector_config.api_key
rag_corpus.vector_db_config.pinecone.index_name = index_name
rag_corpus.vector_db_config.api_auth.api_key_config.api_key_secret_version = (
api_key
)
        else:
            raise TypeError(
                "backend_config.vector_db must be a RagManagedDb, "
                "VertexVectorSearch, or Pinecone instance."
            )
if backend_config.rag_embedding_model_config:
set_embedding_model_config(
backend_config.rag_embedding_model_config, rag_corpus
)
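

# A minimal usage sketch for ``set_backend_config``; the index name, secret
# version, and corpus display name below are illustrative placeholders, not
# real resources:
#
#   backend_config = RagVectorDbConfig(
#       vector_db=Pinecone(
#           index_name="my-pinecone-index",
#           api_key="projects/my-project/secrets/pinecone-key/versions/1",
#       ),
#   )
#   rag_corpus = GapicRagCorpus(display_name="my-corpus")
#   set_backend_config(backend_config, rag_corpus)
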
def set_vertex_ai_search_config(
vertex_ai_search_config: VertexAiSearchConfig,
rag_corpus: GapicRagCorpus,
) -> None:
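    """Validates the Vertex AI Search serving config and writes it onto the corpus.

    ``serving_config`` must be the full resource name of either an engine
    serving config or a data store serving config; any other value raises
    ``ValueError``.
    """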
if not vertex_ai_search_config.serving_config:
raise ValueError("serving_config must be set.")
engine_resource_name = re.match(
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/collections/(?P<collection>.+?)/engines/(?P<engine>.+?)/servingConfigs/(?P<serving_config>.+?)$",
vertex_ai_search_config.serving_config,
)
data_store_resource_name = re.match(
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/collections/(?P<collection>.+?)/dataStores/(?P<data_store>.+?)/servingConfigs/(?P<serving_config>.+?)$",
vertex_ai_search_config.serving_config,
)
if engine_resource_name or data_store_resource_name:
rag_corpus.vertex_ai_search_config = GapicVertexAiSearchConfig(
serving_config=vertex_ai_search_config.serving_config,
)
else:
raise ValueError(
"serving_config must be of the format `projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}` or `projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}`"
)

View File

@@ -0,0 +1,447 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
from typing import List, Optional, Sequence, Union

from google.protobuf import timestamp_pb2


@dataclasses.dataclass
class RagFile:
"""RAG file (output only).
Attributes:
name: Generated resource name. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}/ragFiles/{rag_file}``
display_name: Display name that was configured at client side.
description: The description of the RagFile.
"""
name: Optional[str] = None
display_name: Optional[str] = None
    description: Optional[str] = None


@dataclasses.dataclass
class VertexPredictionEndpoint:
"""VertexPredictionEndpoint.
Attributes:
publisher_model: 1P publisher model resource name. Format:
``publishers/google/models/{model}`` or
``projects/{project}/locations/{location}/publishers/google/models/{model}``
endpoint: 1P fine tuned embedding model resource name. Format:
``endpoints/{endpoint}`` or
``projects/{project}/locations/{location}/endpoints/{endpoint}``.
model:
Output only. The resource name of the model that is deployed
on the endpoint. Present only when the endpoint is not a
publisher model. Pattern:
``projects/{project}/locations/{location}/models/{model}``
model_version_id:
Output only. Version ID of the model that is
deployed on the endpoint. Present only when the
endpoint is not a publisher model.
"""
endpoint: Optional[str] = None
publisher_model: Optional[str] = None
model: Optional[str] = None
    model_version_id: Optional[str] = None


@dataclasses.dataclass
class RagEmbeddingModelConfig:
"""RagEmbeddingModelConfig.
Attributes:
        vertex_prediction_endpoint: The Vertex AI Prediction Endpoint config
            (a ``VertexPredictionEndpoint``) that serves the embedding model.
"""
    vertex_prediction_endpoint: Optional[VertexPredictionEndpoint] = None


@dataclasses.dataclass
class Weaviate:
"""Weaviate.
Attributes:
        weaviate_http_endpoint: The HTTP endpoint of the Weaviate DB instance.
        collection_name: The Weaviate collection that this corpus maps to.
api_key: The SecretManager resource name for the Weaviate DB API token. Format:
``projects/{project}/secrets/{secret}/versions/{version}``
"""
weaviate_http_endpoint: Optional[str] = None
collection_name: Optional[str] = None
    api_key: Optional[str] = None


@dataclasses.dataclass
class VertexFeatureStore:
"""VertexFeatureStore.
Attributes:
resource_name: The resource name of the FeatureView. Format:
``projects/{project}/locations/{location}/featureOnlineStores/
{feature_online_store}/featureViews/{feature_view}``
"""
    resource_name: Optional[str] = None


@dataclasses.dataclass
class VertexVectorSearch:
"""VertexVectorSearch.
Attributes:
index_endpoint (str):
The resource name of the Index Endpoint. Format:
``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}``
index (str):
The resource name of the Index. Format:
``projects/{project}/locations/{location}/indexes/{index}``
"""
index_endpoint: Optional[str] = None
    index: Optional[str] = None


@dataclasses.dataclass
class RagManagedDb:
"""RagManagedDb."""
@dataclasses.dataclass
class Pinecone:
"""Pinecone.
Attributes:
index_name: The Pinecone index name.
api_key: The SecretManager resource name for the Pinecone DB API token. Format:
``projects/{project}/secrets/{secret}/versions/{version}``
"""
index_name: Optional[str] = None
    api_key: Optional[str] = None


@dataclasses.dataclass
class VertexAiSearchConfig:
"""VertexAiSearchConfig.
Attributes:
serving_config: The resource name of the Vertex AI Search serving config.
Format:
``projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}``
or
``projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}``
"""
    serving_config: Optional[str] = None


@dataclasses.dataclass
class RagVectorDbConfig:
"""RagVectorDbConfig.
Attributes:
vector_db: Can be one of the following: RagManagedDb, Pinecone,
VertexVectorSearch.
rag_embedding_model_config: The embedding model config of the Vector DB.
"""
vector_db: Optional[
Union[
VertexVectorSearch,
Pinecone,
RagManagedDb,
]
] = None
rag_embedding_model_config: Optional[RagEmbeddingModelConfig] = None
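

# A construction sketch for ``RagVectorDbConfig``; the resource names below
# are illustrative placeholders, not real resources:
#
#   vector_db_config = RagVectorDbConfig(
#       vector_db=VertexVectorSearch(
#           index_endpoint="projects/my-project/locations/us-central1/indexEndpoints/123",
#           index="projects/my-project/locations/us-central1/indexes/456",
#       ),
#   )
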
@dataclasses.dataclass
class RagCorpus:
"""RAG corpus(output only).
Attributes:
name: Generated resource name. Format:
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}``
display_name: Display name that was configured at client side.
description: The description of the RagCorpus.
vertex_ai_search_config: The Vertex AI Search config of the RagCorpus.
backend_config: The backend config of the RagCorpus. It can be a data
store and/or retrieval engine.
"""
name: Optional[str] = None
display_name: Optional[str] = None
description: Optional[str] = None
vertex_ai_search_config: Optional[VertexAiSearchConfig] = None
    backend_config: Optional[RagVectorDbConfig] = None


@dataclasses.dataclass
class RagResource:
"""RagResource.
    The representation of the RAG source. It can be used to specify a corpus
    or RAG files within a corpus. Currently only one corpus, or multiple files
    from a single corpus, is supported; multiple-corpora support may be added
    in the future.

    Attributes:
        rag_corpus: A RagCorpus resource name or corpus ID. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}``
            or ``{rag_corpus_id}``.
        rag_file_ids: A list of RagFile resource names or file IDs from the
            same corpus. Format: ``{rag_file}``.
"""
rag_corpus: Optional[str] = None
    rag_file_ids: Optional[List[str]] = None


@dataclasses.dataclass
class SlackChannel:
"""SlackChannel.
Attributes:
channel_id: The Slack channel ID.
api_key: The SecretManager resource name for the Slack API token. Format:
``projects/{project}/secrets/{secret}/versions/{version}``
See: https://api.slack.com/tutorials/tracks/getting-a-token.
start_time: The starting timestamp for messages to import.
end_time: The ending timestamp for messages to import.
"""
channel_id: str
api_key: str
start_time: Optional[timestamp_pb2.Timestamp] = None
    end_time: Optional[timestamp_pb2.Timestamp] = None


@dataclasses.dataclass
class SlackChannelsSource:
"""SlackChannelsSource.
Attributes:
channels: The Slack channels.
"""
channels: Sequence[SlackChannel]
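

# A construction sketch for ``SlackChannelsSource``; the channel ID and secret
# version are illustrative placeholders:
#
#   slack_source = SlackChannelsSource(
#       channels=[
#           SlackChannel(
#               channel_id="C1234567890",
#               api_key="projects/my-project/secrets/slack-token/versions/1",
#           ),
#       ],
#   )
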
@dataclasses.dataclass
class JiraQuery:
"""JiraQuery.
Attributes:
email: The Jira email address.
jira_projects: A list of Jira projects to import in their entirety.
custom_queries: A list of custom JQL Jira queries to import.
api_key: The SecretManager version resource name for Jira API access. Format:
``projects/{project}/secrets/{secret}/versions/{version}``
See: https://support.atlassian.com/atlassian-account/docs/manage-api-tokens-for-your-atlassian-account/
server_uri: The Jira server URI. Format:
``{server}.atlassian.net``
"""
email: str
jira_projects: Sequence[str]
custom_queries: Sequence[str]
api_key: str
    server_uri: str


@dataclasses.dataclass
class JiraSource:
"""JiraSource.
Attributes:
queries: The Jira queries.
"""
queries: Sequence[JiraQuery]
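

# A construction sketch for ``JiraSource``; all values below are illustrative
# placeholders:
#
#   jira_source = JiraSource(
#       queries=[
#           JiraQuery(
#               email="user@example.com",
#               jira_projects=["MYPROJ"],
#               custom_queries=['project = "MYPROJ" AND status = "Done"'],
#               api_key="projects/my-project/secrets/jira-token/versions/1",
#               server_uri="my-org.atlassian.net",
#           ),
#       ],
#   )
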
@dataclasses.dataclass
class SharePointSource:
"""SharePointSource.
Attributes:
sharepoint_folder_path: The path of the SharePoint folder to download
from.
sharepoint_folder_id: The ID of the SharePoint folder to download
from.
drive_name: The name of the drive to download from.
drive_id: The ID of the drive to download from.
        client_id: The Application ID for the app registered in the
            Microsoft Azure Portal. The application must also be configured
            with the MS Graph permissions "Files.Read.All", "Sites.Read.All",
            and "BrowserSiteLists.Read.All".
client_secret: The application secret for the app registered
in Azure.
tenant_id: Unique identifier of the Azure Active
Directory Instance.
        sharepoint_site_name: The name of the SharePoint site to download
            from. This can be the site name or the site ID.
"""
sharepoint_folder_path: Optional[str] = None
sharepoint_folder_id: Optional[str] = None
drive_name: Optional[str] = None
drive_id: Optional[str] = None
    client_id: Optional[str] = None
    client_secret: Optional[str] = None
    tenant_id: Optional[str] = None
    sharepoint_site_name: Optional[str] = None


@dataclasses.dataclass
class SharePointSources:
"""SharePointSources.
Attributes:
share_point_sources: The SharePoint sources.
"""
    share_point_sources: Sequence[SharePointSource]


@dataclasses.dataclass
class Filter:
"""Filter.
Attributes:
vector_distance_threshold: Only returns contexts with vector
distance smaller than the threshold.
vector_similarity_threshold: Only returns contexts with vector
similarity larger than the threshold.
metadata_filter: String for metadata filtering.
"""
vector_distance_threshold: Optional[float] = None
vector_similarity_threshold: Optional[float] = None
    metadata_filter: Optional[str] = None


@dataclasses.dataclass
class LlmRanker:
"""LlmRanker.
Attributes:
model_name: The model name used for ranking. Only Gemini models are
supported for now.
"""
    model_name: Optional[str] = None


@dataclasses.dataclass
class RankService:
"""RankService.
Attributes:
model_name: The model name of the rank service. Format:
``semantic-ranker-512@latest``
"""
    model_name: Optional[str] = None


@dataclasses.dataclass
class Ranking:
"""Ranking.
Attributes:
rank_service: Config for Rank Service.
llm_ranker: Config for LlmRanker.
"""
rank_service: Optional[RankService] = None
    llm_ranker: Optional[LlmRanker] = None


@dataclasses.dataclass
class RagRetrievalConfig:
"""RagRetrievalConfig.
Attributes:
top_k: The number of contexts to retrieve.
filter: Config for filters.
ranking: Config for ranking.
"""
top_k: Optional[int] = None
filter: Optional[Filter] = None
ranking: Optional[Ranking] = None
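

# A construction sketch for ``RagRetrievalConfig``; the threshold is an
# illustrative value, and the ranker name follows the format documented on
# ``RankService`` above:
#
#   retrieval_config = RagRetrievalConfig(
#       top_k=5,
#       filter=Filter(vector_distance_threshold=0.5),
#       ranking=Ranking(
#           rank_service=RankService(model_name="semantic-ranker-512@latest"),
#       ),
#   )
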
@dataclasses.dataclass
class ChunkingConfig:
"""ChunkingConfig.
Attributes:
chunk_size: The size of each chunk.
chunk_overlap: The size of the overlap between chunks.
"""
chunk_size: int
    chunk_overlap: int


@dataclasses.dataclass
class TransformationConfig:
"""TransformationConfig.
Attributes:
chunking_config: The chunking config.
"""
chunking_config: Optional[ChunkingConfig] = None
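

# A construction sketch: chunk documents into 512-token chunks with a
# 100-token overlap (the sizes are illustrative):
#
#   transformation_config = TransformationConfig(
#       chunking_config=ChunkingConfig(chunk_size=512, chunk_overlap=100),
#   )
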
@dataclasses.dataclass
class LayoutParserConfig:
"""Configuration for the Document AI Layout Parser Processor.
Attributes:
processor_name: The full resource name of a Document AI processor or
processor version. The processor must have type
`LAYOUT_PARSER_PROCESSOR`.
Format must be one of the following:
- `projects/{project_id}/locations/{location}/processors/{processor_id}`
- `projects/{project_id}/locations/{location}/processors/{processor_id}/processorVersions/{processor_version_id}`
max_parsing_requests_per_min: The maximum number of requests the job is
allowed to make to the Document AI processor per minute. Consult
https://cloud.google.com/document-ai/quotas and the Quota page for
your project to set an appropriate value here. If unspecified, a
default value of 120 QPM will be used.
"""
processor_name: str
max_parsing_requests_per_min: Optional[int] = None
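

# A construction sketch for ``LayoutParserConfig``; the processor resource
# name is an illustrative placeholder, and the referenced processor must have
# type `LAYOUT_PARSER_PROCESSOR`:
#
#   layout_parser_config = LayoutParserConfig(
#       processor_name="projects/my-project/locations/us/processors/abc123",
#       max_parsing_requests_per_min=120,
#   )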