structure saas with tools
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from google.cloud.aiplatform.metadata import metadata


# For Vertex AI Experiment.

# ExperimentRun manipulation.
start_run = metadata._experiment_tracker.start_run
end_run = metadata._experiment_tracker.end_run
get_experiment_df = metadata._experiment_tracker.get_experiment_df

# Experiment logging.
log_params = metadata._experiment_tracker.log_params
log_metrics = metadata._experiment_tracker.log_metrics
log_time_series_metrics = metadata._experiment_tracker.log_time_series_metrics
log_classification_metrics = metadata._experiment_tracker.log_classification_metrics


__all__ = (
    "start_run",
    "end_run",
    "get_experiment_df",
    "log_params",
    "log_metrics",
    "log_time_series_metrics",
    "log_classification_metrics",
)
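A minimal usage sketch for these re-exported experiment-tracking helpers (assuming this module is exposed under `vertexai.preview`, which the `EvalTask` code later in this commit also assumes; the project, experiment, and metric names are placeholders):

```python
import vertexai
from vertexai.preview import start_run, end_run, log_params, log_metrics

# Assumes a project, location, and experiment were configured beforehand.
vertexai.init(project="my-project", location="us-central1", experiment="my-experiment")

start_run("baseline-run")                      # create or resume an ExperimentRun
log_params({"temperature": 0.2, "top_k": 40})  # run-level hyperparameters
log_metrics({"accuracy": 0.91})                # summary metrics for the run
end_run()
```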
Binary file not shown.
@@ -0,0 +1,25 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Classes for batch prediction."""

# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.batch_prediction._batch_prediction import (
    BatchPredictionJob,
)

__all__ = [
    "BatchPredictionJob",
]
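A hedged sketch of how the re-exported `BatchPredictionJob` is typically driven; the `submit` keyword arguments, model name, and Cloud Storage URIs below are assumptions based on the public batch-prediction docs, not part of this commit:

```python
import time
from vertexai.batch_prediction import BatchPredictionJob

# Hypothetical model name and GCS paths.
job = BatchPredictionJob.submit(
    source_model="gemini-1.5-flash-002",
    input_dataset="gs://my-bucket/prompts.jsonl",
    output_uri_prefix="gs://my-bucket/batch-output/",
)

# Poll until the job reaches a terminal state.
while not job.has_ended:
    time.sleep(60)
    job.refresh()

print("succeeded:", job.has_succeeded)
```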
@@ -0,0 +1,20 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from vertexai.caching._caching import CachedContent

__all__ = [
    "CachedContent",
]
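A hedged sketch of using the re-exported `CachedContent` to cache a large shared prefix and reuse it with a Gemini model; the `create` arguments and the `GenerativeModel.from_cached_content` call are assumptions based on the public context-caching docs:

```python
import datetime
from vertexai.caching import CachedContent
from vertexai.generative_models import GenerativeModel

# Hypothetical model name and contents.
cached = CachedContent.create(
    model_name="gemini-1.5-pro-002",
    contents=["<a long document shared across many requests>"],
    ttl=datetime.timedelta(hours=1),
)

model = GenerativeModel.from_cached_content(cached_content=cached)
print(model.generate_content("Summarize the cached document.").text)
```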
@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Vertex Gen AI Evaluation Service Module."""

from vertexai.preview.evaluation import _base
from vertexai.preview.evaluation import autorater_utils
from vertexai.preview.evaluation import eval_task
from vertexai.preview.evaluation import metrics
from vertexai.preview.evaluation import prompt_template


EvalResult = _base.EvalResult
EvalTask = eval_task.EvalTask
PairwiseMetric = metrics.PairwiseMetric
PointwiseMetric = metrics.PointwiseMetric
CustomMetric = metrics.CustomMetric
PromptTemplate = prompt_template.PromptTemplate
PairwiseMetricPromptTemplate = metrics.PairwiseMetricPromptTemplate
PointwiseMetricPromptTemplate = metrics.PointwiseMetricPromptTemplate
MetricPromptTemplateExamples = metrics.MetricPromptTemplateExamples
AutoraterConfig = autorater_utils.AutoraterConfig
CustomOutputConfig = metrics.CustomOutputConfig
RubricBasedMetric = metrics.RubricBasedMetric
RubricGenerationConfig = metrics.RubricGenerationConfig
PredefinedRubricMetrics = metrics.PredefinedRubricMetrics

__all__ = [
    "EvalTask",
    "EvalResult",
    "PairwiseMetric",
    "PointwiseMetric",
    "CustomMetric",
    "PromptTemplate",
    "PairwiseMetricPromptTemplate",
    "PointwiseMetricPromptTemplate",
    "MetricPromptTemplateExamples",
    "AutoraterConfig",
    "CustomOutputConfig",
    "RubricBasedMetric",
    "RubricGenerationConfig",
    "PredefinedRubricMetrics",
]
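A short, hedged sketch of the public surface this `__init__` exposes, mirroring the usage examples in the `EvalTask` docstring later in this commit (dataset values are placeholders):

```python
import pandas as pd
from vertexai.preview.evaluation import EvalTask, MetricPromptTemplateExamples

eval_dataset = pd.DataFrame({
    "prompt": ["Write a short product description for a hiking backpack."],
    "response": ["This 40L backpack is light, waterproof, and ..."],
})

# Bring-your-own-response evaluation with a pointwise example metric.
eval_result = EvalTask(
    dataset=eval_dataset,
    metrics=[MetricPromptTemplateExamples.Pointwise.FLUENCY],
).evaluate()
print(eval_result.summary_metrics)
```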
Binary file not shown.
@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Base classes for evaluation."""


import dataclasses
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING

from google.cloud.aiplatform_v1beta1.services import (
    evaluation_service as gapic_evaluation_services,
)
from google.cloud.aiplatform_v1beta1.types import (
    evaluation_service as gapic_eval_service_types,
)
from vertexai.preview.evaluation.metrics import (
    _base as metrics_base,
)


if TYPE_CHECKING:
    import pandas as pd

AutoraterConfig = gapic_eval_service_types.AutoraterConfig


@dataclasses.dataclass
class EvaluationRunConfig:
    """Evaluation Run Configurations.

    Attributes:
      dataset: The dataset to evaluate.
      metrics: The list of metric names, or Metric instances to evaluate.
      metric_column_mapping: An optional dictionary column mapping that overrides
        the metric prompt template input variable names with the mapped
        evaluation dataset column names, used during evaluation. For example, if
        the input_variables of the metric prompt template are ["context",
        "reference"], the metric_column_mapping can be { "context":
        "news_context", "reference": "ground_truth", "response":
        "model_1_response" } if the dataset has columns "news_context",
        "ground_truth" and "model_1_response".
      client: The evaluation service client.
      evaluation_service_qps: The custom QPS limit for the evaluation service.
      retry_timeout: How long to keep retrying the evaluation requests, in
        seconds.
      autorater_config: The autorater config for model based evaluation.
    """

    dataset: "pd.DataFrame"
    metrics: List[Union[str, metrics_base._Metric]]
    metric_column_mapping: Dict[str, str]
    client: gapic_evaluation_services.EvaluationServiceClient
    evaluation_service_qps: float
    retry_timeout: float
    autorater_config: Optional[AutoraterConfig] = None

    def validate_dataset_column(self, column_name: str) -> None:
        """Validates that the column names in the column map are in the dataset.

        Args:
          column_name: The column name to validate.

        Raises:
          KeyError: If any of the column names are not in the dataset.
        """
        if (
            self.metric_column_mapping.get(column_name, column_name)
            not in self.dataset.columns
        ):
            raise KeyError(
                "Required column"
                f" `{self.metric_column_mapping.get(column_name, column_name)}` not"
                " found in the evaluation dataset. The columns in the evaluation"
                f" dataset are {list(self.dataset.columns)}."
            )


@dataclasses.dataclass
class EvalResult:
    """Evaluation result.

    Attributes:
      summary_metrics: A dictionary of summary evaluation metrics for an
        evaluation run.
      metrics_table: A pandas.DataFrame table containing evaluation dataset
        inputs, predictions, explanations, and metric results per row.
      metadata: The metadata for the evaluation run.
    """

    summary_metrics: Dict[str, float]
    metrics_table: Optional["pd.DataFrame"] = None
    metadata: Optional[Dict[str, str]] = None


@dataclasses.dataclass
class AutoraterEvalResult:
    """Evaluation result for autorater evaluation."""

    def __init__(
        self,
        eval_result: Optional[List[Dict[str, Any]]],
        eval_dataset_metadata: Optional[Dict[str, Any]],
        autorater_config: Optional[AutoraterConfig],
        **kwargs,
    ):
        """Initializes an AutoraterEvalResult.

        Args:
          eval_result: Evaluation result from an evaluation run.
          eval_dataset_metadata: Evaluation dataset metadata.
          autorater_config: Autorater configuration.
          **kwargs: Additional arguments added to AutoraterEvalResult.
        """
        self.eval_result = eval_result
        self.eval_dataset_metadata = eval_dataset_metadata
        self.autorater_config = autorater_config
        self.__dict__.update(kwargs)
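The `metric_column_mapping` lookup above falls back to the metric's own variable name when no mapping entry exists. A small illustrative sketch of that resolution logic with plain pandas objects (column names are hypothetical):

```python
import pandas as pd

metric_column_mapping = {"reference": "ground_truth"}
dataset = pd.DataFrame(
    {"prompt": ["..."], "ground_truth": ["..."], "response": ["..."]}
)

for column_name in ("prompt", "reference", "response"):
    # Same resolution as EvaluationRunConfig.validate_dataset_column.
    resolved = metric_column_mapping.get(column_name, column_name)
    assert resolved in dataset.columns, f"Required column `{resolved}` not found"
```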
File diff suppressed because it is too large
@@ -0,0 +1,272 @@
# -*- coding: utf-8 -*-

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utility functions for all pre-evaluation steps."""

from __future__ import annotations

from concurrent import futures
from typing import Callable, Optional, Set, TYPE_CHECKING, Union, List

from google.cloud.aiplatform import base
from google.cloud.aiplatform_v1beta1.types import (
    content as gapic_content_types,
)
from vertexai import generative_models
from vertexai.preview.evaluation import _base as evaluation_base
from vertexai.preview.evaluation import constants
from vertexai.preview.evaluation import multimodal_utils
from vertexai.preview.evaluation import (
    prompt_template as prompt_template_base,
)


if TYPE_CHECKING:
    import pandas as pd

try:
    from tqdm import tqdm
except ImportError:
    raise ImportError(
        'tqdm is not installed. Please install the SDK using "pip install'
        ' google-cloud-aiplatform[evaluation]"'
    )


_LOGGER = base.Logger(__name__)
_SUCCESSFUL_FINISH_REASONS = [
    gapic_content_types.Candidate.FinishReason.STOP,
    gapic_content_types.Candidate.FinishReason.MAX_TOKENS,
    # Many responses have this finish reason
    gapic_content_types.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED,
]


def _assemble_prompt(
    row: "pd.Series",
    prompt_template: Union[prompt_template_base.PromptTemplate, str],
) -> str:
    """Assembles the prompt template with the given row data."""
    if isinstance(prompt_template, str):
        prompt_template = prompt_template_base.PromptTemplate(prompt_template)
    _check_variable_columns_exist(row, prompt_template.variables)
    return str(
        prompt_template.assemble(
            **row[list(prompt_template.variables)].astype(str).to_dict()
        )
    )


def _generate_content_text_response(
    model: generative_models.GenerativeModel, prompt: str, max_attempts: int = 3
) -> str:
    """Generates a text response from the Gemini model for a text prompt, with retries.

    Args:
      model: The Gemini model instance.
      prompt: The prompt to send to the model.
      max_attempts: Maximum number of attempts for response generation.

    Returns:
      The text response from the model.

    Raises:
      RuntimeError: If the prompt or the response for the prompt is blocked for
        safety reasons.
    """
    for attempt in range(max_attempts):
        try:
            response = model.generate_content(prompt)
            if not response.candidates:
                error_message = (
                    f"The model response was blocked due to"
                    f" {response._raw_response.prompt_feedback.block_reason.name}.\n"
                    f"Blocked reason message:"
                    f" {response._raw_response.prompt_feedback.block_reason_message}.\n"
                    "The input prompt may be blocked for safety reasons.\n"
                    f"Prompt: {prompt}.\n"
                    f"Attempt: {attempt + 1}/{max_attempts}"
                )
                _LOGGER.warning(error_message)
                break
            else:
                candidate = response.candidates[0]
                if candidate.finish_reason not in _SUCCESSFUL_FINISH_REASONS:
                    error_message = (
                        "The model response did not finish"
                        " successfully.\n"
                        f"Finish reason: {candidate.finish_reason}.\n"
                        f"Finish message: {candidate.finish_message}.\n"
                        f"Safety ratings: {candidate.safety_ratings}.\n"
                        "Please adjust the model safety_settings, or"
                        " try a different prompt.\n"
                        f"Attempt: {attempt + 1}/{max_attempts}"
                    )
                    _LOGGER.warning(error_message)
                else:
                    return response.candidates[0].content.parts[0].text
        except Exception as e:
            error_message = (
                f"Failed to generate response candidates from Gemini model"
                f" {model._model_name}.\n"
                f"Error: {e}.\n"
                f"Prompt: {prompt}.\n"
                f"Attempt: {attempt + 1}/{max_attempts}"
            )
            _LOGGER.warning(error_message)
        if attempt < max_attempts - 1:
            _LOGGER.info(
                f"Retrying response generation for prompt: {prompt}, attempt"
                f" {attempt + 1}/{max_attempts}..."
            )

    final_error_message = (
        f"Failed to generate response from Gemini model {model._model_name}.\n"
        f"Prompt: {prompt}."
    )
    _LOGGER.error(final_error_message)
    return constants.RESPONSE_ERROR


def _generate_responses_from_gemini_model(
    model: generative_models.GenerativeModel,
    df: "pd.DataFrame",
    rubric_generation_prompt_template: Optional[str] = None,
) -> List[str]:
    """Generates responses from the Gemini model for the given evaluation dataset.

    Args:
      model: The Gemini model instance.
      df: Evaluation Dataset.
      rubric_generation_prompt_template: Optional prompt template used to
        assemble rubric generation prompts from the dataset columns.

    Returns:
      The list of model responses.
    """
    _LOGGER.info(
        f"Generating a total of {df.shape[0]} "
        f"responses from Gemini model {model._model_name.split('/')[-1]}."
    )
    tasks = []
    with tqdm(total=len(df)) as pbar:
        with futures.ThreadPoolExecutor(max_workers=constants.MAX_WORKERS) as executor:
            for idx, row in df.iterrows():
                if rubric_generation_prompt_template:
                    input_columns = prompt_template_base.PromptTemplate(
                        rubric_generation_prompt_template
                    ).variables
                    if multimodal_utils.is_multimodal_instance(
                        row[list(input_columns)].to_dict()
                    ):
                        prompt = multimodal_utils._assemble_multi_modal_prompt(
                            rubric_generation_prompt_template, row, idx, input_columns
                        )
                    else:
                        prompt = _assemble_prompt(
                            row, rubric_generation_prompt_template
                        )
                else:
                    prompt = row[constants.Dataset.PROMPT_COLUMN]
                task = executor.submit(
                    _generate_content_text_response,
                    prompt=prompt,
                    model=model,
                )
                task.add_done_callback(lambda _: pbar.update(1))
                tasks.append(task)
    responses = [future.result() for future in tasks]
    return responses


def _generate_response_from_custom_model_fn(
    model_fn: Callable[[str], str], eval_dataset: "pd.DataFrame"
) -> List[str]:
    """Generates responses from a custom model function.

    Args:
      model_fn: The custom model function.
      eval_dataset: Evaluation Dataset.

    Returns:
      The list of model responses.
    """
    max_workers = 5

    _LOGGER.info(
        f"Generating a total of {eval_dataset.shape[0]} "
        "responses from the custom model function."
    )
    tasks = []
    try:
        with tqdm(total=len(eval_dataset)) as pbar:
            with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                for _, row in eval_dataset.iterrows():
                    task = executor.submit(
                        model_fn, row[constants.Dataset.PROMPT_COLUMN]
                    )
                    task.add_done_callback(lambda _: pbar.update(1))
                    tasks.append(task)
    except (ValueError, IndexError) as e:
        _LOGGER.warning(f"Failed to generate response from model function: {e}")

    responses = [task.result() for task in tasks]
    return responses


def populate_eval_dataset_with_model_responses(
    responses: List[str],
    evaluation_run_config: evaluation_base.EvaluationRunConfig,
    is_baseline_model: bool = False,
) -> None:
    """Populates the evaluation dataset with model responses.

    Args:
      responses: The list of model responses.
      evaluation_run_config: Evaluation Run Configurations.
      is_baseline_model: Whether the model is a baseline model for
        PairwiseMetric.
    """
    df = evaluation_run_config.dataset.copy()
    if is_baseline_model:
        evaluation_run_config.dataset = df.assign(baseline_model_response=responses)
    else:
        evaluation_run_config.dataset = df.assign(response=responses)

    _LOGGER.info(
        f"All {evaluation_run_config.dataset.shape[0]} responses are successfully"
        " generated from the model."
    )


def _check_variable_columns_exist(
    dataset_row: "pd.Series", variable_names_set: Set[str]
) -> None:
    """Checks if all variable names exist in the dataset columns.

    Args:
      dataset_row: The dataset row to evaluate.
      variable_names_set: A set of variable names.

    Raises:
      ValueError: If any variable names do not exist in the dataset columns
        or the prompt template is invalid.
    """
    actual_column_names_set = set(dataset_row.to_dict().keys())
    if not variable_names_set.issubset(actual_column_names_set):
        missing_columns = variable_names_set - actual_column_names_set
        raise ValueError(
            "Failed to assemble prompt template: The following column(s) are"
            f" missing: {', '.join(missing_columns)}. "
            f"Please verify prompt_template variables {variable_names_set} and "
            f"evaluation dataset column names {actual_column_names_set}."
        )
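`_assemble_prompt` above fills the prompt-template variables from a single dataset row. A rough stand-alone sketch of the same idea using the public `PromptTemplate` class exported earlier in this commit (column names and text are placeholders):

```python
import pandas as pd
from vertexai.preview.evaluation import PromptTemplate

template = PromptTemplate("Instruction: {instruction}\nContext: {context}\nAnswer:")
row = pd.Series({"instruction": "Summarize the text.", "context": "Vertex AI is ..."})

# Keep only the template's variables, cast to str, and assemble -- the same
# shape as _assemble_prompt in this module.
prompt = str(template.assemble(**row[list(template.variables)].astype(str).to_dict()))
print(prompt)
```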
@@ -0,0 +1,238 @@
# -*- coding: utf-8 -*-

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Autorater Utils Class and Functions."""

import logging
import time
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Union

from vertexai import generative_models
from vertexai.preview.evaluation import _base as evaluation_base
from vertexai.preview.evaluation import eval_task
from vertexai.preview.evaluation.metrics import pairwise_metric
from vertexai.preview.evaluation.metrics import pointwise_metric
from vertexai.preview.tuning import sft
from sklearn import metrics


if TYPE_CHECKING:
    import pandas as pd

AutoraterConfig = evaluation_base.AutoraterConfig
AutoraterEvalResult = evaluation_base.AutoraterEvalResult
EvalTask = eval_task.EvalTask
PointwiseMetric = pointwise_metric.PointwiseMetric
PairwiseMetric = pairwise_metric.PairwiseMetric
_SCORE = "score"
_METRIC = "metric"
_PAIRWISE_CHOICE = "pairwise_choice"
_HUMAN_RATING = "human_rating"
_HUMAN_PAIRWISE_CHOICE = "human_pairwise_choice"
_ACCURACY_BALANCED = "accuracy_balanced"
_F1_SCORE_BALANCED = "f1_score_balanced"
_CONFUSION_MATRIX = "confusion_matrix"
_CONFUSION_MATRIX_LABELS = "confusion_matrix_labels"
_METRICS_CATEGORY_LIMIT = 10
_NAN = "nan"
_ERROR = "error"


def tune_autorater(
    *,
    base_model: Union[str, generative_models.GenerativeModel],
    train_dataset: str,
    validation_dataset: Optional[str] = None,
    tuned_model_display_name: Optional[str] = None,
    epochs: Optional[int] = None,
    learning_rate_multiplier: Optional[float] = None,
    adapter_size: Optional[Literal[1, 4, 8, 16]] = None,
    labels: Optional[Dict[str, str]] = None,
    time_out_hours: int = 10,
) -> AutoraterConfig:
    """LoRA-tunes an autorater model.

    Args:
      base_model: Model name for tuning, e.g., "gemini-1.0-pro-002".
      train_dataset: Cloud Storage path to file containing training dataset for
        tuning. The dataset should be in JSONL format.
      validation_dataset: Cloud Storage path to file containing validation
        dataset for tuning. The dataset should be in JSONL format.
      tuned_model_display_name: The display name of the
        [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
        128 characters long and can consist of any UTF-8 characters.
      epochs: Number of training epochs for this tuning job.
      learning_rate_multiplier: Learning rate multiplier for tuning.
      adapter_size: Adapter size for tuning.
      labels: User-defined metadata to be associated with trained models.
      time_out_hours: Timeout in hours for the tuning job. Default value is 10
        hours.

    Returns:
      An `AutoraterConfig` object with the tuned model endpoint.
    """
    tune_job = sft.train(
        source_model=base_model,
        train_dataset=train_dataset,
        validation_dataset=validation_dataset,
        tuned_model_display_name=tuned_model_display_name,
        epochs=epochs,
        learning_rate_multiplier=learning_rate_multiplier,
        adapter_size=adapter_size,
        labels=labels,
    )
    time_out_seconds = time_out_hours * 60 * 60
    while not tune_job.refresh().has_ended and time_out_seconds > 0:
        time.sleep(60)
        time_out_seconds -= 60

    if tune_job.has_succeeded:
        return AutoraterConfig(autorater_model=tune_job.tuned_model_endpoint_name)
    else:
        raise ValueError(
            "Failed to tune autorater model. Please check the logs for more details."
        )


def _get_evaluation_result(
    metric: Union[PointwiseMetric, PairwiseMetric],
    autorater_eval_results: List[str],
    human_eval_results: List[str],
) -> Dict[str, Any]:
    """Gets the evaluation result for the autorater."""
    filtered_autorater_eval_results = []
    filtered_human_eval_results = []
    for autorater_eval_result, human_eval_result in zip(
        autorater_eval_results, human_eval_results
    ):
        # Filter failed pointwise evaluation results.
        if autorater_eval_result.lower() == _NAN or human_eval_result.lower() == _NAN:
            continue

        # Filter failed pairwise evaluation results.
        if (
            autorater_eval_result.lower() == _ERROR
            or human_eval_result.lower() == _ERROR
        ):
            continue

        filtered_autorater_eval_results.append(autorater_eval_result)
        filtered_human_eval_results.append(human_eval_result)

    labels = list(
        sorted(set(filtered_autorater_eval_results) | set(filtered_human_eval_results))
    )
    eval_result = {_METRIC: metric.metric_name}
    eval_result[_ACCURACY_BALANCED] = metrics.balanced_accuracy_score(
        filtered_human_eval_results, filtered_autorater_eval_results
    )
    eval_result[_F1_SCORE_BALANCED] = metrics.f1_score(
        filtered_human_eval_results,
        filtered_autorater_eval_results,
        average="weighted",
    )
    if len(labels) > _METRICS_CATEGORY_LIMIT:
        logging.warning(
            "Confusion matrix is not provided as the number of"
            " rating rubric values %d is greater than the limit %d.",
            len(labels),
            _METRICS_CATEGORY_LIMIT,
        )
    else:
        eval_result[_CONFUSION_MATRIX] = metrics.confusion_matrix(
            filtered_human_eval_results,
            filtered_autorater_eval_results,
            labels=labels,
        )
        eval_result[_CONFUSION_MATRIX_LABELS] = labels
    return eval_result


def evaluate_autorater(
    *,
    evaluate_autorater_input: "pd.DataFrame",
    eval_metrics: List[Union[PointwiseMetric, PairwiseMetric]],
    autorater_config: Optional[AutoraterConfig] = None,
    eval_dataset_metadata: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> AutoraterEvalResult:
    """Evaluates the autorater model using human evaluation results.

    Args:
      evaluate_autorater_input: Autorater evaluation input, including
        evaluation results from human evaluation and the autorater model.
      eval_metrics: List of model based metrics.
      autorater_config: Autorater configuration.
      eval_dataset_metadata: Evaluation dataset metadata.
      **kwargs: Additional arguments added to AutoraterEvalResult.

    Returns:
      The autorater evaluation result.
    """
    eval_result = []
    for metric in eval_metrics:
        if isinstance(metric, PointwiseMetric):
            autorater_score = list(
                map(
                    lambda x: str(float(x)),
                    list(evaluate_autorater_input[metric.metric_name + "/" + _SCORE]),
                )
            )
            human_score = list(
                map(
                    lambda x: str(float(x)),
                    list(
                        evaluate_autorater_input[
                            metric.metric_name + "/" + _HUMAN_RATING
                        ]
                    ),
                )
            )
            eval_result.append(
                _get_evaluation_result(metric, autorater_score, human_score)
            )
        elif isinstance(metric, PairwiseMetric):
            autorater_choice = list(
                map(
                    str,
                    list(
                        evaluate_autorater_input[
                            metric.metric_name + "/" + _PAIRWISE_CHOICE
                        ]
                    ),
                )
            )
            human_choice = list(
                map(
                    str,
                    list(
                        evaluate_autorater_input[
                            metric.metric_name + "/" + _HUMAN_PAIRWISE_CHOICE
                        ]
                    ),
                )
            )
            eval_result.append(
                _get_evaluation_result(metric, autorater_choice, human_choice)
            )
        else:
            continue
    return AutoraterEvalResult(
        eval_result=eval_result,
        eval_dataset_metadata=eval_dataset_metadata,
        autorater_config=autorater_config,
        **kwargs,
    )
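`_get_evaluation_result` reduces autorater-versus-human agreement to balanced accuracy, a weighted F1 score, and (when the label set is small enough) a confusion matrix. A self-contained sketch of that computation with made-up rating labels:

```python
from sklearn import metrics

human = ["5.0", "4.0", "3.0", "5.0", "2.0"]
autorater = ["5.0", "4.0", "2.0", "5.0", "2.0"]

labels = sorted(set(human) | set(autorater))
print("accuracy_balanced:", metrics.balanced_accuracy_score(human, autorater))
print("f1_score_balanced:", metrics.f1_score(human, autorater, average="weighted"))
print(metrics.confusion_matrix(human, autorater, labels=labels))
```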
@@ -0,0 +1,206 @@
# -*- coding: utf-8 -*-

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Constants for evaluation."""
import dataclasses

# The number of concurrent workers to use for making model inference and
# evaluation requests.
MAX_WORKERS = 100
RESPONSE_ERROR = "Error"


@dataclasses.dataclass(frozen=True)
class Metric:
    """Namespace for Metrics."""

    # Model-based Pointwise Metrics.
    COHERENCE = "coherence"
    FLUENCY = "fluency"
    SAFETY = "safety"
    GROUNDEDNESS = "groundedness"
    INSTRUCTION_FOLLOWING = "instruction_following"
    VERBOSITY = "verbosity"
    TEXT_QUALITY = "text_quality"
    SUMMARIZATION_QUALITY = "summarization_quality"
    QUESTION_ANSWERING_QUALITY = "question_answering_quality"
    MULTI_TURN_CHAT_QUALITY = "multi_turn_chat_quality"
    MULTI_TURN_SAFETY = "multi_turn_safety"
    RUBRIC_BASED_INSTRUCTION_FOLLOWING = "rubric_based_instruction_following"

    # Model-based Pairwise Metrics.
    PAIRWISE_COHERENCE = "pairwise_coherence"
    PAIRWISE_FLUENCY = "pairwise_fluency"
    PAIRWISE_SAFETY = "pairwise_safety"
    PAIRWISE_GROUNDEDNESS = "pairwise_groundedness"
    PAIRWISE_INSTRUCTION_FOLLOWING = "pairwise_instruction_following"
    PAIRWISE_VERBOSITY = "pairwise_verbosity"
    PAIRWISE_TEXT_QUALITY = "pairwise_text_quality"
    PAIRWISE_SUMMARIZATION_QUALITY = "pairwise_summarization_quality"
    PAIRWISE_QUESTION_ANSWERING_QUALITY = "pairwise_question_answering_quality"
    PAIRWISE_MULTI_TURN_CHAT_QUALITY = "pairwise_multi_turn_chat_quality"
    PAIRWISE_MULTI_TURN_SAFETY = "pairwise_multi_turn_safety"

    POINTWISE_METRIC = "pointwise_metric"
    PAIRWISE_METRIC = "pairwise_metric"

    # Automatic Metrics.
    EXACT_MATCH = "exact_match"
    BLEU = "bleu"
    ROUGE = "rouge"
    ROUGE_1 = "rouge_1"
    ROUGE_2 = "rouge_2"
    ROUGE_L = "rouge_l"
    ROUGE_L_SUM = "rouge_l_sum"
    TOOL_CALL_VALID = "tool_call_valid"
    TOOL_NAME_MATCH = "tool_name_match"
    TOOL_PARAMETER_KEY_MATCH = "tool_parameter_key_match"
    TOOL_PARAMETER_KV_MATCH = "tool_parameter_kv_match"
    TRAJECTORY_EXACT_MATCH = "trajectory_exact_match"
    TRAJECTORY_IN_ORDER_MATCH = "trajectory_in_order_match"
    TRAJECTORY_ANY_ORDER_MATCH = "trajectory_any_order_match"
    TRAJECTORY_PRECISION = "trajectory_precision"
    TRAJECTORY_RECALL = "trajectory_recall"
    TRAJECTORY_SINGLE_TOOL_USE = "trajectory_single_tool_use"
    LATENCY = "latency_in_seconds"
    FAILURE = "failure"

    AUTOMATIC_METRIC_LIST = (
        EXACT_MATCH,
        BLEU,
        ROUGE,
        ROUGE_1,
        ROUGE_2,
        ROUGE_L,
        ROUGE_L_SUM,
        TOOL_CALL_VALID,
        TOOL_NAME_MATCH,
        TOOL_PARAMETER_KEY_MATCH,
        TOOL_PARAMETER_KV_MATCH,
    )

    TRAJECTORY_METRIC_LIST = (
        TRAJECTORY_EXACT_MATCH,
        TRAJECTORY_IN_ORDER_MATCH,
        TRAJECTORY_ANY_ORDER_MATCH,
        TRAJECTORY_PRECISION,
        TRAJECTORY_RECALL,
        TRAJECTORY_SINGLE_TOOL_USE,
    )
    DEFAULT_METRIC_LIST = (
        LATENCY,
        FAILURE,
    )

    POINTWISE_METRIC_PROMPT_TEMPLATE_EXAMPLE_LIST = (
        COHERENCE,
        FLUENCY,
        SAFETY,
        GROUNDEDNESS,
        INSTRUCTION_FOLLOWING,
        VERBOSITY,
        TEXT_QUALITY,
        SUMMARIZATION_QUALITY,
        QUESTION_ANSWERING_QUALITY,
        MULTI_TURN_CHAT_QUALITY,
        MULTI_TURN_SAFETY,
    )

    PAIRWISE_METRIC_PROMPT_TEMPLATE_EXAMPLE_LIST = (
        PAIRWISE_COHERENCE,
        PAIRWISE_FLUENCY,
        PAIRWISE_SAFETY,
        PAIRWISE_GROUNDEDNESS,
        PAIRWISE_INSTRUCTION_FOLLOWING,
        PAIRWISE_VERBOSITY,
        PAIRWISE_TEXT_QUALITY,
        PAIRWISE_SUMMARIZATION_QUALITY,
        PAIRWISE_QUESTION_ANSWERING_QUALITY,
        PAIRWISE_MULTI_TURN_CHAT_QUALITY,
        PAIRWISE_MULTI_TURN_SAFETY,
    )


@dataclasses.dataclass(frozen=True)
class MetricResult:
    ROW_COUNT_KEY = "row_count"
    SCORE_KEY = "score"
    EXPLANATION_KEY = "explanation"
    CUSTOM_OUTPUT_KEY = "custom_output"
    RAW_OUTPUT_KEY = "raw_output"
    RAW_OUTPUTS_KEY = "raw_outputs"
    PAIRWISE_CHOICE_KEY = "pairwise_choice"
    IS_UNSAFE_KEY = "is_unsafe"
    IS_UNSAFE_PROBABILITY_KEY = "is_unsafe_probability"
    VIOLATED_POLICIES_KEY = "violated_policies"
    RUBRIC_LEVEL_INSTRUCTION_FOLLOWING_KEY = "per_rubric_result"

    # Automatic Metrics.
    EXACT_MATCH_RESULTS = "exact_match_results"
    BLEU_RESULTS = "bleu_results"
    ROUGE_RESULTS = "rouge_results"
    TOOL_CALL_VALID_RESULTS = "tool_call_valid_results"
    TOOL_NAME_MATCH_RESULTS = "tool_name_match_results"
    TOOL_PARAMETER_KEY_MATCH_RESULTS = "tool_parameter_key_match_results"
    TOOL_PARAMETER_KV_MATCH_RESULTS = "tool_parameter_kv_match_results"
    TRAJECTORY_EXACT_MATCH_RESULTS = "trajectory_exact_match_results"
    TRAJECTORY_IN_ORDER_MATCH_RESULTS = "trajectory_in_order_match_results"
    TRAJECTORY_ANY_ORDER_MATCH_RESULTS = "trajectory_any_order_match_results"
    TRAJECTORY_PRECISION_RESULTS = "trajectory_precision_results"
    TRAJECTORY_RECALL_RESULTS = "trajectory_recall_results"
    TRAJECTORY_SINGLE_TOOL_USE_RESULTS = "trajectory_single_tool_use_results"

    POINTWISE_METRIC_RESULT = "pointwise_metric_result"
    PAIRWISE_METRIC_RESULT = "pairwise_metric_result"
    RUBRIC_BASED_INSTRUCTION_FOLLOWING_RESULT = (
        "rubric_based_instruction_following_result"
    )

    AUTOMATIC_METRIC_RESULTS_LIST = (
        EXACT_MATCH_RESULTS,
        BLEU_RESULTS,
        ROUGE_RESULTS,
        TOOL_CALL_VALID_RESULTS,
        TOOL_NAME_MATCH_RESULTS,
        TOOL_PARAMETER_KEY_MATCH_RESULTS,
        TOOL_PARAMETER_KV_MATCH_RESULTS,
        TRAJECTORY_EXACT_MATCH_RESULTS,
        TRAJECTORY_IN_ORDER_MATCH_RESULTS,
        TRAJECTORY_ANY_ORDER_MATCH_RESULTS,
        TRAJECTORY_PRECISION_RESULTS,
        TRAJECTORY_RECALL_RESULTS,
        TRAJECTORY_SINGLE_TOOL_USE_RESULTS,
    )


@dataclasses.dataclass(frozen=True)
class Dataset:
    # Default evaluation dataset schema column names.
    MODEL_RESPONSE_COLUMN = "response"
    BASELINE_MODEL_RESPONSE_COLUMN = "baseline_model_response"
    PROMPT_COLUMN = "prompt"
    REFERENCE_COLUMN = "reference"
    PREDICTED_TRAJECTORY_COLUMN = "predicted_trajectory"
    REFERENCE_TRAJECTORY_COLUMN = "reference_trajectory"
    RUBRICS_COLUMN = "rubrics"


@dataclasses.dataclass(frozen=True)
class QuotaLimit:
    """Generative AI on Vertex AI quota limits."""

    # Default Evaluation Service QPS limit.
    EVAL_SERVICE_QPS = 10
@@ -0,0 +1,630 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Evaluation Task class."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Union
|
||||
import uuid
|
||||
import warnings
|
||||
|
||||
from google.api_core import exceptions
|
||||
import vertexai
|
||||
from google.cloud.aiplatform import base
|
||||
from google.cloud.aiplatform import utils
|
||||
from google.cloud.aiplatform.metadata import metadata
|
||||
from vertexai import generative_models
|
||||
from vertexai.preview import reasoning_engines
|
||||
from vertexai.preview.evaluation import _base as eval_base
|
||||
from vertexai.preview.evaluation import _evaluation
|
||||
from vertexai.preview.evaluation import constants
|
||||
from vertexai.preview.evaluation import utils as eval_utils
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
_base as metrics_base,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import pairwise_metric
|
||||
from vertexai.preview.evaluation.metrics import pointwise_metric
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
rubric_based_metric,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
from google.colab import sheets
|
||||
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
try:
|
||||
from IPython import display as IPython_display
|
||||
except ImportError:
|
||||
IPython_display = None
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
AutoraterConfig = eval_base.AutoraterConfig
|
||||
EvalResult = eval_base.EvalResult
|
||||
GenerativeModel = generative_models.GenerativeModel
|
||||
|
||||
_RunnableType = Union[reasoning_engines.Queryable, Callable[[str], Dict[str, str]]]
|
||||
_ModelType = Union[generative_models.GenerativeModel, Callable[[str], str]]
|
||||
|
||||
|
||||
class EvalTask:
|
||||
"""A class representing an EvalTask.
|
||||
|
||||
An evaluation task assesses the ability of a Gen AI model, agent or
|
||||
application to perform a specific task in response to prompts.
|
||||
Each evaluation task includes an evaluation dataset, which can be a set of
|
||||
test cases and a set of metrics for assessment. These tasks provide the
|
||||
framework for running evaluations in a standardized and repeatable way,
|
||||
allowing for comparative assessment with varying run-specific parameters.
|
||||
|
||||
|
||||
Dataset Details:
|
||||
|
||||
Default dataset column names:
|
||||
* prompt_column_name: "prompt"
|
||||
* reference_column_name: "reference"
|
||||
* response_column_name: "response"
|
||||
* baseline_model_response_column_name: "baseline_model_response"
|
||||
* rubrics_column_name: "rubrics"
|
||||
|
||||
|
||||
Requirement for different use cases:
|
||||
* Bring-your-own-response (BYOR): You already have the data that you
|
||||
want to evaluate stored in the dataset. Response column name can be
|
||||
customized by providing `response_column_name` parameter, or in the
|
||||
`metric_column_mapping`. For BYOR pairwise evaluation, the baseline
|
||||
model response column name can be customized by providing
|
||||
`baseline_model_response_column_name` parameter, or
|
||||
in the `metric_column_mapping`. If the `response` column or
|
||||
`baseline_model_response` column is present while the
|
||||
corresponding model is specified, an error will be raised.
|
||||
|
||||
* Perform model/agent inference without a prompt template: You have a dataset
|
||||
containing the input prompts to the model/agent and want to perform
|
||||
inference before evaluation. A column named `prompt` is required
|
||||
in the evaluation dataset and is used directly as input to the model/agent.
|
||||
|
||||
* Perform model/agent inference with a prompt template: You have a dataset
|
||||
containing the input variables to the prompt template and want to
|
||||
assemble the prompts for inference. Evaluation dataset
|
||||
must contain column names corresponding to the variable names in
|
||||
the prompt template. For example, if prompt template is
|
||||
"Instruction: {instruction}, context: {context}", the dataset must
|
||||
contain `instruction` and `context` columns.
|
||||
|
||||
Metrics Details:
|
||||
|
||||
The supported metrics descriptions, rating rubrics, and the required
|
||||
input variables can be found on the Vertex AI public documentation page.
|
||||
[Evaluation methods and metrics](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval).
|
||||
|
||||
Usage Examples:
|
||||
|
||||
1. To perform bring-your-own-response(BYOR) evaluation, provide the model
|
||||
responses in the `response` column in the dataset. If a pairwise metric is
|
||||
used for BYOR evaluation, provide the baseline model responses in the
|
||||
`baseline_model_response` column.
|
||||
|
||||
```
|
||||
eval_dataset = pd.DataFrame({
|
||||
"prompt" : [...],
|
||||
"reference": [...],
|
||||
"response" : [...],
|
||||
"baseline_model_response": [...],
|
||||
})
|
||||
eval_task = EvalTask(
|
||||
dataset=eval_dataset,
|
||||
metrics=[
|
||||
"bleu",
|
||||
"rouge_l_sum",
|
||||
MetricPromptTemplateExamples.Pointwise.FLUENCY,
|
||||
MetricPromptTemplateExamples.Pairwise.SAFETY
|
||||
],
|
||||
experiment="my-experiment",
|
||||
)
|
||||
eval_result = eval_task.evaluate(experiment_run_name="eval-experiment-run")
|
||||
```
|
||||
|
||||
2. To perform evaluation with Gemini model inference, specify the `model`
|
||||
parameter with a `GenerativeModel` instance. The input column name to the
|
||||
model is `prompt` and must be present in the dataset.
|
||||
|
||||
```
|
||||
eval_dataset = pd.DataFrame({
|
||||
"reference": [...],
|
||||
"prompt" : [...],
|
||||
})
|
||||
result = EvalTask(
|
||||
dataset=eval_dataset,
|
||||
metrics=["exact_match", "bleu", "rouge_1", "rouge_l_sum"],
|
||||
experiment="my-experiment",
|
||||
).evaluate(
|
||||
model=GenerativeModel("gemini-1.5-pro"),
|
||||
experiment_run_name="gemini-eval-run"
|
||||
)
|
||||
```
|
||||
|
||||
3. If a `prompt_template` is specified, the `prompt` column is not required.
|
||||
Prompts can be assembled from the evaluation dataset, and all prompt
|
||||
template variable names must be present in the dataset columns.
|
||||
```
|
||||
eval_dataset = pd.DataFrame({
|
||||
"context" : [...],
|
||||
"instruction": [...],
|
||||
})
|
||||
result = EvalTask(
|
||||
dataset=eval_dataset,
|
||||
metrics=[MetricPromptTemplateExamples.Pointwise.SUMMARIZATION_QUALITY],
|
||||
).evaluate(
|
||||
model=GenerativeModel("gemini-1.5-pro"),
|
||||
prompt_template="{instruction}. Article: {context}. Summary:",
|
||||
)
|
||||
```
|
||||
|
||||
4. To perform evaluation with custom model inference, specify the `model`
|
||||
parameter with a custom inference function. The input column name to the
|
||||
custom inference function is `prompt` and must be present in the dataset.
|
||||
|
||||
```
|
||||
from openai import OpenAI
|
||||
client = OpenAI()
|
||||
def custom_model_fn(input: str) -> str:
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
{"role": "user", "content": input}
|
||||
]
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
eval_dataset = pd.DataFrame({
|
||||
"prompt" : [...],
|
||||
"reference": [...],
|
||||
})
|
||||
result = EvalTask(
|
||||
dataset=eval_dataset,
|
||||
metrics=[MetricPromptTemplateExamples.Pointwise.SAFETY],
|
||||
experiment="my-experiment",
|
||||
).evaluate(
|
||||
model=custom_model_fn,
|
||||
experiment_run_name="gpt-eval-run"
|
||||
)
|
||||
```
|
||||
|
||||
5. To perform pairwise metric evaluation with model inference step, specify
|
||||
the `baseline_model` input to a `PairwiseMetric` instance and the candidate
|
||||
`model` input to the `EvalTask.evaluate()` function. The input column name
|
||||
to both models is `prompt` and must be present in the dataset.
|
||||
|
||||
```
|
||||
baseline_model = GenerativeModel("gemini-1.0-pro")
|
||||
candidate_model = GenerativeModel("gemini-1.5-pro")
|
||||
|
||||
pairwise_groundedness = PairwiseMetric(
|
||||
metric_prompt_template=MetricPromptTemplateExamples.get_prompt_template(
|
||||
"pairwise_groundedness"
|
||||
),
|
||||
baseline_model=baseline_model,
|
||||
)
|
||||
eval_dataset = pd.DataFrame({
|
||||
"prompt" : [...],
|
||||
})
|
||||
result = EvalTask(
|
||||
dataset=eval_dataset,
|
||||
metrics=[pairwise_groundedness],
|
||||
experiment="my-pairwise-experiment",
|
||||
).evaluate(
|
||||
model=candidate_model,
|
||||
experiment_run_name="gemini-pairwise-eval-run",
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dataset: Union["pd.DataFrame", str, Dict[str, Any], "sheets.InteractiveSheet"],
|
||||
metrics: List[
|
||||
Union[
|
||||
Literal[
|
||||
"exact_match",
|
||||
"bleu",
|
||||
"rouge_1",
|
||||
"rouge_2",
|
||||
"rouge_l",
|
||||
"rouge_l_sum",
|
||||
"tool_call_valid",
|
||||
"tool_name_match",
|
||||
"tool_parameter_key_match",
|
||||
"tool_parameter_kv_match",
|
||||
"trajectory_exact_match",
|
||||
"trajectory_in_order_match",
|
||||
"trajectory_any_order_match",
|
||||
"trajectory_precision",
|
||||
"trajectory_recall",
|
||||
"rubric_based_instruction_following",
|
||||
],
|
||||
metrics_base.CustomMetric,
|
||||
metrics_base._AutomaticMetric,
|
||||
pointwise_metric.PointwiseMetric,
|
||||
pairwise_metric.PairwiseMetric,
|
||||
rubric_based_metric.RubricBasedMetric,
|
||||
]
|
||||
],
|
||||
experiment: Optional[str] = None,
|
||||
metric_column_mapping: Optional[Dict[str, str]] = None,
|
||||
output_uri_prefix: Optional[str] = "",
|
||||
autorater_config: Optional[AutoraterConfig] = None,
|
||||
):
|
||||
"""Initializes an EvalTask.
|
||||
|
||||
Args:
|
||||
dataset: The dataset to be evaluated.
|
||||
Supports the following dataset formats:
|
||||
* pandas.DataFrame: Used directly for evaluation.
|
||||
* Dict: Converted to a pandas DataFrame before evaluation.
|
||||
* str: Interpreted as a file path or URI. Supported formats include:
|
||||
* Local JSONL or CSV files: Loaded from the local filesystem.
|
||||
* GCS JSONL or CSV files: Loaded from Google Cloud Storage
|
||||
(e.g., 'gs://bucket/data.csv').
|
||||
* BigQuery table URI: Loaded from Google Cloud BigQuery
|
||||
(e.g., 'bq://project-id.dataset.table_name').
|
||||
metrics: The list of metric names, or Metric instances to evaluate.
|
||||
Prompt template is required for PairwiseMetric.
|
||||
experiment: The name of the experiment to log the evaluations to.
|
||||
metric_column_mapping: An optional dictionary column mapping that
|
||||
overrides the metric prompt template input variable names with
|
||||
mapped the evaluation dataset column names, used during evaluation.
|
||||
For example, if the input_variables of the metric prompt template
|
||||
are ["context", "reference"], the metric_column_mapping can be
|
||||
{
|
||||
"context": "news_context",
|
||||
"reference": "ground_truth",
|
||||
"response": "model_1_response"
|
||||
}
|
||||
if the dataset has columns "news_context", "ground_truth" and
|
||||
"model_1_response".
|
||||
output_uri_prefix: GCS location to store the metrics_table from
|
||||
evaluation results.
|
||||
autorater_config: The autorater config for model based evaluation.
|
||||
If autorater config is specified on a metric, it will override the
|
||||
autorater config specified here.
|
||||
"""
|
||||
self._dataset = eval_utils.load_dataset(dataset)
|
||||
self._metrics = metrics
|
||||
self._experiment = experiment
|
||||
self._metric_column_mapping = eval_utils.initialize_metric_column_mapping(
|
||||
metric_column_mapping, self._dataset
|
||||
)
|
||||
self.output_uri_prefix = output_uri_prefix
|
||||
self._autorater_config = autorater_config
|
||||
|
||||
@property
|
||||
def dataset(self) -> "pd.DataFrame":
|
||||
"""Returns evaluation dataset."""
|
||||
return self._dataset
|
||||
|
||||
@property
|
||||
def metrics(self) -> List[Union[str, metrics_base.CustomMetric]]:
|
||||
"""Returns metrics."""
|
||||
return self._metrics
|
||||
|
||||
@property
|
||||
def autorater_config(self) -> Optional[AutoraterConfig]:
|
||||
"""Returns autorater config."""
|
||||
return self._autorater_config
|
||||
|
||||
@property
|
||||
def experiment(self) -> Optional[str]:
|
||||
"""Returns experiment name."""
|
||||
return self._experiment
|
||||
|
||||
def _evaluate_with_experiment(
|
||||
self,
|
||||
model: Optional[_ModelType] = None,
|
||||
runnable: Optional[_RunnableType] = None,
|
||||
prompt_template: Optional[str] = None,
|
||||
experiment_run_name: Optional[str] = None,
|
||||
evaluation_service_qps: Optional[float] = None,
|
||||
retry_timeout: float = 120.0,
|
||||
output_file_name: Optional[str] = None,
|
||||
) -> EvalResult:
|
||||
"""Runs an evaluation for the EvalTask with an experiment.
|
||||
|
||||
Args:
|
||||
model: A GenerativeModel instance or a custom model function to generate
|
||||
responses to evaluate. If not provided, the evaluation is computed with
|
||||
the `response` column in the `dataset`.
|
||||
runnable: The runnable to generate responses to evaluate. If not provided,
|
||||
the evaluation is computed with the `response` and/or `predicted_trajectory`
|
||||
column in the `dataset`.
|
||||
prompt_template: The prompt template to use for the evaluation. If not
|
||||
set, the prompt template that was used to create the EvalTask will be
|
||||
used.
|
||||
experiment_run_name: The name of the experiment run to log the evaluation
|
||||
to if an experiment is set for this EvalTask. If not provided, a random
|
||||
unique experiment run name is used.
|
||||
evaluation_service_qps: The custom QPS limit for the evaluation service.
|
||||
retry_timeout: How long to keep retrying the evaluation requests for
|
||||
the whole evaluation dataset, in seconds.
|
||||
output_path: The file name with csv suffix to store the output
|
||||
metrics_table to be tracked in the experiment run.
|
||||
|
||||
Returns:
|
||||
The evaluation result.
|
||||
"""
|
||||
self._validate_experiment_run()
|
||||
with vertexai.preview.start_run(experiment_run_name):
|
||||
self._log_eval_experiment_param(
|
||||
model=model,
|
||||
runnable=runnable,
|
||||
prompt_template=prompt_template,
|
||||
output_file_name=output_file_name,
|
||||
)
|
||||
eval_result = _evaluation.evaluate(
|
||||
dataset=self._dataset,
|
||||
metrics=self._metrics,
|
||||
model=model,
|
||||
runnable=runnable,
|
||||
prompt_template=prompt_template,
|
||||
metric_column_mapping=self._metric_column_mapping,
|
||||
evaluation_service_qps=evaluation_service_qps,
|
||||
retry_timeout=retry_timeout,
|
||||
autorater_config=self._autorater_config,
|
||||
)
|
||||
|
||||
eval_result.summary_metrics = {
|
||||
k: ("NaN" if isinstance(v, float) and np.isnan(v) else v)
|
||||
for k, v in eval_result.summary_metrics.items()
|
||||
}
|
||||
eval_result.metadata = {
|
||||
"experiment": self._experiment,
|
||||
"experiment_run": experiment_run_name,
|
||||
}
|
||||
try:
|
||||
vertexai.preview.log_metrics(eval_result.summary_metrics)
|
||||
except (TypeError, exceptions.InvalidArgument) as e:
|
||||
_LOGGER.warning(f"Experiment metrics logging failed: {str(e)}")
|
||||
return eval_result
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
*,
|
||||
model: Optional[_ModelType] = None,
|
||||
runnable: Optional[_RunnableType] = None,
|
||||
prompt_template: Optional[str] = None,
|
||||
experiment_run_name: Optional[str] = None,
|
||||
response_column_name: Optional[str] = None,
|
||||
baseline_model_response_column_name: Optional[str] = None,
|
||||
evaluation_service_qps: Optional[float] = None,
|
||||
retry_timeout: float = 120.0,
|
||||
output_file_name: Optional[str] = "",
|
||||
) -> EvalResult:
|
||||
"""Runs an evaluation for the EvalTask.
|
||||
|
||||
Args:
|
||||
model: A GenerativeModel instance or a custom model function to generate
|
||||
responses to evaluate. If not provided, the evaluation can be performed
|
||||
in the bring-your-own-response (BYOR) mode.
|
||||
runnable: The runnable to generate responses to evaluate. If not provided,
|
||||
the evaluation is computed with the `response` and/or `predicted_trajectory`
|
||||
column in the `dataset`.
|
||||
prompt_template: The prompt template to use for the evaluation. If not
|
||||
set, the prompt template that was used to create the EvalTask will be
|
||||
used.
|
||||
experiment_run_name: The name of the experiment run to log the evaluation
|
||||
to if an experiment is set for this EvalTask. If not provided, a random
|
||||
unique experiment run name is used.
|
||||
response_column_name: The column name of model response in the dataset. If
|
||||
provided, this will override the `metric_column_mapping` of the `EvalTask`.
|
||||
baseline_model_response_column_name: The column name of baseline model
|
||||
response in the dataset for pairwise metrics. If provided, this will
|
||||
override the `metric_column_mapping` of the `EvalTask`
|
||||
evaluation_service_qps: The custom QPS limit for the evaluation service.
|
||||
retry_timeout: How long to keep retrying the evaluation requests for
|
||||
the whole evaluation dataset, in seconds.
|
||||
output_file_name: The file name with csv suffix to store the output
|
||||
metrics_table.
|
||||
|
||||
Returns:
|
||||
The evaluation result.
|
||||
"""
|
||||
global_experiment_name = (
|
||||
metadata._experiment_tracker.experiment_name
|
||||
) # pylint: disable=protected-access
|
||||
if experiment_run_name and not self._experiment and not global_experiment_name:
|
||||
raise ValueError(
|
||||
"Experiment is not set. Please initialize EvalTask with an"
|
||||
" experiment, or initialize a global experiment with "
|
||||
"`vertexai.init(experiment='experiment_name')`for logging this"
|
||||
" evaluation run."
|
||||
)
|
||||
|
||||
self._verify_and_set_response_column_name(
|
||||
response_column_name=response_column_name,
|
||||
metric_column_mapping_key=constants.Dataset.MODEL_RESPONSE_COLUMN,
|
||||
)
|
||||
self._verify_and_set_response_column_name(
|
||||
response_column_name=baseline_model_response_column_name,
|
||||
metric_column_mapping_key=constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
|
||||
)
|
||||
if self.output_uri_prefix and not output_file_name:
|
||||
output_file_name = f"eval_results_{utils.timestamped_unique_name()}.csv"
|
||||
experiment_run_name = experiment_run_name or f"{uuid.uuid4()}"
|
||||
if self._experiment and global_experiment_name:
|
||||
metadata._experiment_tracker.set_experiment( # pylint: disable=protected-access
|
||||
experiment=self._experiment, backing_tensorboard=False
|
||||
)
|
||||
eval_result = self._evaluate_with_experiment(
|
||||
model=model,
|
||||
runnable=runnable,
|
||||
prompt_template=prompt_template,
|
||||
experiment_run_name=experiment_run_name,
|
||||
evaluation_service_qps=evaluation_service_qps,
|
||||
retry_timeout=retry_timeout,
|
||||
output_file_name=output_file_name,
|
||||
)
|
||||
metadata._experiment_tracker.set_experiment( # pylint: disable=protected-access
|
||||
experiment=global_experiment_name, backing_tensorboard=False
|
||||
)
|
||||
elif self._experiment and not global_experiment_name:
|
||||
metadata._experiment_tracker.set_experiment( # pylint: disable=protected-access
|
||||
experiment=self._experiment, backing_tensorboard=False
|
||||
)
|
||||
eval_result = self._evaluate_with_experiment(
|
||||
model=model,
|
||||
runnable=runnable,
|
||||
prompt_template=prompt_template,
|
||||
experiment_run_name=experiment_run_name,
|
||||
evaluation_service_qps=evaluation_service_qps,
|
||||
retry_timeout=retry_timeout,
|
||||
output_file_name=output_file_name,
|
||||
)
|
||||
metadata._experiment_tracker.reset() # pylint: disable=protected-access
|
||||
elif not self._experiment and global_experiment_name:
|
||||
eval_result = self._evaluate_with_experiment(
|
||||
model=model,
|
||||
runnable=runnable,
|
||||
prompt_template=prompt_template,
|
||||
experiment_run_name=experiment_run_name,
|
||||
evaluation_service_qps=evaluation_service_qps,
|
||||
retry_timeout=retry_timeout,
|
||||
output_file_name=output_file_name,
|
||||
)
|
||||
else:
|
||||
eval_result = _evaluation.evaluate(
|
||||
dataset=self._dataset,
|
||||
metrics=self._metrics,
|
||||
model=model,
|
||||
runnable=runnable,
|
||||
prompt_template=prompt_template,
|
||||
metric_column_mapping=self._metric_column_mapping,
|
||||
evaluation_service_qps=evaluation_service_qps,
|
||||
retry_timeout=retry_timeout,
|
||||
autorater_config=self._autorater_config,
|
||||
)
|
||||
eval_utils.upload_evaluation_results(
|
||||
eval_result, self.output_uri_prefix, output_file_name
|
||||
)
|
||||
return eval_result
|
||||
|
||||
def _validate_experiment_run(self) -> None:
|
||||
"""Checks if an experiment run already exists."""
|
||||
if (
|
||||
metadata._experiment_tracker.experiment_run
|
||||
): # pylint: disable=protected-access
|
||||
raise ValueError(
|
||||
"Experiment run already exists. Please specify the name of the"
|
||||
" experiment run to assign current session within this evaluation."
|
||||
)
|
||||
|
||||
def _log_eval_experiment_param(
|
||||
self,
|
||||
model: _ModelType = None,
|
||||
runnable: _RunnableType = None,
|
||||
prompt_template: Optional[str] = None,
|
||||
output_file_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Logs variable input parameters of an evaluation to an experiment run."""
|
||||
eval_metadata = {}
|
||||
if prompt_template is not None:
|
||||
eval_metadata.update({"prompt_template": prompt_template})
|
||||
|
||||
if model:
|
||||
if isinstance(model, GenerativeModel):
|
||||
eval_metadata.update(
|
||||
{
|
||||
"model_name": model._model_name, # pylint: disable=protected-access
|
||||
}
|
||||
)
|
||||
|
||||
if (
|
||||
model._generation_config
|
||||
and isinstance( # pylint: disable=protected-access
|
||||
model._generation_config,
|
||||
dict, # pylint: disable=protected-access
|
||||
)
|
||||
):
|
||||
eval_metadata.update(
|
||||
**model._generation_config
|
||||
) # pylint: disable=protected-access
|
||||
|
||||
if model._safety_settings and isinstance(
|
||||
model._safety_settings, dict
|
||||
): # pylint: disable=protected-access
|
||||
safety_settings = (
|
||||
model._safety_settings
|
||||
) # pylint: disable=protected-access
|
||||
safety_settings_as_str = {
|
||||
category.name: threshold.name
|
||||
for category, threshold in safety_settings.items()
|
||||
}
|
||||
eval_metadata.update(safety_settings_as_str)
|
||||
|
||||
if runnable:
|
||||
if isinstance(runnable, reasoning_engines.LangchainAgent):
|
||||
eval_metadata.update(
|
||||
{
|
||||
"model_name": runnable._model_name,
|
||||
"tools": runnable._tools,
|
||||
} # pylint: disable=protected-access
|
||||
)
|
||||
|
||||
if self.output_uri_prefix and output_file_name:
|
||||
eval_metadata.update(
|
||||
{"output_file": self.output_uri_prefix + "/" + output_file_name}
|
||||
)
|
||||
|
||||
if eval_metadata:
|
||||
_LOGGER.info(
|
||||
f"Logging Eval experiment evaluation metadata: {eval_metadata}"
|
||||
)
|
||||
try:
|
||||
vertexai.preview.log_params(eval_metadata)
|
||||
except (ValueError, TypeError) as e:
|
||||
_LOGGER.warning(
|
||||
f"Experiment evaluation metadata logging failed: {str(e)}"
|
||||
)
|
||||
|
||||
def _verify_and_set_response_column_name(
|
||||
self, response_column_name: str, metric_column_mapping_key: str
|
||||
) -> None:
|
||||
"""Verifies and sets the model response column names."""
|
||||
if response_column_name:
|
||||
if response_column_name in self._dataset.columns:
|
||||
self._metric_column_mapping[
|
||||
metric_column_mapping_key
|
||||
] = response_column_name
|
||||
else:
|
||||
raise ValueError(
|
||||
f"(Baseline) Model response column {response_column_name} is not"
|
||||
" found in the dataset."
|
||||
)
|
||||
|
||||
def display_runs(self):
|
||||
"""Displays experiment runs associated with this EvalTask."""
|
||||
if not self._experiment:
|
||||
raise ValueError("Experiment is not set.")
|
||||
elif IPython_display:
|
||||
IPython_display.display(
|
||||
vertexai.preview.get_experiment_df(self._experiment)
|
||||
)
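# --- Hedged usage sketch (illustrative only; project, dataset, metric, and experiment names are assumptions) ---
# A minimal BYOR (bring-your-own-response) evaluation with an EvalTask might look like:
#
#   import pandas as pd
#   import vertexai
#   from vertexai.preview.evaluation import EvalTask
#
#   vertexai.init(project="my-project", location="us-central1")  # hypothetical project
#   eval_df = pd.DataFrame(
#       {"response": ["Hello!"], "reference": ["Hello!"]}
#   )
#   task = EvalTask(dataset=eval_df, metrics=["exact_match"], experiment="my-eval-experiment")
#   result = task.evaluate(experiment_run_name="run-001")
#   print(result.summary_metrics)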
|
||||
@@ -0,0 +1,324 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Utility functions for metrics."""
|
||||
|
||||
import datetime
|
||||
import io
|
||||
from typing import Any, Dict, List, Optional, Union, Callable
|
||||
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform_v1beta1.types import (
|
||||
evaluation_service as gapic_eval_service_types,
|
||||
)
|
||||
from vertexai.preview.evaluation import metrics
|
||||
from vertexai.preview.evaluation import prompt_template
|
||||
from vertexai.preview.evaluation import utils
|
||||
from vertexai.preview.evaluation.metrics import _schema
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
custom_output_config,
|
||||
)
|
||||
from vertexai import generative_models
|
||||
from vertexai.generative_models import _generative_models
|
||||
import jsonschema
|
||||
import ruamel.yaml
|
||||
from ruamel.yaml import scalarstring
|
||||
|
||||
|
||||
AutoraterConfig = gapic_eval_service_types.AutoraterConfig
|
||||
GenerativeModel = generative_models.GenerativeModel
|
||||
CustomOutputConfig = custom_output_config.CustomOutputConfig
|
||||
PairwiseMetric = metrics.PairwiseMetric
|
||||
PointwiseMetric = metrics.PointwiseMetric
|
||||
RubricBasedMetric = metrics.RubricBasedMetric
|
||||
RubricGenerationConfig = metrics.RubricGenerationConfig
|
||||
PromptTemplate = prompt_template.PromptTemplate
|
||||
|
||||
# Initialize schema validator.
|
||||
_schema = ruamel.yaml.YAML(typ="safe").load(_schema.AUTORATER_METRIC_SCHEMA)
|
||||
_schema_validator = jsonschema.Draft202012Validator(schema=_schema)
|
||||
|
||||
|
||||
def dump(
|
||||
metric: Union[PointwiseMetric, PairwiseMetric, RubricBasedMetric],
|
||||
file_path: str,
|
||||
version: Optional[str] = None,
|
||||
):
|
||||
"""Dumps a metric object to a YAML file.
|
||||
|
||||
Args:
|
||||
metric: The metric to be dumped to a file.
|
||||
file_path: The path to the file. Local and GCS files are supported.
|
||||
version: Optional. The version of the metric. Defaults to the timestamp
|
||||
when the metric file is created.
|
||||
"""
|
||||
yaml_data = dumps(metric, version)
|
||||
if file_path.startswith(utils._GCS_PREFIX):
|
||||
utils._upload_string_to_gcs(file_path, yaml_data)
|
||||
else:
|
||||
with open(file_path, "w") as f:
|
||||
f.write(yaml_data)
|
||||
|
||||
|
||||
def dumps(
|
||||
metric: Union[PointwiseMetric, PairwiseMetric, RubricBasedMetric],
|
||||
version: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Dumps a metric object to YAML data.
|
||||
|
||||
Args:
|
||||
metric: The metric to be dumped to YAML data.
|
||||
version: Optional. The version of the metric. Defaults to the timestamp
|
||||
when the metric file is created.
|
||||
|
||||
Returns:
|
||||
The YAML data of the metric.
|
||||
"""
|
||||
steps = []
|
||||
metric_name = None
|
||||
if isinstance(metric, PointwiseMetric) or isinstance(metric, PairwiseMetric):
|
||||
metric_name = metric.metric_name
|
||||
steps.append(_dump_metric(metric))
|
||||
elif isinstance(metric, RubricBasedMetric):
|
||||
metric_name = metric.critique_metric.metric_name
|
||||
steps.append(_dump_rubric(metric.generation_config))
|
||||
steps.append(_dump_metric(metric.critique_metric))
|
||||
metadata = {
|
||||
"name": metric_name,
|
||||
"version": (
|
||||
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
if version is None
|
||||
else version
|
||||
),
|
||||
"required_inputs": _parse_required_inputs(metric),
|
||||
}
|
||||
metric_config = {
|
||||
"metadata": metadata,
|
||||
"steps": steps,
|
||||
}
|
||||
yaml = ruamel.yaml.YAML()
|
||||
yaml.indent(sequence=4, offset=2)
|
||||
with io.StringIO() as s:
|
||||
yaml.dump(metric_config, s)
|
||||
return s.getvalue()
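# --- Hedged usage sketch (metric name and prompt text are illustrative) ---
#   fluency = PointwiseMetric(
#       metric="fluency",
#       metric_prompt_template="Rate the fluency of this response: {response}",
#   )
#   yaml_text = dumps(fluency, version="v1")        # serialize to YAML data
#   dump(fluency, "gs://my-bucket/fluency.yaml")    # or write to a GCS/local file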
|
||||
|
||||
|
||||
def _dump_metric(metric: Union[PointwiseMetric, PairwiseMetric]) -> Dict[str, Any]:
|
||||
"""Dumps a metric object to autorater metric schema."""
|
||||
output_type = None
|
||||
if metric.custom_output_config and metric.custom_output_config.return_raw_output:
|
||||
output_type = "raw"
|
||||
step = {
|
||||
"type": (
|
||||
"pairwise_metric"
|
||||
if isinstance(metric, PairwiseMetric)
|
||||
else "pointwise_metric"
|
||||
),
|
||||
"prompt": {
|
||||
"template": scalarstring.preserve_literal(metric.metric_prompt_template),
|
||||
},
|
||||
}
|
||||
if metric.system_instruction:
|
||||
step["prompt"]["system_instruction"] = metric.system_instruction
|
||||
if output_type:
|
||||
step["output"] = {
|
||||
"type": output_type,
|
||||
}
|
||||
if metric.autorater_config:
|
||||
step["model"] = {
|
||||
"model_name_or_endpoint": (metric.autorater_config.autorater_model),
|
||||
}
|
||||
options = {}
|
||||
if metric.autorater_config.flip_enabled:
|
||||
options["flip_enabled"] = metric.autorater_config.flip_enabled
|
||||
if metric.autorater_config.sampling_count:
|
||||
options["sample_count"] = metric.autorater_config.sampling_count
|
||||
if options:
|
||||
step["options"] = options
|
||||
return step
|
||||
|
||||
|
||||
def _dump_rubric(generation_config: RubricGenerationConfig) -> Dict[str, Any]:
|
||||
"""Dumps a rubric generation config to autorater metric schema."""
|
||||
# TODO: b/396217889 - add support for custom output.
|
||||
step = {
|
||||
"type": "rubric",
|
||||
"prompt": {
|
||||
"template": scalarstring.preserve_literal(
|
||||
generation_config.prompt_template
|
||||
),
|
||||
},
|
||||
}
|
||||
if generation_config.model and isinstance(generation_config.model, GenerativeModel):
|
||||
step["model"] = {
|
||||
"model_name_or_endpoint": generation_config.model._model_name,
|
||||
}
|
||||
return step
|
||||
|
||||
|
||||
def _parse_required_inputs(
|
||||
metric: Union[PointwiseMetric, PairwiseMetric, RubricBasedMetric],
|
||||
) -> List[str]:
|
||||
"""Parses required inputs from a metric object."""
|
||||
if isinstance(metric, PointwiseMetric) or isinstance(metric, PairwiseMetric):
|
||||
return list(PromptTemplate(metric.metric_prompt_template).variables)
|
||||
elif isinstance(metric, RubricBasedMetric):
|
||||
met = PromptTemplate(metric.critique_metric.metric_prompt_template).variables
|
||||
gen = PromptTemplate(metric.generation_config.prompt_template).variables
|
||||
return list(met.union(gen))
|
||||
else:
|
||||
raise ValueError(f"Unsupported metric type: {type(metric)}")
|
||||
|
||||
|
||||
def load(
|
||||
file_path: str,
|
||||
baseline_model: Optional[Union[GenerativeModel, Callable[[str], str]]] = None,
|
||||
) -> Union[PointwiseMetric, PairwiseMetric, RubricBasedMetric]:
|
||||
"""Loads a metric object from a YAML file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file containing the autorater metric configuration.
|
||||
Local and GCS files are supported.
|
||||
baseline_model: Optional. The baseline model to use for pairwise metrics.
|
||||
|
||||
Returns:
|
||||
The metric object loaded from the file.
|
||||
"""
|
||||
if file_path.startswith(utils._GCS_PREFIX):
|
||||
file_contents = utils._read_gcs_file_contents(file_path)
|
||||
return loads(file_contents, baseline_model)
|
||||
with open(file_path, "r") as f:
|
||||
return loads(f.read(), baseline_model)
|
||||
|
||||
|
||||
def loads(
|
||||
yaml_data: str,
|
||||
baseline_model: Optional[Union[GenerativeModel, Callable[[str], str]]] = None,
|
||||
) -> Union[PointwiseMetric, PairwiseMetric, RubricBasedMetric]:
|
||||
"""Loads a metric object from YAML data.
|
||||
|
||||
Args:
|
||||
yaml_data: YAML data containing the autorater metric configuration.
|
||||
baseline_model: Optional. The baseline model to use for pairwise metrics.
|
||||
|
||||
Returns:
|
||||
The metric object loaded from the YAML data.
|
||||
"""
|
||||
yaml = ruamel.yaml.YAML(typ="safe")
|
||||
yaml_obj = yaml.load(yaml_data)
|
||||
try:
|
||||
_schema_validator.validate(yaml_obj)
|
||||
except jsonschema.exceptions.ValidationError as e:
|
||||
raise ValueError(
|
||||
f"Invalid autorater metric config: {e.message} for {e.path.pop()}"
|
||||
) from e
|
||||
metadata = yaml_obj["metadata"]
|
||||
steps = yaml_obj["steps"]
|
||||
required_inputs = set(metadata["required_inputs"])
|
||||
metric = None
|
||||
rubric = None
|
||||
for step in steps:
|
||||
_validate_template(step["prompt"]["template"], required_inputs)
|
||||
model_name = None
|
||||
flip = None
|
||||
sampling = None
|
||||
if "model" in step:
|
||||
model_name = _parse_model_name(step["model"]["model_name_or_endpoint"])
|
||||
if "options" in step:
|
||||
flip = step["options"].get("flip_enabled", False)
|
||||
sampling = step["options"].get("sample_count", 1)
|
||||
autorater = None
|
||||
if model_name:
|
||||
autorater = AutoraterConfig(
|
||||
autorater_model=model_name,
|
||||
flip_enabled=flip,
|
||||
sampling_count=sampling,
|
||||
)
|
||||
system_instruction = step["prompt"].get("system_instruction")
|
||||
custom_output = None
|
||||
if "output" in step and step["output"]["type"] == "raw":
|
||||
custom_output = CustomOutputConfig(return_raw_output=True)
|
||||
if step["type"] == "pointwise_metric":
|
||||
if metric is not None:
|
||||
raise ValueError("Only one metric step is supported.")
|
||||
if baseline_model:
|
||||
raise ValueError("Baseline model provided for pointwise metric.")
|
||||
metric = PointwiseMetric(
|
||||
metric=metadata["name"],
|
||||
metric_prompt_template=step["prompt"]["template"],
|
||||
system_instruction=system_instruction,
|
||||
autorater_config=autorater,
|
||||
custom_output_config=custom_output,
|
||||
)
|
||||
elif step["type"] == "pairwise_metric":
|
||||
if metric is not None:
|
||||
raise ValueError("Only one metric step is supported.")
|
||||
metric = PairwiseMetric(
|
||||
metric=metadata["name"],
|
||||
metric_prompt_template=step["prompt"]["template"],
|
||||
system_instruction=system_instruction,
|
||||
baseline_model=baseline_model,
|
||||
autorater_config=autorater,
|
||||
custom_output_config=custom_output,
|
||||
)
|
||||
elif step["type"] == "rubric":
|
||||
if rubric is not None:
|
||||
raise ValueError("Only one rubric step is supported.")
|
||||
model = None
|
||||
if model_name:
|
||||
model = generative_models.GenerativeModel(model_name=model_name)
|
||||
rubric = RubricGenerationConfig(
|
||||
prompt_template=step["prompt"]["template"],
|
||||
model=model,
|
||||
)
|
||||
if metric is None:
|
||||
raise ValueError("A metric step must be provided.")
|
||||
if rubric is not None:
|
||||
return RubricBasedMetric(
|
||||
generation_config=rubric,
|
||||
critique_metric=metric,
|
||||
)
|
||||
return metric
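# --- Hedged round-trip sketch (paths are illustrative) ---
#   metric = loads(yaml_text)                      # rebuild a metric from YAML data
#   metric = load("gs://my-bucket/fluency.yaml")   # or load it from a file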
|
||||
|
||||
|
||||
def _parse_model_name(model_name_or_endpoint: str) -> str:
|
||||
"""Parses model name or endpoint.
|
||||
|
||||
Args:
|
||||
model_name_or_endpoint: A Model Garden model name or a tuned model
endpoint resource name.
|
||||
|
||||
Returns:
|
||||
The model resource name.
|
||||
"""
|
||||
project = initializer.global_config.project
|
||||
location = initializer.global_config.location
|
||||
model_name = _generative_models._reconcile_model_name(
|
||||
model_name_or_endpoint, project, location
|
||||
)
|
||||
return _generative_models._get_resource_name_from_model_name(
|
||||
model_name, project, location
|
||||
)
|
||||
|
||||
|
||||
def _validate_template(template: str, required_inputs: List[str]) -> None:
|
||||
"""Validates the template contains only required inputs."""
|
||||
placeholders = PromptTemplate(template).variables
|
||||
if not placeholders.issubset(required_inputs):
|
||||
raise ValueError(
|
||||
"Template contains placeholders that are not in required inputs:"
|
||||
f" {placeholders - required_inputs}"
|
||||
)
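# Hedged example: _validate_template("Rate {response}.", ["response", "reference"])
# passes, while a template that also references {foo} raises a ValueError.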
|
||||
@@ -0,0 +1,72 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Evaluation Metrics Module."""
|
||||
|
||||
from vertexai.preview.evaluation.metrics import _base
|
||||
from vertexai.preview.evaluation.metrics import _rouge
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
_trajectory_single_tool_use,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
custom_output_config,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
metric_prompt_template,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
metric_prompt_template_examples,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import pairwise_metric
|
||||
from vertexai.preview.evaluation.metrics import pointwise_metric
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
predefined_rubric_metrics,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
rubric_based_metric,
|
||||
)
|
||||
|
||||
|
||||
PairwiseMetric = pairwise_metric.PairwiseMetric
|
||||
PointwiseMetric = pointwise_metric.PointwiseMetric
|
||||
CustomMetric = _base.CustomMetric
|
||||
PairwiseMetricPromptTemplate = metric_prompt_template.PairwiseMetricPromptTemplate
|
||||
PointwiseMetricPromptTemplate = metric_prompt_template.PointwiseMetricPromptTemplate
|
||||
MetricPromptTemplateExamples = (
|
||||
metric_prompt_template_examples.MetricPromptTemplateExamples
|
||||
)
|
||||
Rouge = _rouge.Rouge
|
||||
TrajectorySingleToolUse = _trajectory_single_tool_use.TrajectorySingleToolUse
|
||||
CustomOutputConfig = custom_output_config.CustomOutputConfig
|
||||
RubricBasedMetric = rubric_based_metric.RubricBasedMetric
|
||||
RubricGenerationConfig = _base.RubricGenerationConfig
|
||||
PredefinedRubricMetrics = predefined_rubric_metrics.PredefinedRubricMetrics
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CustomMetric",
|
||||
"PairwiseMetric",
|
||||
"PointwiseMetric",
|
||||
"PairwiseMetricPromptTemplate",
|
||||
"PointwiseMetricPromptTemplate",
|
||||
"MetricPromptTemplateExamples",
|
||||
"Rouge",
|
||||
"TrajectorySingleToolUse",
|
||||
"CustomOutputConfig",
|
||||
"RubricBasedMetric",
|
||||
"RubricGenerationConfig",
|
||||
"PredefinedRubricMetrics",
|
||||
]
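# Hedged usage sketch: these re-exports can be imported directly, e.g.
#   from vertexai.preview.evaluation.metrics import PointwiseMetric, MetricPromptTemplateExamples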
|
||||
@@ -0,0 +1,168 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Base classes for evaluation metrics."""
|
||||
|
||||
import abc
|
||||
from typing import Any, Callable, Dict, Literal, Optional, Union, List
|
||||
|
||||
from google.cloud.aiplatform_v1beta1.types import (
|
||||
evaluation_service as gapic_eval_service_types,
|
||||
)
|
||||
from vertexai import generative_models
|
||||
from vertexai.preview.evaluation import constants
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
custom_output_config as custom_output_config_class,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
metric_prompt_template as metric_prompt_template_base,
|
||||
)
|
||||
|
||||
|
||||
_ModelType = Union[generative_models.GenerativeModel, Callable[[str], str]]
|
||||
|
||||
|
||||
class _Metric(abc.ABC):
|
||||
"""The abstract class for evaluation metric."""
|
||||
|
||||
def __init__(self, metric: str):
|
||||
self._metric = metric
|
||||
|
||||
def __str__(self):
|
||||
return self.metric_name
|
||||
|
||||
@property
|
||||
def metric_name(self) -> str:
|
||||
return self._metric
|
||||
|
||||
|
||||
class _ModelBasedMetric(_Metric):
|
||||
"""A Model-based Metric.
|
||||
|
||||
An evaluation metric that evaluates generative AI model responses with
|
||||
another generative model as a judge. This metric can be used to evaluate a
|
||||
single model, or two models side-by-side.
|
||||
|
||||
For more details on when to use model-based metrics, see
|
||||
[Evaluation methods and metrics](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metric: str,
|
||||
metric_prompt_template: Union[
|
||||
metric_prompt_template_base.PointwiseMetricPromptTemplate,
|
||||
metric_prompt_template_base.PairwiseMetricPromptTemplate,
|
||||
str,
|
||||
],
|
||||
system_instruction: Optional[str] = None,
|
||||
autorater_config: Optional[gapic_eval_service_types.AutoraterConfig] = None,
|
||||
custom_output_config: Optional[
|
||||
custom_output_config_class.CustomOutputConfig
|
||||
] = None,
|
||||
):
|
||||
"""Initializes the model-based evaluation metric.
|
||||
|
||||
Args:
|
||||
metric: Generic model based metric name.
|
||||
metric_prompt_template: A metric prompt template for performing
|
||||
the model-based evaluation. A freeform string is also accepted.
|
||||
system_instruction: The system instruction to be used in the metric
|
||||
prompt.
|
||||
autorater_config: The config for judge model.
|
||||
custom_output_config: Config for custom output from the judge model.
|
||||
"""
|
||||
super().__init__(metric=metric)
|
||||
self.metric_prompt_template = str(metric_prompt_template)
|
||||
self.system_instruction = system_instruction
|
||||
self.autorater_config = autorater_config
|
||||
self.custom_output_config = custom_output_config
|
||||
|
||||
|
||||
class CustomMetric(_Metric):
|
||||
"""The custom evaluation metric.
|
||||
|
||||
A fully-customized CustomMetric that can be used to evaluate a single model
|
||||
by defining a metric function for a computation-based metric. The
|
||||
CustomMetric is computed on the client-side using the user-defined metric
|
||||
function in SDK only, not by the Vertex Gen AI Evaluation Service.
|
||||
|
||||
Attributes:
|
||||
name: The name of the metric.
|
||||
metric_function: The user-defined evaluation function to compute a metric
|
||||
score. It must take the dataset row dictionary as input and return the
per-instance metric result as a dictionary. The metric score must be
mapped to the name of the CustomMetric as its key.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
metric_function: Callable[
|
||||
[Dict[str, Any]],
|
||||
Dict[str, Any],
|
||||
],
|
||||
):
|
||||
"""Initializes the evaluation metric."""
|
||||
super().__init__(name)
|
||||
self.name = name
|
||||
self.metric_function = metric_function
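# --- Hedged usage sketch for CustomMetric (function and column names are illustrative) ---
#   def word_count(row: Dict[str, Any]) -> Dict[str, Any]:
#       # The returned dict must key the score by the metric's name.
#       return {"word_count": len(row["response"].split())}
#
#   word_count_metric = CustomMetric(name="word_count", metric_function=word_count)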
|
||||
|
||||
|
||||
class _AutomaticMetric(_Metric):
|
||||
"""An automatic metric that computes deterministic score based on reference.
|
||||
|
||||
An lexicon-based evaluation metric that evaluate a generative model's
|
||||
response on the given evaluation task with reference ground truth answers.
|
||||
It is a type of pointwise evaluation metric.
|
||||
|
||||
For more details on when to use automatic metrics, see
|
||||
[Evaluation methods and
|
||||
metrics](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metric: Literal[constants.Metric.ROUGE],
|
||||
):
|
||||
"""Initializes the automatic evaluation metric.
|
||||
|
||||
Args:
|
||||
metric: The automatic evaluation metric name.
|
||||
"""
|
||||
super().__init__(metric=metric)
|
||||
|
||||
|
||||
class RubricGenerationConfig:
|
||||
"""The rubric generation config."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_template: str,
|
||||
model: Optional[_ModelType] = None,
|
||||
parsing_fn: Optional[Callable[[str], List[str]]] = None,
|
||||
):
|
||||
"""Initializes the rubric generation config.
|
||||
|
||||
Args:
|
||||
prompt_template: The prompt template for rubric generation.
|
||||
model: The model to use for rubric generation.
|
||||
parsing_fn: The function to parse the rubric generation response.
|
||||
"""
|
||||
self.prompt_template = prompt_template
|
||||
self.model = model
|
||||
self.parsing_fn = parsing_fn
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,802 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Library for metrics computation with Gen AI Evaluation Service."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from google import api_core
|
||||
from google.cloud.aiplatform import base
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform_v1beta1.services import (
|
||||
evaluation_service as gapic_evaluation_services,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.types import (
|
||||
evaluation_service as gapic_eval_service_types,
|
||||
)
|
||||
from vertexai.preview.evaluation import _base as eval_base
|
||||
from vertexai.preview.evaluation import constants
|
||||
from vertexai.preview.evaluation import multimodal_utils
|
||||
from vertexai.preview.evaluation import (
|
||||
prompt_template as prompt_template_base,
|
||||
)
|
||||
from vertexai.preview.evaluation import utils
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
_base as metrics_base,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
_default_templates,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import _rouge
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
_trajectory_single_tool_use,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
custom_output_config as custom_output_config_class,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import pairwise_metric
|
||||
from vertexai.preview.evaluation.metrics import pointwise_metric
|
||||
from google.protobuf import json_format
|
||||
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
_METRIC_NAME_TO_METRIC_SPEC = {
|
||||
# Automatic Metrics.
|
||||
constants.Metric.EXACT_MATCH: (gapic_eval_service_types.ExactMatchSpec()),
|
||||
constants.Metric.BLEU: gapic_eval_service_types.BleuSpec(),
|
||||
constants.Metric.ROUGE: gapic_eval_service_types.RougeSpec(),
|
||||
constants.Metric.ROUGE_1: gapic_eval_service_types.RougeSpec(rouge_type="rouge1"),
|
||||
constants.Metric.ROUGE_2: gapic_eval_service_types.RougeSpec(rouge_type="rouge2"),
|
||||
constants.Metric.ROUGE_L: gapic_eval_service_types.RougeSpec(rouge_type="rougeL"),
|
||||
constants.Metric.ROUGE_L_SUM: gapic_eval_service_types.RougeSpec(
|
||||
rouge_type="rougeLsum"
|
||||
),
|
||||
constants.Metric.TOOL_CALL_VALID: (gapic_eval_service_types.ToolCallValidSpec()),
|
||||
constants.Metric.TOOL_NAME_MATCH: (gapic_eval_service_types.ToolNameMatchSpec()),
|
||||
constants.Metric.TOOL_PARAMETER_KV_MATCH: (
|
||||
gapic_eval_service_types.ToolParameterKVMatchSpec()
|
||||
),
|
||||
constants.Metric.TOOL_PARAMETER_KEY_MATCH: (
|
||||
gapic_eval_service_types.ToolParameterKeyMatchSpec()
|
||||
),
|
||||
# Pointwise Metrics.
|
||||
constants.Metric.POINTWISE_METRIC: (gapic_eval_service_types.PointwiseMetricSpec()),
|
||||
# Pairwise Metrics.
|
||||
constants.Metric.PAIRWISE_METRIC: (gapic_eval_service_types.PairwiseMetricSpec()),
|
||||
constants.Metric.RUBRIC_BASED_INSTRUCTION_FOLLOWING: (
|
||||
gapic_eval_service_types.RubricBasedInstructionFollowingSpec()
|
||||
),
|
||||
constants.Metric.TRAJECTORY_EXACT_MATCH: (
|
||||
gapic_eval_service_types.TrajectoryExactMatchSpec()
|
||||
),
|
||||
constants.Metric.TRAJECTORY_IN_ORDER_MATCH: (
|
||||
gapic_eval_service_types.TrajectoryInOrderMatchSpec()
|
||||
),
|
||||
constants.Metric.TRAJECTORY_ANY_ORDER_MATCH: (
|
||||
gapic_eval_service_types.TrajectoryAnyOrderMatchSpec()
|
||||
),
|
||||
constants.Metric.TRAJECTORY_PRECISION: (
|
||||
gapic_eval_service_types.TrajectoryPrecisionSpec()
|
||||
),
|
||||
constants.Metric.TRAJECTORY_RECALL: (
|
||||
gapic_eval_service_types.TrajectoryRecallSpec()
|
||||
),
|
||||
constants.Metric.TRAJECTORY_SINGLE_TOOL_USE: (
|
||||
gapic_eval_service_types.TrajectorySingleToolUseSpec()
|
||||
),
|
||||
}
|
||||
_QUESTION_TEMPLATE = """<question>{question}"""
|
||||
|
||||
|
||||
def _format_rubrics(questions: List[str]) -> str:
|
||||
"""Formats the list of rubrics into a question block."""
|
||||
question_block = "\n".join(
|
||||
_QUESTION_TEMPLATE.format(question=q.strip()) for q in questions
|
||||
)
|
||||
return question_block
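# Hedged example: _format_rubrics(["Is it polite?", "Is it concise?"]) yields
# "<question>Is it polite?\n<question>Is it concise?" (questions are illustrative).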
|
||||
|
||||
|
||||
def build_custom_output_format_config(
|
||||
custom_output_config: custom_output_config_class.CustomOutputConfig,
|
||||
) -> Union[gapic_eval_service_types.CustomOutputFormatConfig, None]:
|
||||
"""Builds a CustomOutputFormatConfig from user input."""
|
||||
custom_output_cfg = gapic_eval_service_types.CustomOutputFormatConfig()
|
||||
if custom_output_config.return_raw_output:
|
||||
custom_output_cfg.return_raw_output = True
|
||||
return custom_output_cfg
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def build_trajectory(
|
||||
trajectory: Union[str, List[Dict[str, Any]]],
|
||||
) -> gapic_eval_service_types.Trajectory:
|
||||
"""Builds a trajectory from user input."""
|
||||
if not trajectory:
|
||||
return
|
||||
|
||||
if isinstance(trajectory, str):
|
||||
trajectory = json.loads(trajectory)
|
||||
|
||||
if isinstance(trajectory, List):
|
||||
try:
|
||||
tool_calls = []
|
||||
for tool_call_dict in trajectory:
|
||||
tool_input_str = json.dumps(tool_call_dict["tool_input"])
|
||||
tool_calls.append(
|
||||
gapic_eval_service_types.ToolCall(
|
||||
tool_name=tool_call_dict["tool_name"], tool_input=tool_input_str
|
||||
)
|
||||
)
|
||||
return gapic_eval_service_types.Trajectory(tool_calls=tool_calls)
|
||||
except KeyError as e:
|
||||
_LOGGER.error(f"Failed to parse trajectory: {e}")
|
||||
else:
|
||||
_LOGGER.error(
|
||||
f"Unsupported trajectory type: {type(trajectory)}, expected list or"
|
||||
" a JSON array."
|
||||
)
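# --- Hedged sketch of the expected trajectory format (tool names and inputs are illustrative) ---
#   trajectory = [
#       {"tool_name": "search", "tool_input": {"query": "weather in Paris"}},
#       {"tool_name": "calculator", "tool_input": {"expression": "2+2"}},
#   ]
#   proto_trajectory = build_trajectory(trajectory)   # a JSON string of the same list also works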
|
||||
|
||||
|
||||
def build_request(
|
||||
metric: Union[str, metrics_base._Metric],
|
||||
row_dict: Dict[str, Any],
|
||||
evaluation_run_config: eval_base.EvaluationRunConfig,
|
||||
) -> gapic_eval_service_types.EvaluateInstancesRequest:
|
||||
"""Builds a metric instance and form the request for the evaluation service.
|
||||
|
||||
Args:
|
||||
metric: The metric name or metric object to evaluate.
|
||||
row_dict: An evaluation dataset instance as a dictionary.
|
||||
evaluation_run_config: Evaluation run configurations.
|
||||
|
||||
Returns:
|
||||
A single EvaluateInstancesRequest.
|
||||
|
||||
Raises:
|
||||
ValueError: If required request fields are not provided.
|
||||
"""
|
||||
project = initializer.global_config.project
|
||||
location = initializer.global_config.location
|
||||
if not project or not location:
|
||||
raise ValueError(
|
||||
"No project or location specified. Please run `vertexai.init()` to"
|
||||
" provide these parameters."
|
||||
)
|
||||
location_path = (
|
||||
gapic_evaluation_services.EvaluationServiceClient.common_location_path(
|
||||
project, location
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(metric, pointwise_metric.PointwiseMetric):
|
||||
metric_name = constants.Metric.POINTWISE_METRIC
|
||||
elif isinstance(metric, pairwise_metric.PairwiseMetric):
|
||||
metric_name = constants.Metric.PAIRWISE_METRIC
|
||||
else:
|
||||
metric_name = str(metric)
|
||||
|
||||
try:
|
||||
metric_spec = _METRIC_NAME_TO_METRIC_SPEC[metric_name]
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Metric name: {metric_name} is not supported.") from e
|
||||
|
||||
model_based_metric_instance_input = {}
|
||||
metric_column_mapping = evaluation_run_config.metric_column_mapping
|
||||
if isinstance(
|
||||
metric, metrics_base._ModelBasedMetric # pylint: disable=protected-access
|
||||
):
|
||||
metric_spec.metric_prompt_template = metric.metric_prompt_template
|
||||
metric_spec.system_instruction = metric.system_instruction
|
||||
if metric.custom_output_config:
|
||||
metric_spec.custom_output_format_config = build_custom_output_format_config(
|
||||
metric.custom_output_config
|
||||
)
|
||||
for variable in prompt_template_base.PromptTemplate(
|
||||
metric.metric_prompt_template
|
||||
).variables:
|
||||
model_based_metric_instance_input[variable] = row_dict.get(
|
||||
metric_column_mapping.get(variable),
|
||||
"",
|
||||
)
|
||||
if isinstance(metric, pairwise_metric.PairwiseMetric):
|
||||
metric_column_mapping = evaluation_run_config.metric_column_mapping
|
||||
metric_spec.candidate_response_field_name = metric_column_mapping.get(
|
||||
constants.Dataset.MODEL_RESPONSE_COLUMN,
|
||||
constants.Dataset.MODEL_RESPONSE_COLUMN,
|
||||
)
|
||||
metric_spec.baseline_response_field_name = metric_column_mapping.get(
|
||||
constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
|
||||
constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
|
||||
)
|
||||
elif isinstance(metric, _rouge.Rouge):
|
||||
metric_spec.rouge_type = metric.rouge_type
|
||||
metric_spec.use_stemmer = metric.use_stemmer
|
||||
metric_spec.split_summaries = metric.split_summaries
|
||||
elif isinstance(metric, _trajectory_single_tool_use.TrajectorySingleToolUse):
|
||||
metric_spec.tool_name = metric.tool_name
|
||||
|
||||
response = row_dict.get(
|
||||
metric_column_mapping.get(constants.Dataset.MODEL_RESPONSE_COLUMN), ""
|
||||
)
|
||||
reference = row_dict.get(
|
||||
metric_column_mapping.get(constants.Dataset.REFERENCE_COLUMN), ""
|
||||
)
|
||||
predicted_trajectory = build_trajectory(
|
||||
row_dict.get(
|
||||
metric_column_mapping.get(constants.Dataset.PREDICTED_TRAJECTORY_COLUMN),
|
||||
"",
|
||||
)
|
||||
)
|
||||
reference_trajectory = build_trajectory(
|
||||
row_dict.get(
|
||||
metric_column_mapping.get(constants.Dataset.REFERENCE_TRAJECTORY_COLUMN),
|
||||
"",
|
||||
)
|
||||
)
|
||||
if isinstance(metric, metrics_base._ModelBasedMetric):
|
||||
if metric_spec.metric_prompt_template in (
|
||||
_default_templates.INSTRUCTION_FOLLOWING_RUBRIC_CRITIQUE_TEMPLATE,
|
||||
_default_templates.MULTIMODAL_UNDERSTANDING_RUBRIC_CRITIQUE_TEMPLATE,
|
||||
_default_templates.TEXT_QUALITY_RUBRIC_CRITIQUE_TEMPLATE,
|
||||
_default_templates.PAIRWISE_INSTRUCTION_FOLLOWING_RUBRIC_CRITIQUE_TEMPLATE,
|
||||
_default_templates.PAIRWISE_MULTIMODAL_UNDERSTANDING_RUBRIC_CRITIQUE_TEMPLATE,
|
||||
_default_templates.PAIRWISE_TEXT_QUALITY_RUBRIC_CRITIQUE_TEMPLATE,
|
||||
):
|
||||
model_based_metric_instance_input[
|
||||
constants.Dataset.RUBRICS_COLUMN
|
||||
] = _format_rubrics(
|
||||
model_based_metric_instance_input[constants.Dataset.RUBRICS_COLUMN]
|
||||
)
|
||||
if (
|
||||
constants.Dataset.RUBRICS_COLUMN in model_based_metric_instance_input
|
||||
and isinstance(
|
||||
model_based_metric_instance_input[constants.Dataset.RUBRICS_COLUMN],
|
||||
List,
|
||||
)
|
||||
):
|
||||
model_based_metric_instance_input[
|
||||
constants.Dataset.RUBRICS_COLUMN
|
||||
] = "\n".join(
|
||||
model_based_metric_instance_input[constants.Dataset.RUBRICS_COLUMN]
|
||||
)
|
||||
|
||||
if metric_name == constants.Metric.EXACT_MATCH:
|
||||
instance = gapic_eval_service_types.ExactMatchInput(
|
||||
metric_spec=metric_spec,
|
||||
instances=[
|
||||
gapic_eval_service_types.ExactMatchInstance(
|
||||
prediction=response,
|
||||
reference=reference,
|
||||
)
|
||||
],
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
exact_match_input=instance,
|
||||
)
|
||||
elif metric_name == constants.Metric.BLEU:
|
||||
instance = gapic_eval_service_types.BleuInput(
|
||||
metric_spec=metric_spec,
|
||||
instances=[
|
||||
gapic_eval_service_types.BleuInstance(
|
||||
prediction=response,
|
||||
reference=reference,
|
||||
)
|
||||
],
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
bleu_input=instance,
|
||||
)
|
||||
elif metric_name in (
|
||||
constants.Metric.ROUGE,
|
||||
constants.Metric.ROUGE_1,
|
||||
constants.Metric.ROUGE_2,
|
||||
constants.Metric.ROUGE_L,
|
||||
constants.Metric.ROUGE_L_SUM,
|
||||
):
|
||||
instance = gapic_eval_service_types.RougeInput(
|
||||
metric_spec=metric_spec,
|
||||
instances=[
|
||||
gapic_eval_service_types.RougeInstance(
|
||||
prediction=response,
|
||||
reference=reference,
|
||||
)
|
||||
],
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
rouge_input=instance,
|
||||
)
|
||||
elif metric_name == constants.Metric.TOOL_CALL_VALID:
|
||||
instance = gapic_eval_service_types.ToolCallValidInput(
|
||||
metric_spec=metric_spec,
|
||||
instances=[
|
||||
gapic_eval_service_types.ToolCallValidInstance(
|
||||
prediction=response,
|
||||
reference=reference,
|
||||
)
|
||||
],
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
tool_call_valid_input=instance,
|
||||
)
|
||||
elif metric_name == constants.Metric.TOOL_NAME_MATCH:
|
||||
instance = gapic_eval_service_types.ToolNameMatchInput(
|
||||
metric_spec=metric_spec,
|
||||
instances=[
|
||||
gapic_eval_service_types.ToolNameMatchInstance(
|
||||
prediction=response,
|
||||
reference=reference,
|
||||
)
|
||||
],
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
tool_name_match_input=instance,
|
||||
)
|
||||
elif metric_name == constants.Metric.TOOL_PARAMETER_KEY_MATCH:
|
||||
instance = gapic_eval_service_types.ToolParameterKeyMatchInput(
|
||||
metric_spec=metric_spec,
|
||||
instances=[
|
||||
gapic_eval_service_types.ToolParameterKeyMatchInstance(
|
||||
prediction=response,
|
||||
reference=reference,
|
||||
)
|
||||
],
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
tool_parameter_key_match_input=instance,
|
||||
)
|
||||
elif metric_name == constants.Metric.TOOL_PARAMETER_KV_MATCH:
|
||||
instance = gapic_eval_service_types.ToolParameterKVMatchInput(
|
||||
metric_spec=metric_spec,
|
||||
instances=[
|
||||
gapic_eval_service_types.ToolParameterKVMatchInstance(
|
||||
prediction=response,
|
||||
reference=reference,
|
||||
)
|
||||
],
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
tool_parameter_kv_match_input=instance,
|
||||
)
|
||||
elif metric_name == constants.Metric.POINTWISE_METRIC:
|
||||
if multimodal_utils.is_multimodal_instance(model_based_metric_instance_input):
|
||||
instance = gapic_eval_service_types.PointwiseMetricInput(
|
||||
metric_spec=metric_spec,
|
||||
instance=gapic_eval_service_types.PointwiseMetricInstance(
|
||||
content_map_instance=multimodal_utils.convert_multimodal_response_to_content_map(
|
||||
model_based_metric_instance_input
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
instance = gapic_eval_service_types.PointwiseMetricInput(
|
||||
metric_spec=metric_spec,
|
||||
instance=gapic_eval_service_types.PointwiseMetricInstance(
|
||||
json_instance=json.dumps(model_based_metric_instance_input),
|
||||
),
|
||||
)
|
||||
autorater_config = evaluation_run_config.autorater_config
|
||||
if (
|
||||
isinstance(metric, metrics_base._ModelBasedMetric)
|
||||
and metric.autorater_config
|
||||
):
|
||||
autorater_config = metric.autorater_config
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
pointwise_metric_input=instance,
|
||||
autorater_config=autorater_config,
|
||||
)
|
||||
elif metric_name == constants.Metric.PAIRWISE_METRIC:
|
||||
if multimodal_utils.is_multimodal_instance(model_based_metric_instance_input):
|
||||
instance = gapic_eval_service_types.PairwiseMetricInput(
|
||||
metric_spec=metric_spec,
|
||||
instance=gapic_eval_service_types.PairwiseMetricInstance(
|
||||
content_map_instance=multimodal_utils.convert_multimodal_response_to_content_map(
|
||||
model_based_metric_instance_input
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
instance = gapic_eval_service_types.PairwiseMetricInput(
|
||||
metric_spec=metric_spec,
|
||||
instance=gapic_eval_service_types.PairwiseMetricInstance(
|
||||
json_instance=json.dumps(model_based_metric_instance_input),
|
||||
),
|
||||
)
|
||||
autorater_config = evaluation_run_config.autorater_config
|
||||
if (
|
||||
isinstance(metric, metrics_base._ModelBasedMetric)
|
||||
and metric.autorater_config
|
||||
):
|
||||
autorater_config = metric.autorater_config
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
pairwise_metric_input=instance,
|
||||
autorater_config=autorater_config,
|
||||
)
|
||||
elif metric_name == constants.Metric.RUBRIC_BASED_INSTRUCTION_FOLLOWING:
|
||||
required_rbif_fields = [
|
||||
constants.Dataset.MODEL_RESPONSE_COLUMN,
|
||||
constants.Dataset.PROMPT_COLUMN,
|
||||
]
|
||||
for field in required_rbif_fields:
|
||||
column_name = metric_column_mapping.get(field)
|
||||
value = row_dict.get(column_name)
|
||||
if value is None and field in required_rbif_fields:
|
||||
raise ValueError(
|
||||
f"Missing required field: `{field}` for "
|
||||
f"{constants.Metric.RUBRIC_BASED_INSTRUCTION_FOLLOWING}."
|
||||
)
|
||||
else:
|
||||
model_based_metric_instance_input[field] = value
|
||||
instance = gapic_eval_service_types.RubricBasedInstructionFollowingInput(
|
||||
metric_spec=metric_spec,
|
||||
instance=gapic_eval_service_types.RubricBasedInstructionFollowingInstance(
|
||||
json_instance=json.dumps(model_based_metric_instance_input),
|
||||
),
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
rubric_based_instruction_following_input=instance,
|
||||
)
|
||||
elif metric_name == constants.Metric.TRAJECTORY_EXACT_MATCH:
|
||||
instance = gapic_eval_service_types.TrajectoryExactMatchInput(
|
||||
metric_spec=metric_spec,
|
||||
instances=[
|
||||
gapic_eval_service_types.TrajectoryExactMatchInstance(
|
||||
predicted_trajectory=predicted_trajectory,
|
||||
reference_trajectory=reference_trajectory,
|
||||
)
|
||||
],
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
trajectory_exact_match_input=instance,
|
||||
)
|
||||
elif metric_name == constants.Metric.TRAJECTORY_IN_ORDER_MATCH:
|
||||
instance = gapic_eval_service_types.TrajectoryInOrderMatchInput(
|
||||
metric_spec=metric_spec,
|
||||
instances=[
|
||||
gapic_eval_service_types.TrajectoryInOrderMatchInstance(
|
||||
predicted_trajectory=predicted_trajectory,
|
||||
reference_trajectory=reference_trajectory,
|
||||
)
|
||||
],
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
trajectory_in_order_match_input=instance,
|
||||
)
|
||||
elif metric_name == constants.Metric.TRAJECTORY_ANY_ORDER_MATCH:
|
||||
instance = gapic_eval_service_types.TrajectoryAnyOrderMatchInput(
|
||||
metric_spec=metric_spec,
|
||||
instances=[
|
||||
gapic_eval_service_types.TrajectoryAnyOrderMatchInstance(
|
||||
predicted_trajectory=predicted_trajectory,
|
||||
reference_trajectory=reference_trajectory,
|
||||
)
|
||||
],
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
trajectory_any_order_match_input=instance,
|
||||
)
|
||||
elif metric_name == constants.Metric.TRAJECTORY_PRECISION:
|
||||
instance = gapic_eval_service_types.TrajectoryPrecisionInput(
|
||||
metric_spec=metric_spec,
|
||||
instances=[
|
||||
gapic_eval_service_types.TrajectoryPrecisionInstance(
|
||||
predicted_trajectory=predicted_trajectory,
|
||||
reference_trajectory=reference_trajectory,
|
||||
)
|
||||
],
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
trajectory_precision_input=instance,
|
||||
)
|
||||
elif metric_name == constants.Metric.TRAJECTORY_RECALL:
|
||||
instance = gapic_eval_service_types.TrajectoryRecallInput(
|
||||
metric_spec=metric_spec,
|
||||
instances=[
|
||||
gapic_eval_service_types.TrajectoryRecallInstance(
|
||||
predicted_trajectory=predicted_trajectory,
|
||||
reference_trajectory=reference_trajectory,
|
||||
)
|
||||
],
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
trajectory_recall_input=instance,
|
||||
)
|
||||
elif metric_name == constants.Metric.TRAJECTORY_SINGLE_TOOL_USE:
|
||||
instance = gapic_eval_service_types.TrajectorySingleToolUseInput(
|
||||
metric_spec=metric_spec,
|
||||
instances=[
|
||||
gapic_eval_service_types.TrajectorySingleToolUseInstance(
|
||||
predicted_trajectory=predicted_trajectory,
|
||||
)
|
||||
],
|
||||
)
|
||||
return gapic_eval_service_types.EvaluateInstancesRequest(
|
||||
location=location_path,
|
||||
trajectory_single_tool_use_input=instance,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown metric type: {metric_name}")
|
||||
|
||||
|
||||
def _parse_autometric_results(
|
||||
metric_result_dict: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Parses the automatic metric results from the evaluation results.
|
||||
|
||||
Args:
|
||||
metric_result_dict: The metric results dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the score of the metric.
|
||||
"""
|
||||
for value in metric_result_dict.values():
|
||||
return {
|
||||
constants.MetricResult.SCORE_KEY: value[0].get(
|
||||
constants.MetricResult.SCORE_KEY
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def _parse_pointwise_results(
|
||||
metric_result_dict: Dict[str, Any],
|
||||
metric: Union[str, metrics_base._Metric],
|
||||
) -> Dict[str, Any]:
|
||||
"""Parses the model-based pointwise metric results from the evaluation results.
|
||||
|
||||
Args:
|
||||
metric_result_dict: The metric results dictionary.
|
||||
metric: The metric to evaluate.
|
||||
|
||||
Returns:
|
||||
One of the following:
|
||||
1. A dictionary containing raw outputs from the judge model if
|
||||
return_raw_output is set to True in custom_output_config.
|
||||
2. A dictionary containing metric score and explanation of the
|
||||
metric if custom_output_config is not set.
|
||||
"""
|
||||
if (
|
||||
isinstance(metric, pointwise_metric.PointwiseMetric)
|
||||
and getattr(metric, "custom_output_config", None)
|
||||
and getattr(metric.custom_output_config, "return_raw_output", False)
|
||||
):
|
||||
raw_outputs = (
|
||||
metric_result_dict.get(constants.MetricResult.CUSTOM_OUTPUT_KEY)
|
||||
.get(constants.MetricResult.RAW_OUTPUTS_KEY)
|
||||
.get(constants.MetricResult.RAW_OUTPUT_KEY)
|
||||
)
|
||||
if (
|
||||
isinstance(metric, pointwise_metric.PointwiseMetric)
|
||||
and getattr(metric, "custom_output_config", None)
|
||||
and getattr(metric.custom_output_config, "parsing_fn", None)
|
||||
):
|
||||
parsing_fn = metric.custom_output_config.parsing_fn
|
||||
return parsing_fn(raw_outputs)
|
||||
return {constants.MetricResult.RAW_OUTPUT_KEY: raw_outputs}
|
||||
else:
|
||||
return {
|
||||
constants.MetricResult.SCORE_KEY: metric_result_dict.get(
|
||||
constants.MetricResult.SCORE_KEY
|
||||
),
|
||||
constants.MetricResult.EXPLANATION_KEY: metric_result_dict.get(
|
||||
constants.MetricResult.EXPLANATION_KEY
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _parse_pairwise_results(
|
||||
metric_result_dict: Dict[str, Any],
|
||||
metric: Union[str, metrics_base._Metric],
|
||||
) -> Dict[str, Any]:
|
||||
"""Parses the pairwise metric results from the evaluation results.
|
||||
|
||||
Args:
|
||||
metric_result_dict: The metric results dictionary.
|
||||
metric: The metric to evaluate.
|
||||
|
||||
Returns:
|
||||
One of the following:
|
||||
1. A dictionary containing raw outputs from the judge model if
|
||||
return_raw_output is set to True in custom_output_config.
|
||||
2. A dictionary containing metric score and explanation of the
|
||||
metric if custom_output_config is not set.
|
||||
"""
|
||||
if (
|
||||
isinstance(metric, pairwise_metric.PairwiseMetric)
|
||||
and getattr(metric, "custom_output_config", None)
|
||||
and getattr(metric.custom_output_config, "return_raw_output", False)
|
||||
):
|
||||
raw_outputs = (
|
||||
metric_result_dict.get(constants.MetricResult.CUSTOM_OUTPUT_KEY)
|
||||
.get(constants.MetricResult.RAW_OUTPUTS_KEY)
|
||||
.get(constants.MetricResult.RAW_OUTPUT_KEY)
|
||||
)
|
||||
if (
|
||||
isinstance(metric, pairwise_metric.PairwiseMetric)
|
||||
and getattr(metric, "custom_output_config", None)
|
||||
and getattr(metric.custom_output_config, "parsing_fn", None)
|
||||
):
|
||||
parsing_fn = metric.custom_output_config.parsing_fn
|
||||
return parsing_fn(raw_outputs)
|
||||
return {constants.MetricResult.RAW_OUTPUT_KEY: raw_outputs}
|
||||
else:
|
||||
return {
|
||||
constants.MetricResult.PAIRWISE_CHOICE_KEY: metric_result_dict.get(
|
||||
constants.MetricResult.PAIRWISE_CHOICE_KEY,
|
||||
),
|
||||
constants.MetricResult.EXPLANATION_KEY: metric_result_dict.get(
|
||||
constants.MetricResult.EXPLANATION_KEY
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _parse_rubric_based_instruction_following_results(
|
||||
metric_result_dict: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Parses the rubric-based instruction following metric results from the evaluation results.
|
||||
|
||||
Args:
|
||||
metric_result_dict: The metric results dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary containing a list of rubrics and corresponding verdicts and
|
||||
an overall instruction following score.
|
||||
"""
|
||||
rubric_critique_results = []
|
||||
for rc_result in metric_result_dict["rubric_critique_results"]:
|
||||
if "verdict" not in rc_result:
|
||||
rc_result["verdict"] = False # proto3 shows False bool as unset
|
||||
rubric_critique_results.append(
|
||||
{
|
||||
"rubric": rc_result["rubric"],
|
||||
"verdict": rc_result["verdict"],
|
||||
}
|
||||
)
|
||||
return {
|
||||
constants.MetricResult.RUBRIC_LEVEL_INSTRUCTION_FOLLOWING_KEY: (
|
||||
rubric_critique_results
|
||||
),
|
||||
constants.MetricResult.SCORE_KEY: (
|
||||
metric_result_dict.get(constants.MetricResult.SCORE_KEY)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def handle_response(
|
||||
response: Union[str, gapic_eval_service_types.EvaluateInstancesResponse],
|
||||
metric: Union[str, metrics_base._Metric],
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""Handles the response from the evaluation service.
|
||||
|
||||
Args:
|
||||
response: The response from the evaluation service.
|
||||
metric: The metric to evaluate to check the output type.
|
||||
|
||||
Returns:
|
||||
A parsed metric result dictionary, or an error message string.
|
||||
"""
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
|
||||
metric_type = response._pb.WhichOneof( # pylint: disable=protected-access
|
||||
"evaluation_results"
|
||||
)
|
||||
|
||||
if metric_type == constants.MetricResult.EXACT_MATCH_RESULTS:
|
||||
metric_result = response.exact_match_results
|
||||
elif metric_type == constants.MetricResult.BLEU_RESULTS:
|
||||
metric_result = response.bleu_results
|
||||
elif metric_type == constants.MetricResult.ROUGE_RESULTS:
|
||||
metric_result = response.rouge_results
|
||||
elif metric_type == constants.MetricResult.TOOL_CALL_VALID_RESULTS:
|
||||
metric_result = response.tool_call_valid_results
|
||||
elif metric_type == constants.MetricResult.TOOL_NAME_MATCH_RESULTS:
|
||||
metric_result = response.tool_name_match_results
|
||||
elif metric_type == constants.MetricResult.TOOL_PARAMETER_KEY_MATCH_RESULTS:
|
||||
metric_result = response.tool_parameter_key_match_results
|
||||
elif metric_type == constants.MetricResult.TOOL_PARAMETER_KV_MATCH_RESULTS:
|
||||
metric_result = response.tool_parameter_kv_match_results
|
||||
elif metric_type == constants.MetricResult.POINTWISE_METRIC_RESULT:
|
||||
metric_result = response.pointwise_metric_result
|
||||
elif metric_type == constants.MetricResult.PAIRWISE_METRIC_RESULT:
|
||||
metric_result = response.pairwise_metric_result
|
||||
elif metric_type == constants.MetricResult.TRAJECTORY_EXACT_MATCH_RESULTS:
|
||||
metric_result = response.trajectory_exact_match_results
|
||||
elif metric_type == constants.MetricResult.TRAJECTORY_IN_ORDER_MATCH_RESULTS:
|
||||
metric_result = response.trajectory_in_order_match_results
|
||||
elif metric_type == constants.MetricResult.TRAJECTORY_ANY_ORDER_MATCH_RESULTS:
|
||||
metric_result = response.trajectory_any_order_match_results
|
||||
elif metric_type == constants.MetricResult.TRAJECTORY_PRECISION_RESULTS:
|
||||
metric_result = response.trajectory_precision_results
|
||||
elif metric_type == constants.MetricResult.TRAJECTORY_RECALL_RESULTS:
|
||||
metric_result = response.trajectory_recall_results
|
||||
elif metric_type == constants.MetricResult.TRAJECTORY_SINGLE_TOOL_USE_RESULTS:
|
||||
metric_result = response.trajectory_single_tool_use_results
|
||||
elif (
|
||||
metric_type == constants.MetricResult.RUBRIC_BASED_INSTRUCTION_FOLLOWING_RESULT
|
||||
):
|
||||
metric_result = response.rubric_based_instruction_following_result
|
||||
else:
|
||||
raise ValueError(f"Unknown metric type: {metric_type}")
|
||||
|
||||
metric_result_dict = json_format.MessageToDict(
|
||||
metric_result._pb, # pylint: disable=protected-access
|
||||
preserving_proto_field_name=True,
|
||||
)
|
||||
if metric_type in constants.MetricResult.AUTOMATIC_METRIC_RESULTS_LIST:
|
||||
result = _parse_autometric_results(metric_result_dict)
|
||||
elif metric_type == constants.MetricResult.POINTWISE_METRIC_RESULT:
|
||||
result = _parse_pointwise_results(metric_result_dict, metric)
|
||||
elif metric_type == constants.MetricResult.PAIRWISE_METRIC_RESULT:
|
||||
result = _parse_pairwise_results(metric_result_dict, metric)
|
||||
elif (
|
||||
metric_type == constants.MetricResult.RUBRIC_BASED_INSTRUCTION_FOLLOWING_RESULT
|
||||
):
|
||||
result = _parse_rubric_based_instruction_following_results(metric_result_dict)
|
||||
else:
|
||||
raise ValueError(f"Unknown metric type: {metric_type}")
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_instances(
|
||||
client: gapic_evaluation_services.EvaluationServiceClient,
|
||||
request: gapic_eval_service_types.EvaluateInstancesRequest,
|
||||
rate_limiter: utils.RateLimiter,
|
||||
retry_timeout: float,
|
||||
) -> gapic_eval_service_types.EvaluateInstancesResponse:
|
||||
"""Evaluates an instance using Vertex Gen AI Evaluation Service.
|
||||
|
||||
Args:
|
||||
client: The Vertex Gen AI evaluation service client for evaluation.
|
||||
request: An EvaluateInstancesRequest.
|
||||
rate_limiter: The rate limiter for evaluation service requests.
|
||||
retry_timeout: How long to keep retrying the evaluation requests, in seconds.
|
||||
|
||||
Returns:
|
||||
An EvaluateInstancesResponse from Vertex Gen AI Evaluation Service.
|
||||
"""
|
||||
rate_limiter.sleep_and_advance()
|
||||
return client.evaluate_instances(
|
||||
request=request,
|
||||
retry=api_core.retry.Retry(
|
||||
initial=0.250,
|
||||
maximum=90.0,
|
||||
multiplier=1.45,
|
||||
timeout=retry_timeout,
|
||||
predicate=api_core.retry.if_exception_type(
|
||||
api_core.exceptions.Aborted,
|
||||
api_core.exceptions.DeadlineExceeded,
|
||||
api_core.exceptions.ResourceExhausted,
|
||||
api_core.exceptions.ServiceUnavailable,
|
||||
api_core.exceptions.Cancelled,
|
||||
),
|
||||
),
|
||||
)
|
||||
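A minimal usage sketch for `evaluate_instances`, assuming a GAPIC `EvaluationServiceClient` and a `utils.RateLimiter` have already been constructed elsewhere; the request fields and the response access mirror the exact-match branch of `handle_response` above and are illustrative only.

```
# Illustrative sketch; `client`, `rate_limiter`, and the request contents are assumptions.
from google.cloud.aiplatform_v1beta1.types import (
    evaluation_service as gapic_eval_service_types,
)

request = gapic_eval_service_types.EvaluateInstancesRequest(
    location="projects/my-project/locations/us-central1",
    exact_match_input=gapic_eval_service_types.ExactMatchInput(
        instances=[
            gapic_eval_service_types.ExactMatchInstance(
                prediction="Paris", reference="Paris"
            )
        ],
    ),
)

# Rate limiting and retries (Aborted, DeadlineExceeded, etc.) are handled internally.
response = evaluate_instances(
    client=client,
    request=request,
    rate_limiter=rate_limiter,
    retry_timeout=600.0,
)
scores = [v.score for v in response.exact_match_results.exact_match_metric_values]
```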
@@ -0,0 +1,79 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""ROUGE Metric."""
|
||||
|
||||
from typing import Literal
|
||||
from vertexai.preview.evaluation import constants
|
||||
from vertexai.preview.evaluation.metrics import _base
|
||||
|
||||
|
||||
class Rouge(_base._AutomaticMetric): # pylint: disable=protected-access
|
||||
"""The ROUGE Metric.
|
||||
|
||||
Calculates the recall of n-grams in prediction as compared to reference and
|
||||
returns a score ranging between 0 and 1. Supported rouge types are
|
||||
rouge1 through rouge9, rougeL, and rougeLsum.
|
||||
"""
|
||||
|
||||
_metric_name = constants.Metric.ROUGE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
rouge_type: Literal[
|
||||
"rouge1",
|
||||
"rouge2",
|
||||
"rouge3",
|
||||
"rouge4",
|
||||
"rouge5",
|
||||
"rouge6",
|
||||
"rouge7",
|
||||
"rouge8",
|
||||
"rouge9",
|
||||
"rougeL",
|
||||
"rougeLsum",
|
||||
],
|
||||
use_stemmer: bool = False,
|
||||
split_summaries: bool = False
|
||||
):
|
||||
"""Initializes the ROUGE metric.
|
||||
|
||||
Args:
|
||||
rouge_type: Supported rouge types are rouge1 through rouge9, rougeL, and rougeLsum.
|
||||
use_stemmer: Whether to use stemmer to compute rouge score.
|
||||
split_summaries: Whether to split summaries while using 'rougeLsum' to
|
||||
compute rouge score.
|
||||
"""
|
||||
self._rouge_type = rouge_type
|
||||
self._use_stemmer = use_stemmer
|
||||
self._split_summaries = split_summaries
|
||||
|
||||
super().__init__(
|
||||
metric=Rouge._metric_name,
|
||||
)
|
||||
|
||||
@property
|
||||
def rouge_type(self) -> str:
|
||||
return self._rouge_type
|
||||
|
||||
@property
|
||||
def use_stemmer(self) -> bool:
|
||||
return self._use_stemmer
|
||||
|
||||
@property
|
||||
def split_summaries(self) -> bool:
|
||||
return self._split_summaries
|
||||
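A short usage sketch for the `Rouge` metric, assuming it is exported from `vertexai.preview.evaluation.metrics` and evaluated in bring-your-own-response mode; the dataset contents are placeholders.

```
# Illustrative sketch; import paths and column values are assumptions.
import pandas as pd
from vertexai.preview.evaluation import EvalTask
from vertexai.preview.evaluation.metrics import Rouge

rouge_l_sum = Rouge(rouge_type="rougeLsum", use_stemmer=True, split_summaries=True)

eval_dataset = pd.DataFrame(
    {
        "response": ["The quick brown fox jumps over the lazy dog."],
        "reference": ["A quick brown fox jumped over the lazy dog."],
    }
)
eval_result = EvalTask(dataset=eval_dataset, metrics=[rouge_l_sum]).evaluate()
print(eval_result.summary_metrics)
```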
@@ -0,0 +1,148 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Schema for autorater metric configuration."""
|
||||
|
||||
AUTORATER_METRIC_SCHEMA = """
|
||||
$schema: https://json-schema.org/draft/2020-12/schema
|
||||
title: AutoRater Metric Configuration
|
||||
description: A metric definition for model-based evaluation.
|
||||
type: object
|
||||
properties:
|
||||
metadata:
|
||||
description: Useful information about the metric.
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: Name of the metric.
|
||||
type: string
|
||||
description:
|
||||
description: Description of the metric.
|
||||
type: string
|
||||
author:
|
||||
description: Author of the metric.
|
||||
type: string
|
||||
contact:
|
||||
description: Point of contact (PoC) for the metric.
|
||||
type: string
|
||||
version:
|
||||
description: Version of the metric.
|
||||
type: string
|
||||
classification:
|
||||
description: Classification of the metric.
|
||||
type: string
|
||||
enum:
|
||||
- experimental
|
||||
- benchmarked
|
||||
- deprecated
|
||||
required_inputs:
|
||||
description: Input fields used in the metric prompt template.
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
minItems: 1
|
||||
uniqueItems: true
|
||||
benchmarks:
|
||||
description: List of benchmarks used for the metric.
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
dataset:
|
||||
description: Dataset used for benchmarking.
|
||||
type: string
|
||||
results:
|
||||
description: Results from benchmarking.
|
||||
type: string
|
||||
required:
|
||||
- results
|
||||
minItems: 1
|
||||
uniqueItems: true
|
||||
usage:
|
||||
description: Links to documentation or notebooks with example usage.
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
minItems: 1
|
||||
uniqueItems: true
|
||||
required:
|
||||
- name
|
||||
- version
|
||||
- required_inputs
|
||||
steps:
|
||||
description: List of steps used for the autorater workflow.
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
description: Type of the step.
|
||||
type: string
|
||||
enum:
|
||||
- pointwise_metric
|
||||
- pairwise_metric
|
||||
- rubric
|
||||
prompt:
|
||||
description: Prompt template for the step.
|
||||
type: object
|
||||
properties:
|
||||
system_instruction:
|
||||
description: System instruction for the model.
|
||||
type: string
|
||||
template:
|
||||
description: Template to populate with inputs from the dataset.
|
||||
type: string
|
||||
required:
|
||||
- template
|
||||
model:
|
||||
description: Configuration of the model for the step.
|
||||
type: object
|
||||
properties:
|
||||
model_name_or_endpoint:
|
||||
description: Name or endpoint of the model.
|
||||
type: string
|
||||
required:
|
||||
- model_name_or_endpoint
|
||||
options:
|
||||
description: Options for the step.
|
||||
type: object
|
||||
properties:
|
||||
sample_count:
|
||||
description: Number of samples for each instance in the dataset.
|
||||
type: integer
|
||||
flip_enabled:
|
||||
description: Whether to flip candidate and baseline responses.
|
||||
type: boolean
|
||||
output:
|
||||
description: Output of the step.
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
description: Type of the output.
|
||||
type: string
|
||||
enum:
|
||||
- raw
|
||||
required:
|
||||
- type
|
||||
required:
|
||||
- type
|
||||
- prompt
|
||||
minItems: 1
|
||||
uniqueItems: true
|
||||
required:
|
||||
- metadata
|
||||
- steps
|
||||
"""
|
||||
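A hypothetical config that should satisfy `AUTORATER_METRIC_SCHEMA`; the sketch below parses and validates it with PyYAML and jsonschema, neither of which is required by this module itself, and every config value is invented for illustration.

```
# Illustrative sketch; requires PyYAML and jsonschema, and the config values are made up.
import jsonschema
import yaml

example_config = """
metadata:
  name: my_instruction_following
  version: "0.1"
  required_inputs:
    - prompt
    - response
steps:
  - type: pointwise_metric
    prompt:
      system_instruction: You are an expert evaluator.
      template: "Rate how well {response} follows the instructions in {prompt}."
    model:
      model_name_or_endpoint: gemini-2.0-flash-001
"""

# Raises jsonschema.ValidationError if the config violates the schema.
jsonschema.validate(
    instance=yaml.safe_load(example_config),
    schema=yaml.safe_load(AUTORATER_METRIC_SCHEMA),
)
```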
@@ -0,0 +1,49 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from vertexai.preview.evaluation import constants
|
||||
from vertexai.preview.evaluation.metrics import _base
|
||||
|
||||
|
||||
class TrajectorySingleToolUse(
|
||||
_base._AutomaticMetric
|
||||
): # pylint: disable=protected-access
|
||||
"""The TrajectorySingleToolUse Metric.
|
||||
|
||||
Evaluates if a tool is present in the trajectory or not.
|
||||
"""
|
||||
|
||||
_metric_name = constants.Metric.TRAJECTORY_SINGLE_TOOL_USE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_name: str,
|
||||
):
|
||||
"""Initializes the TrajectorySingleToolUse metric.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to check for in the trajectory.
|
||||
"""
|
||||
self._tool_name = tool_name
|
||||
|
||||
super().__init__(
|
||||
metric=TrajectorySingleToolUse._metric_name,
|
||||
)
|
||||
|
||||
@property
|
||||
def tool_name(self) -> str:
|
||||
return self._tool_name
|
||||
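A usage sketch for `TrajectorySingleToolUse`, assuming the metric is exported from the metrics package and that the evaluation dataset carries a `predicted_trajectory` column of tool calls; the import path, column name, and tool-call format are all assumptions.

```
# Illustrative sketch; import path, column name, and tool-call format are assumptions.
import pandas as pd
from vertexai.preview.evaluation import EvalTask
from vertexai.preview.evaluation.metrics import TrajectorySingleToolUse

uses_search_tool = TrajectorySingleToolUse(tool_name="search_flights")

eval_dataset = pd.DataFrame(
    {
        "predicted_trajectory": [
            [{"tool_name": "search_flights", "tool_input": {"destination": "SFO"}}],
            [{"tool_name": "book_hotel", "tool_input": {"city": "SFO"}}],
        ],
    }
)
result = EvalTask(dataset=eval_dataset, metrics=[uses_search_tool]).evaluate()
```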
@@ -0,0 +1,39 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""Custom output config for model-based metrics."""
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
|
||||
class CustomOutputConfig:
|
||||
"""Custom output config for model-based metrics.
|
||||
|
||||
Attributes:
|
||||
return_raw_output: Whether to return the raw output of the metric
|
||||
function.
|
||||
parsing_fn: Function to parse the raw output of the metric.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
return_raw_output: bool = False,
|
||||
parsing_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
|
||||
):
|
||||
"""Initializes CustomOutputConfig."""
|
||||
self.return_raw_output = return_raw_output
|
||||
self.parsing_fn = parsing_fn
|
||||
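A minimal sketch of a custom parser wired into `CustomOutputConfig`, assuming the judge model returns a JSON object with `score` and `explanation` fields; the field names and JSON shape are assumptions, not part of the SDK contract.

```
# Illustrative sketch; the judge output format below is an assumption.
import json
from typing import Any, Dict


def parse_json_judge_output(raw_output: str) -> Dict[str, Any]:
    """Maps a raw JSON judge response onto the metric result fields."""
    parsed = json.loads(raw_output)
    return {
        "score": parsed.get("score"),
        "explanation": parsed.get("explanation"),
    }


custom_output = CustomOutputConfig(
    return_raw_output=True,
    parsing_fn=parse_json_judge_output,
)
```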
@@ -0,0 +1,395 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Metric prompt template classes for model-based metrics evaluation."""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from google.cloud.aiplatform import base
|
||||
from vertexai.preview.evaluation import (
|
||||
prompt_template,
|
||||
)
|
||||
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
_NEWLINE = "\n"
|
||||
|
||||
|
||||
def serialize_dict_in_order(elements: Optional[Dict[str, str]]):
|
||||
"""Serializes dictionary to ordered string value without brackets."""
|
||||
if elements is None:
|
||||
return ""
|
||||
return _NEWLINE.join(f"{key}: {value}" for key, value in sorted(elements.items()))
|
||||
|
||||
|
||||
class _MetricPromptTemplate(prompt_template.PromptTemplate):
|
||||
"""Metric prompt template for generic model-based metrics evaluation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
criteria: Dict[str, str],
|
||||
rating_rubric: Dict[str, str],
|
||||
input_variables: List[str],
|
||||
instruction: Optional[str] = None,
|
||||
evaluation_steps: Optional[Dict[str, str]] = None,
|
||||
metric_definition: Optional[str] = None,
|
||||
few_shot_examples: Optional[List[str]] = None,
|
||||
):
|
||||
"""Initializes a metric prompt template."""
|
||||
self._input_variables = input_variables
|
||||
|
||||
self._instruction = instruction
|
||||
self._metric_definition = metric_definition
|
||||
self._criteria = criteria
|
||||
self._rating_rubric = rating_rubric
|
||||
self._evaluation_steps = evaluation_steps
|
||||
self._few_shot_examples = few_shot_examples
|
||||
|
||||
self.template = self.__str__()
|
||||
|
||||
@property
|
||||
def prompt_data(self) -> str:
|
||||
return self.template
|
||||
|
||||
|
||||
class PointwiseMetricPromptTemplate(_MetricPromptTemplate):
|
||||
"""Pointwise metric prompt template for pointwise model-based metrics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
criteria: Dict[str, str],
|
||||
rating_rubric: Dict[str, str],
|
||||
input_variables: Optional[List[str]] = None,
|
||||
instruction: Optional[str] = None,
|
||||
metric_definition: Optional[str] = None,
|
||||
evaluation_steps: Optional[Dict[str, str]] = None,
|
||||
few_shot_examples: Optional[List[str]] = None,
|
||||
):
|
||||
"""Initializes a pointwise metric prompt template.
|
||||
|
||||
Args:
|
||||
criteria: The standards and measures used to evaluate the model
|
||||
responses. It is a dictionary of criterion names and criterion
|
||||
definitions.
|
||||
rating_rubric: A dictionary mapping of rating name and rating
|
||||
definition, used to assign ratings or scores based on specific
|
||||
criteria.
|
||||
input_variables: An optional list of input fields to use in the metric
|
||||
prompt template for generating model-based evaluation results. Model
|
||||
"response" column is included by default. If metric_column_mapping is
|
||||
provided, the mapping values of the input fields will be used to
|
||||
retrieve data from the evaluation dataset.
|
||||
instruction: The general instruction to the model that performs the
|
||||
evaluation. If not provided, a default pointwise metric instruction
|
||||
will be used.
|
||||
metric_definition: The optional metric definition. It is a string
|
||||
describing the metric to be evaluated at a high level. If not
|
||||
provided, this field will not be included in the prompt template.
|
||||
evaluation_steps: The optional guidelines for evaluation steps. A
|
||||
dictionary of evaluation step name and evaluation step definition. If
|
||||
not provided, default pointwise metric evaluation steps will be
|
||||
used.
|
||||
few_shot_examples: The optional list of few-shot examples to be used in
|
||||
the prompt, to provide the model with demonstrations of how to perform
|
||||
the evaluation, and improve the evaluation accuracy. If not provided,
|
||||
this field will not be included in the prompt template.
|
||||
"""
|
||||
if not input_variables:
|
||||
input_variables = []
|
||||
_LOGGER.info(
|
||||
"The `input_variables` parameter is empty. Only the `response`"
|
||||
" column is used for computing this model-based metric."
|
||||
)
|
||||
input_variables = list(set(input_variables + ["response"]))
|
||||
|
||||
instruction = instruction or self.get_default_pointwise_instruction()
|
||||
|
||||
evaluation_steps = (
|
||||
evaluation_steps or self.get_default_pointwise_evaluation_steps()
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
input_variables=input_variables,
|
||||
criteria=criteria,
|
||||
rating_rubric=rating_rubric,
|
||||
instruction=instruction,
|
||||
metric_definition=metric_definition,
|
||||
evaluation_steps=evaluation_steps,
|
||||
few_shot_examples=few_shot_examples,
|
||||
)
|
||||
|
||||
def get_default_pointwise_instruction(self) -> str:
|
||||
"""Returns the default instruction for the metric prompt template."""
|
||||
|
||||
return (
"You are an expert evaluator. Your task is to evaluate the quality of"
" the responses generated by AI models. We will provide you with the"
" user prompt and an AI-generated response.\nYou should first read"
|
||||
" the user input carefully for analyzing the task, and then evaluate"
|
||||
" the quality of the responses based on the Criteria provided in the"
|
||||
" Evaluation section below.\nYou will assign the response a rating"
|
||||
" following the Rating Rubric and Evaluation Steps. Give step by step"
|
||||
" explanations for your rating, and only choose ratings from the Rating"
|
||||
" Rubric."
|
||||
)
|
||||
|
||||
def get_default_pointwise_evaluation_steps(self) -> Dict[str, str]:
|
||||
"""Returns the default evaluation steps for the metric prompt template."""
|
||||
return {
|
||||
"Step 1": (
|
||||
"Assess the response in aspects of all criteria provided. Provide"
|
||||
" assessment according to each criterion."
|
||||
),
|
||||
"Step 2": (
|
||||
"Score based on the rating rubric. Give a brief rationale to"
|
||||
" explain your evaluation considering each individual criterion."
|
||||
),
|
||||
}
|
||||
|
||||
def __str__(self):
|
||||
"""Serializes the pointwise metric prompt template to a string."""
|
||||
metric_prompt_template_str = [
|
||||
"# Instruction",
|
||||
f"{self._instruction}",
|
||||
_NEWLINE,
|
||||
"# Evaluation",
|
||||
]
|
||||
if self._metric_definition:
|
||||
metric_prompt_template_str.extend(
|
||||
[
|
||||
"## Metric Definition",
|
||||
f"{self._metric_definition}\n",
|
||||
]
|
||||
)
|
||||
metric_prompt_template_str.extend(
|
||||
[
|
||||
"## Criteria",
|
||||
f"{serialize_dict_in_order(self._criteria)}\n",
|
||||
"## Rating Rubric",
|
||||
f"{serialize_dict_in_order(self._rating_rubric)}\n",
|
||||
]
|
||||
)
|
||||
if self._evaluation_steps:
|
||||
metric_prompt_template_str.extend(
|
||||
[
|
||||
"## Evaluation Steps",
|
||||
f"{serialize_dict_in_order(self._evaluation_steps)}\n",
|
||||
]
|
||||
)
|
||||
if self._few_shot_examples:
|
||||
metric_prompt_template_str.extend(
|
||||
[
|
||||
"## Evaluation Examples",
|
||||
f"{_NEWLINE.join(self._few_shot_examples)}\n",
|
||||
]
|
||||
)
|
||||
metric_prompt_template_str.extend(
|
||||
["\n# User Inputs and AI-generated Response", "## User Inputs"]
|
||||
)
|
||||
for input_variable in self._input_variables:
|
||||
if input_variable == "response":
|
||||
continue
|
||||
metric_prompt_template_str.extend(
|
||||
[
|
||||
f"### {input_variable}",
|
||||
f"{{{input_variable}}}\n",
|
||||
]
|
||||
)
|
||||
metric_prompt_template_str.extend(
|
||||
[
|
||||
_NEWLINE,
|
||||
"\n## AI-generated Response",
|
||||
"{response}",
|
||||
]
|
||||
)
|
||||
return _NEWLINE.join(metric_prompt_template_str)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"PointwiseMetricPromptTemplate(prompt_data={self.prompt_data},"
|
||||
f" variables={self.variables})"
|
||||
)
|
||||
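A sketch of assembling a `PointwiseMetricPromptTemplate`, assuming a single-criterion conciseness metric; the criterion wording, rubric labels, and input variable are invented for illustration.

```
# Illustrative sketch; criteria, rubric, and input variables are assumptions.
conciseness_template = PointwiseMetricPromptTemplate(
    criteria={
        "conciseness": "The response answers the prompt without unnecessary detail.",
    },
    rating_rubric={
        "1": "The response is concise.",
        "0": "The response is verbose or padded.",
    },
    input_variables=["prompt"],
)
# `response` is appended automatically; the serialized template keeps {prompt} and
# {response} as placeholders to be filled from the evaluation dataset.
print(conciseness_template.prompt_data)
```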
|
||||
|
||||
class PairwiseMetricPromptTemplate(_MetricPromptTemplate):
|
||||
"""Pairwise metric prompt template for pairwise model-based metrics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
criteria: Dict[str, str],
|
||||
rating_rubric: Dict[str, str],
|
||||
input_variables: Optional[List[str]] = None,
|
||||
instruction: Optional[str] = None,
|
||||
metric_definition: Optional[str] = None,
|
||||
evaluation_steps: Optional[Dict[str, str]] = None,
|
||||
few_shot_examples: Optional[List[str]] = None,
|
||||
):
|
||||
"""Initializes a pairwise metric prompt template.
|
||||
|
||||
Args:
|
||||
criteria: The standards and measures used to evaluate the model
|
||||
responses. It is a dictionary of criterion names and criterion
|
||||
definitions.
|
||||
rating_rubric: A dictionary mapping of rating name and rating
|
||||
definition, used to assign ratings or scores based on specific
|
||||
criteria.
|
||||
input_variables: An optional list of input fields to use in the metric
|
||||
prompt template for generating model-based evaluation results.
|
||||
Candidate model "response" column and "baseline_model_response" column
|
||||
are included by default. If metric_column_mapping is provided, the
|
||||
mapping values of the input fields will be used to retrieve data from
|
||||
the evaluation dataset.
|
||||
instruction: The general instruction to the model that performs the
|
||||
evaluation. If not provided, a default pairwise metric instruction
|
||||
will be used.
|
||||
metric_definition: The optional metric definition. It is a string
|
||||
describing the metric to be evaluated at a high level. If not
|
||||
provided, this field will not be included in the prompt template.
|
||||
evaluation_steps: The optional guidelines for evaluation steps. A
|
||||
dictionary of evaluation step name and evaluation step definition. If
|
||||
not provided, default pairwise metric evaluation steps will be used.
|
||||
few_shot_examples: The optional list of few-shot examples to be used in
|
||||
the prompt, to provide the model with demonstrations of how to perform
|
||||
the evaluation, and improve the evaluation accuracy. If not provided,
|
||||
this field will not be included in the prompt template.
|
||||
"""
|
||||
if not input_variables:
|
||||
input_variables = []
|
||||
_LOGGER.info(
"The `input_variables` parameter is empty. Only the `response`"
" and `baseline_model_response` columns are used for"
" computing this model-based metric."
)
|
||||
input_variables = list(
|
||||
set(input_variables + ["response", "baseline_model_response"])
|
||||
)
|
||||
|
||||
instruction = instruction or self.get_default_pairwise_instruction()
|
||||
|
||||
evaluation_steps = (
|
||||
evaluation_steps or self.get_default_pairwise_evaluation_steps()
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
input_variables=input_variables,
|
||||
criteria=criteria,
|
||||
rating_rubric=rating_rubric,
|
||||
instruction=instruction,
|
||||
metric_definition=metric_definition,
|
||||
evaluation_steps=evaluation_steps,
|
||||
few_shot_examples=few_shot_examples,
|
||||
)
|
||||
|
||||
def get_default_pairwise_instruction(self) -> str:
|
||||
"""Returns the default instruction for the metric prompt template."""
|
||||
|
||||
return (
"You are an expert evaluator. Your task is to evaluate the quality of"
" the responses generated by two AI models. We will provide you with"
" the user input and a pair of AI-generated responses (Response A and"
" Response B).\nYou should first read the user input carefully for"
" analyzing the task, and then evaluate the quality of the responses"
" based on the Criteria provided in the Evaluation section"
" below.\nYou will first judge responses individually, following the"
" Rating Rubric and Evaluation Steps. Then you will give step by step"
" explanations for your judgement and compare the results to declare the"
" winner based on the Rating Rubric and Evaluation Steps."
)
|
||||
|
||||
def get_default_pairwise_evaluation_steps(self) -> Dict[str, str]:
|
||||
"""Returns the default evaluation steps for the metric prompt template."""
|
||||
return {
|
||||
"Step 1": "Analyze Response A based on all the Criteria.",
|
||||
"Step 2": "Analyze Response B based on all the Criteria.",
|
||||
"Step 3": (
|
||||
"Compare the overall performance of Response A and Response B based"
|
||||
" on your analyses and assessment."
|
||||
),
|
||||
"Step 4": (
|
||||
'Output your preference of "A", "SAME" or "B" to the'
|
||||
" pairwise_choice field according to the Rating Rubrics."
|
||||
),
|
||||
"Step 5": "Output your assessment reasoning in the explanation field",
|
||||
}
|
||||
|
||||
def __str__(self):
|
||||
"""Serializes the pairwise metric prompt template to a string."""
|
||||
metric_prompt_template_str = [
|
||||
"# Instruction",
|
||||
f"{self._instruction}",
|
||||
_NEWLINE,
|
||||
"# Evaluation",
|
||||
]
|
||||
if self._metric_definition:
|
||||
metric_prompt_template_str.extend(
|
||||
[
|
||||
"## Metric Definition",
|
||||
f"{self._metric_definition}\n",
|
||||
]
|
||||
)
|
||||
metric_prompt_template_str.extend(
|
||||
[
|
||||
"## Criteria",
|
||||
f"{serialize_dict_in_order(self._criteria)}\n",
|
||||
"## Rating Rubric",
|
||||
f"{serialize_dict_in_order(self._rating_rubric)}\n",
|
||||
]
|
||||
)
|
||||
if self._evaluation_steps:
|
||||
metric_prompt_template_str.extend(
|
||||
[
|
||||
"## Evaluation Steps",
|
||||
f"{serialize_dict_in_order(self._evaluation_steps)}\n",
|
||||
]
|
||||
)
|
||||
if self._few_shot_examples:
|
||||
metric_prompt_template_str.extend(
|
||||
[
|
||||
"## Evaluation Examples",
|
||||
f"{_NEWLINE.join(self._few_shot_examples)}\n",
|
||||
]
|
||||
)
|
||||
metric_prompt_template_str.extend(
|
||||
["\n# User Inputs and AI-generated Responses", "## User Inputs"]
|
||||
)
|
||||
for input_variable in self._input_variables:
|
||||
if input_variable in ["response", "baseline_model_response"]:
|
||||
continue
|
||||
metric_prompt_template_str.extend(
|
||||
[
|
||||
f"### {input_variable}",
|
||||
f"{{{input_variable}}}\n",
|
||||
]
|
||||
)
|
||||
metric_prompt_template_str.extend(
|
||||
[
|
||||
"\n## AI-generated Responses",
|
||||
"### Response A",
|
||||
"{baseline_model_response}\n",
|
||||
"### Response B",
|
||||
"{response}",
|
||||
]
|
||||
)
|
||||
return _NEWLINE.join(metric_prompt_template_str)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"PairwiseMetricPromptTemplate(prompt_data={self.prompt_data},"
|
||||
f" variables={self.variables})"
|
||||
)
|
||||
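The pairwise variant follows the same pattern; the sketch below builds a helpfulness comparison template with invented criteria and rubric values.

```
# Illustrative sketch; criteria and rubric values are assumptions.
helpfulness_template = PairwiseMetricPromptTemplate(
    criteria={
        "helpfulness": "The response fully resolves the user's request.",
    },
    rating_rubric={
        "A": "Response A is more helpful.",
        "SAME": "Both responses are equally helpful.",
        "B": "Response B is more helpful.",
    },
    input_variables=["prompt"],
)
# `response` and `baseline_model_response` are appended automatically.
print(helpfulness_template.variables)
```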
@@ -0,0 +1,197 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Example metric prompt templates for model-based evaluation."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from vertexai.preview.evaluation import constants
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
_default_templates,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import pairwise_metric
|
||||
from vertexai.preview.evaluation.metrics import pointwise_metric
|
||||
|
||||
|
||||
class MetricPromptTemplateExamples:
|
||||
"""Examples of metric prompt templates for model-based evaluation."""
|
||||
|
||||
_PROMPT_TEMPLATE_MAP = {
|
||||
constants.Metric.COHERENCE: _default_templates.COHERENCE_PROMPT_TEMPLATE,
|
||||
constants.Metric.FLUENCY: _default_templates.FLUENCY_PROMPT_TEMPLATE,
|
||||
constants.Metric.SAFETY: _default_templates.SAFETY_PROMPT_TEMPLATE,
|
||||
constants.Metric.GROUNDEDNESS: (
|
||||
_default_templates.GROUNDEDNESS_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.INSTRUCTION_FOLLOWING: (
|
||||
_default_templates.INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.VERBOSITY: _default_templates.VERBOSITY_PROMPT_TEMPLATE,
|
||||
constants.Metric.TEXT_QUALITY: (
|
||||
_default_templates.TEXT_QUALITY_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.SUMMARIZATION_QUALITY: (
|
||||
_default_templates.SUMMARIZATION_QUALITY_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.QUESTION_ANSWERING_QUALITY: (
|
||||
_default_templates.QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.MULTI_TURN_CHAT_QUALITY: (
|
||||
_default_templates.MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.MULTI_TURN_SAFETY: (
|
||||
_default_templates.MULTI_TURN_SAFETY_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.PAIRWISE_COHERENCE: (
|
||||
_default_templates.PAIRWISE_COHERENCE_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.PAIRWISE_FLUENCY: (
|
||||
_default_templates.PAIRWISE_FLUENCY_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.PAIRWISE_SAFETY: (
|
||||
_default_templates.PAIRWISE_SAFETY_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.PAIRWISE_GROUNDEDNESS: (
|
||||
_default_templates.PAIRWISE_GROUNDEDNESS_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.PAIRWISE_INSTRUCTION_FOLLOWING: (
|
||||
_default_templates.PAIRWISE_INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.PAIRWISE_VERBOSITY: (
|
||||
_default_templates.PAIRWISE_VERBOSITY_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.PAIRWISE_TEXT_QUALITY: (
|
||||
_default_templates.PAIRWISE_TEXT_QUALITY_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.PAIRWISE_SUMMARIZATION_QUALITY: (
|
||||
_default_templates.PAIRWISE_SUMMARIZATION_QUALITY_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.PAIRWISE_QUESTION_ANSWERING_QUALITY: (
|
||||
_default_templates.PAIRWISE_QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.PAIRWISE_MULTI_TURN_CHAT_QUALITY: (
|
||||
_default_templates.PAIRWISE_MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE
|
||||
),
|
||||
constants.Metric.PAIRWISE_MULTI_TURN_SAFETY: (
|
||||
_default_templates.PAIRWISE_MULTI_TURN_SAFETY_PROMPT_TEMPLATE
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_prompt_template(cls, metric_name: str) -> str:
|
||||
"""Returns the prompt template for the given metric name."""
|
||||
return cls._PROMPT_TEMPLATE_MAP[metric_name]
|
||||
|
||||
@classmethod
|
||||
def list_example_metric_names(cls) -> List[str]:
|
||||
"""Returns a list of all metric prompt templates."""
|
||||
return list(cls._PROMPT_TEMPLATE_MAP.keys())
|
||||
|
||||
class Pointwise:
|
||||
"""Example PointwiseMetric instances."""
|
||||
|
||||
FLUENCY = pointwise_metric.PointwiseMetric(
|
||||
metric=constants.Metric.FLUENCY,
|
||||
metric_prompt_template=_default_templates.FLUENCY_PROMPT_TEMPLATE,
|
||||
)
|
||||
COHERENCE = pointwise_metric.PointwiseMetric(
|
||||
metric=constants.Metric.COHERENCE,
|
||||
metric_prompt_template=_default_templates.COHERENCE_PROMPT_TEMPLATE,
|
||||
)
|
||||
SAFETY = pointwise_metric.PointwiseMetric(
|
||||
metric=constants.Metric.SAFETY,
|
||||
metric_prompt_template=_default_templates.SAFETY_PROMPT_TEMPLATE,
|
||||
)
|
||||
GROUNDEDNESS = pointwise_metric.PointwiseMetric(
|
||||
metric=constants.Metric.GROUNDEDNESS,
|
||||
metric_prompt_template=_default_templates.GROUNDEDNESS_PROMPT_TEMPLATE,
|
||||
)
|
||||
INSTRUCTION_FOLLOWING = pointwise_metric.PointwiseMetric(
|
||||
metric=constants.Metric.INSTRUCTION_FOLLOWING,
|
||||
metric_prompt_template=_default_templates.INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE,
|
||||
)
|
||||
VERBOSITY = pointwise_metric.PointwiseMetric(
|
||||
metric=constants.Metric.VERBOSITY,
|
||||
metric_prompt_template=_default_templates.VERBOSITY_PROMPT_TEMPLATE,
|
||||
)
|
||||
TEXT_QUALITY = pointwise_metric.PointwiseMetric(
|
||||
metric=constants.Metric.TEXT_QUALITY,
|
||||
metric_prompt_template=_default_templates.TEXT_QUALITY_PROMPT_TEMPLATE,
|
||||
)
|
||||
SUMMARIZATION_QUALITY = pointwise_metric.PointwiseMetric(
|
||||
metric=constants.Metric.SUMMARIZATION_QUALITY,
|
||||
metric_prompt_template=_default_templates.SUMMARIZATION_QUALITY_PROMPT_TEMPLATE,
|
||||
)
|
||||
QUESTION_ANSWERING_QUALITY = pointwise_metric.PointwiseMetric(
|
||||
metric=constants.Metric.QUESTION_ANSWERING_QUALITY,
|
||||
metric_prompt_template=_default_templates.QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE,
|
||||
)
|
||||
MULTI_TURN_CHAT_QUALITY = pointwise_metric.PointwiseMetric(
|
||||
metric=constants.Metric.MULTI_TURN_CHAT_QUALITY,
|
||||
metric_prompt_template=_default_templates.MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE,
|
||||
)
|
||||
MULTI_TURN_SAFETY_QUALITY = pointwise_metric.PointwiseMetric(
|
||||
metric=constants.Metric.MULTI_TURN_SAFETY,
|
||||
metric_prompt_template=_default_templates.MULTI_TURN_SAFETY_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
class Pairwise:
|
||||
"""Example PairwiseMetric instances."""
|
||||
|
||||
FLUENCY = pairwise_metric.PairwiseMetric(
|
||||
metric=constants.Metric.PAIRWISE_FLUENCY,
|
||||
metric_prompt_template=_default_templates.PAIRWISE_FLUENCY_PROMPT_TEMPLATE,
|
||||
)
|
||||
COHERENCE = pairwise_metric.PairwiseMetric(
|
||||
metric=constants.Metric.PAIRWISE_COHERENCE,
|
||||
metric_prompt_template=_default_templates.PAIRWISE_COHERENCE_PROMPT_TEMPLATE,
|
||||
)
|
||||
SAFETY = pairwise_metric.PairwiseMetric(
|
||||
metric=constants.Metric.PAIRWISE_SAFETY,
|
||||
metric_prompt_template=_default_templates.PAIRWISE_SAFETY_PROMPT_TEMPLATE,
|
||||
)
|
||||
GROUNDEDNESS = pairwise_metric.PairwiseMetric(
|
||||
metric=constants.Metric.PAIRWISE_GROUNDEDNESS,
|
||||
metric_prompt_template=_default_templates.PAIRWISE_GROUNDEDNESS_PROMPT_TEMPLATE,
|
||||
)
|
||||
INSTRUCTION_FOLLOWING = pairwise_metric.PairwiseMetric(
|
||||
metric=constants.Metric.PAIRWISE_INSTRUCTION_FOLLOWING,
|
||||
metric_prompt_template=_default_templates.PAIRWISE_INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE,
|
||||
)
|
||||
VERBOSITY = pairwise_metric.PairwiseMetric(
|
||||
metric=constants.Metric.PAIRWISE_VERBOSITY,
|
||||
metric_prompt_template=_default_templates.PAIRWISE_VERBOSITY_PROMPT_TEMPLATE,
|
||||
)
|
||||
TEXT_QUALITY = pairwise_metric.PairwiseMetric(
|
||||
metric=constants.Metric.PAIRWISE_TEXT_QUALITY,
|
||||
metric_prompt_template=_default_templates.PAIRWISE_TEXT_QUALITY_PROMPT_TEMPLATE,
|
||||
)
|
||||
SUMMARIZATION_QUALITY = pairwise_metric.PairwiseMetric(
|
||||
metric=constants.Metric.PAIRWISE_SUMMARIZATION_QUALITY,
|
||||
metric_prompt_template=_default_templates.PAIRWISE_SUMMARIZATION_QUALITY_PROMPT_TEMPLATE,
|
||||
)
|
||||
QUESTION_ANSWERING_QUALITY = pairwise_metric.PairwiseMetric(
|
||||
metric=constants.Metric.PAIRWISE_QUESTION_ANSWERING_QUALITY,
|
||||
metric_prompt_template=_default_templates.PAIRWISE_QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE,
|
||||
)
|
||||
MULTI_TURN_CHAT_QUALITY = pairwise_metric.PairwiseMetric(
|
||||
metric=constants.Metric.PAIRWISE_MULTI_TURN_CHAT_QUALITY,
|
||||
metric_prompt_template=_default_templates.PAIRWISE_MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE,
|
||||
)
|
||||
MULTI_TURN_SAFETY_QUALITY = pairwise_metric.PairwiseMetric(
|
||||
metric=constants.Metric.PAIRWISE_MULTI_TURN_SAFETY,
|
||||
metric_prompt_template=_default_templates.PAIRWISE_MULTI_TURN_SAFETY_PROMPT_TEMPLATE,
|
||||
)
|
||||
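A lookup sketch for `MetricPromptTemplateExamples`; it assumes "fluency" matches the registered key in `constants.Metric`, consistent with how the map above is populated.

```
# Illustrative sketch; assumes "fluency" is one of the registered example metric names.
from vertexai.preview.evaluation.metrics import MetricPromptTemplateExamples

print(MetricPromptTemplateExamples.list_example_metric_names())

fluency_template = MetricPromptTemplateExamples.get_prompt_template("fluency")

# The same templates are also available as ready-made metric instances.
fluency_metric = MetricPromptTemplateExamples.Pointwise.FLUENCY
pairwise_fluency_metric = MetricPromptTemplateExamples.Pairwise.FLUENCY
```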
@@ -0,0 +1,133 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Model-based Pairwise Metric."""
|
||||
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from google.cloud.aiplatform_v1beta1.types import (
|
||||
evaluation_service as gapic_eval_service_types,
|
||||
)
|
||||
from vertexai.preview import generative_models
|
||||
from vertexai.preview.evaluation.metrics import _base
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
custom_output_config as custom_output_config_class,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
metric_prompt_template as metric_prompt_template_base,
|
||||
)
|
||||
|
||||
|
||||
class PairwiseMetric(_base._ModelBasedMetric): # pylint: disable=protected-access
|
||||
"""A Model-based Pairwise Metric.
|
||||
|
||||
A model-based evaluation metric that compares two generative models' responses
|
||||
side-by-side, and allows users to A/B test their generative models to
|
||||
determine which model is performing better.
|
||||
|
||||
For more details on when to use pairwise metrics, see
|
||||
[Evaluation methods and
|
||||
metrics](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval#pointwise_versus_pairwise).
|
||||
|
||||
Result Details:
|
||||
|
||||
* In `EvalResult.summary_metrics`, win rates for both the baseline and
|
||||
candidate model are computed. The win rate is computed as the proportion of
|
||||
wins of one model's responses to the total attempts, as a decimal value
|
||||
between 0 and 1.
|
||||
|
||||
* In `EvalResult.metrics_table`, a pairwise metric produces two
|
||||
evaluation results per dataset row:
|
||||
* `pairwise_choice`: The choice shows whether the candidate model or
|
||||
the baseline model performs better, or if they are equally good.
|
||||
* `explanation`: The rationale behind each verdict using
|
||||
chain-of-thought reasoning. The explanation helps users scrutinize
|
||||
the judgment and builds appropriate trust in the decisions.
|
||||
|
||||
See [documentation
|
||||
page](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval#understand-results)
|
||||
for more details on understanding the metric results.
|
||||
|
||||
Usage Examples:
|
||||
|
||||
```
|
||||
baseline_model = GenerativeModel("gemini-1.0-pro")
|
||||
candidate_model = GenerativeModel("gemini-1.5-pro")
|
||||
|
||||
pairwise_groundedness = PairwiseMetric(
|
||||
metric_prompt_template=MetricPromptTemplateExamples.get_prompt_template(
|
||||
"pairwise_groundedness"
|
||||
),
|
||||
baseline_model=baseline_model,
|
||||
)
|
||||
eval_dataset = pd.DataFrame({
|
||||
"prompt" : [...],
|
||||
})
|
||||
pairwise_task = EvalTask(
|
||||
dataset=eval_dataset,
|
||||
metrics=[pairwise_groundedness],
|
||||
experiment="my-pairwise-experiment",
|
||||
)
|
||||
pairwise_result = pairwise_task.evaluate(
|
||||
model=candidate_model,
|
||||
experiment_run_name="gemini-pairwise-eval-run",
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metric: str,
|
||||
metric_prompt_template: Union[
|
||||
metric_prompt_template_base.PairwiseMetricPromptTemplate, str
|
||||
],
|
||||
baseline_model: Optional[
|
||||
Union[generative_models.GenerativeModel, Callable[[str], str]]
|
||||
] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
autorater_config: Optional[gapic_eval_service_types.AutoraterConfig] = None,
|
||||
custom_output_config: Optional[
|
||||
custom_output_config_class.CustomOutputConfig
|
||||
] = None,
|
||||
):
|
||||
"""Initializes a pairwise evaluation metric.
|
||||
|
||||
Args:
|
||||
metric: The pairwise evaluation metric name.
|
||||
metric_prompt_template: Pairwise metric prompt template for performing
|
||||
the pairwise model-based evaluation. A freeform string is also accepted.
|
||||
baseline_model: The baseline model for side-by-side comparison. If not
|
||||
specified, `baseline_model_response` column is required in the dataset
|
||||
to perform bring-your-own-response (BYOR) evaluation.
|
||||
system_instruction: The system instruction for the evaluation.
|
||||
autorater_config: The config for judge model.
|
||||
custom_output_config: Config for custom output from the judge model.
|
||||
"""
|
||||
super().__init__(
|
||||
metric_prompt_template=metric_prompt_template,
|
||||
metric=metric,
|
||||
system_instruction=system_instruction,
|
||||
autorater_config=autorater_config,
|
||||
custom_output_config=custom_output_config,
|
||||
)
|
||||
self._baseline_model = baseline_model
|
||||
|
||||
@property
|
||||
def baseline_model(
|
||||
self,
|
||||
) -> Union[generative_models.GenerativeModel, Callable[[str], str]]:
|
||||
return self._baseline_model
|
||||
@@ -0,0 +1,95 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Model-based Pointwise Metric."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from google.cloud.aiplatform_v1beta1.types import (
|
||||
evaluation_service as gapic_eval_service_types,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import _base
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
custom_output_config as custom_output_config_class,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
metric_prompt_template as metric_prompt_template_base,
|
||||
)
|
||||
|
||||
|
||||
class PointwiseMetric(_base._ModelBasedMetric): # pylint: disable=protected-access
|
||||
"""A Model-based Pointwise Metric.
|
||||
|
||||
A model-based evaluation metric that evaluates a single generative model's
|
||||
response.
|
||||
|
||||
For more details on when to use model-based pointwise metrics, see
|
||||
[Evaluation methods and metrics](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval).
|
||||
|
||||
Usage Examples:
|
||||
|
||||
```
|
||||
candidate_model = GenerativeModel("gemini-1.5-pro")
|
||||
eval_dataset = pd.DataFrame({
|
||||
"prompt" : [...],
|
||||
})
|
||||
fluency_metric = PointwiseMetric(
|
||||
metric="fluency",
|
||||
metric_prompt_template=MetricPromptTemplateExamples.get_prompt_template('fluency'),
|
||||
)
|
||||
pointwise_eval_task = EvalTask(
|
||||
dataset=eval_dataset,
|
||||
metrics=[
|
||||
fluency_metric,
|
||||
MetricPromptTemplateExamples.Pointwise.GROUNDEDNESS,
|
||||
],
|
||||
)
|
||||
pointwise_result = pointwise_eval_task.evaluate(
|
||||
model=candidate_model,
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metric: str,
|
||||
metric_prompt_template: Union[
|
||||
metric_prompt_template_base.PointwiseMetricPromptTemplate, str
|
||||
],
|
||||
system_instruction: Optional[str] = None,
|
||||
autorater_config: Optional[gapic_eval_service_types.AutoraterConfig] = None,
|
||||
custom_output_config: Optional[
|
||||
custom_output_config_class.CustomOutputConfig
|
||||
] = None,
|
||||
):
|
||||
"""Initializes a pointwise evaluation metric.
|
||||
|
||||
Args:
|
||||
metric: The pointwise evaluation metric name.
|
||||
metric_prompt_template: Pointwise metric prompt template for performing
|
||||
the model-based evaluation. A freeform string is also accepted.
|
||||
system_instruction: The system instruction for the evaluation.
|
||||
autorater_config: The config for judge model.
|
||||
custom_output_config: Config for custom output from the judge model.
|
||||
"""
|
||||
super().__init__(
|
||||
metric_prompt_template=metric_prompt_template,
|
||||
metric=metric,
|
||||
system_instruction=system_instruction,
|
||||
autorater_config=autorater_config,
|
||||
custom_output_config=custom_output_config,
|
||||
)
|
||||
@@ -0,0 +1,126 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from google.cloud.aiplatform_v1beta1.types import (
|
||||
evaluation_service as gapic_eval_service_types,
|
||||
)
|
||||
from vertexai.preview.evaluation import utils
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
_base as metrics_base,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
_default_templates,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
custom_output_config,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import pairwise_metric
|
||||
from vertexai.preview.evaluation.metrics import pointwise_metric
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
rubric_based_metric,
|
||||
)
|
||||
|
||||
|
||||
AutoraterConfig = gapic_eval_service_types.AutoraterConfig
|
||||
|
||||
_POINTWISE_OUTPUT_CONFIG = custom_output_config.CustomOutputConfig(
|
||||
return_raw_output=True,
|
||||
parsing_fn=utils.parse_pointwise_rubric_result,
|
||||
)
|
||||
|
||||
_PAIRWISE_OUTPUT_CONFIG = custom_output_config.CustomOutputConfig(
|
||||
return_raw_output=True,
|
||||
parsing_fn=utils.parse_pairwise_rubric_result,
|
||||
)
|
||||
_PAIRWISE_AUTORATER_CONFIG = AutoraterConfig(
|
||||
sampling_count=1,
|
||||
)
|
||||
|
||||
|
||||
class PredefinedRubricMetrics:
|
||||
"""Predefined rubric-based metrics."""
|
||||
|
||||
class Pointwise:
|
||||
"""Pointwise rubric-based metrics."""
|
||||
|
||||
INSTRUCTION_FOLLOWING = rubric_based_metric.RubricBasedMetric(
|
||||
generation_config=metrics_base.RubricGenerationConfig(
|
||||
prompt_template=_default_templates.INSTRUCTION_FOLLOWING_RUBRIC_GENERATION_PROMPT_TEMPLATE,
|
||||
),
|
||||
critique_metric=pointwise_metric.PointwiseMetric(
|
||||
metric="rb_instruction_following",
|
||||
metric_prompt_template=_default_templates.INSTRUCTION_FOLLOWING_RUBRIC_CRITIQUE_TEMPLATE,
|
||||
custom_output_config=_POINTWISE_OUTPUT_CONFIG,
|
||||
),
|
||||
)
|
||||
MULTIMODAL_UNDERSTANDING = rubric_based_metric.RubricBasedMetric(
|
||||
generation_config=metrics_base.RubricGenerationConfig(
|
||||
prompt_template=_default_templates.MULTIMODAL_UNDERSTANDING_RUBRIC_GENERATION_PROMPT_TEMPLATE
|
||||
),
|
||||
critique_metric=pointwise_metric.PointwiseMetric(
|
||||
metric="rb_multimodal_understanding",
|
||||
metric_prompt_template=_default_templates.MULTIMODAL_UNDERSTANDING_RUBRIC_CRITIQUE_TEMPLATE,
|
||||
custom_output_config=_POINTWISE_OUTPUT_CONFIG,
|
||||
),
|
||||
)
|
||||
TEXT_QUALITY = rubric_based_metric.RubricBasedMetric(
|
||||
generation_config=metrics_base.RubricGenerationConfig(
|
||||
prompt_template=_default_templates.TEXT_QUALITY_RUBRIC_GENERATION_PROMPT_TEMPLATE
|
||||
),
|
||||
critique_metric=pointwise_metric.PointwiseMetric(
|
||||
metric="rb_text_quality",
|
||||
metric_prompt_template=_default_templates.TEXT_QUALITY_RUBRIC_CRITIQUE_TEMPLATE,
|
||||
custom_output_config=_POINTWISE_OUTPUT_CONFIG,
|
||||
),
|
||||
)
|
||||
|
||||
class Pairwise:
|
||||
"""Pairwise rubric-based metrics."""
|
||||
|
||||
INSTRUCTION_FOLLOWING = rubric_based_metric.RubricBasedMetric(
|
||||
generation_config=metrics_base.RubricGenerationConfig(
|
||||
prompt_template=_default_templates.INSTRUCTION_FOLLOWING_RUBRIC_GENERATION_PROMPT_TEMPLATE,
|
||||
),
|
||||
critique_metric=pairwise_metric.PairwiseMetric(
|
||||
metric="pairwise_rb_instruction_following",
|
||||
metric_prompt_template=_default_templates.PAIRWISE_INSTRUCTION_FOLLOWING_RUBRIC_CRITIQUE_TEMPLATE,
|
||||
custom_output_config=_PAIRWISE_OUTPUT_CONFIG,
|
||||
autorater_config=_PAIRWISE_AUTORATER_CONFIG,
|
||||
),
|
||||
)
|
||||
MULTIMODAL_UNDERSTANDING = rubric_based_metric.RubricBasedMetric(
|
||||
generation_config=metrics_base.RubricGenerationConfig(
|
||||
prompt_template=_default_templates.MULTIMODAL_UNDERSTANDING_RUBRIC_GENERATION_PROMPT_TEMPLATE
|
||||
),
|
||||
critique_metric=pairwise_metric.PairwiseMetric(
|
||||
metric="pairwise_rb_multimodal_understanding",
|
||||
metric_prompt_template=_default_templates.PAIRWISE_MULTIMODAL_UNDERSTANDING_RUBRIC_CRITIQUE_TEMPLATE,
|
||||
custom_output_config=_PAIRWISE_OUTPUT_CONFIG,
|
||||
autorater_config=_PAIRWISE_AUTORATER_CONFIG,
|
||||
),
|
||||
)
|
||||
TEXT_QUALITY = rubric_based_metric.RubricBasedMetric(
|
||||
generation_config=metrics_base.RubricGenerationConfig(
|
||||
prompt_template=_default_templates.TEXT_QUALITY_RUBRIC_GENERATION_PROMPT_TEMPLATE
|
||||
),
|
||||
critique_metric=pairwise_metric.PairwiseMetric(
|
||||
metric="pairwise_rb_text_quality",
|
||||
metric_prompt_template=_default_templates.PAIRWISE_TEXT_QUALITY_RUBRIC_CRITIQUE_TEMPLATE,
|
||||
custom_output_config=_PAIRWISE_OUTPUT_CONFIG,
|
||||
autorater_config=_PAIRWISE_AUTORATER_CONFIG,
|
||||
),
|
||||
)
|
||||
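A workflow sketch using one of the predefined pointwise rubric metrics defined above; the dataset values and the candidate model are assumptions, and the exact EvalTask wiring may differ in practice.

```
# Illustrative sketch; dataset contents and model wiring are assumptions.
import pandas as pd
from vertexai.preview import generative_models
from vertexai.preview.evaluation import EvalTask

rubric_metric = PredefinedRubricMetrics.Pointwise.INSTRUCTION_FOLLOWING

eval_dataset = pd.DataFrame(
    {"prompt": ["Answer in exactly three words: what is the capital of France?"]}
)
dataset_with_rubrics = rubric_metric.generate_rubrics(eval_dataset)

candidate_model = generative_models.GenerativeModel("gemini-2.0-flash-001")
result = EvalTask(dataset=dataset_with_rubrics, metrics=[rubric_metric]).evaluate(
    model=candidate_model,
)
```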
@@ -0,0 +1,104 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import collections
|
||||
from typing import Union, TYPE_CHECKING
|
||||
|
||||
from google.cloud.aiplatform import base
|
||||
from vertexai import generative_models
|
||||
from vertexai.preview.evaluation import _pre_eval_utils
|
||||
from vertexai.preview.evaluation import constants
|
||||
from vertexai.preview.evaluation import utils
|
||||
from vertexai.preview.evaluation.metrics import (
|
||||
_base as metrics_base,
|
||||
)
|
||||
from vertexai.preview.evaluation.metrics import pairwise_metric
|
||||
from vertexai.preview.evaluation.metrics import pointwise_metric
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
|
||||
_DEFAULT_MODEL_NAME = "gemini-2.0-flash-001"
|
||||
_LOGGER = base.Logger(__name__)
|
||||
|
||||
|
||||
class RubricBasedMetric(metrics_base._Metric):
|
||||
"""Config for Rubric-Based Eval."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
generation_config: metrics_base.RubricGenerationConfig,
|
||||
critique_metric: Union[
|
||||
pointwise_metric.PointwiseMetric, pairwise_metric.PairwiseMetric
|
||||
]
|
||||
):
|
||||
"""Initializes RubricBasedMetric.
|
||||
|
||||
Args:
|
||||
generation_config: Config for rubric generation.
|
||||
critique_metric: Pointwise/pairwise metric for rubric critique.
|
||||
"""
|
||||
super().__init__(metric=critique_metric._metric)
|
||||
|
||||
self.generation_config = generation_config
|
||||
self.critique_metric = critique_metric
|
||||
|
||||
def generate_rubrics(
|
||||
self,
|
||||
eval_dataset: "pd.DataFrame",
|
||||
) -> "pd.DataFrame":
|
||||
"""Generates rubrics for given eval dataset."""
|
||||
if not self.generation_config.model:
|
||||
model = generative_models.GenerativeModel(model_name=_DEFAULT_MODEL_NAME)
|
||||
else:
|
||||
model = self.generation_config.model
|
||||
|
||||
if constants.Dataset.RUBRICS_COLUMN in eval_dataset.columns:
|
||||
_LOGGER.warning(
|
||||
"Rubrics column already exists in the dataset. Skipping rubric"
|
||||
" generation."
|
||||
)
|
||||
return eval_dataset
|
||||
|
||||
responses = _pre_eval_utils._generate_responses_from_gemini_model(
|
||||
model,
|
||||
eval_dataset,
|
||||
self.generation_config.prompt_template,
|
||||
)
|
||||
if self.generation_config.parsing_fn:
|
||||
parsing_fn = self.generation_config.parsing_fn
|
||||
else:
|
||||
parsing_fn = utils.parse_rubrics
|
||||
dataset_with_rubrics = eval_dataset.copy()
|
||||
aggregated = collections.defaultdict(list)
|
||||
for idx, response in enumerate(responses):
|
||||
result = parsing_fn(response)
|
||||
if isinstance(result, dict):
|
||||
questions = result.pop("questions", None)
|
||||
if questions is not None:
|
||||
aggregated[constants.Dataset.RUBRICS_COLUMN].append(
|
||||
(idx, questions)
|
||||
)
|
||||
for key, value in result.items():
|
||||
aggregated[key].append((idx, value))
|
||||
else:
|
||||
aggregated[constants.Dataset.RUBRICS_COLUMN].append((idx, result))
|
||||
for key, values in aggregated.items():
|
||||
dataset_with_rubrics[key] = None
|
||||
dataset_with_rubrics[key] = dataset_with_rubrics[key].astype(object)
|
||||
for idx, value in values:
|
||||
dataset_with_rubrics.at[idx, key] = value
|
||||
return dataset_with_rubrics
|
||||
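`generate_rubrics` can also be called on its own to inspect or edit the generated rubrics before the critique step; the sketch below assumes a prompt-only dataset and that `PredefinedRubricMetrics` from the module above is importable, then shows the rubrics column the call adds.

```
# Illustrative sketch; the dataset contents are assumptions.
import pandas as pd

eval_dataset = pd.DataFrame(
    {"prompt": ["Summarize the meeting notes in two sentences, in a formal tone."]}
)

rubric_metric = PredefinedRubricMetrics.Pointwise.TEXT_QUALITY
dataset_with_rubrics = rubric_metric.generate_rubrics(eval_dataset)

# The returned copy carries the generated rubrics; calling generate_rubrics again on
# this frame is a no-op because the rubrics column already exists.
print(dataset_with_rubrics[constants.Dataset.RUBRICS_COLUMN])
```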
@@ -0,0 +1,146 @@
|
||||
"""Utility functions for multimodal evaluation."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Union, List, Set
|
||||
|
||||
from google.cloud.aiplatform import base
|
||||
from google.cloud.aiplatform_v1beta1.types import content
|
||||
from google.cloud.aiplatform_v1beta1.types import (
|
||||
evaluation_service as gapic_eval_service_types,
|
||||
)
|
||||
from vertexai import generative_models
|
||||
from vertexai.preview.evaluation import (
|
||||
prompt_template as prompt_template_base,
|
||||
)
|
||||
|
||||
from google.protobuf import json_format
|
||||
|
||||
|
||||
ContentMap = gapic_eval_service_types.ContentMap
|
||||
Content = content.Content
|
||||
Part = content.Part
|
||||
_CONTENTS_DETECTOR = "contents {"
|
||||
_PARTS_DETECTOR = "parts {"
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
|
||||
|
||||
def _string_to_content_list(input_str: str) -> ContentMap.Contents:
"""Parses a string into ContentMap.Contents if possible, otherwise returns None."""
|
||||
try:
|
||||
return json_format.Parse(
|
||||
input_str,
|
||||
ContentMap.Contents.pb(ContentMap.Contents()),
|
||||
)
|
||||
except json_format.ParseError as e:
|
||||
if _CONTENTS_DETECTOR in input_str and _PARTS_DETECTOR in input_str:
|
||||
logging.warning(
|
||||
"Failed to parse %s to ContentMap.Contents: %s", input_str, e
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _is_multimodal_response(response: str) -> bool:
|
||||
"""Checks if the model response contains multimodal input."""
|
||||
content_list = _string_to_content_list(response)
|
||||
if content_list is None:
|
||||
if _CONTENTS_DETECTOR in response and _PARTS_DETECTOR in response:
|
||||
logging.warning(
|
||||
"Response contains multimodal input: %s. Please check whether"
|
||||
" the response format conforms to ContentMap type.",
|
||||
response,
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def is_multimodal_instance(
|
||||
model_based_metric_instance_input: Dict[str, str],
|
||||
) -> bool:
|
||||
"""Checks if the evaluation instance contains multimodal input."""
|
||||
for placeholder in model_based_metric_instance_input:
|
||||
if _is_multimodal_response(model_based_metric_instance_input[placeholder]):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def convert_multimodal_response_to_content_map(
|
||||
model_based_metric_instance_input: Dict[str, str],
|
||||
) -> ContentMap:
|
||||
"""Converts a multimodal model response to a ContentMap."""
|
||||
content_map = ContentMap()
|
||||
for placeholder in model_based_metric_instance_input.keys():
|
||||
content_list = _string_to_content_list(
|
||||
model_based_metric_instance_input[placeholder]
|
||||
)
|
||||
if content_list is None:
|
||||
content_map.values[placeholder] = ContentMap.Contents(
|
||||
contents=[
|
||||
Content(
|
||||
parts=[
|
||||
Part(text=model_based_metric_instance_input[placeholder])
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
content_map.values[placeholder] = content_list
|
||||
|
||||
return content_map
|
||||
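# Illustrative sketch (not part of the library; the placeholder name "response"
# is assumed for illustration only): a plain-text instance input is wrapped into
# a ContentMap with a single text Part.
#
#   instance_input = {"response": "The sky is blue."}
#   content_map = convert_multimodal_response_to_content_map(instance_input)
#   text = content_map.values["response"].contents[0].parts[0].text
#   assert text == "The sky is blue."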
|
||||
|
||||
def _split_metric_prompt_template(
|
||||
metric_prompt_template: str,
|
||||
placeholders: Set[str],
|
||||
) -> List[str]:
|
||||
"""Splits the metric prompt template into a list of strings by placeholders."""
|
||||
placeholders_with_brackets = [
|
||||
re.escape("{" + placeholder + "}") for placeholder in placeholders
|
||||
]
|
||||
pattern = "|".join(f"({placeholder})" for placeholder in placeholders_with_brackets)
|
||||
split_metric_prompt_template = re.split(pattern, metric_prompt_template)
|
||||
return [element for element in split_metric_prompt_template if element]
|
||||
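# Illustrative sketch (assumed inputs): the template is split so that each
# placeholder becomes its own element, which lets multimodal data be spliced in
# between the literal text pieces.
#
#   _split_metric_prompt_template(
#       "Evaluate {response} against {prompt}.", {"response", "prompt"}
#   )
#   # -> ["Evaluate ", "{response}", " against ", "{prompt}", "."]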
|
||||
|
||||
def _assemble_multi_modal_prompt(
|
||||
metric_prompt_template: Union[prompt_template_base.PromptTemplate, str],
|
||||
data_row: Dict[str, Any],
|
||||
row_index: int,
|
||||
placeholders: Set[str],
|
||||
) -> List[Union[str, generative_models.Part]]:
|
||||
"""Fills in the split metric prompt template elements with multimodal data to be sent to the model."""
|
||||
split_template_elements = _split_metric_prompt_template(
|
||||
str(metric_prompt_template), placeholders
|
||||
)
|
||||
part_inputs = []
|
||||
for element in split_template_elements:
|
||||
placeholder = element.replace("{", "").replace("}", "")
|
||||
if placeholder in data_row.keys():
|
||||
content_list = _string_to_content_list(data_row[placeholder])
|
||||
if content_list is None:
|
||||
part_inputs.append(data_row[placeholder])
|
||||
else:
|
||||
for content_inp in content_list.contents:
|
||||
for part in content_inp.parts:
|
||||
if part.HasField("text"):
|
||||
part_inputs.append(part.text)
|
||||
elif part.HasField("file_data"):
|
||||
part_inputs.append(
|
||||
generative_models.Part.from_uri(
|
||||
part.file_data.file_uri,
|
||||
mime_type=part.file_data.mime_type,
|
||||
)
|
||||
)
|
||||
else:
|
||||
_LOGGER.warning(
|
||||
"The multimodal input you provided "
|
||||
f"at row {row_index} "
|
||||
"contains part types that are not "
|
||||
"yet supported. Currently supported"
|
||||
"part types are text and file_data"
|
||||
)
|
||||
else:
|
||||
part_inputs.append(element)
|
||||
return part_inputs
|
||||
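# Illustrative sketch (assumed row data): when a placeholder's value parses as
# ContentMap-style contents, its parts are expanded into the prompt; plain text
# stays as a string. `data_row` below is a hypothetical evaluation dataset row.
#
#   prompt_parts = _assemble_multi_modal_prompt(
#       "Describe {image} briefly.",
#       data_row,
#       row_index=0,
#       placeholders={"image"},
#   )
#   # -> ["Describe ", generative_models.Part.from_uri(...), " briefly."]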
@@ -0,0 +1,251 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Python functions which run only within a Jupyter or Colab notebook."""
|
||||
|
||||
import random
|
||||
import string
|
||||
import sys
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from vertexai.preview.evaluation import _base as eval_base
|
||||
from vertexai.preview.evaluation import constants
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = None
|
||||
|
||||
_MARKDOWN_H2 = "##"
|
||||
_MARKDOWN_H3 = "###"
|
||||
_DEFAULT_COLUMNS_TO_DISPLAY = [
|
||||
constants.Dataset.MODEL_RESPONSE_COLUMN,
|
||||
constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
|
||||
constants.Dataset.PROMPT_COLUMN,
|
||||
constants.MetricResult.ROW_COUNT_KEY,
|
||||
]
|
||||
_DEFAULT_RADAR_RANGE = (0, 5)
|
||||
|
||||
|
||||
def _get_ipython_shell_name() -> str:
|
||||
if "IPython" in sys.modules:
|
||||
# pylint: disable=g-import-not-at-top, g-importing-member
|
||||
from IPython import get_ipython
|
||||
|
||||
return get_ipython().__class__.__name__
|
||||
return ""
|
||||
|
||||
|
||||
def is_ipython_available() -> bool:
|
||||
return bool(_get_ipython_shell_name())
|
||||
|
||||
|
||||
def _filter_df(
|
||||
df: pd.DataFrame, substrings: Optional[List[str]] = None
|
||||
) -> pd.DataFrame:
|
||||
"""Filters a DataFrame to include only columns containing the given substrings."""
|
||||
if substrings is None:
|
||||
return df
|
||||
|
||||
return df.copy().filter(
|
||||
[
|
||||
column_name
|
||||
for column_name in df.columns
|
||||
if any(substring in column_name for substring in substrings)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def display_eval_result(
|
||||
eval_result: "eval_base.EvalResult",
|
||||
title: Optional[str] = None,
|
||||
metrics: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Displays evaluation results in a notebook using IPython.display.
|
||||
|
||||
Args:
|
||||
eval_result: An object containing evaluation results with
|
||||
`summary_metrics` and `metrics_table` attributes.
|
||||
title: A string title to display above the results.
|
||||
metrics: A list of metric name substrings to filter displayed columns. If
|
||||
provided, only metrics whose names contain any of these strings will be
|
||||
displayed.
|
||||
"""
|
||||
if not is_ipython_available():
|
||||
return
|
||||
# pylint: disable=g-import-not-at-top, g-importing-member
|
||||
from IPython.display import display
|
||||
from IPython.display import Markdown
|
||||
|
||||
summary_metrics, metrics_table = (
|
||||
eval_result.summary_metrics,
|
||||
eval_result.metrics_table,
|
||||
)
|
||||
|
||||
summary_metrics_df = pd.DataFrame.from_dict(summary_metrics, orient="index").T
|
||||
|
||||
if metrics:
|
||||
columns_to_keep = metrics + _DEFAULT_COLUMNS_TO_DISPLAY
|
||||
summary_metrics_df = _filter_df(summary_metrics_df, columns_to_keep)
|
||||
metrics_table = _filter_df(metrics_table, columns_to_keep)
|
||||
|
||||
# Display the title in Markdown.
|
||||
if title:
|
||||
display(Markdown(f"{_MARKDOWN_H2} {title}"))
|
||||
|
||||
# Display the summary metrics.
|
||||
display(Markdown(f"{_MARKDOWN_H3} Summary Metrics"))
|
||||
display(summary_metrics_df)
|
||||
|
||||
# Display the metrics table.
|
||||
display(Markdown(f"{_MARKDOWN_H3} Row-based Metrics"))
|
||||
display(metrics_table)
|
||||
|
||||
|
||||
def display_explanations(
|
||||
eval_result: "eval_base.EvalResult",
|
||||
num: int = 1,
|
||||
metrics: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Displays the explanations in a notebook using IPython.display.
|
||||
|
||||
Args:
|
||||
eval_result: An object containing evaluation results. It is expected to
|
||||
have attributes `summary_metrics` and `metrics_table`.
|
||||
num: The number of row samples to display. Defaults to 1. If the number of
|
||||
rows is less than `num`, all rows will be displayed.
|
||||
metrics: A list of metric name substrings to filter displayed columns. If
|
||||
provided, only metrics whose names contain any of these strings will be
|
||||
displayed.
|
||||
"""
|
||||
if not is_ipython_available():
|
||||
return
|
||||
# pylint: disable=g-import-not-at-top, g-importing-member
|
||||
from IPython.display import display
|
||||
from IPython.display import HTML
|
||||
|
||||
style = "white-space: pre-wrap; width: 1500px; overflow-x: auto;"
|
||||
metrics_table = eval_result.metrics_table
|
||||
|
||||
if num < 1:
|
||||
raise ValueError("Num must be greater than 0.")
|
||||
num = min(num, len(metrics_table))
|
||||
|
||||
df = metrics_table.sample(n=num)
|
||||
|
||||
if metrics:
|
||||
columns_to_keep = metrics + _DEFAULT_COLUMNS_TO_DISPLAY
|
||||
df = _filter_df(df, columns_to_keep)
|
||||
|
||||
for _, row in df.iterrows():
|
||||
for col in df.columns:
|
||||
display(HTML(f"<div style='{style}'><h4>{col}:</h4>{row[col]}</div>"))
|
||||
display(HTML("<hr>"))
|
||||
|
||||
|
||||
def display_radar_plot(
|
||||
eval_results_with_title: List[Tuple[str, "eval_base.EvalResult"]],
|
||||
metrics: List[str],
|
||||
radar_range: Tuple[float, float] = _DEFAULT_RADAR_RANGE,
|
||||
) -> None:
|
||||
"""Plots a radar plot comparing evaluation results.
|
||||
|
||||
Args:
|
||||
eval_results_with_title: List of (title, eval_result) tuples.
|
||||
metrics: A list of metrics whose mean values will be plotted.
|
||||
radar_range: Range of the radar plot axes.
|
||||
"""
|
||||
# pylint: disable=g-import-not-at-top
|
||||
try:
|
||||
import plotly.graph_objects as go
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
'`plotly` is not installed. Please install using "!pip install plotly"'
|
||||
) from exc
|
||||
|
||||
fig = go.Figure()
|
||||
for title, eval_result in eval_results_with_title:
|
||||
summary_metrics = eval_result.summary_metrics
|
||||
if metrics:
|
||||
summary_metrics = {
|
||||
key.replace("/mean", ""): summary_metrics[key]
|
||||
for key in summary_metrics
|
||||
if any(selected_metric + "/mean" in key for selected_metric in metrics)
|
||||
}
|
||||
fig.add_trace(
|
||||
go.Scatterpolar(
|
||||
r=list(summary_metrics.values()),
|
||||
theta=list(summary_metrics.keys()),
|
||||
fill="toself",
|
||||
name=title,
|
||||
)
|
||||
)
|
||||
fig.update_layout(
|
||||
polar=dict(radialaxis=dict(visible=True, range=radar_range)),
|
||||
showlegend=True,
|
||||
)
|
||||
fig.show()
|
||||
|
||||
|
||||
def display_bar_plot(
|
||||
eval_results_with_title: List[Tuple[str, "eval_base.EvalResult"]],
|
||||
metrics: List[str],
|
||||
) -> None:
|
||||
"""Plots a bar plot comparing evaluation results.
|
||||
|
||||
Args:
|
||||
eval_results_with_title: List of (title, eval_result) tuples.
|
||||
metrics: A list of metrics whose mean values will be plotted.
|
||||
"""
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
try:
|
||||
import plotly.graph_objects as go
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
'`plotly` is not installed. Please install using "!pip install plotly"'
|
||||
) from exc
|
||||
|
||||
data = []
|
||||
|
||||
for title, eval_result in eval_results_with_title:
|
||||
summary_metrics = eval_result.summary_metrics
|
||||
mean_summary_metrics = [f"{metric}/mean" for metric in metrics]
|
||||
updated_summary_metrics = []
|
||||
if metrics:
|
||||
for k, v in summary_metrics.items():
|
||||
if k in mean_summary_metrics:
|
||||
updated_summary_metrics.append((k, v))
|
||||
summary_metrics = dict(updated_summary_metrics)
|
||||
|
||||
data.append(
|
||||
go.Bar(
|
||||
x=list(summary_metrics.keys()),
|
||||
y=list(summary_metrics.values()),
|
||||
name=title,
|
||||
)
|
||||
)
|
||||
|
||||
fig = go.Figure(data=data)
|
||||
|
||||
fig.update_layout(barmode="group", showlegend=True)
|
||||
fig.show()
|
||||
|
||||
|
||||
def generate_uuid(length: int = 8) -> str:
|
||||
"""Generates a uuid of a specified length (default=8)."""
|
||||
return "".join(random.choices(string.ascii_lowercase + string.digits, k=length))
|
||||
@@ -0,0 +1,86 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Prompt template for creating prompts with variables."""
|
||||
|
||||
import re
|
||||
from typing import Set
|
||||
|
||||
_VARIABLE_NAME_REGEX = r"\{([_a-zA-Z][_a-zA-Z0-9]*)\}"
|
||||
|
||||
|
||||
class PromptTemplate:
|
||||
"""A prompt template for creating prompts with variables.
|
||||
|
||||
The `PromptTemplate` class allows users to define a template string with
|
||||
variables represented in curly braces `{variable}`. The variable
|
||||
names cannot contain spaces and must start with a letter or underscore,
|
||||
followed by letters, digits, or underscore. These variables can be
|
||||
replaced with specific values using the `assemble` method, providing
|
||||
flexibility in generating dynamic prompts.
|
||||
|
||||
Usage:
|
||||
|
||||
```
|
||||
template_str = "Hello, {name}! Today is {day}. How are you?"
|
||||
prompt_template = PromptTemplate(template_str)
|
||||
completed_prompt = prompt_template.assemble(name="John", day="Monday")
|
||||
print(completed_prompt)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, template: str):
|
||||
"""Initializes the PromptTemplate with a given template.
|
||||
|
||||
Args:
|
||||
template: The template string with variables. Variables should be
|
||||
represented in curly braces `{variable}`.
|
||||
"""
|
||||
self.template = str(template)
|
||||
self.variables = self._get_variables()
|
||||
|
||||
def _get_variables(self) -> Set[str]:
|
||||
"""Extracts and return a set of variable names from the template."""
|
||||
return set(re.findall(_VARIABLE_NAME_REGEX, self.template))
|
||||
|
||||
def assemble(self, **kwargs) -> "PromptTemplate":
|
||||
"""Replaces only the provided variables in the template with specific values.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments where keys are placeholder names and values
|
||||
are the replacements.
|
||||
|
||||
Returns:
|
||||
A new PromptTemplate instance with the updated template string.
|
||||
"""
|
||||
assembled_string = self.template
|
||||
for variable_name, value in kwargs.items():
|
||||
if variable_name not in self.variables:
|
||||
raise ValueError(
|
||||
f"Invalid variable name '{variable_name}'. "
|
||||
f"Valid variables are: {self.variables}"
|
||||
)
|
||||
placeholder = "{" + variable_name + "}"
|
||||
assembled_string = assembled_string.replace(placeholder, str(value))
|
||||
return PromptTemplate(assembled_string)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Returns the template string."""
|
||||
return self.template
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Returns a string representation of the PromptTemplate."""
|
||||
return f"PromptTemplate('{self.template}')"
|
||||
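# Illustrative sketch (assumed values): `assemble` only substitutes the variables
# that are provided, so templates can be filled in incrementally.
#
#   template = PromptTemplate("Translate {text} into {language}.")
#   partial = template.assemble(language="French")
#   str(partial)                               # "Translate {text} into French."
#   str(partial.assemble(text="Good morning"))
#   # -> "Translate Good morning into French."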
@@ -0,0 +1,640 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Utility functions for evaluation."""
|
||||
|
||||
import functools
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Pattern,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
|
||||
from google.cloud import bigquery
|
||||
from google.cloud import storage
|
||||
from google.cloud.aiplatform import base
|
||||
from google.cloud.aiplatform import compat
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import utils
|
||||
from google.cloud.aiplatform.utils import _ipython_utils
|
||||
from google.cloud.aiplatform_v1beta1.services import (
|
||||
evaluation_service as gapic_evaluation_services,
|
||||
)
|
||||
from vertexai.evaluation import _base as eval_base
|
||||
from vertexai.evaluation.metrics import _base as metrics_base
|
||||
from vertexai.evaluation.metrics import (
|
||||
metric_prompt_template as metric_prompt_template_base,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
|
||||
|
||||
_BQ_PREFIX = "bq://"
|
||||
_GCS_PREFIX = "gs://"
|
||||
_LOGGER = base.Logger(__name__)
|
||||
|
||||
_QUESTION_REGEX = re.compile(r"Question:(.*?)Verdict:", re.DOTALL)
|
||||
_VERDICT_REGEX = re.compile("Verdict:(.*)")
|
||||
_QUESTION_BLOCK_REGEX = re.compile("<question>(.*?)</question>", re.DOTALL)
|
||||
_RESPONSE_A_REGEX = re.compile(
|
||||
r"\[\[Response A Answers:\]\](.*?)\[\[Rubric Score:", re.DOTALL
|
||||
)
|
||||
_RESPONSE_B_REGEX = re.compile(
|
||||
r"\[\[Response B Answers:\]\](.*?)\[\[Rubric Score:", re.DOTALL
|
||||
)
|
||||
_SXS_RATING_REGEX = re.compile(r"\[\[(SxSRating:[AB<>=]+)\]\]", re.DOTALL)
|
||||
|
||||
RATING_TO_VERDICT = {
|
||||
"B>>A": "Candidate response is better than the baseline response.",
|
||||
"A<<B": "Candiate response is better than the baseline response.",
|
||||
"B>A": "Candidate response is slightly better than the baseline response.",
|
||||
"A<B": "Candidate response is slightly better than the baseline response.",
|
||||
"A=B": "Both responses are equally good.",
|
||||
"B=A": "Both responses are equally good.",
|
||||
"A>B": "Baseline response is slightly better than the candidate response.",
|
||||
"B<A": "Baseline response is slightly better than the candidate response.",
|
||||
"B<<A": "Baseline response is better than the candidate response.",
|
||||
"A>>B": "Baseline response is better than the candidate response.",
|
||||
}
|
||||
_RATING_TO_SCORE = {
|
||||
"B>>A": 1,
|
||||
"A<<B": 1,
|
||||
"B>A": 0.5,
|
||||
"A<B": 0.5,
|
||||
"A=B": 0,
|
||||
"B=A": 0,
|
||||
"A>B": -0.5,
|
||||
"B<A": -0.5,
|
||||
"B<<A": -1,
|
||||
"A>>B": -1,
|
||||
}
|
||||
|
||||
|
||||
class _EvaluationServiceClientWithOverride(utils.ClientWithOverride):
|
||||
_is_temporary = False
|
||||
_default_version = compat.V1
|
||||
_version_map = (
|
||||
(
|
||||
compat.V1,
|
||||
gapic_evaluation_services.EvaluationServiceClient,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Helper class for rate-limiting requests to Vertex AI to improve QoS.
|
||||
|
||||
Attributes:
|
||||
seconds_per_event: The time interval (in seconds) between events to
|
||||
maintain the desired rate.
|
||||
last: The timestamp of the last event.
|
||||
_lock: A lock to ensure thread safety.
|
||||
"""
|
||||
|
||||
def __init__(self, rate: Optional[float] = None):
|
||||
"""Initializes the rate limiter.
|
||||
|
||||
A simple rate limiter for controlling the frequency of API calls. This class
|
||||
enforces a minimum time interval between consecutive events to limit the rate at which events
|
||||
can occur. It's designed for cases where the batch size (number of events
|
||||
per call) is always 1 for traffic shaping and rate limiting.
|
||||
|
||||
Args:
|
||||
rate: The number of queries allowed per second.
|
||||
|
||||
Raises:
|
||||
ValueError: If the rate is not positive.
|
||||
"""
|
||||
if not rate or rate <= 0:
|
||||
raise ValueError("Rate must be a positive number")
|
||||
self.seconds_per_event = 1.0 / rate
|
||||
self.last = time.time() - self.seconds_per_event
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _admit(self) -> float:
|
||||
"""Checks if an event can be admitted or calculates the remaining delay."""
|
||||
now = time.time()
|
||||
time_since_last = now - self.last
|
||||
if time_since_last >= self.seconds_per_event:
|
||||
self.last = now
|
||||
return 0
|
||||
else:
|
||||
return self.seconds_per_event - time_since_last
|
||||
|
||||
def sleep_and_advance(self):
|
||||
"""Blocks the current thread until the next event can be admitted."""
|
||||
with self._lock:
|
||||
delay = self._admit()
|
||||
if delay > 0:
|
||||
time.sleep(delay)
|
||||
self.last = time.time()
|
||||
|
||||
|
||||
def rate_limit(rate: Optional[float] = None) -> Callable[[Any], Any]:
|
||||
"""Decorator version of rate limiter."""
|
||||
|
||||
def _rate_limit(method):
|
||||
limiter = RateLimiter(rate)
|
||||
|
||||
@functools.wraps(method)
|
||||
def wrapper(*args, **kwargs):
|
||||
limiter.sleep_and_advance()
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return _rate_limit
|
||||
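# Illustrative sketch (assumed rate): the decorator spaces out calls so that the
# wrapped function runs at most `rate` times per second.
#
#   @rate_limit(rate=2.0)
#   def send_request(payload):
#       ...  # successive calls are separated by at least ~0.5 seconds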
|
||||
|
||||
def create_evaluation_service_client(
|
||||
api_base_path_override: Optional[str] = None,
|
||||
) -> _EvaluationServiceClientWithOverride:
|
||||
"""Creates a client for the evaluation service.
|
||||
|
||||
Args:
|
||||
api_base_path_override: Optional. Override default api base path.
|
||||
|
||||
Returns:
|
||||
Instantiated Vertex AI EvaluationServiceClient with optional
|
||||
overrides.
|
||||
"""
|
||||
return initializer.global_config.create_client(
|
||||
client_class=_EvaluationServiceClientWithOverride,
|
||||
location_override=initializer.global_config.location,
|
||||
api_base_path_override=api_base_path_override,
|
||||
)
|
||||
|
||||
|
||||
def load_dataset(
|
||||
source: Union[str, "pd.DataFrame", Dict[str, Any]],
|
||||
) -> "pd.DataFrame":
|
||||
"""Loads dataset from various sources into a DataFrame.
|
||||
|
||||
Args:
|
||||
source: The dataset source. Supports the following dataset formats:
|
||||
* pandas.DataFrame: Used directly for evaluation.
|
||||
* Dict: Converted to a pandas DataFrame before evaluation.
|
||||
* str: Interpreted as a file path or URI. Supported formats include:
|
||||
* Local JSONL or CSV files: Loaded from the local filesystem.
|
||||
* GCS JSONL or CSV files: Loaded from Google Cloud Storage (e.g.,
|
||||
'gs://bucket/data.csv').
|
||||
* BigQuery table URI: Loaded from Google Cloud
|
||||
BigQuery (e.g., 'bq://project-id.dataset.table_name').
|
||||
|
||||
Returns:
|
||||
The dataset in pandas DataFrame format.
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Pandas is not installed. Please install the SDK using "pip install'
|
||||
' google-cloud-aiplatform[evaluation]"'
|
||||
)
|
||||
if "google.colab" in sys.modules:
|
||||
from google.colab import sheets
|
||||
|
||||
if isinstance(source, sheets.InteractiveSheet):
|
||||
return source.as_df().copy()
|
||||
|
||||
if isinstance(source, pd.DataFrame):
|
||||
return source.copy()
|
||||
elif isinstance(source, dict):
|
||||
return pd.DataFrame(source)
|
||||
elif isinstance(source, str):
|
||||
if source.startswith(_BQ_PREFIX):
|
||||
return _load_bigquery(source[len(_BQ_PREFIX) :])
|
||||
|
||||
_, extension = os.path.splitext(source)
|
||||
file_type = extension.lower()[1:]
|
||||
|
||||
if file_type == "jsonl":
|
||||
return _load_jsonl(source)
|
||||
elif file_type == "csv":
|
||||
return _load_csv(source)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported file type: {file_type} from {source}. Please"
|
||||
" provide a valid GCS path with `jsonl` or `csv` suffix or a valid"
|
||||
" BigQuery table URI."
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
"Unsupported dataset type. Must be a `pd.DataFrame`, Python dictionary,"
|
||||
" valid GCS path with `jsonl` or `csv` suffix or a valid BigQuery"
|
||||
" table URI."
|
||||
)
|
||||
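# Illustrative sketch (bucket, table, and column names are assumed): the same
# helper accepts in-memory data, GCS files, and BigQuery tables.
#
#   df = load_dataset({"prompt": ["Hi"], "response": ["Hello"]})
#   df = load_dataset("gs://my-bucket/eval_dataset.jsonl")
#   df = load_dataset("bq://my-project.my_dataset.eval_table")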
|
||||
|
||||
def _load_jsonl(filepath: str) -> "pd.DataFrame":
|
||||
"""Loads data from a JSONL file into a DataFrame."""
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Pandas is not installed. Please install the SDK using "pip install'
|
||||
' google-cloud-aiplatform[evaluation]"'
|
||||
)
|
||||
if filepath.startswith(_GCS_PREFIX):
|
||||
file_contents = _read_gcs_file_contents(filepath)
|
||||
return pd.read_json(file_contents, lines=True)
|
||||
else:
|
||||
with open(filepath, "r") as f:
|
||||
return pd.read_json(f, lines=True)
|
||||
|
||||
|
||||
def _load_csv(filepath: str) -> "pd.DataFrame":
|
||||
"""Loads data from a CSV file into a DataFrame."""
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Pandas is not installed. Please install the SDK using "pip install'
|
||||
' google-cloud-aiplatform[evaluation]"'
|
||||
)
|
||||
if filepath.startswith(_GCS_PREFIX):
|
||||
file_contents = _read_gcs_file_contents(filepath)
|
||||
return pd.read_csv(io.StringIO(file_contents), encoding="utf-8")
|
||||
else:
|
||||
return pd.read_csv(filepath, encoding="utf-8")
|
||||
|
||||
|
||||
def _load_bigquery(table_id: str) -> "pd.DataFrame":
|
||||
"""Loads data from a BigQuery table into a DataFrame."""
|
||||
|
||||
bigquery_client = bigquery.Client(project=initializer.global_config.project)
|
||||
table = bigquery_client.get_table(table_id)
|
||||
return bigquery_client.list_rows(table).to_dataframe()
|
||||
|
||||
|
||||
def _read_gcs_file_contents(filepath: str) -> str:
|
||||
"""Reads the contents of a file from Google Cloud Storage.
|
||||
|
||||
Args:
|
||||
filepath: The GCS file path (e.g., 'gs://bucket_name/file.csv')
|
||||
|
||||
Returns:
|
||||
str: The contents of the file.
|
||||
"""
|
||||
|
||||
storage_client = storage.Client(
|
||||
project=initializer.global_config.project,
|
||||
credentials=initializer.global_config.credentials,
|
||||
)
|
||||
bucket_name, blob_path = filepath[len(_GCS_PREFIX) :].split("/", 1)
|
||||
bucket = storage_client.get_bucket(bucket_name)
|
||||
blob = bucket.blob(blob_path)
|
||||
return blob.download_as_string().decode("utf-8")
|
||||
|
||||
|
||||
def _upload_file_to_gcs(upload_gcs_path: str, filename: str) -> None:
|
||||
storage_client = storage.Client(
|
||||
project=initializer.global_config.project,
|
||||
credentials=initializer.global_config.credentials,
|
||||
)
|
||||
storage.Blob.from_string(
|
||||
uri=upload_gcs_path, client=storage_client
|
||||
).upload_from_filename(filename)
|
||||
|
||||
|
||||
def _upload_string_to_gcs(upload_gcs_path: str, contents: str) -> None:
|
||||
"""Uploads the provided string to a GCS bucket."""
|
||||
storage_client = storage.Client(
|
||||
project=initializer.global_config.project,
|
||||
credentials=initializer.global_config.credentials,
|
||||
)
|
||||
storage.Blob.from_string(
|
||||
uri=upload_gcs_path, client=storage_client
|
||||
).upload_from_string(contents)
|
||||
|
||||
|
||||
def _upload_pandas_df_to_gcs(
|
||||
df: "pd.DataFrame", upload_gcs_path: str, file_type: str
|
||||
) -> None:
|
||||
"""Uploads the provided Pandas DataFrame to a GCS bucket.
|
||||
|
||||
Args:
|
||||
df: The Pandas DataFrame to upload.
|
||||
upload_gcs_path: The GCS path to upload the data file.
|
||||
file_type: The file type of the data file.
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
if file_type == "csv":
|
||||
local_dataset_path = os.path.join(temp_dir, "metrics_table.csv")
|
||||
df.to_csv(path_or_buf=local_dataset_path)
|
||||
elif file_type == "jsonl":
|
||||
local_dataset_path = os.path.join(temp_dir, "metrics_table.jsonl")
|
||||
df.to_json(path_or_buf=local_dataset_path, orient="records", lines=True)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported file type: {file_type} from {upload_gcs_path}."
|
||||
" Please provide a valid GCS path with `jsonl` or `csv` suffix."
|
||||
)
|
||||
|
||||
storage_client = storage.Client(
|
||||
project=initializer.global_config.project,
|
||||
credentials=initializer.global_config.credentials,
|
||||
)
|
||||
storage.Blob.from_string(
|
||||
uri=upload_gcs_path, client=storage_client
|
||||
).upload_from_filename(filename=local_dataset_path)
|
||||
|
||||
|
||||
def _upload_evaluation_summary_to_gcs(
|
||||
summary_metrics: Dict[str, float],
|
||||
upload_gcs_path: str,
|
||||
candidate_model_name: Optional[str] = None,
|
||||
baseline_model_name: Optional[str] = None,
|
||||
dataset_uri: Optional[str] = None,
|
||||
metrics: Optional[List[Union[str, metrics_base._Metric]]] = None,
|
||||
) -> None:
|
||||
"""Uploads the evaluation summary to a GCS bucket."""
|
||||
summary = {
|
||||
"summary_metrics": summary_metrics,
|
||||
}
|
||||
if candidate_model_name:
|
||||
summary["candidate_model_name"] = candidate_model_name
|
||||
if baseline_model_name:
|
||||
summary["baseline_model_name"] = baseline_model_name
|
||||
if dataset_uri:
|
||||
summary["dataset_uri"] = dataset_uri
|
||||
|
||||
if metrics:
|
||||
metric_descriptions = {}
|
||||
for metric in metrics:
|
||||
if isinstance(metric, metrics_base._ModelBasedMetric) and isinstance(
|
||||
metric._raw_metric_prompt_template,
|
||||
metric_prompt_template_base._MetricPromptTemplate,
|
||||
):
|
||||
metric_descriptions[metric.metric_name] = {
|
||||
"criteria": metric._raw_metric_prompt_template._criteria,
|
||||
"rating_rubric": metric._raw_metric_prompt_template._rating_rubric,
|
||||
}
|
||||
summary["metric_descriptions"] = metric_descriptions
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
local_summary_path = os.path.join(temp_dir, "summary_metrics.json")
|
||||
with open(local_summary_path, "w") as summary_file:
    json.dump(summary, summary_file)
|
||||
_upload_file_to_gcs(upload_gcs_path, local_summary_path)
|
||||
|
||||
|
||||
def upload_evaluation_results(
|
||||
eval_result: eval_base.EvalResult,
|
||||
destination_uri_prefix: str,
|
||||
file_name: Optional[str] = None,
|
||||
candidate_model_name: Optional[str] = None,
|
||||
baseline_model_name: Optional[str] = None,
|
||||
dataset_uri: Optional[str] = None,
|
||||
metrics: Optional[List[Union[str, metrics_base._Metric]]] = None,
|
||||
) -> None:
|
||||
"""Uploads eval results to GCS destination.
|
||||
|
||||
Args:
|
||||
eval_result: Eval results to upload.
|
||||
destination_uri_prefix: GCS folder to store the data.
|
||||
file_name: Optional. File name to store the metrics table.
|
||||
candidate_model_name: Optional. Candidate model name.
|
||||
baseline_model_name: Optional. Baseline model name.
|
||||
dataset_uri: Optional. URI pointing to the dataset.
|
||||
metrics: Optional. List of metrics used for evaluation.
|
||||
"""
|
||||
if not destination_uri_prefix:
|
||||
_ipython_utils.display_gen_ai_evaluation_results_button()
|
||||
return
|
||||
if eval_result.metrics_table is None:
|
||||
return
|
||||
if destination_uri_prefix.startswith(_GCS_PREFIX):
|
||||
if file_name:
|
||||
base_name, extension = os.path.splitext(file_name)
|
||||
file_type = extension.lower()[1:]
|
||||
output_folder = destination_uri_prefix + "/" + base_name
|
||||
metrics_table_path = output_folder + "/" + file_name
|
||||
_upload_pandas_df_to_gcs(
|
||||
eval_result.metrics_table, metrics_table_path, file_type
|
||||
)
|
||||
_upload_evaluation_summary_to_gcs(
|
||||
eval_result.summary_metrics,
|
||||
output_folder + "/summary_metrics.json",
|
||||
candidate_model_name,
|
||||
baseline_model_name,
|
||||
dataset_uri,
|
||||
metrics,
|
||||
)
|
||||
_ipython_utils.display_gen_ai_evaluation_results_button(
|
||||
metrics_table_path.split(_GCS_PREFIX)[1]
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported destination URI: {destination_uri_prefix}."
|
||||
f" Please provide a valid GCS bucket URI prefix starting with"
|
||||
f" {_GCS_PREFIX}."
|
||||
)
|
||||
|
||||
|
||||
def initialize_metric_column_mapping(
|
||||
metric_column_mapping: Optional[Dict[str, str]], dataset: "pd.DataFrame"
|
||||
):
|
||||
"""Initializes metric column mapping with dataset columns."""
|
||||
initialized_metric_column_mapping = {}
|
||||
for column in dataset.columns:
|
||||
initialized_metric_column_mapping[column] = column
|
||||
if metric_column_mapping:
|
||||
for key, value in metric_column_mapping.items():
|
||||
if key in initialized_metric_column_mapping:
|
||||
_LOGGER.warning(
|
||||
f"Cannot override `{key}` column with `{key}:{value}` mapping"
|
||||
f" because `{key}` column is present in the evaluation"
|
||||
" dataset. `metric_column_mapping` cannot override keys"
|
||||
" that are already in evaluation dataset columns."
|
||||
)
|
||||
else:
|
||||
initialized_metric_column_mapping[key] = value
|
||||
return initialized_metric_column_mapping
|
||||
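# Illustrative sketch (assumed columns): dataset columns map to themselves, and
# extra mappings are only added for keys that are not already dataset columns.
#
#   initialize_metric_column_mapping(
#       {"response": "model_output"},
#       pd.DataFrame(columns=["prompt", "model_output"]),
#   )
#   # -> {"prompt": "prompt", "model_output": "model_output", "response": "model_output"}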
|
||||
|
||||
def parse_intermediate_steps(intermediate_steps: List[Dict[str, Any]]):
|
||||
"""Parses intermediate steps from the response to create trajectory."""
|
||||
trajectory = []
|
||||
try:
|
||||
for step in intermediate_steps:
|
||||
step_input, _ = step[0], step[1]
|
||||
tool_name = step_input["kwargs"]["tool"]
|
||||
tool_input = step_input["kwargs"]["tool_input"]
|
||||
trajectory.append(
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
"tool_input": tool_input,
|
||||
}
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
_LOGGER.error(
|
||||
f"Failed to parse intermediate steps: {e}. The runnable you are using"
|
||||
" is likely not compatible with the evaluation service. Please ensure"
|
||||
" that the runnable you are using is compatible with the evaluation"
|
||||
" service, if not, consider building a custom runnable function."
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def parse_rubrics(rubric_generation_response: str) -> Dict[str, Any]:
|
||||
"""Parses the rubric generation responses."""
|
||||
try:
|
||||
_, response = rubric_generation_response.split("```json")
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"Failed to parse rubric generation response. Does not contain ```json"
|
||||
)
|
||||
return {"questions": ""}
|
||||
try:
|
||||
result = json.loads(response.strip("\n` "))
|
||||
except json.JSONDecodeError:
|
||||
_LOGGER.warning(
|
||||
"Failed to parse rubric generation response. Does not contain valid"
|
||||
" JSON."
|
||||
)
|
||||
return {"questions": ""}
|
||||
return result
|
||||
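# Illustrative sketch (assumed response text): the generated rubrics are expected
# inside a ```json fenced block; anything else falls back to empty questions.
#
#   parse_rubrics('Rubrics:\n```json\n{"questions": "Is the answer factual?"}\n```')
#   # -> {"questions": "Is the answer factual?"}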
|
||||
|
||||
def parse_pairwise_rubric_verdict_pairs(prediction: str, regex: Pattern[str]) -> str:
|
||||
"""Parses the pairwise rubric critique responses."""
|
||||
response = "Unable to parse rubric verdict pairs from response."
|
||||
response_matches = regex.findall(prediction)
|
||||
if response_matches:
|
||||
response_pairs = parse_question_blocks(response_matches[0])
|
||||
response = "\n".join(f"{q}: {v}" for q, v in response_pairs)
|
||||
return response
|
||||
|
||||
|
||||
def parse_pairwise_rubric_result(
|
||||
predictions: List[str],
|
||||
) -> Dict[str, Any]:
|
||||
"""Parses the pairwise rubric critique responses."""
|
||||
prediction = predictions[0] # currently only supports one sample
|
||||
rating_str = "Unable to parse verdict."
|
||||
|
||||
response_a = parse_pairwise_rubric_verdict_pairs(prediction, _RESPONSE_A_REGEX)
|
||||
response_b = parse_pairwise_rubric_verdict_pairs(prediction, _RESPONSE_B_REGEX)
|
||||
|
||||
sxs_rating_matches = _SXS_RATING_REGEX.findall(prediction.replace(" ", ""))
|
||||
|
||||
if sxs_rating_matches:
|
||||
rating_str = sxs_rating_matches[0].strip("[]")
|
||||
rating_str = rating_str[rating_str.find(":") + 1 :]
|
||||
|
||||
return {
|
||||
"pairwise_choice": (
|
||||
RATING_TO_VERDICT[rating_str]
|
||||
if rating_str in RATING_TO_VERDICT
|
||||
else rating_str
|
||||
),
|
||||
"score": (
|
||||
_RATING_TO_SCORE[rating_str] if rating_str in _RATING_TO_SCORE else None
|
||||
),
|
||||
"baseline_rubric_verdict_pairs": response_a,
|
||||
"candidate_rubric_verdict_pairs": response_b,
|
||||
"raw_outputs": predictions,
|
||||
}
|
||||
|
||||
|
||||
def parse_verdict(txt: str):
|
||||
"""Parses the verdict from the rubric critique response."""
|
||||
if not isinstance(txt, str) or not txt:
|
||||
return None
|
||||
try:
|
||||
if verdict := _VERDICT_REGEX.findall(txt):
|
||||
verdict = verdict[0]
|
||||
if "yes" in verdict.lower():
|
||||
return True
|
||||
elif "no" in verdict.lower():
|
||||
return False
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return None
|
||||
|
||||
|
||||
def parse_question(txt: str):
|
||||
"""Parses the question from the rubric critique response."""
|
||||
if not isinstance(txt, str) or not txt:
|
||||
return None
|
||||
try:
|
||||
txt = txt.split("Verdict:")[0]
|
||||
if "Question:" in txt:
|
||||
return txt.split("Question:")[-1].strip()
|
||||
|
||||
if not (question := _QUESTION_REGEX.findall(txt)):
|
||||
return txt.strip().split("\n")[0].removeprefix("STEP 1:").strip()
|
||||
return question[0].strip()
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return None
|
||||
|
||||
|
||||
def parse_question_blocks(txt: str) -> List[Tuple[str, bool]]:
|
||||
"""Parses the question blocks from the rubric critique response."""
|
||||
if not txt.startswith("<question>\n"):
|
||||
txt = "<question>\n" + txt
|
||||
responses = []
|
||||
question_blocks = _QUESTION_BLOCK_REGEX.findall(txt)
|
||||
if not question_blocks:
|
||||
question_blocks = [txt]
|
||||
for block in question_blocks:
|
||||
q = parse_question(block)
|
||||
v = parse_verdict(block)
|
||||
if q is not None and v is not None:
|
||||
responses.append((q, v))
|
||||
return responses
|
||||
|
||||
|
||||
def parse_pointwise_rubric_result(results: List[str]) -> Dict[str, Any]:
|
||||
"""Parses the pointwise rubric critique responses."""
|
||||
self_consistency_results = {}
|
||||
for sample_result in results:
|
||||
rubric_verdict_pairs = parse_question_blocks(sample_result)
|
||||
for rubric, verdict in rubric_verdict_pairs:
|
||||
if rubric not in self_consistency_results:
|
||||
self_consistency_results[rubric] = 0
|
||||
|
||||
self_consistency_results[rubric] += 1 if verdict else -1
|
||||
|
||||
rubric_results = {}
|
||||
for rubric, verdict_counts in self_consistency_results.items():
|
||||
rubric_results[rubric] = verdict_counts > 0
|
||||
|
||||
rubric_results_str = "\n".join(f"{q}: {v}" for q, v in rubric_results.items())
|
||||
row_results = {
|
||||
"score": (
|
||||
sum(rubric_results.values()) / len(rubric_results) if rubric_results else 0
|
||||
)
|
||||
}
|
||||
row_results["rubric_verdict_pairs"] = rubric_results_str
|
||||
row_results["raw_outputs"] = results
|
||||
return row_results
|
||||
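# Illustrative sketch (assumed critique text): each <question> block contributes
# a rubric/verdict pair, and the score is the fraction of rubrics answered "Yes".
#
#   sample = (
#       "<question>\nQuestion: Is it accurate?\nVerdict: Yes\n</question>\n"
#       "<question>\nQuestion: Is it complete?\nVerdict: No\n</question>"
#   )
#   parse_pointwise_rubric_result([sample])["score"]   # -> 0.5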
@@ -0,0 +1,47 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Classes for working with example stores."""
|
||||
|
||||
# We just want to re-export certain classes
|
||||
# pylint: disable=g-multiple-import,g-importing-member
|
||||
from google.cloud.aiplatform_v1beta1 import types
|
||||
from vertexai.example_stores._example_stores import (
|
||||
ContentsExample,
|
||||
ContentSearchKey,
|
||||
Example,
|
||||
ExampleStore,
|
||||
ExamplesArrayFilter,
|
||||
ExpectedContent,
|
||||
StoredContentsExample,
|
||||
StoredContentsExampleFilter,
|
||||
StoredContentsExampleParameters,
|
||||
)
|
||||
|
||||
ArrayOperator = types.ExamplesArrayFilter.ArrayOperator
|
||||
ExampleStoreConfig = types.ExampleStoreConfig
|
||||
|
||||
__all__ = (
|
||||
"ArrayOperator",
|
||||
"ContentsExample",
|
||||
"ContentSearchKey",
|
||||
"Example",
|
||||
"ExampleStore",
|
||||
"ExampleStoreConfig",
|
||||
"ExamplesArrayFilter",
|
||||
"ExpectedContent",
|
||||
"StoredContentsExample",
|
||||
"StoredContentsExampleFilter",
|
||||
"StoredContentsExampleParameters",
|
||||
)
|
||||
@@ -0,0 +1,23 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Classes for working with extensions."""
|
||||
|
||||
# We just want to re-export certain classes
|
||||
# pylint: disable=g-multiple-import,g-importing-member
|
||||
from vertexai.extensions._extensions import (
|
||||
Extension,
|
||||
)
|
||||
|
||||
__all__ = ("Extension",)
|
||||
@@ -0,0 +1,74 @@
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Classes for working with the Gemini models."""
|
||||
|
||||
# We just want to re-export certain classes
|
||||
# pylint: disable=g-multiple-import,g-importing-member
|
||||
from vertexai.generative_models._generative_models import (
|
||||
preview_grounding as grounding,
|
||||
_PreviewGenerativeModel,
|
||||
_PreviewChatSession,
|
||||
GenerationConfig,
|
||||
GenerationResponse,
|
||||
AutomaticFunctionCallingResponder,
|
||||
CallableFunctionDeclaration,
|
||||
Candidate,
|
||||
Content,
|
||||
FinishReason,
|
||||
FunctionCall,
|
||||
FunctionDeclaration,
|
||||
HarmCategory,
|
||||
HarmBlockThreshold,
|
||||
Image,
|
||||
Part,
|
||||
ResponseBlockedError,
|
||||
ResponseValidationError,
|
||||
SafetySetting,
|
||||
Tool,
|
||||
ToolConfig,
|
||||
)
|
||||
|
||||
|
||||
class GenerativeModel(_PreviewGenerativeModel):
|
||||
__doc__ = _PreviewGenerativeModel.__doc__
|
||||
|
||||
|
||||
class ChatSession(_PreviewChatSession):
|
||||
__doc__ = _PreviewChatSession.__doc__
|
||||
|
||||
|
||||
__all__ = [
|
||||
"grounding",
|
||||
"GenerationConfig",
|
||||
"GenerativeModel",
|
||||
"GenerationResponse",
|
||||
"AutomaticFunctionCallingResponder",
|
||||
"CallableFunctionDeclaration",
|
||||
"Candidate",
|
||||
"ChatSession",
|
||||
"Content",
|
||||
"FinishReason",
|
||||
"FunctionCall",
|
||||
"FunctionDeclaration",
|
||||
"HarmCategory",
|
||||
"HarmBlockThreshold",
|
||||
"Image",
|
||||
"Part",
|
||||
"ResponseBlockedError",
|
||||
"ResponseValidationError",
|
||||
"SafetySetting",
|
||||
"Tool",
|
||||
"ToolConfig",
|
||||
]
|
||||
@@ -0,0 +1,73 @@
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Classes for working with language models."""
|
||||
|
||||
from vertexai.language_models._language_models import (
|
||||
_PreviewChatModel,
|
||||
_PreviewChatSession,
|
||||
_PreviewCodeChatModel,
|
||||
_PreviewCodeChatSession,
|
||||
_PreviewCodeGenerationModel,
|
||||
_PreviewTextEmbeddingModel,
|
||||
_PreviewTextGenerationModel,
|
||||
ChatMessage,
|
||||
CountTokensResponse,
|
||||
InputOutputTextPair,
|
||||
TextEmbedding,
|
||||
TextEmbeddingInput,
|
||||
TextGenerationResponse,
|
||||
TuningEvaluationSpec,
|
||||
)
|
||||
|
||||
from vertexai.language_models._evaluatable_language_models import (
|
||||
EvaluationTextGenerationSpec,
|
||||
EvaluationTextSummarizationSpec,
|
||||
EvaluationQuestionAnsweringSpec,
|
||||
EvaluationTextClassificationSpec,
|
||||
EvaluationClassificationMetric,
|
||||
EvaluationMetric,
|
||||
)
|
||||
|
||||
|
||||
ChatModel = _PreviewChatModel
|
||||
ChatSession = _PreviewChatSession
|
||||
CodeChatModel = _PreviewCodeChatModel
|
||||
CodeChatSession = _PreviewCodeChatSession
|
||||
CodeGenerationModel = _PreviewCodeGenerationModel
|
||||
TextGenerationModel = _PreviewTextGenerationModel
|
||||
TextEmbeddingModel = _PreviewTextEmbeddingModel
|
||||
|
||||
__all__ = [
|
||||
"ChatMessage",
|
||||
"ChatModel",
|
||||
"ChatSession",
|
||||
"CodeChatModel",
|
||||
"CodeChatSession",
|
||||
"CodeGenerationModel",
|
||||
"CountTokensResponse",
|
||||
"EvaluationClassificationMetric",
|
||||
"EvaluationMetric",
|
||||
"EvaluationTextGenerationSpec",
|
||||
"EvaluationTextSummarizationSpec",
|
||||
"EvaluationQuestionAnsweringSpec",
|
||||
"EvaluationTextClassificationSpec",
|
||||
"InputOutputTextPair",
|
||||
"TextEmbedding",
|
||||
"TextEmbeddingInput",
|
||||
"TextEmbeddingModel",
|
||||
"TextGenerationModel",
|
||||
"TextGenerationResponse",
|
||||
"TuningEvaluationSpec",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Classes and functions for working with Model Garden."""
|
||||
|
||||
# We just want to re-export certain classes
|
||||
# pylint: disable=g-multiple-import,g-importing-member
|
||||
from vertexai.model_garden._model_garden import (
|
||||
OpenModel,
|
||||
list_deployable_models,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("OpenModel", "list_deployable_models")
|
||||
@@ -0,0 +1,36 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from vertexai.prompts._prompts import (
|
||||
Prompt,
|
||||
)
|
||||
from vertexai.prompts._prompt_management import (
|
||||
create_version,
|
||||
delete,
|
||||
get,
|
||||
list_prompts as list,
|
||||
list_versions,
|
||||
restore_version,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Prompt",
|
||||
"delete",
|
||||
"create_version",
|
||||
"get",
|
||||
"list",
|
||||
"list_versions",
|
||||
"restore_version",
|
||||
]
|
||||
@@ -0,0 +1,124 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from vertexai.preview.rag.rag_data import (
|
||||
create_corpus,
|
||||
delete_corpus,
|
||||
delete_file,
|
||||
get_corpus,
|
||||
get_file,
|
||||
get_rag_engine_config,
|
||||
import_files,
|
||||
import_files_async,
|
||||
list_corpora,
|
||||
list_files,
|
||||
update_corpus,
|
||||
update_rag_engine_config,
|
||||
upload_file,
|
||||
)
|
||||
from vertexai.preview.rag.rag_retrieval import (
|
||||
retrieval_query,
|
||||
)
|
||||
from vertexai.preview.rag.rag_store import (
|
||||
Retrieval,
|
||||
VertexRagStore,
|
||||
)
|
||||
from vertexai.preview.rag.utils.resources import (
|
||||
ChunkingConfig,
|
||||
Basic,
|
||||
Enterprise,
|
||||
EmbeddingModelConfig,
|
||||
Filter,
|
||||
HybridSearch,
|
||||
JiraQuery,
|
||||
JiraSource,
|
||||
LayoutParserConfig,
|
||||
LlmParserConfig,
|
||||
LlmRanker,
|
||||
Pinecone,
|
||||
RagCorpus,
|
||||
RagEmbeddingModelConfig,
|
||||
RagEngineConfig,
|
||||
RagFile,
|
||||
RagManagedDb,
|
||||
RagManagedDbConfig,
|
||||
RagResource,
|
||||
RagRetrievalConfig,
|
||||
RagVectorDbConfig,
|
||||
RankService,
|
||||
Ranking,
|
||||
SharePointSource,
|
||||
SharePointSources,
|
||||
SlackChannel,
|
||||
SlackChannelsSource,
|
||||
TransformationConfig,
|
||||
VertexAiSearchConfig,
|
||||
VertexFeatureStore,
|
||||
VertexPredictionEndpoint,
|
||||
VertexVectorSearch,
|
||||
Weaviate,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"ChunkingConfig",
|
||||
"Basic",
|
||||
"Enterprise",
|
||||
"EmbeddingModelConfig",
|
||||
"Filter",
|
||||
"HybridSearch",
|
||||
"JiraQuery",
|
||||
"JiraSource",
|
||||
"LayoutParserConfig",
|
||||
"LlmParserConfig",
|
||||
"LlmRanker",
|
||||
"Pinecone",
|
||||
"RagEngineConfig",
|
||||
"RagCorpus",
|
||||
"RagFile",
|
||||
"RagManagedDb",
|
||||
"RagManagedDbConfig",
|
||||
"RagResource",
|
||||
"RagRetrievalConfig",
|
||||
"Ranking",
|
||||
"RankService",
|
||||
"Retrieval",
|
||||
"SharePointSource",
|
||||
"SharePointSources",
|
||||
"SlackChannel",
|
||||
"SlackChannelsSource",
|
||||
"TransformationConfig",
|
||||
"VertexAiSearchConfig",
|
||||
"VertexFeatureStore",
|
||||
"VertexRagStore",
|
||||
"VertexVectorSearch",
|
||||
"Weaviate",
|
||||
"RagEmbeddingModelConfig",
|
||||
"VertexPredictionEndpoint",
|
||||
"RagVectorDbConfig",
|
||||
"create_corpus",
|
||||
"delete_corpus",
|
||||
"delete_file",
|
||||
"get_corpus",
|
||||
"get_file",
|
||||
"import_files",
|
||||
"import_files_async",
|
||||
"list_corpora",
|
||||
"list_files",
|
||||
"retrieval_query",
|
||||
"upload_file",
|
||||
"update_corpus",
|
||||
"update_rag_engine_config",
|
||||
"get_rag_engine_config",
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1001
.venv/lib/python3.10/site-packages/vertexai/preview/rag/rag_data.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,274 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Retrieval query to get relevant contexts."""
|
||||
|
||||
import re
|
||||
from typing import List, Optional
|
||||
import warnings
|
||||
|
||||
from google.cloud import aiplatform_v1beta1
|
||||
from google.cloud.aiplatform import initializer
|
||||
from vertexai.preview.rag.utils import _gapic_utils
|
||||
from vertexai.preview.rag.utils import resources
|
||||
|
||||
|
||||
def retrieval_query(
|
||||
text: str,
|
||||
rag_resources: Optional[List[resources.RagResource]] = None,
|
||||
rag_corpora: Optional[List[str]] = None,
|
||||
similarity_top_k: Optional[int] = None,
|
||||
vector_distance_threshold: Optional[float] = None,
|
||||
vector_search_alpha: Optional[float] = None,
|
||||
rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
|
||||
) -> aiplatform_v1beta1.RetrieveContextsResponse:
|
||||
"""Retrieve top k relevant docs/chunks.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
import vertexai
|
||||
|
||||
vertexai.init(project="my-project")
|
||||
|
||||
# Using deprecated parameters
|
||||
results = vertexai.preview.rag.retrieval_query(
|
||||
text="Why is the sky blue?",
|
||||
rag_resources=[vertexai.preview.rag.RagResource(
|
||||
rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1",
|
||||
rag_file_ids=["rag-file-1", "rag-file-2", ...],
|
||||
)],
|
||||
similarity_top_k=2,
|
||||
vector_distance_threshold=0.5,
|
||||
vector_search_alpha=0.5,
|
||||
)
|
||||
|
||||
# Using RagRetrievalConfig. Equivalent to the above example.
|
||||
config = vertexai.preview.rag.RagRetrievalConfig(
|
||||
top_k=2,
|
||||
filter=vertexai.preview.rag.Filter(
|
||||
vector_distance_threshold=0.5
|
||||
),
|
||||
hybrid_search=vertexai.preview.rag.HybridSearch(
|
||||
alpha=0.5
|
||||
),
|
||||
ranking=vertexai.preview.rag.Ranking(
|
||||
llm_ranker=vertexai.preview.rag.LlmRanker(
|
||||
model_name="gemini-1.5-flash-002"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
results = vertexai.preview.rag.retrieval_query(
|
||||
text="Why is the sky blue?",
|
||||
rag_resources=[vertexai.preview.rag.RagResource(
|
||||
rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1",
|
||||
rag_file_ids=["rag-file-1", "rag-file-2", ...],
|
||||
)],
|
||||
rag_retrieval_config=config,
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
text: The query in text format to get relevant contexts.
|
||||
rag_resources: A list of RagResource. It can be used to specify corpus
|
||||
only or ragfiles. Currently only supports one corpus or multiple files
|
||||
from one corpus. In the future we may open up multiple corpora support.
|
||||
rag_corpora: If rag_resources is not specified, use rag_corpora as a list
|
||||
of rag corpora names. Deprecated. Use rag_resources instead.
|
||||
similarity_top_k: The number of contexts to retrieve. Deprecated. Use
|
||||
rag_retrieval_config.top_k instead.
|
||||
vector_distance_threshold: Optional. Only return contexts with vector
|
||||
distance smaller than the threshold. Deprecated. Use
|
||||
rag_retrieval_config.filter.vector_distance_threshold instead.
|
||||
vector_search_alpha: Optional. Controls the weight between dense and
|
||||
sparse vector search results. The range is [0, 1], where 0 means
|
||||
sparse vector search only and 1 means dense vector search only.
|
||||
The default value is 0.5. Deprecated. Use
|
||||
rag_retrieval_config.hybrid_search.alpha instead.
|
||||
rag_retrieval_config: Optional. The config containing the retrieval
|
||||
parameters, including top_k, vector_distance_threshold,
|
||||
and alpha.
|
||||
|
||||
Returns:
|
||||
RetrieveContextsResonse.
|
||||
"""
|
||||
parent = initializer.global_config.common_location_path()
|
||||
|
||||
client = _gapic_utils.create_rag_service_client()
|
||||
|
||||
if rag_resources:
|
||||
if len(rag_resources) > 1:
|
||||
raise ValueError("Currently only support 1 RagResource.")
|
||||
name = rag_resources[0].rag_corpus
|
||||
elif rag_corpora:
|
||||
if len(rag_corpora) > 1:
|
||||
raise ValueError("Currently only support 1 RagCorpus.")
|
||||
name = rag_corpora[0]
|
||||
warnings.warn(
|
||||
f"rag_corpora is deprecated. Please use rag_resources instead."
|
||||
f" After {resources.DEPRECATION_DATE} using"
|
||||
" rag_corpora will raise error",
|
||||
DeprecationWarning,
|
||||
)
|
||||
else:
|
||||
raise ValueError("rag_resources or rag_corpora must be specified.")
|
||||
|
||||
data_client = _gapic_utils.create_rag_data_service_client()
|
||||
if data_client.parse_rag_corpus_path(name):
|
||||
rag_corpus_name = name
|
||||
elif re.match("^{}$".format(_gapic_utils._VALID_RESOURCE_NAME_REGEX), name):
|
||||
rag_corpus_name = parent + "/ragCorpora/" + name
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid RagCorpus name: {rag_corpora}. Proper format should be:"
|
||||
" projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}"
|
||||
)
|
||||
|
||||
if rag_resources:
|
||||
gapic_rag_resource = (
|
||||
aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore.RagResource(
|
||||
rag_corpus=rag_corpus_name,
|
||||
rag_file_ids=rag_resources[0].rag_file_ids,
|
||||
)
|
||||
)
|
||||
vertex_rag_store = aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore(
|
||||
rag_resources=[gapic_rag_resource],
|
||||
)
|
||||
else:
|
||||
vertex_rag_store = aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore(
|
||||
rag_corpora=[rag_corpus_name],
|
||||
)
|
||||
|
||||
# Check for deprecated parameters and raise warnings.
|
||||
if similarity_top_k:
|
||||
# If similarity_top_k is specified, throw deprecation warning.
|
||||
warnings.warn(
|
||||
"similarity_top_k is deprecated. Please use"
|
||||
" rag_retrieval_config.top_k instead."
|
||||
f" After {resources.DEPRECATION_DATE} using"
|
||||
" similarity_top_k will raise error",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if vector_search_alpha:
|
||||
# If vector_search_alpha is specified, throw deprecation warning.
|
||||
warnings.warn(
|
||||
"vector_search_alpha is deprecated. Please use"
|
||||
" rag_retrieval_config.alpha instead."
|
||||
f" After {resources.DEPRECATION_DATE} using"
|
||||
" vector_search_alpha will raise error",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if vector_distance_threshold:
|
||||
# If vector_distance_threshold is specified, throw deprecation warning.
|
||||
warnings.warn(
|
||||
"vector_distance_threshold is deprecated. Please use"
|
||||
" rag_retrieval_config.filter.vector_distance_threshold instead."
|
||||
f" After {resources.DEPRECATION_DATE} using"
|
||||
" vector_distance_threshold will raise error",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
# If rag_retrieval_config is not specified, set it to default values.
|
||||
if not rag_retrieval_config:
|
||||
api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig(
|
||||
top_k=similarity_top_k,
|
||||
hybrid_search=aiplatform_v1beta1.RagRetrievalConfig.HybridSearch(
|
||||
alpha=vector_search_alpha,
|
||||
),
|
||||
filter=aiplatform_v1beta1.RagRetrievalConfig.Filter(
|
||||
vector_distance_threshold=vector_distance_threshold
|
||||
),
|
||||
)
|
||||
else:
|
||||
# If rag_retrieval_config is specified, check for missing parameters.
|
||||
api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig()
|
||||
# Set top_k to config value if specified
|
||||
if rag_retrieval_config.top_k:
|
||||
api_retrival_config.top_k = rag_retrieval_config.top_k
|
||||
else:
|
||||
api_retrival_config.top_k = similarity_top_k
|
||||
# Set alpha to config value if specified
|
||||
if (
|
||||
rag_retrieval_config.hybrid_search
|
||||
and rag_retrieval_config.hybrid_search.alpha
|
||||
):
|
||||
api_retrival_config.hybrid_search.alpha = (
|
||||
rag_retrieval_config.hybrid_search.alpha
|
||||
)
|
||||
else:
|
||||
api_retrival_config.hybrid_search.alpha = vector_search_alpha
|
||||
# Check if both vector_distance_threshold and vector_similarity_threshold
|
||||
# are specified.
|
||||
if (
|
||||
rag_retrieval_config.filter
|
||||
and rag_retrieval_config.filter.vector_distance_threshold
|
||||
and rag_retrieval_config.filter.vector_similarity_threshold
|
||||
):
|
||||
raise ValueError(
|
||||
"Only one of vector_distance_threshold or"
|
||||
" vector_similarity_threshold can be specified at a time"
|
||||
" in rag_retrieval_config."
|
||||
)
|
||||
# Set vector_distance_threshold to config value if specified
|
||||
if (
|
||||
rag_retrieval_config.filter
|
||||
and rag_retrieval_config.filter.vector_distance_threshold
|
||||
):
|
||||
api_retrival_config.filter.vector_distance_threshold = (
|
||||
rag_retrieval_config.filter.vector_distance_threshold
|
||||
)
|
||||
else:
|
||||
api_retrival_config.filter.vector_distance_threshold = (
|
||||
vector_distance_threshold
|
||||
)
|
||||
# Set vector_similarity_threshold to config value if specified
|
||||
if (
|
||||
rag_retrieval_config.filter
|
||||
and rag_retrieval_config.filter.vector_similarity_threshold
|
||||
):
|
||||
api_retrival_config.filter.vector_similarity_threshold = (
|
||||
rag_retrieval_config.filter.vector_similarity_threshold
|
||||
)
|
||||
|
||||
if (
|
||||
rag_retrieval_config.ranking
|
||||
and rag_retrieval_config.ranking.rank_service
|
||||
and rag_retrieval_config.ranking.llm_ranker
|
||||
):
|
||||
raise ValueError("Only one of rank_service and llm_ranker can be set.")
|
||||
if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service:
|
||||
api_retrival_config.ranking.rank_service.model_name = (
|
||||
rag_retrieval_config.ranking.rank_service.model_name
|
||||
)
|
||||
elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker:
|
||||
api_retrival_config.ranking.llm_ranker.model_name = (
|
||||
rag_retrieval_config.ranking.llm_ranker.model_name
|
||||
)
|
||||
query = aiplatform_v1beta1.RagQuery(
|
||||
text=text,
|
||||
rag_retrieval_config=api_retrival_config,
|
||||
)
|
||||
request = aiplatform_v1beta1.RetrieveContextsRequest(
|
||||
vertex_rag_store=vertex_rag_store,
|
||||
parent=parent,
|
||||
query=query,
|
||||
)
|
||||
try:
|
||||
response = client.retrieve_contexts(request=request)
|
||||
except Exception as e:
|
||||
raise RuntimeError("Failed in retrieving contexts due to: ", e) from e
|
||||
|
||||
return response
|
||||
@@ -0,0 +1,246 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""RAG retrieval tool for content generation."""
|
||||
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
import warnings
|
||||
|
||||
from google.cloud import aiplatform_v1beta1
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types
|
||||
from vertexai.preview import generative_models
|
||||
from vertexai.preview.rag.utils import _gapic_utils
|
||||
from vertexai.preview.rag.utils import resources
|
||||
|
||||
|
||||
class Retrieval(generative_models.grounding.Retrieval):
|
||||
"""Defines a retrieval tool that a model can call to access external knowledge."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: Union["VertexRagStore"],
|
||||
disable_attribution: Optional[bool] = False,
|
||||
):
|
||||
self._raw_retrieval = gapic_tool_types.Retrieval(
|
||||
vertex_rag_store=source._raw_vertex_rag_store,
|
||||
disable_attribution=disable_attribution,
|
||||
)
|
||||
|
||||
|
||||
class VertexRagStore:
|
||||
"""Retrieve from Vertex RAG Store."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rag_resources: Optional[List[resources.RagResource]] = None,
|
||||
rag_corpora: Optional[List[str]] = None,
|
||||
similarity_top_k: Optional[int] = None,
|
||||
vector_distance_threshold: Optional[float] = None,
|
||||
rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
|
||||
):
|
||||
"""Initializes a Vertex RAG store tool.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
import vertexai
|
||||
|
||||
vertexai.init(project="my-project")
|
||||
|
||||
# Using deprecated parameters
|
||||
tool = Tool.from_retrieval(
|
||||
retrieval=vertexai.preview.rag.Retrieval(
|
||||
source=vertexai.preview.rag.VertexRagStore(
|
||||
rag_corpora=["projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1"],
|
||||
similarity_top_k=3,
|
||||
vector_distance_threshold=0.4,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Using RagRetrievalConfig. Equivalent to the above example.
|
||||
config = vertexai.preview.rag.RagRetrievalConfig(
|
||||
top_k=3,
|
||||
filter=vertexai.preview.rag.Filter(
|
||||
vector_distance_threshold=0.4
|
||||
),
|
||||
)
|
||||
|
||||
tool = Tool.from_retrieval(
|
||||
retrieval=vertexai.preview.rag.Retrieval(
|
||||
source=vertexai.preview.rag.VertexRagStore(
|
||||
rag_corpora=["projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1"],
|
||||
rag_retrieval_config=config,
|
||||
),
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
rag_resources: List of RagResource to retrieve from. It can be used
|
||||
to specify corpus only or ragfiles. Currently only supports one
|
||||
corpus or multiple files from one corpus. In the future we
|
||||
may open up multiple corpora support.
|
||||
rag_corpora: If rag_resources is not specified, use rag_corpora as a
|
||||
list of rag corpora names. Deprecated. Use rag_resources instead.
|
||||
similarity_top_k: Number of top k results to return from the selected
|
||||
corpora. Deprecated. Use rag_retrieval_config.top_k instead.
|
||||
vector_distance_threshold (float):
|
||||
Optional. Only return results with vector distance smaller
|
||||
than the threshold. Deprecated. Use
|
||||
rag_retrieval_config.filter.vector_distance_threshold instead.
|
||||
rag_retrieval_config: Optional. The config containing the retrieval
|
||||
parameters, including top_k and vector_distance_threshold.
|
||||
"""
|
||||
|
||||
if rag_resources:
|
||||
if len(rag_resources) > 1:
|
||||
raise ValueError("Currently only support 1 RagResource.")
|
||||
name = rag_resources[0].rag_corpus
|
||||
elif rag_corpora:
|
||||
if len(rag_corpora) > 1:
|
||||
raise ValueError("Currently only support 1 RagCorpus.")
|
||||
warnings.warn(
|
||||
"rag_corpora is deprecated. Please use rag_resources instead."
|
||||
f" After {resources.DEPRECATION_DATE} using"
|
||||
" rag_corpora will raise error",
|
||||
DeprecationWarning,
|
||||
)
|
||||
name = rag_corpora[0]
|
||||
else:
|
||||
raise ValueError("rag_resources or rag_corpora must be specified.")
|
||||
|
||||
data_client = _gapic_utils.create_rag_data_service_client()
|
||||
if data_client.parse_rag_corpus_path(name):
|
||||
rag_corpus_name = name
|
||||
elif re.match("^{}$".format(_gapic_utils._VALID_RESOURCE_NAME_REGEX), name):
|
||||
parent = initializer.global_config.common_location_path()
|
||||
rag_corpus_name = parent + "/ragCorpora/" + name
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid RagCorpus name: {rag_corpora}. Proper format should"
|
||||
+ " be: projects/{{project}}/locations/{{location}}/ragCorpora/{{rag_corpus_id}}"
|
||||
)
|
||||
|
||||
# Check for deprecated parameters and raise warnings.
|
||||
if similarity_top_k:
|
||||
# If similarity_top_k is specified, throw deprecation warning.
|
||||
warnings.warn(
|
||||
"similarity_top_k is deprecated. Please use"
|
||||
" rag_retrieval_config.top_k instead."
|
||||
f" After {resources.DEPRECATION_DATE} using"
|
||||
" similarity_top_k will raise error",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if vector_distance_threshold:
|
||||
# If vector_distance_threshold is specified, throw deprecation warning.
|
||||
warnings.warn(
|
||||
"vector_distance_threshold is deprecated. Please use"
|
||||
" rag_retrieval_config.filter.vector_distance_threshold instead."
|
||||
f" After {resources.DEPRECATION_DATE} using"
|
||||
" vector_distance_threshold will raise error",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
# If rag_retrieval_config is not specified, set it to default values.
|
||||
if not rag_retrieval_config:
|
||||
api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig(
|
||||
top_k=similarity_top_k,
|
||||
filter=aiplatform_v1beta1.RagRetrievalConfig.Filter(
|
||||
vector_distance_threshold=vector_distance_threshold
|
||||
),
|
||||
)
|
||||
else:
|
||||
# If rag_retrieval_config is specified, check for missing parameters.
|
||||
api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig()
|
||||
# Set top_k to config value if specified
|
||||
if rag_retrieval_config.top_k:
|
||||
api_retrival_config.top_k = rag_retrieval_config.top_k
|
||||
else:
|
||||
api_retrival_config.top_k = similarity_top_k
|
||||
# Check if both vector_distance_threshold and vector_similarity_threshold
|
||||
# are specified.
|
||||
if (
|
||||
rag_retrieval_config.filter
|
||||
and rag_retrieval_config.filter.vector_distance_threshold
|
||||
and rag_retrieval_config.filter.vector_similarity_threshold
|
||||
):
|
||||
raise ValueError(
|
||||
"Only one of vector_distance_threshold or"
|
||||
" vector_similarity_threshold can be specified at a time"
|
||||
" in rag_retrieval_config."
|
||||
)
|
||||
# Set vector_distance_threshold to config value if specified
|
||||
if (
|
||||
rag_retrieval_config.filter
|
||||
and rag_retrieval_config.filter.vector_distance_threshold
|
||||
):
|
||||
api_retrival_config.filter.vector_distance_threshold = (
|
||||
rag_retrieval_config.filter.vector_distance_threshold
|
||||
)
|
||||
else:
|
||||
api_retrival_config.filter.vector_distance_threshold = (
|
||||
vector_distance_threshold
|
||||
)
|
||||
# Set vector_similarity_threshold to config value if specified
|
||||
if (
|
||||
rag_retrieval_config.filter
|
||||
and rag_retrieval_config.filter.vector_similarity_threshold
|
||||
):
|
||||
api_retrival_config.filter.vector_similarity_threshold = (
|
||||
rag_retrieval_config.filter.vector_similarity_threshold
|
||||
)
|
||||
# Check if both rank_service and llm_ranker are specified.
|
||||
if (
|
||||
rag_retrieval_config.ranking
|
||||
and rag_retrieval_config.ranking.rank_service
|
||||
and rag_retrieval_config.ranking.rank_service.model_name
|
||||
and rag_retrieval_config.ranking.llm_ranker
|
||||
and rag_retrieval_config.ranking.llm_ranker.model_name
|
||||
):
|
||||
raise ValueError(
|
||||
"Only one of rank_service or llm_ranker can be specified"
|
||||
" at a time in rag_retrieval_config."
|
||||
)
|
||||
# Set rank_service to config value if specified
|
||||
if (
|
||||
rag_retrieval_config.ranking
|
||||
and rag_retrieval_config.ranking.rank_service
|
||||
):
|
||||
api_retrival_config.ranking.rank_service.model_name = (
|
||||
rag_retrieval_config.ranking.rank_service.model_name
|
||||
)
|
||||
# Set llm_ranker to config value if specified
|
||||
if rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker:
|
||||
api_retrival_config.ranking.llm_ranker.model_name = (
|
||||
rag_retrieval_config.ranking.llm_ranker.model_name
|
||||
)
|
||||
|
||||
if rag_resources:
|
||||
gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource(
|
||||
rag_corpus=rag_corpus_name,
|
||||
rag_file_ids=rag_resources[0].rag_file_ids,
|
||||
)
|
||||
self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore(
|
||||
rag_resources=[gapic_rag_resource],
|
||||
rag_retrieval_config=api_retrival_config,
|
||||
)
|
||||
else:
|
||||
self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore(
|
||||
rag_corpora=[rag_corpus_name],
|
||||
rag_retrieval_config=api_retrival_config,
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,849 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
from typing import Any, Dict, Optional, Sequence, Union
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform.utils import (
|
||||
VertexRagClientWithOverride,
|
||||
VertexRagDataAsyncClientWithOverride,
|
||||
VertexRagDataClientWithOverride,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1 import (
|
||||
GoogleDriveSource,
|
||||
ImportRagFilesConfig,
|
||||
ImportRagFilesRequest,
|
||||
JiraSource as GapicJiraSource,
|
||||
RagCorpus as GapicRagCorpus,
|
||||
RagEmbeddingModelConfig as GapicRagEmbeddingModelConfig,
|
||||
RagEngineConfig as GapicRagEngineConfig,
|
||||
RagFileChunkingConfig,
|
||||
RagFileParsingConfig,
|
||||
RagFileTransformationConfig,
|
||||
RagFile as GapicRagFile,
|
||||
RagManagedDbConfig as GapicRagManagedDbConfig,
|
||||
RagVectorDbConfig as GapicRagVectorDbConfig,
|
||||
SharePointSources as GapicSharePointSources,
|
||||
SlackSource as GapicSlackSource,
|
||||
VertexAiSearchConfig as GapicVertexAiSearchConfig,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.types import api_auth
|
||||
from vertexai.preview.rag.utils.resources import (
|
||||
EmbeddingModelConfig,
|
||||
JiraSource,
|
||||
LayoutParserConfig,
|
||||
LlmParserConfig,
|
||||
Pinecone,
|
||||
RagCorpus,
|
||||
RagEmbeddingModelConfig,
|
||||
RagEngineConfig,
|
||||
RagFile,
|
||||
RagManagedDb,
|
||||
RagManagedDbConfig,
|
||||
RagVectorDbConfig,
|
||||
Basic,
|
||||
Enterprise,
|
||||
SharePointSources,
|
||||
SlackChannelsSource,
|
||||
TransformationConfig,
|
||||
VertexAiSearchConfig,
|
||||
VertexFeatureStore,
|
||||
VertexPredictionEndpoint,
|
||||
VertexVectorSearch,
|
||||
Weaviate,
|
||||
)
|
||||
|
||||
|
||||
_VALID_RESOURCE_NAME_REGEX = "[a-z][a-zA-Z0-9._-]{0,127}"
|
||||
_VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX = (
|
||||
r"projects/[^/]+/locations/[^/]+/processors/[^/]+(?:/processorVersions/[^/]+)?"
|
||||
)
|
||||
|
||||
|
||||
def create_rag_data_service_client():
|
||||
return initializer.global_config.create_client(
|
||||
client_class=VertexRagDataClientWithOverride,
|
||||
).select_version("v1beta1")
|
||||
|
||||
|
||||
def create_rag_data_service_async_client():
|
||||
return initializer.global_config.create_client(
|
||||
client_class=VertexRagDataAsyncClientWithOverride,
|
||||
).select_version("v1beta1")
|
||||
|
||||
|
||||
def create_rag_service_client():
|
||||
return initializer.global_config.create_client(
|
||||
client_class=VertexRagClientWithOverride,
|
||||
).select_version("v1beta1")
|
||||
|
||||
|
||||
def convert_gapic_to_embedding_model_config(
|
||||
gapic_embedding_model_config: GapicRagEmbeddingModelConfig,
|
||||
) -> EmbeddingModelConfig:
|
||||
"""Convert GapicRagEmbeddingModelConfig to EmbeddingModelConfig."""
|
||||
embedding_model_config = EmbeddingModelConfig()
|
||||
path = gapic_embedding_model_config.vertex_prediction_endpoint.endpoint
|
||||
publisher_model = re.match(
|
||||
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
|
||||
path,
|
||||
)
|
||||
endpoint = re.match(
|
||||
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
|
||||
path,
|
||||
)
|
||||
if publisher_model:
|
||||
embedding_model_config.publisher_model = path
|
||||
if endpoint:
|
||||
embedding_model_config.endpoint = path
|
||||
embedding_model_config.model = (
|
||||
gapic_embedding_model_config.vertex_prediction_endpoint.model
|
||||
)
|
||||
embedding_model_config.model_version_id = (
|
||||
gapic_embedding_model_config.vertex_prediction_endpoint.model_version_id
|
||||
)
|
||||
|
||||
return embedding_model_config
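# Sketch of the two endpoint path shapes this converter distinguishes
# (placeholder ids, not real resources):
#   "projects/p/locations/l/publishers/google/models/<model-id>"
#       -> only publisher_model is populated
#   "projects/p/locations/l/endpoints/<endpoint-id>"
#       -> endpoint, model, and model_version_id are populated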
|
||||
|
||||
|
||||
def _check_weaviate(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
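# Proto-plus message wrappers implement __contains__ for field-presence
# checks ("weaviate" in message); plain protobuf messages do not and raise
# AttributeError when __contains__ is looked up, so fall back to testing
# whether the sub-message is non-empty via ByteSize(). The same probing
# pattern is shared by the other _check_* helpers below.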
|
||||
try:
|
||||
return gapic_vector_db.__contains__("weaviate")
|
||||
except AttributeError:
|
||||
return gapic_vector_db.weaviate.ByteSize() > 0
|
||||
|
||||
|
||||
def _check_rag_managed_db(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
|
||||
try:
|
||||
return gapic_vector_db.__contains__("rag_managed_db")
|
||||
except AttributeError:
|
||||
return gapic_vector_db.rag_managed_db.ByteSize() > 0
|
||||
|
||||
|
||||
def _check_vertex_feature_store(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
|
||||
try:
|
||||
return gapic_vector_db.__contains__("vertex_feature_store")
|
||||
except AttributeError:
|
||||
return gapic_vector_db.vertex_feature_store.ByteSize() > 0
|
||||
|
||||
|
||||
def _check_pinecone(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
|
||||
try:
|
||||
return gapic_vector_db.__contains__("pinecone")
|
||||
except AttributeError:
|
||||
return gapic_vector_db.pinecone.ByteSize() > 0
|
||||
|
||||
|
||||
def _check_vertex_vector_search(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
|
||||
try:
|
||||
return gapic_vector_db.__contains__("vertex_vector_search")
|
||||
except AttributeError:
|
||||
return gapic_vector_db.vertex_vector_search.ByteSize() > 0
|
||||
|
||||
|
||||
def _check_rag_embedding_model_config(
|
||||
gapic_vector_db: GapicRagVectorDbConfig,
|
||||
) -> bool:
|
||||
try:
|
||||
return gapic_vector_db.__contains__("rag_embedding_model_config")
|
||||
except AttributeError:
|
||||
return gapic_vector_db.rag_embedding_model_config.ByteSize() > 0
|
||||
|
||||
|
||||
def convert_gapic_to_vector_db(
|
||||
gapic_vector_db: GapicRagVectorDbConfig,
|
||||
) -> Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb]:
|
||||
"""Convert Gapic GapicRagVectorDbConfig to Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, or Pinecone."""
|
||||
if _check_weaviate(gapic_vector_db):
|
||||
return Weaviate(
|
||||
weaviate_http_endpoint=gapic_vector_db.weaviate.http_endpoint,
|
||||
collection_name=gapic_vector_db.weaviate.collection_name,
|
||||
api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
|
||||
)
|
||||
elif _check_vertex_feature_store(gapic_vector_db):
|
||||
return VertexFeatureStore(
|
||||
resource_name=gapic_vector_db.vertex_feature_store.feature_view_resource_name,
|
||||
)
|
||||
elif _check_pinecone(gapic_vector_db):
|
||||
return Pinecone(
|
||||
index_name=gapic_vector_db.pinecone.index_name,
|
||||
api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
|
||||
)
|
||||
elif _check_vertex_vector_search(gapic_vector_db):
|
||||
return VertexVectorSearch(
|
||||
index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint,
|
||||
index=gapic_vector_db.vertex_vector_search.index,
|
||||
)
|
||||
elif _check_rag_managed_db(gapic_vector_db):
|
||||
return RagManagedDb()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def convert_gapic_to_vertex_ai_search_config(
|
||||
gapic_vertex_ai_search_config: GapicVertexAiSearchConfig,
|
||||
) -> VertexAiSearchConfig:
|
||||
"""Convert Gapic VertexAiSearchConfig to VertexAiSearchConfig."""
|
||||
if gapic_vertex_ai_search_config.serving_config:
|
||||
return VertexAiSearchConfig(
|
||||
serving_config=gapic_vertex_ai_search_config.serving_config,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def convert_gapic_to_rag_embedding_model_config(
|
||||
gapic_embedding_model_config: GapicRagEmbeddingModelConfig,
|
||||
) -> RagEmbeddingModelConfig:
|
||||
"""Convert GapicRagEmbeddingModelConfig to RagEmbeddingModelConfig."""
|
||||
embedding_model_config = RagEmbeddingModelConfig()
|
||||
path = gapic_embedding_model_config.vertex_prediction_endpoint.endpoint
|
||||
publisher_model = re.match(
|
||||
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
|
||||
path,
|
||||
)
|
||||
endpoint = re.match(
|
||||
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
|
||||
path,
|
||||
)
|
||||
if publisher_model:
|
||||
embedding_model_config.vertex_prediction_endpoint = VertexPredictionEndpoint(
|
||||
publisher_model=path
|
||||
)
|
||||
if endpoint:
|
||||
embedding_model_config.vertex_prediction_endpoint = VertexPredictionEndpoint(
|
||||
endpoint=path,
|
||||
model=gapic_embedding_model_config.vertex_prediction_endpoint.model,
|
||||
model_version_id=gapic_embedding_model_config.vertex_prediction_endpoint.model_version_id,
|
||||
)
|
||||
return embedding_model_config
|
||||
|
||||
|
||||
def convert_gapic_to_backend_config(
|
||||
gapic_vector_db: GapicRagVectorDbConfig,
|
||||
) -> RagVectorDbConfig:
|
||||
"""Convert Gapic RagVectorDbConfig to VertexVectorSearch, Pinecone, or RagManagedDb."""
|
||||
vector_config = RagVectorDbConfig()
|
||||
if _check_pinecone(gapic_vector_db):
|
||||
vector_config.vector_db = Pinecone(
|
||||
index_name=gapic_vector_db.pinecone.index_name,
|
||||
api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
|
||||
)
|
||||
elif _check_vertex_vector_search(gapic_vector_db):
|
||||
vector_config.vector_db = VertexVectorSearch(
|
||||
index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint,
|
||||
index=gapic_vector_db.vertex_vector_search.index,
|
||||
)
|
||||
elif _check_rag_managed_db(gapic_vector_db):
|
||||
vector_config.vector_db = RagManagedDb()
|
||||
if _check_rag_embedding_model_config(gapic_vector_db):
|
||||
vector_config.rag_embedding_model_config = (
|
||||
convert_gapic_to_rag_embedding_model_config(
|
||||
gapic_vector_db.rag_embedding_model_config
|
||||
)
|
||||
)
|
||||
return vector_config
|
||||
|
||||
|
||||
def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
|
||||
"""Convert GapicRagCorpus to RagCorpus."""
|
||||
rag_corpus = RagCorpus(
|
||||
name=gapic_rag_corpus.name,
|
||||
display_name=gapic_rag_corpus.display_name,
|
||||
description=gapic_rag_corpus.description,
|
||||
embedding_model_config=convert_gapic_to_embedding_model_config(
|
||||
gapic_rag_corpus.rag_embedding_model_config
|
||||
),
|
||||
vector_db=convert_gapic_to_vector_db(gapic_rag_corpus.rag_vector_db_config),
|
||||
vertex_ai_search_config=convert_gapic_to_vertex_ai_search_config(
|
||||
gapic_rag_corpus.vertex_ai_search_config
|
||||
),
|
||||
backend_config=convert_gapic_to_backend_config(
|
||||
gapic_rag_corpus.rag_vector_db_config
|
||||
),
|
||||
)
|
||||
return rag_corpus
|
||||
|
||||
|
||||
def convert_gapic_to_rag_corpus_no_embedding_model_config(
|
||||
gapic_rag_corpus: GapicRagCorpus,
|
||||
) -> RagCorpus:
|
||||
"""Convert GapicRagCorpus without embedding model config (for UpdateRagCorpus) to RagCorpus."""
|
||||
rag_vector_db_config_no_embedding_model_config = gapic_rag_corpus.vector_db_config
|
||||
rag_vector_db_config_no_embedding_model_config.rag_embedding_model_config = None
|
||||
rag_corpus = RagCorpus(
|
||||
name=gapic_rag_corpus.name,
|
||||
display_name=gapic_rag_corpus.display_name,
|
||||
description=gapic_rag_corpus.description,
|
||||
vector_db=convert_gapic_to_vector_db(gapic_rag_corpus.rag_vector_db_config),
|
||||
vertex_ai_search_config=convert_gapic_to_vertex_ai_search_config(
|
||||
gapic_rag_corpus.vertex_ai_search_config
|
||||
),
|
||||
backend_config=convert_gapic_to_backend_config(
|
||||
rag_vector_db_config_no_embedding_model_config
|
||||
),
|
||||
)
|
||||
return rag_corpus
|
||||
|
||||
|
||||
def convert_gapic_to_rag_file(gapic_rag_file: GapicRagFile) -> RagFile:
|
||||
"""Convert GapicRagFile to RagFile."""
|
||||
rag_file = RagFile(
|
||||
name=gapic_rag_file.name,
|
||||
display_name=gapic_rag_file.display_name,
|
||||
description=gapic_rag_file.description,
|
||||
)
|
||||
return rag_file
|
||||
|
||||
|
||||
def convert_json_to_rag_file(upload_rag_file_response: Dict[str, Any]) -> RagFile:
|
||||
"""Converts a JSON response to a RagFile."""
|
||||
rag_file = RagFile(
|
||||
name=upload_rag_file_response.get("ragFile").get("name"),
|
||||
display_name=upload_rag_file_response.get("ragFile").get("displayName"),
|
||||
description=upload_rag_file_response.get("ragFile").get("description"),
|
||||
)
|
||||
return rag_file
|
||||
|
||||
|
||||
def convert_path_to_resource_id(
|
||||
path: str,
|
||||
) -> Union[str, GoogleDriveSource.ResourceId]:
|
||||
"""Converts a path to a Google Cloud storage uri or GoogleDriveSource.ResourceId."""
|
||||
if path.startswith("gs://"):
|
||||
# Google Cloud Storage source
|
||||
return path
|
||||
elif path.startswith("https://drive.google.com/"):
|
||||
# Google Drive source
|
||||
path_list = path.split("/")
|
||||
if "file" in path_list:
|
||||
index = path_list.index("file") + 2
|
||||
resource_id = path_list[index].split("?")[0]
|
||||
resource_type = GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FILE
|
||||
elif "folders" in path_list:
|
||||
index = path_list.index("folders") + 1
|
||||
resource_id = path_list[index].split("?")[0]
|
||||
resource_type = (
|
||||
GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER
|
||||
)
|
||||
else:
|
||||
raise ValueError("path %s is not a valid Google Drive url.", path)
|
||||
|
||||
return GoogleDriveSource.ResourceId(
|
||||
resource_id=resource_id,
|
||||
resource_type=resource_type,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"path must be a Google Cloud Storage uri or a Google Drive url."
|
||||
)
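# Illustrative behavior (sketch, placeholder ids):
#   convert_path_to_resource_id("gs://my-bucket/doc.pdf")
#       -> "gs://my-bucket/doc.pdf"  (GCS uris are returned unchanged)
#   convert_path_to_resource_id("https://drive.google.com/file/d/FILE_ID/view")
#       -> GoogleDriveSource.ResourceId(resource_id="FILE_ID",
#              resource_type=RESOURCE_TYPE_FILE)
#   convert_path_to_resource_id("https://drive.google.com/drive/folders/FOLDER_ID")
#       -> GoogleDriveSource.ResourceId(resource_id="FOLDER_ID",
#              resource_type=RESOURCE_TYPE_FOLDER)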
|
||||
|
||||
|
||||
def convert_source_for_rag_import(
|
||||
source: Union[SlackChannelsSource, JiraSource, SharePointSources]
|
||||
) -> Union[GapicSlackSource, GapicJiraSource]:
|
||||
"""Converts a SlackChannelsSource or JiraSource to a GapicSlackSource or GapicJiraSource."""
|
||||
if isinstance(source, SlackChannelsSource):
|
||||
result_source_channels = []
|
||||
for channel in source.channels:
|
||||
api_key = channel.api_key
|
||||
cid = channel.channel_id
|
||||
start_time = channel.start_time
|
||||
end_time = channel.end_time
|
||||
result_channels = GapicSlackSource.SlackChannels(
|
||||
channels=[
|
||||
GapicSlackSource.SlackChannels.SlackChannel(
|
||||
channel_id=cid,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
],
|
||||
api_key_config=api_auth.ApiAuth.ApiKeyConfig(
|
||||
api_key_secret_version=api_key
|
||||
),
|
||||
)
|
||||
result_source_channels.append(result_channels)
|
||||
return GapicSlackSource(
|
||||
channels=result_source_channels,
|
||||
)
|
||||
elif isinstance(source, JiraSource):
|
||||
result_source_queries = []
|
||||
for query in source.queries:
|
||||
api_key = query.api_key
|
||||
custom_queries = query.custom_queries
|
||||
projects = query.jira_projects
|
||||
email = query.email
|
||||
server_uri = query.server_uri
|
||||
result_query = GapicJiraSource.JiraQueries(
|
||||
custom_queries=custom_queries,
|
||||
projects=projects,
|
||||
email=email,
|
||||
server_uri=server_uri,
|
||||
api_key_config=api_auth.ApiAuth.ApiKeyConfig(
|
||||
api_key_secret_version=api_key
|
||||
),
|
||||
)
|
||||
result_source_queries.append(result_query)
|
||||
return GapicJiraSource(
|
||||
jira_queries=result_source_queries,
|
||||
)
|
||||
elif isinstance(source, SharePointSources):
|
||||
result_source_share_point_sources = []
|
||||
for share_point_source in source.share_point_sources:
|
||||
sharepoint_folder_path = share_point_source.sharepoint_folder_path
|
||||
sharepoint_folder_id = share_point_source.sharepoint_folder_id
|
||||
drive_name = share_point_source.drive_name
|
||||
drive_id = share_point_source.drive_id
|
||||
client_id = share_point_source.client_id
|
||||
client_secret = share_point_source.client_secret
|
||||
tenant_id = share_point_source.tenant_id
|
||||
sharepoint_site_name = share_point_source.sharepoint_site_name
|
||||
result_share_point_source = GapicSharePointSources.SharePointSource(
|
||||
client_id=client_id,
|
||||
client_secret=api_auth.ApiAuth.ApiKeyConfig(
|
||||
api_key_secret_version=client_secret
|
||||
),
|
||||
tenant_id=tenant_id,
|
||||
sharepoint_site_name=sharepoint_site_name,
|
||||
)
|
||||
if sharepoint_folder_path is not None and sharepoint_folder_id is not None:
|
||||
raise ValueError(
|
||||
"sharepoint_folder_path and sharepoint_folder_id cannot both be set."
|
||||
)
|
||||
elif sharepoint_folder_path is not None:
|
||||
result_share_point_source.sharepoint_folder_path = (
|
||||
sharepoint_folder_path
|
||||
)
|
||||
elif sharepoint_folder_id is not None:
|
||||
result_share_point_source.sharepoint_folder_id = sharepoint_folder_id
|
||||
if drive_name is not None and drive_id is not None:
|
||||
raise ValueError("drive_name and drive_id cannot both be set.")
|
||||
elif drive_name is not None:
|
||||
result_share_point_source.drive_name = drive_name
|
||||
elif drive_id is not None:
|
||||
result_share_point_source.drive_id = drive_id
|
||||
else:
|
||||
raise ValueError("Either drive_name and drive_id must be set.")
|
||||
result_source_share_point_sources.append(result_share_point_source)
|
||||
return GapicSharePointSources(
|
||||
share_point_sources=result_source_share_point_sources,
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
"source must be a SlackChannelsSource or JiraSource or SharePointSources."
|
||||
)
|
||||
|
||||
|
||||
def prepare_import_files_request(
|
||||
corpus_name: str,
|
||||
paths: Optional[Sequence[str]] = None,
|
||||
source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None,
|
||||
chunk_size: int = 1024,
|
||||
chunk_overlap: int = 200,
|
||||
transformation_config: Optional[TransformationConfig] = None,
|
||||
max_embedding_requests_per_min: int = 1000,
|
||||
use_advanced_pdf_parsing: bool = False,
|
||||
partial_failures_sink: Optional[str] = None,
|
||||
layout_parser: Optional[LayoutParserConfig] = None,
|
||||
llm_parser: Optional[LlmParserConfig] = None,
|
||||
) -> ImportRagFilesRequest:
|
||||
if len(corpus_name.split("/")) != 6:
|
||||
raise ValueError(
|
||||
"corpus_name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`"
|
||||
)
|
||||
|
||||
rag_file_parsing_config = RagFileParsingConfig(
|
||||
advanced_parser=RagFileParsingConfig.AdvancedParser(
|
||||
use_advanced_pdf_parsing=use_advanced_pdf_parsing,
|
||||
),
|
||||
)
|
||||
if layout_parser is not None:
|
||||
if (
|
||||
re.fullmatch(
|
||||
_VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX, layout_parser.processor_name
|
||||
)
|
||||
is None
|
||||
):
|
||||
raise ValueError(
|
||||
"processor_name must be of the format "
|
||||
"`projects/{project_id}/locations/{location}/processors/{processor_id}`"
|
||||
"or "
|
||||
"`projects/{project_id}/locations/{location}/processors/{processor_id}/processorVersions/{processor_version_id}`, "
|
||||
f"got {layout_parser.processor_name!r}"
|
||||
)
|
||||
rag_file_parsing_config.layout_parser = RagFileParsingConfig.LayoutParser(
|
||||
processor_name=layout_parser.processor_name,
|
||||
max_parsing_requests_per_min=layout_parser.max_parsing_requests_per_min,
|
||||
)
|
||||
if llm_parser is not None:
|
||||
rag_file_parsing_config.llm_parser = RagFileParsingConfig.LlmParser(
|
||||
model_name=llm_parser.model_name
|
||||
)
|
||||
if llm_parser.max_parsing_requests_per_min is not None:
|
||||
rag_file_parsing_config.llm_parser.max_parsing_requests_per_min = (
|
||||
llm_parser.max_parsing_requests_per_min
|
||||
)
|
||||
if llm_parser.custom_parsing_prompt is not None:
|
||||
rag_file_parsing_config.llm_parser.custom_parsing_prompt = (
|
||||
llm_parser.custom_parsing_prompt
|
||||
)
|
||||
|
||||
local_chunk_size = chunk_size
|
||||
local_chunk_overlap = chunk_overlap
|
||||
if transformation_config and transformation_config.chunking_config:
|
||||
local_chunk_size = transformation_config.chunking_config.chunk_size
|
||||
local_chunk_overlap = transformation_config.chunking_config.chunk_overlap
|
||||
|
||||
rag_file_transformation_config = RagFileTransformationConfig(
|
||||
rag_file_chunking_config=RagFileChunkingConfig(
|
||||
fixed_length_chunking=RagFileChunkingConfig.FixedLengthChunking(
|
||||
chunk_size=local_chunk_size,
|
||||
chunk_overlap=local_chunk_overlap,
|
||||
),
|
||||
),
|
||||
)
|
||||
import_rag_files_config = ImportRagFilesConfig(
|
||||
rag_file_transformation_config=rag_file_transformation_config,
|
||||
max_embedding_requests_per_min=max_embedding_requests_per_min,
|
||||
rag_file_parsing_config=rag_file_parsing_config,
|
||||
)
|
||||
|
||||
if source is not None:
|
||||
gapic_source = convert_source_for_rag_import(source)
|
||||
if isinstance(gapic_source, GapicSlackSource):
|
||||
import_rag_files_config.slack_source = gapic_source
|
||||
if isinstance(gapic_source, GapicJiraSource):
|
||||
import_rag_files_config.jira_source = gapic_source
|
||||
if isinstance(gapic_source, GapicSharePointSources):
|
||||
import_rag_files_config.share_point_sources = gapic_source
|
||||
else:
|
||||
uris = []
|
||||
resource_ids = []
|
||||
for p in paths:
|
||||
output = convert_path_to_resource_id(p)
|
||||
if isinstance(output, str):
|
||||
uris.append(p)
|
||||
else:
|
||||
resource_ids.append(output)
|
||||
if uris:
|
||||
import_rag_files_config.gcs_source.uris = uris
|
||||
if resource_ids:
|
||||
google_drive_source = GoogleDriveSource(
|
||||
resource_ids=resource_ids,
|
||||
)
|
||||
import_rag_files_config.google_drive_source = google_drive_source
|
||||
|
||||
if partial_failures_sink is not None:
|
||||
if partial_failures_sink.startswith("gs://"):
|
||||
import_rag_files_config.partial_failure_gcs_sink.output_uri_prefix = (
|
||||
partial_failures_sink
|
||||
)
|
||||
elif partial_failures_sink.startswith(
|
||||
"bq://"
|
||||
) or partial_failures_sink.startswith("bigquery://"):
|
||||
import_rag_files_config.partial_failure_bigquery_sink.output_uri = (
|
||||
partial_failures_sink
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"if provided, partial_failures_sink must be a GCS path or a BigQuery table."
|
||||
)
|
||||
|
||||
request = ImportRagFilesRequest(
|
||||
parent=corpus_name, import_rag_files_config=import_rag_files_config
|
||||
)
|
||||
return request
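# Minimal usage sketch (placeholder resource names):
#   request = prepare_import_files_request(
#       corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus",
#       paths=["gs://my-bucket/my-file.pdf"],
#       chunk_size=512,
#       chunk_overlap=100,
#   )
#   # request.import_rag_files_config.gcs_source.uris now holds the GCS path.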
|
||||
|
||||
|
||||
def get_corpus_name(
|
||||
name: str,
|
||||
) -> str:
|
||||
if name:
|
||||
client = create_rag_data_service_client()
|
||||
if client.parse_rag_corpus_path(name):
|
||||
return name
|
||||
elif re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name):
|
||||
return client.rag_corpus_path(
|
||||
project=initializer.global_config.project,
|
||||
location=initializer.global_config.location,
|
||||
rag_corpus=name,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}` or `{rag_corpus}`"
|
||||
)
|
||||
return name
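# Accepts either a bare corpus id or a full resource name (sketch):
#   get_corpus_name("my-corpus")
#       -> "projects/<project>/locations/<location>/ragCorpora/my-corpus"
#   get_corpus_name("projects/p/locations/l/ragCorpora/my-corpus")
#       -> returned unchanged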
|
||||
|
||||
|
||||
def get_file_name(
|
||||
name: str,
|
||||
corpus_name: str,
|
||||
) -> str:
|
||||
client = create_rag_data_service_client()
|
||||
if client.parse_rag_file_path(name):
|
||||
return name
|
||||
elif re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name):
|
||||
if not corpus_name:
|
||||
raise ValueError(
|
||||
"corpus_name must be provided if name is a `{rag_file}`, not a "
|
||||
"full resource name (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`). "
|
||||
)
|
||||
return client.rag_file_path(
|
||||
project=initializer.global_config.project,
|
||||
location=initializer.global_config.location,
|
||||
rag_corpus=get_corpus_name(corpus_name),
|
||||
rag_file=name,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}` or `{rag_file}`"
|
||||
)
|
||||
|
||||
|
||||
def set_embedding_model_config(
|
||||
embedding_model_config: EmbeddingModelConfig,
|
||||
rag_corpus: GapicRagCorpus,
|
||||
) -> None:
|
||||
if embedding_model_config.publisher_model and embedding_model_config.endpoint:
|
||||
raise ValueError("publisher_model and endpoint cannot be set at the same time.")
|
||||
if (
|
||||
not embedding_model_config.publisher_model
|
||||
and not embedding_model_config.endpoint
|
||||
):
|
||||
raise ValueError("At least one of publisher_model and endpoint must be set.")
|
||||
parent = initializer.global_config.common_location_path(project=None, location=None)
|
||||
|
||||
if embedding_model_config.publisher_model:
|
||||
publisher_model = embedding_model_config.publisher_model
|
||||
full_resource_name = re.match(
|
||||
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
|
||||
publisher_model,
|
||||
)
|
||||
resource_name = re.match(
|
||||
r"^publishers/google/models/(?P<model_id>.+?)$",
|
||||
publisher_model,
|
||||
)
|
||||
if full_resource_name:
|
||||
rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
|
||||
publisher_model
|
||||
)
|
||||
elif resource_name:
|
||||
rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
|
||||
parent + "/" + publisher_model
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"publisher_model must be of the format `projects/{project}/locations/{location}/publishers/google/models/{model_id}` or `publishers/google/models/{model_id}`"
|
||||
)
|
||||
|
||||
if embedding_model_config.endpoint:
|
||||
endpoint = embedding_model_config.endpoint
|
||||
full_resource_name = re.match(
|
||||
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
|
||||
endpoint,
|
||||
)
|
||||
resource_name = re.match(
|
||||
r"^endpoints/(?P<endpoint>.+?)$",
|
||||
endpoint,
|
||||
)
|
||||
if full_resource_name:
|
||||
rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
|
||||
endpoint
|
||||
)
|
||||
elif resource_name:
|
||||
rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
|
||||
parent + "/" + endpoint
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"endpoint must be of the format `projects/{project}/locations/{location}/endpoints/{endpoint}` or `endpoints/{endpoint}`"
|
||||
)
|
||||
|
||||
|
||||
def set_vector_db(
|
||||
vector_db: Union[
|
||||
Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb, None
|
||||
],
|
||||
rag_corpus: GapicRagCorpus,
|
||||
) -> None:
|
||||
"""Sets the vector db configuration for the rag corpus."""
|
||||
if vector_db is None or isinstance(vector_db, RagManagedDb):
|
||||
rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig(
|
||||
rag_managed_db=GapicRagVectorDbConfig.RagManagedDb(),
|
||||
)
|
||||
elif isinstance(vector_db, Weaviate):
|
||||
http_endpoint = vector_db.weaviate_http_endpoint
|
||||
collection_name = vector_db.collection_name
|
||||
api_key = vector_db.api_key
|
||||
|
||||
rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig(
|
||||
weaviate=GapicRagVectorDbConfig.Weaviate(
|
||||
http_endpoint=http_endpoint,
|
||||
collection_name=collection_name,
|
||||
),
|
||||
api_auth=api_auth.ApiAuth(
|
||||
api_key_config=api_auth.ApiAuth.ApiKeyConfig(
|
||||
api_key_secret_version=api_key
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(vector_db, VertexFeatureStore):
|
||||
resource_name = vector_db.resource_name
|
||||
|
||||
rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig(
|
||||
vertex_feature_store=GapicRagVectorDbConfig.VertexFeatureStore(
|
||||
feature_view_resource_name=resource_name,
|
||||
),
|
||||
)
|
||||
elif isinstance(vector_db, VertexVectorSearch):
|
||||
index_endpoint = vector_db.index_endpoint
|
||||
index = vector_db.index
|
||||
|
||||
rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig(
|
||||
vertex_vector_search=GapicRagVectorDbConfig.VertexVectorSearch(
|
||||
index_endpoint=index_endpoint,
|
||||
index=index,
|
||||
),
|
||||
)
|
||||
elif isinstance(vector_db, Pinecone):
|
||||
index_name = vector_db.index_name
|
||||
api_key = vector_db.api_key
|
||||
|
||||
rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig(
|
||||
pinecone=GapicRagVectorDbConfig.Pinecone(
|
||||
index_name=index_name,
|
||||
),
|
||||
api_auth=api_auth.ApiAuth(
|
||||
api_key_config=api_auth.ApiAuth.ApiKeyConfig(
|
||||
api_key_secret_version=api_key
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
"vector_db must be a Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, or Pinecone."
|
||||
)
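# Usage sketch (placeholder values): attach a Pinecone backend to a corpus.
#   corpus = GapicRagCorpus(display_name="my-corpus")
#   set_vector_db(
#       Pinecone(
#           index_name="my-index",
#           api_key="projects/my-project/secrets/my-secret/versions/1",
#       ),
#       corpus,
#   )
#   # corpus.rag_vector_db_config.pinecone.index_name is set to "my-index".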
|
||||
|
||||
|
||||
def set_vertex_ai_search_config(
|
||||
vertex_ai_search_config: VertexAiSearchConfig,
|
||||
rag_corpus: GapicRagCorpus,
|
||||
) -> None:
|
||||
if not vertex_ai_search_config.serving_config:
|
||||
raise ValueError("serving_config must be set.")
|
||||
engine_resource_name = re.match(
|
||||
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/collections/(?P<collection>.+?)/engines/(?P<engine>.+?)/servingConfigs/(?P<serving_config>.+?)$",
|
||||
vertex_ai_search_config.serving_config,
|
||||
)
|
||||
data_store_resource_name = re.match(
|
||||
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/collections/(?P<collection>.+?)/dataStores/(?P<data_store>.+?)/servingConfigs/(?P<serving_config>.+?)$",
|
||||
vertex_ai_search_config.serving_config,
|
||||
)
|
||||
if engine_resource_name or data_store_resource_name:
|
||||
rag_corpus.vertex_ai_search_config = GapicVertexAiSearchConfig(
|
||||
serving_config=vertex_ai_search_config.serving_config,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"serving_config must be of the format `projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}` or `projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}`"
|
||||
)
|
||||
|
||||
|
||||
def set_backend_config(
|
||||
backend_config: Optional[RagVectorDbConfig],
|
||||
rag_corpus: GapicRagCorpus,
|
||||
) -> None:
|
||||
"""Sets the vector db configuration for the rag corpus."""
|
||||
if backend_config is None:
|
||||
return
|
||||
|
||||
if backend_config.vector_db is not None:
|
||||
vector_config = backend_config.vector_db
|
||||
if vector_config is None or isinstance(vector_config, RagManagedDb):
|
||||
rag_corpus.vector_db_config.rag_managed_db.CopyFrom(
|
||||
GapicRagVectorDbConfig.RagManagedDb()
|
||||
)
|
||||
elif isinstance(vector_config, VertexVectorSearch):
|
||||
index_endpoint = vector_config.index_endpoint
|
||||
index = vector_config.index
|
||||
|
||||
rag_corpus.vector_db_config.vertex_vector_search.index_endpoint = (
|
||||
index_endpoint
|
||||
)
|
||||
rag_corpus.vector_db_config.vertex_vector_search.index = index
|
||||
elif isinstance(vector_config, Pinecone):
|
||||
index_name = vector_config.index_name
|
||||
api_key = vector_config.api_key
|
||||
|
||||
rag_corpus.vector_db_config.pinecone.index_name = index_name
|
||||
rag_corpus.vector_db_config.api_auth.api_key_config.api_key_secret_version = (
|
||||
api_key
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
"backend_config must be a VertexFeatureStore,"
|
||||
"RagManagedDb, or Pinecone."
|
||||
)
|
||||
if backend_config.rag_embedding_model_config:
|
||||
set_embedding_model_config(
|
||||
backend_config.rag_embedding_model_config, rag_corpus
|
||||
)
|
||||
|
||||
|
||||
def convert_gapic_to_rag_engine_config(
|
||||
response: GapicRagEngineConfig,
|
||||
) -> RagEngineConfig:
|
||||
"""Converts a GapicRagEngineConfig to a RagEngineConfig."""
|
||||
rag_managed_db_config = RagManagedDbConfig()
|
||||
# If future fields are added with similar names, beware that __contains__
|
||||
# may match them.
|
||||
if response.rag_managed_db_config.__contains__("enterprise"):
|
||||
rag_managed_db_config.tier = Enterprise()
|
||||
elif response.rag_managed_db_config.__contains__("basic"):
|
||||
rag_managed_db_config.tier = Basic()
|
||||
else:
|
||||
raise ValueError("At least one of rag_managed_db_config must be set.")
|
||||
return RagEngineConfig(
|
||||
name=response.name,
|
||||
rag_managed_db_config=rag_managed_db_config,
|
||||
)
|
||||
|
||||
|
||||
def convert_rag_engine_config_to_gapic(
|
||||
rag_engine_config: RagEngineConfig,
|
||||
) -> GapicRagEngineConfig:
|
||||
"""Converts a RagEngineConfig to a GapicRagEngineConfig."""
|
||||
rag_managed_db_config = GapicRagManagedDbConfig()
|
||||
if (
|
||||
rag_engine_config.rag_managed_db_config is None
|
||||
or rag_engine_config.rag_managed_db_config.tier is None
|
||||
):
|
||||
rag_managed_db_config = GapicRagManagedDbConfig(
|
||||
enterprise=GapicRagManagedDbConfig.Enterprise()
|
||||
)
|
||||
else:
|
||||
if isinstance(rag_engine_config.rag_managed_db_config.tier, Enterprise):
|
||||
rag_managed_db_config.enterprise = GapicRagManagedDbConfig.Enterprise()
|
||||
elif isinstance(rag_engine_config.rag_managed_db_config.tier, Basic):
|
||||
rag_managed_db_config.basic = GapicRagManagedDbConfig.Basic()
|
||||
return GapicRagEngineConfig(
|
||||
name=rag_engine_config.name,
|
||||
rag_managed_db_config=rag_managed_db_config,
|
||||
)
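# Conversion sketch (placeholder name): a config with no explicit tier is
# translated to the Enterprise tier by default.
#   gapic_config = convert_rag_engine_config_to_gapic(
#       RagEngineConfig(
#           name="projects/p/locations/l/ragEngineConfig",
#           rag_managed_db_config=None,
#       )
#   )
#   # gapic_config.rag_managed_db_config.enterprise is set.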
|
||||
@@ -0,0 +1,576 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import dataclasses
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from google.protobuf import timestamp_pb2
|
||||
|
||||
DEPRECATION_DATE = "June 2025"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RagFile:
|
||||
"""RAG file (output only).
|
||||
|
||||
Attributes:
|
||||
name: Generated resource name. Format:
|
||||
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}/ragFiles/{rag_file}``
|
||||
display_name: Display name that was configured at client side.
|
||||
description: The description of the RagFile.
|
||||
"""
|
||||
|
||||
name: Optional[str] = None
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class EmbeddingModelConfig:
|
||||
"""EmbeddingModelConfig.
|
||||
|
||||
The representation of the embedding model config. Users input a 1P embedding
|
||||
model as a Publisher model resource, or a 1P fine tuned embedding model
|
||||
as an Endpoint resource.
|
||||
|
||||
Attributes:
|
||||
publisher_model: 1P publisher model resource name. Format:
|
||||
``publishers/google/models/{model}`` or
|
||||
``projects/{project}/locations/{location}/publishers/google/models/{model}``
|
||||
endpoint: 1P fine tuned embedding model resource name. Format:
|
||||
``endpoints/{endpoint}`` or
|
||||
``projects/{project}/locations/{location}/endpoints/{endpoint}``.
|
||||
model:
|
||||
Output only. The resource name of the model that is deployed
|
||||
on the endpoint. Present only when the endpoint is not a
|
||||
publisher model. Pattern:
|
||||
``projects/{project}/locations/{location}/models/{model}``
|
||||
model_version_id:
|
||||
Output only. Version ID of the model that is
|
||||
deployed on the endpoint. Present only when the
|
||||
endpoint is not a publisher model.
|
||||
"""
|
||||
|
||||
publisher_model: Optional[str] = None
|
||||
endpoint: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
model_version_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VertexPredictionEndpoint:
|
||||
"""VertexPredictionEndpoint.
|
||||
|
||||
Attributes:
|
||||
publisher_model: 1P publisher model resource name. Format:
|
||||
``publishers/google/models/{model}`` or
|
||||
``projects/{project}/locations/{location}/publishers/google/models/{model}``
|
||||
endpoint: 1P fine tuned embedding model resource name. Format:
|
||||
``endpoints/{endpoint}`` or
|
||||
``projects/{project}/locations/{location}/endpoints/{endpoint}``.
|
||||
model:
|
||||
Output only. The resource name of the model that is deployed
|
||||
on the endpoint. Present only when the endpoint is not a
|
||||
publisher model. Pattern:
|
||||
``projects/{project}/locations/{location}/models/{model}``
|
||||
model_version_id:
|
||||
Output only. Version ID of the model that is
|
||||
deployed on the endpoint. Present only when the
|
||||
endpoint is not a publisher model.
|
||||
"""
|
||||
|
||||
endpoint: Optional[str] = None
|
||||
publisher_model: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
model_version_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RagEmbeddingModelConfig:
|
||||
"""RagEmbeddingModelConfig.
|
||||
|
||||
Attributes:
|
||||
vertex_prediction_endpoint: The Vertex AI Prediction Endpoint resource
|
||||
name. Format:
|
||||
``projects/{project}/locations/{location}/endpoints/{endpoint}``
|
||||
"""
|
||||
|
||||
vertex_prediction_endpoint: Optional[VertexPredictionEndpoint] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Weaviate:
|
||||
"""Weaviate.
|
||||
|
||||
Attributes:
|
||||
weaviate_http_endpoint: The Weaviate DB instance HTTP endpoint
|
||||
collection_name: The corresponding Weaviate collection this corpus maps to
|
||||
api_key: The SecretManager resource name for the Weaviate DB API token. Format:
|
||||
``projects/{project}/secrets/{secret}/versions/{version}``
|
||||
"""
|
||||
|
||||
weaviate_http_endpoint: Optional[str] = None
|
||||
collection_name: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VertexFeatureStore:
|
||||
"""VertexFeatureStore.
|
||||
|
||||
Attributes:
|
||||
resource_name: The resource name of the FeatureView. Format:
|
||||
``projects/{project}/locations/{location}/featureOnlineStores/
|
||||
{feature_online_store}/featureViews/{feature_view}``
|
||||
"""
|
||||
|
||||
resource_name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VertexVectorSearch:
|
||||
"""VertexVectorSearch.
|
||||
|
||||
Attributes:
|
||||
index_endpoint (str):
|
||||
The resource name of the Index Endpoint. Format:
|
||||
``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}``
|
||||
index (str):
|
||||
The resource name of the Index. Format:
|
||||
``projects/{project}/locations/{location}/indexes/{index}``
|
||||
"""
|
||||
|
||||
index_endpoint: Optional[str] = None
|
||||
index: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RagManagedDb:
|
||||
"""RagManagedDb."""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Pinecone:
|
||||
"""Pinecone.
|
||||
|
||||
Attributes:
|
||||
index_name: The Pinecone index name.
|
||||
api_key: The SecretManager resource name for the Pinecone DB API token. Format:
|
||||
``projects/{project}/secrets/{secret}/versions/{version}``
|
||||
"""
|
||||
|
||||
index_name: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VertexAiSearchConfig:
|
||||
"""VertexAiSearchConfig.
|
||||
|
||||
Attributes:
|
||||
serving_config: The resource name of the Vertex AI Search serving config.
|
||||
Format:
|
||||
``projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}``
|
||||
or
|
||||
``projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}``
|
||||
"""
|
||||
|
||||
serving_config: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RagVectorDbConfig:
|
||||
"""RagVectorDbConfig.
|
||||
|
||||
Attributes:
|
||||
vector_db: Can be one of the following: Weaviate, VertexFeatureStore,
|
||||
VertexVectorSearch, Pinecone, RagManagedDb.
|
||||
rag_embedding_model_config: The embedding model config of the Vector DB.
|
||||
"""
|
||||
|
||||
vector_db: Optional[
|
||||
Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb]
|
||||
] = None
|
||||
rag_embedding_model_config: Optional[RagEmbeddingModelConfig] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RagCorpus:
|
||||
"""RAG corpus(output only).
|
||||
|
||||
Attributes:
|
||||
name: Generated resource name. Format:
|
||||
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}``
|
||||
display_name: Display name that was configured at client side.
|
||||
description: The description of the RagCorpus.
|
||||
embedding_model_config: The embedding model config of the RagCorpus.
|
||||
Note: Deprecated. Use backend_config instead.
|
||||
vector_db: The Vector DB of the RagCorpus.
|
||||
Note: Deprecated. Use backend_config instead.
|
||||
vertex_ai_search_config: The Vertex AI Search config of the RagCorpus.
|
||||
backend_config: The backend config of the RagCorpus. It can specify a
|
||||
Vector DB and/or the embedding model config.
|
||||
"""
|
||||
|
||||
name: Optional[str] = None
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
embedding_model_config: Optional[EmbeddingModelConfig] = None
|
||||
vector_db: Optional[
|
||||
Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb]
|
||||
] = None
|
||||
vertex_ai_search_config: Optional[VertexAiSearchConfig] = None
|
||||
backend_config: Optional[RagVectorDbConfig] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RagResource:
|
||||
"""RagResource.
|
||||
|
||||
The representation of the RAG source. It can be used to specify a corpus only
or RAG files. Currently it supports only one corpus, or multiple files from one
corpus. In the future we may support multiple corpora.
|
||||
|
||||
Attributes:
|
||||
rag_corpus: A RAG corpus resource name or corpus ID. Format:
|
||||
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}``
|
||||
or ``{rag_corpus_id}``.
|
||||
rag_file_ids: List of RAG file resource names or file IDs in the same corpus. Format:
|
||||
``{rag_file}``.
|
||||
"""
|
||||
|
||||
rag_corpus: Optional[str] = None
|
||||
rag_file_ids: Optional[List[str]] = None
|
||||
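# Illustrative sketch (not part of the original commit): pointing retrieval at
# two specific files within a single corpus. The corpus and file IDs below are
# placeholders.
_example_rag_resource = RagResource(
    rag_corpus="projects/my-project/locations/us-central1/ragCorpora/123",
    rag_file_ids=["456", "789"],
)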
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SlackChannel:
|
||||
"""SlackChannel.
|
||||
|
||||
Attributes:
|
||||
channel_id: The Slack channel ID.
|
||||
api_key: The SecretManager resource name for the Slack API token. Format:
|
||||
``projects/{project}/secrets/{secret}/versions/{version}``
|
||||
See: https://api.slack.com/tutorials/tracks/getting-a-token.
|
||||
start_time: The starting timestamp for messages to import.
|
||||
end_time: The ending timestamp for messages to import.
|
||||
"""
|
||||
|
||||
channel_id: str
|
||||
api_key: str
|
||||
start_time: Optional[timestamp_pb2.Timestamp] = None
|
||||
end_time: Optional[timestamp_pb2.Timestamp] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SlackChannelsSource:
|
||||
"""SlackChannelsSource.
|
||||
|
||||
Attributes:
|
||||
channels: The Slack channels.
|
||||
"""
|
||||
|
||||
channels: Sequence[SlackChannel]
|
||||
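# Illustrative sketch (not part of the original commit): importing messages from
# a single Slack channel within a time window. The channel ID, SecretManager
# resource name, and timestamps are placeholders.
_example_slack_source = SlackChannelsSource(
    channels=[
        SlackChannel(
            channel_id="C0123456789",
            api_key="projects/my-project/secrets/slack-token/versions/1",
            start_time=timestamp_pb2.Timestamp(seconds=1704067200),  # 2024-01-01
            end_time=timestamp_pb2.Timestamp(seconds=1706745600),  # 2024-02-01
        )
    ]
)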
|
||||
|
||||
@dataclasses.dataclass
|
||||
class JiraQuery:
|
||||
"""JiraQuery.
|
||||
|
||||
Attributes:
|
||||
email: The Jira email address.
|
||||
jira_projects: A list of Jira projects to import in their entirety.
|
||||
custom_queries: A list of custom JQL Jira queries to import.
|
||||
api_key: The SecretManager version resource name for Jira API access. Format:
|
||||
``projects/{project}/secrets/{secret}/versions/{version}``
|
||||
See: https://support.atlassian.com/atlassian-account/docs/manage-api-tokens-for-your-atlassian-account/
|
||||
server_uri: The Jira server URI. Format:
|
||||
``{server}.atlassian.net``
|
||||
"""
|
||||
|
||||
email: str
|
||||
jira_projects: Sequence[str]
|
||||
custom_queries: Sequence[str]
|
||||
api_key: str
|
||||
server_uri: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class JiraSource:
|
||||
"""JiraSource.
|
||||
|
||||
Attributes:
|
||||
queries: The Jira queries.
|
||||
"""
|
||||
|
||||
queries: Sequence[JiraQuery]
|
||||
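# Illustrative sketch (not part of the original commit): importing one Jira
# project in its entirety plus one custom JQL query. All values are placeholders.
_example_jira_source = JiraSource(
    queries=[
        JiraQuery(
            email="user@example.com",
            jira_projects=["PROJ"],
            custom_queries=['project = PROJ AND status = "Done"'],
            api_key="projects/my-project/secrets/jira-token/versions/1",
            server_uri="my-company.atlassian.net",
        )
    ]
)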
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SharePointSource:
|
||||
"""SharePointSource.
|
||||
|
||||
Attributes:
|
||||
sharepoint_folder_path: The path of the SharePoint folder to download
|
||||
from.
|
||||
sharepoint_folder_id: The ID of the SharePoint folder to download
|
||||
from.
|
||||
drive_name: The name of the drive to download from.
|
||||
drive_id: The ID of the drive to download from.
|
||||
client_id: The Application ID for the app registered in
|
||||
Microsoft Azure Portal. The application must
|
||||
also be configured with MS Graph permissions
|
||||
"Files.ReadAll", "Sites.ReadAll" and
|
||||
"BrowserSiteLists.Read.All".
|
||||
client_secret: The application secret for the app registered
|
||||
in Azure.
|
||||
tenant_id: Unique identifier of the Azure Active
|
||||
Directory Instance.
|
||||
sharepoint_site_name: The name of the SharePoint site to download
|
||||
from. This can be the site name or the site id.
|
||||
"""
|
||||
|
||||
sharepoint_folder_path: Optional[str] = None
|
||||
sharepoint_folder_id: Optional[str] = None
|
||||
drive_name: Optional[str] = None
|
||||
drive_id: Optional[str] = None
|
||||
client_id: Optional[str] = None
client_secret: Optional[str] = None
tenant_id: Optional[str] = None
sharepoint_site_name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SharePointSources:
|
||||
"""SharePointSources.
|
||||
|
||||
Attributes:
|
||||
share_point_sources: The SharePoint sources.
|
||||
"""
|
||||
|
||||
share_point_sources: Sequence[SharePointSource]
|
||||
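# Illustrative sketch (not part of the original commit): downloading one
# SharePoint folder from a named drive. The Azure app registration values and
# site name are placeholders.
_example_sharepoint_sources = SharePointSources(
    share_point_sources=[
        SharePointSource(
            sharepoint_folder_path="shared/reports",
            drive_name="Documents",
            client_id="00000000-0000-0000-0000-000000000000",
            client_secret="projects/my-project/secrets/sharepoint-secret/versions/1",
            tenant_id="11111111-1111-1111-1111-111111111111",
            sharepoint_site_name="my-site",
        )
    ]
)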
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Filter:
|
||||
"""Filter.
|
||||
|
||||
Attributes:
|
||||
vector_distance_threshold: Only returns contexts with vector
|
||||
distance smaller than the threshold.
|
||||
vector_similarity_threshold: Only returns contexts with vector
|
||||
similarity larger than the threshold.
|
||||
metadata_filter: String for metadata filtering.
|
||||
"""
|
||||
|
||||
vector_distance_threshold: Optional[float] = None
|
||||
vector_similarity_threshold: Optional[float] = None
|
||||
metadata_filter: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HybridSearch:
|
||||
"""HybridSearch.
|
||||
|
||||
Attributes:
|
||||
alpha: Alpha value controls the weight between dense and
|
||||
sparse vector search results. The range is [0, 1], where 0
|
||||
means sparse vector search only and 1 means dense vector
|
||||
search only. The default value is 0.5, which balances sparse
|
||||
and dense vector search equally.
|
||||
"""
|
||||
|
||||
alpha: Optional[float] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LlmRanker:
|
||||
"""LlmRanker.
|
||||
|
||||
Attributes:
|
||||
model_name: The model name used for ranking. Only Gemini models are
|
||||
supported for now.
|
||||
"""
|
||||
|
||||
model_name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RankService:
|
||||
"""RankService.
|
||||
|
||||
Attributes:
|
||||
model_name: The model name of the rank service. Format:
|
||||
``semantic-ranker-512@latest``
|
||||
"""
|
||||
|
||||
model_name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Ranking:
|
||||
"""Ranking.
|
||||
|
||||
Attributes:
|
||||
rank_service (google.cloud.aiplatform_v1beta1.types.RagRetrievalConfig.Ranking.RankService):
|
||||
Config for Rank Service.
|
||||
llm_ranker (google.cloud.aiplatform_v1beta1.types.RagRetrievalConfig.Ranking.LlmRanker):
|
||||
Config for LlmRanker.
|
||||
"""
|
||||
|
||||
rank_service: Optional[RankService] = None
|
||||
llm_ranker: Optional[LlmRanker] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RagRetrievalConfig:
|
||||
"""RagRetrievalConfig.
|
||||
|
||||
Attributes:
|
||||
top_k: The number of contexts to retrieve.
|
||||
filter: Config for filters.
|
||||
hybrid_search (google.cloud.aiplatform_v1beta1.types.RagRetrievalConfig.HybridSearch):
|
||||
Config for Hybrid Search.
|
||||
ranking (google.cloud.aiplatform_v1beta1.types.RagRetrievalConfig.Ranking):
|
||||
Config for ranking and reranking.
|
||||
"""
|
||||
|
||||
top_k: Optional[int] = None
|
||||
filter: Optional[Filter] = None
|
||||
hybrid_search: Optional[HybridSearch] = None
|
||||
ranking: Optional[Ranking] = None
|
||||
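# Illustrative sketch (not part of the original commit): a retrieval config that
# keeps the five closest contexts, drops contexts beyond a vector distance of
# 0.5, weights dense and sparse search equally, and reranks with the Rank
# Service. All values are placeholders.
_example_retrieval_config = RagRetrievalConfig(
    top_k=5,
    filter=Filter(vector_distance_threshold=0.5),
    hybrid_search=HybridSearch(alpha=0.5),
    ranking=Ranking(
        rank_service=RankService(model_name="semantic-ranker-512@latest"),
    ),
)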
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ChunkingConfig:
|
||||
"""ChunkingConfig.
|
||||
|
||||
Attributes:
|
||||
chunk_size: The size of each chunk.
|
||||
chunk_overlap: The size of the overlap between chunks.
|
||||
"""
|
||||
|
||||
chunk_size: int
|
||||
chunk_overlap: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TransformationConfig:
|
||||
"""TransformationConfig.
|
||||
|
||||
Attributes:
|
||||
chunking_config: The chunking config.
|
||||
"""
|
||||
|
||||
chunking_config: Optional[ChunkingConfig] = None
|
||||
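# Illustrative sketch (not part of the original commit): chunking imported
# documents into chunks of 512 with an overlap of 100. The unit of chunk_size
# is not stated in this file, so the values are placeholders.
_example_transformation_config = TransformationConfig(
    chunking_config=ChunkingConfig(chunk_size=512, chunk_overlap=100),
)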
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LayoutParserConfig:
|
||||
"""Configuration for the Document AI Layout Parser Processor.
|
||||
|
||||
Attributes:
|
||||
processor_name (str):
|
||||
The full resource name of a Document AI processor or processor
|
||||
version. The processor must have type `LAYOUT_PARSER_PROCESSOR`.
|
||||
Format:
|
||||
- `projects/{project_id}/locations/{location}/processors/{processor_id}`
|
||||
- `projects/{project_id}/locations/{location}/processors/{processor_id}/processorVersions/{processor_version_id}`
|
||||
max_parsing_requests_per_min (int):
|
||||
The maximum number of requests the job is allowed to make to the
|
||||
Document AI processor per minute. Consult
|
||||
https://cloud.google.com/document-ai/quotas and the Quota page for
|
||||
your project to set an appropriate value here. If unspecified, a
|
||||
default value of 120 QPM will be used.
|
||||
"""
|
||||
|
||||
processor_name: str
|
||||
max_parsing_requests_per_min: Optional[int] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LlmParserConfig:
|
||||
"""Configuration for the Document AI Layout Parser Processor.
|
||||
|
||||
Attributes:
|
||||
model_name (str):
|
||||
The full resource name of a Vertex AI model. Format:
|
||||
- `projects/{project_id}/locations/{location}/publishers/google/models/{model_id}`
|
||||
- `projects/{project_id}/locations/{location}/models/{model_id}`
|
||||
max_parsing_requests_per_min (int):
|
||||
The maximum number of requests the job is allowed to make to the
|
||||
Vertex AI model per minute. Consult
|
||||
https://cloud.google.com/vertex-ai/generative-ai/docs/quotas and
|
||||
the Quota page for your project to set an appropriate value here.
|
||||
If unspecified, a default value of 5000 QPM will be used.
|
||||
custom_parsing_prompt (str):
|
||||
A custom prompt to use for parsing.
|
||||
"""
|
||||
|
||||
model_name: str
|
||||
max_parsing_requests_per_min: Optional[int] = None
|
||||
custom_parsing_prompt: Optional[str] = None
|
||||
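# Illustrative sketch (not part of the original commit): the two configs above
# set up Document AI-based and LLM-based parsing, respectively. The processor
# and model resource names below are placeholders.
_example_layout_parser = LayoutParserConfig(
    processor_name="projects/my-project/locations/us/processors/abc123",
    max_parsing_requests_per_min=120,
)
_example_llm_parser = LlmParserConfig(
    model_name="projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash-001",
    max_parsing_requests_per_min=500,
    custom_parsing_prompt="Summarize each page before chunking.",
)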
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Enterprise:
|
||||
"""Enterprise tier offers production grade performance along with
|
||||
|
||||
autoscaling functionality. It is suitable for customers with large
|
||||
amounts of data or performance-sensitive workloads.
|
||||
|
||||
NOTE: This is the default tier if not explicitly chosen.
|
||||
"""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Basic:
|
||||
"""Basic tier is a cost-effective and low compute tier suitable for the following cases:
|
||||
|
||||
* Experimenting with RagManagedDb.
|
||||
* Small data size.
|
||||
* Latency insensitive workload.
|
||||
* Only using RAG Engine with external vector DBs.
|
||||
"""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RagManagedDbConfig:
|
||||
"""RagManagedDbConfig.
|
||||
|
||||
The config of the RagManagedDb used by RagEngine.
|
||||
|
||||
Attributes:
|
||||
tier: The tier of the RagManagedDb. The default tier is Enterprise.
|
||||
"""
|
||||
|
||||
tier: Optional[Union[Enterprise, Basic]] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RagEngineConfig:
|
||||
"""RagEngineConfig.
|
||||
|
||||
Attributes:
|
||||
name: Generated resource name for singleton resource. Format:
|
||||
``projects/{project}/locations/{location}/ragEngineConfig``
|
||||
rag_managed_db_config: The config of the RagManagedDb used by RagEngine.
|
||||
The default tier is Enterprise.
|
||||
"""
|
||||
|
||||
name: str
|
||||
rag_managed_db_config: Optional[RagManagedDbConfig] = None
|
||||
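# Illustrative sketch (not part of the original commit): pinning the
# RagManagedDb to the Basic tier for a small, latency-insensitive workload.
# The project and location in the resource name are placeholders.
_example_rag_engine_config = RagEngineConfig(
    name="projects/my-project/locations/us-central1/ragEngineConfig",
    rag_managed_db_config=RagManagedDbConfig(tier=Basic()),
)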
@@ -0,0 +1,47 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Classes for working with reasoning engines."""
|
||||
|
||||
# We just want to re-export certain classes
|
||||
# pylint: disable=g-multiple-import,g-importing-member
|
||||
from vertexai.reasoning_engines._reasoning_engines import (
|
||||
Queryable,
|
||||
ReasoningEngine,
|
||||
)
|
||||
from vertexai.preview.reasoning_engines.templates.adk import (
|
||||
AdkApp,
|
||||
)
|
||||
from vertexai.preview.reasoning_engines.templates.ag2 import (
|
||||
AG2Agent,
|
||||
)
|
||||
from vertexai.preview.reasoning_engines.templates.langchain import (
|
||||
LangchainAgent,
|
||||
)
|
||||
from vertexai.preview.reasoning_engines.templates.langgraph import (
|
||||
LanggraphAgent,
|
||||
)
|
||||
from vertexai.preview.reasoning_engines.templates.llama_index import (
|
||||
LlamaIndexQueryPipelineAgent,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"AdkApp",
|
||||
"AG2Agent",
|
||||
"LangchainAgent",
|
||||
"LanggraphAgent",
|
||||
"LlamaIndexQueryPipelineAgent",
|
||||
"Queryable",
|
||||
"ReasoningEngine",
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,651 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
from google.adk.events.event import Event
|
||||
|
||||
Event = Event
|
||||
except (ImportError, AttributeError):
|
||||
Event = Any
|
||||
|
||||
try:
|
||||
from google.adk.agents import BaseAgent
|
||||
|
||||
BaseAgent = BaseAgent
|
||||
except (ImportError, AttributeError):
|
||||
BaseAgent = Any
|
||||
|
||||
try:
|
||||
from google.adk.sessions import BaseSessionService
|
||||
|
||||
BaseSessionService = BaseSessionService
|
||||
except (ImportError, AttributeError):
|
||||
BaseSessionService = Any
|
||||
|
||||
try:
|
||||
from google.adk.artifacts import BaseArtifactService
|
||||
|
||||
BaseArtifactService = BaseArtifactService
|
||||
except (ImportError, AttributeError):
|
||||
BaseArtifactService = Any
|
||||
|
||||
try:
|
||||
from opentelemetry.sdk import trace
|
||||
|
||||
TracerProvider = trace.TracerProvider
|
||||
SpanProcessor = trace.SpanProcessor
|
||||
SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
|
||||
except ImportError:
|
||||
TracerProvider = Any
|
||||
SpanProcessor = Any
|
||||
SynchronousMultiSpanProcessor = Any
|
||||
|
||||
|
||||
_DEFAULT_APP_NAME = "default-app-name"
|
||||
_DEFAULT_USER_ID = "default-user-id"
|
||||
|
||||
|
||||
class _ArtifactVersion:
|
||||
def __init__(self, **kwargs):
|
||||
self.version: Optional[str] = kwargs.get("version")
|
||||
self.data = kwargs.get("data")
|
||||
|
||||
def dump(self) -> Dict[str, Any]:
|
||||
result = {}
|
||||
if self.version:
|
||||
result["version"] = self.version
|
||||
if self.data:
|
||||
result["data"] = self.data
|
||||
return result
|
||||
|
||||
|
||||
class _Artifact:
|
||||
def __init__(self, **kwargs):
|
||||
self.file_name: Optional[str] = kwargs.get("file_name")
|
||||
self.versions: List[_ArtifactVersion] = kwargs.get("versions")
|
||||
|
||||
def dump(self) -> Dict[str, Any]:
|
||||
result = {}
|
||||
if self.file_name:
|
||||
result["file_name"] = self.file_name
|
||||
if self.versions:
|
||||
result["versions"] = [version.dump() for version in self.versions]
|
||||
return result
|
||||
|
||||
|
||||
class _Authorization:
|
||||
def __init__(self, **kwargs):
|
||||
self.access_token: Optional[str] = kwargs.get("access_token") or kwargs.get(
|
||||
"accessToken"
|
||||
)
|
||||
|
||||
|
||||
class _StreamRunRequest:
|
||||
"""Request object for `streaming_agent_run_with_events` method."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
from google.adk.events.event import Event
|
||||
from google.genai import types
|
||||
|
||||
self.message: Optional[types.Content] = kwargs.get("message")
|
||||
# The new message to be processed by the agent.
|
||||
|
||||
self.events: Optional[List[Event]] = kwargs.get("events")
|
||||
# List of preceding events happened in the same session.
|
||||
|
||||
self.artifacts: Optional[List[_Artifact]] = kwargs.get("artifacts")
|
||||
# List of artifacts belonging to the session.
|
||||
|
||||
self.authorizations: Dict[str, _Authorization] = kwargs.get(
|
||||
"authorizations", {}
|
||||
)
|
||||
# The authorizations of the user, keyed by authorization ID.
|
||||
|
||||
self.user_id: Optional[str] = kwargs.get("user_id", _DEFAULT_USER_ID)
|
||||
# The user ID.
|
||||
|
||||
|
||||
class _StreamingRunResponse:
|
||||
"""Response object for `streaming_agent_run_with_events` method.
|
||||
|
||||
It contains the generated events together with the belonging artifacts.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.events: Optional[List["Event"]] = kwargs.get("events")
|
||||
# List of generated events.
|
||||
self.artifacts: Optional[List[_Artifact]] = kwargs.get("artifacts")
|
||||
# List of artifacts belonging to the session.
|
||||
|
||||
def dump(self) -> Dict[str, Any]:
|
||||
result = {}
|
||||
if self.events:
|
||||
result["events"] = []
|
||||
for event in self.events:
|
||||
event_dict = event.model_dump(exclude_none=True)
|
||||
event_dict["invocation_id"] = event_dict.get("invocation_id", "")
|
||||
result["events"].append(event_dict)
|
||||
if self.artifacts:
|
||||
result["artifacts"] = [artifact.dump() for artifact in self.artifacts]
|
||||
return result
|
||||
|
||||
|
||||
def _default_instrumentor_builder(project_id: str):
|
||||
from vertexai.agent_engines import _utils
|
||||
|
||||
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
|
||||
cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
|
||||
opentelemetry = _utils._import_opentelemetry_or_warn()
|
||||
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
|
||||
if all(
|
||||
(
|
||||
cloud_trace_exporter,
|
||||
cloud_trace_v2,
|
||||
opentelemetry,
|
||||
opentelemetry_sdk_trace,
|
||||
)
|
||||
):
|
||||
import google.auth
|
||||
|
||||
credentials, _ = google.auth.default()
|
||||
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
|
||||
project_id=project_id,
|
||||
client=cloud_trace_v2.TraceServiceClient(
|
||||
credentials=credentials.with_quota_project(project_id),
|
||||
),
|
||||
)
|
||||
span_processor = opentelemetry_sdk_trace.export.BatchSpanProcessor(
|
||||
span_exporter=span_exporter,
|
||||
)
|
||||
tracer_provider = opentelemetry.trace.get_tracer_provider()
|
||||
# Get the appropriate tracer provider:
|
||||
# 1. If _TRACER_PROVIDER is already set, use that.
|
||||
# 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
|
||||
# variable is set, use that.
|
||||
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
|
||||
# If none of the above is set, we log a warning, and
|
||||
# create a tracer provider.
|
||||
if not tracer_provider:
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
_LOGGER.warning(
|
||||
"No tracer provider. By default, "
|
||||
"we should get one of the following providers: "
|
||||
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
|
||||
"or _PROXY_TRACER_PROVIDER."
|
||||
)
|
||||
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
|
||||
opentelemetry.trace.set_tracer_provider(tracer_provider)
|
||||
# Avoids AttributeError:
|
||||
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
|
||||
# attribute 'add_span_processor'.
|
||||
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
|
||||
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
|
||||
opentelemetry.trace.set_tracer_provider(tracer_provider)
|
||||
# Avoids OpenTelemetry client already exists error.
|
||||
_override_active_span_processor(
|
||||
tracer_provider,
|
||||
opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
|
||||
)
|
||||
tracer_provider.add_span_processor(span_processor)
|
||||
return None
|
||||
else:
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
_LOGGER.warning(
|
||||
"enable_tracing=True but proceeding with tracing disabled "
|
||||
"because not all packages (i.e. `google-cloud-trace`, `opentelemetry-sdk`, "
|
||||
"`opentelemetry-exporter-gcp-trace`) for tracing have been installed"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _override_active_span_processor(
|
||||
tracer_provider: "TracerProvider",
|
||||
active_span_processor: "SynchronousMultiSpanProcessor",
|
||||
):
|
||||
"""Overrides the active span processor.
|
||||
|
||||
When working with multiple LangchainAgents in the same environment,
|
||||
it's crucial to manage trace exports carefully.
|
||||
Each agent needs its own span processor tied to a unique project ID.
|
||||
While we add a new span processor for each agent, this can lead to
|
||||
unexpected behavior.
|
||||
For instance, with two agents linked to different projects, traces from the
|
||||
second agent might be sent to both projects.
|
||||
To prevent this and guarantee traces go to the correct project, we overwrite
|
||||
the active span processor whenever a new LangchainAgent is created.
|
||||
|
||||
Args:
|
||||
tracer_provider (TracerProvider):
|
||||
The tracer provider to use for the project.
|
||||
active_span_processor (SynchronousMultiSpanProcessor):
|
||||
The active span processor overrides the tracer provider's
|
||||
active span processor.
|
||||
"""
|
||||
if tracer_provider._active_span_processor:
|
||||
tracer_provider._active_span_processor.shutdown()
|
||||
tracer_provider._active_span_processor = active_span_processor
|
||||
|
||||
|
||||
class AdkApp:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
agent: "BaseAgent",
|
||||
enable_tracing: bool = False,
|
||||
session_service_builder: Optional[Callable[..., "BaseSessionService"]] = None,
|
||||
artifact_service_builder: Optional[Callable[..., "BaseArtifactService"]] = None,
|
||||
env_vars: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""An ADK Application."""
|
||||
from google.cloud.aiplatform import initializer
|
||||
|
||||
self._tmpl_attrs: Dict[str, Any] = {
|
||||
"project": initializer.global_config.project,
|
||||
"location": initializer.global_config.location,
|
||||
"agent": agent,
|
||||
"enable_tracing": enable_tracing,
|
||||
"session_service_builder": session_service_builder,
|
||||
"artifact_service_builder": artifact_service_builder,
|
||||
"app_name": _DEFAULT_APP_NAME,
|
||||
"env_vars": env_vars or {},
|
||||
}
|
||||
|
||||
def _init_session(
|
||||
self,
|
||||
session_service: "BaseSessionService",
|
||||
artifact_service: "BaseArtifactService",
|
||||
request: _StreamRunRequest,
|
||||
):
|
||||
"""Initializes the session, and returns the session id."""
|
||||
from google.adk.events.event import Event
|
||||
import random
|
||||
|
||||
session_state = None
|
||||
if request.authorizations:
|
||||
session_state = {}
|
||||
for auth_id, auth in request.authorizations.items():
|
||||
auth = _Authorization(**auth)
|
||||
session_state[f"temp:{auth_id}"] = auth.access_token
|
||||
|
||||
session_id = f"temp_session_{random.randbytes(8).hex()}"
|
||||
session = session_service.create_session(
|
||||
app_name=self._tmpl_attrs.get("app_name"),
|
||||
user_id=request.user_id,
|
||||
session_id=session_id,
|
||||
state=session_state,
|
||||
)
|
||||
if not session:
|
||||
raise RuntimeError("Create session failed.")
|
||||
if request.events:
|
||||
for event in request.events:
|
||||
session_service.append_event(session, Event(**event))
|
||||
if request.artifacts:
|
||||
for artifact in request.artifacts:
|
||||
artifact = _Artifact(**artifact)
|
||||
for version_data in sorted(
|
||||
artifact.versions, key=lambda x: x["version"]
|
||||
):
|
||||
version_data = _ArtifactVersion(**version_data)
|
||||
saved_version = artifact_service.save_artifact(
|
||||
app_name=self._tmpl_attrs.get("app_name"),
|
||||
user_id=request.user_id,
|
||||
session_id=session_id,
|
||||
filename=artifact.file_name,
|
||||
artifact=version_data.data,
|
||||
)
|
||||
if saved_version != version_data.version:
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
_LOGGER.debug(
|
||||
"Artifact '%s' saved at version %s instead of %s",
|
||||
artifact.file_name,
|
||||
saved_version,
|
||||
version_data.version,
|
||||
)
|
||||
return session
|
||||
|
||||
def _convert_response_events(
|
||||
self,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
events: List["Event"],
|
||||
artifact_service: Optional["BaseArtifactService"],
|
||||
) -> _StreamingRunResponse:
|
||||
"""Converts the events to the streaming run response object."""
|
||||
import collections
|
||||
|
||||
result = _StreamingRunResponse(events=events, artifacts=[])
|
||||
|
||||
# Save the generated artifacts into the result object.
|
||||
artifact_versions = collections.defaultdict(list)
|
||||
for event in events:
|
||||
if event.actions and event.actions.artifact_delta:
|
||||
for key, version in event.actions.artifact_delta.items():
|
||||
artifact_versions[key].append(version)
|
||||
|
||||
for key, versions in artifact_versions.items():
|
||||
result.artifacts.append(
|
||||
_Artifact(
|
||||
file_name=key,
|
||||
versions=[
|
||||
_ArtifactVersion(
|
||||
version=version,
|
||||
data=artifact_service.load_artifact(
|
||||
app_name=self._tmpl_attrs.get("app_name"),
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=key,
|
||||
version=version,
|
||||
),
|
||||
)
|
||||
for version in versions
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
return result.dump()
|
||||
|
||||
def clone(self):
|
||||
"""Returns a clone of the ADK application."""
|
||||
import copy
|
||||
|
||||
return AdkApp(
|
||||
agent=copy.deepcopy(self._tmpl_attrs.get("agent")),
|
||||
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
|
||||
session_service_builder=self._tmpl_attrs.get("session_service_builder"),
|
||||
artifact_service_builder=self._tmpl_attrs.get("artifact_service_builder"),
|
||||
env_vars=self._tmpl_attrs.get("env_vars"),
|
||||
)
|
||||
|
||||
def set_up(self):
|
||||
"""Sets up the ADK application."""
|
||||
import os
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.artifacts.in_memory_artifact_service import (
|
||||
InMemoryArtifactService,
|
||||
)
|
||||
|
||||
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1"
|
||||
project = self._tmpl_attrs.get("project")
|
||||
os.environ["GOOGLE_CLOUD_PROJECT"] = project
|
||||
location = self._tmpl_attrs.get("location")
|
||||
os.environ["GOOGLE_CLOUD_LOCATION"] = location
|
||||
if self._tmpl_attrs.get("enable_tracing"):
|
||||
self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder(
|
||||
project_id=project
|
||||
)
|
||||
for key, value in self._tmpl_attrs.get("env_vars").items():
|
||||
os.environ[key] = value
|
||||
|
||||
artifact_service_builder = self._tmpl_attrs.get("artifact_service_builder")
|
||||
if artifact_service_builder:
|
||||
self._tmpl_attrs["artifact_service"] = artifact_service_builder()
|
||||
else:
|
||||
self._tmpl_attrs["artifact_service"] = InMemoryArtifactService()
|
||||
|
||||
session_service_builder = self._tmpl_attrs.get("session_service_builder")
|
||||
if session_service_builder:
|
||||
self._tmpl_attrs["session_service"] = session_service_builder()
|
||||
elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ:
|
||||
from google.adk.sessions.vertex_ai_session_service import (
|
||||
VertexAiSessionService,
|
||||
)
|
||||
|
||||
self._tmpl_attrs["session_service"] = VertexAiSessionService(
|
||||
project=project,
|
||||
location=location,
|
||||
)
|
||||
self._tmpl_attrs["app_name"] = os.environ.get(
|
||||
"GOOGLE_CLOUD_AGENT_ENGINE_ID",
|
||||
self._tmpl_attrs.get("app_name"),
|
||||
)
|
||||
else:
|
||||
self._tmpl_attrs["session_service"] = InMemorySessionService()
|
||||
|
||||
self._tmpl_attrs["runner"] = Runner(
|
||||
agent=self._tmpl_attrs.get("agent"),
|
||||
session_service=self._tmpl_attrs.get("session_service"),
|
||||
artifact_service=self._tmpl_attrs.get("artifact_service"),
|
||||
app_name=self._tmpl_attrs.get("app_name"),
|
||||
)
|
||||
self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService()
|
||||
self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService()
|
||||
self._tmpl_attrs["in_memory_runner"] = Runner(
|
||||
agent=self._tmpl_attrs.get("agent"),
|
||||
session_service=self._tmpl_attrs.get("in_memory_session_service"),
|
||||
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
|
||||
app_name=self._tmpl_attrs.get("app_name"),
|
||||
)
|
||||
|
||||
def stream_query(
|
||||
self,
|
||||
*,
|
||||
message: str,
|
||||
user_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Streams responses from the ADK application in response to a message.
|
||||
|
||||
Args:
|
||||
message (str):
|
||||
Required. The message to stream responses for.
|
||||
user_id (str):
|
||||
Required. The ID of the user.
|
||||
session_id (str):
|
||||
Optional. The ID of the session. If not provided, a new
|
||||
session will be created for the user.
|
||||
**kwargs (dict[str, Any]):
|
||||
Optional. Additional keyword arguments to pass to the
|
||||
runner.
|
||||
|
||||
Yields:
|
||||
The output of querying the ADK application.
|
||||
"""
|
||||
from google.genai import types
|
||||
|
||||
content = types.Content(role="user", parts=[types.Part(text=message)])
|
||||
if not self._tmpl_attrs.get("runner"):
|
||||
self.set_up()
|
||||
if not session_id:
|
||||
session = self.create_session(user_id=user_id)
|
||||
session_id = session.id
|
||||
for event in self._tmpl_attrs.get("runner").run(
|
||||
user_id=user_id, session_id=session_id, new_message=content, **kwargs
|
||||
):
|
||||
yield event.model_dump(exclude_none=True)
|
||||
|
||||
def streaming_agent_run_with_events(self, request_json: str):
|
||||
import json
|
||||
from google.genai import types
|
||||
|
||||
request = _StreamRunRequest(**json.loads(request_json))
|
||||
if not self._tmpl_attrs.get("in_memory_runner"):
|
||||
self.set_up()
|
||||
if not self._tmpl_attrs.get("artifact_service"):
|
||||
self.set_up()
|
||||
# Prepare the in-memory session.
|
||||
if not self._tmpl_attrs.get("in_memory_artifact_service"):
|
||||
self.set_up()
|
||||
if not self._tmpl_attrs.get("in_memory_session_service"):
|
||||
self.set_up()
|
||||
session = self._init_session(
|
||||
session_service=self._tmpl_attrs.get("in_memory_session_service"),
|
||||
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
|
||||
request=request,
|
||||
)
|
||||
if not session:
|
||||
raise RuntimeError("Session initialization failed.")
|
||||
# Run the agent.
|
||||
for event in self._tmpl_attrs.get("in_memory_runner").run(
|
||||
user_id=request.user_id,
|
||||
session_id=session.id,
|
||||
new_message=types.Content(**request.message),
|
||||
):
|
||||
yield self._convert_response_events(
|
||||
user_id=request.user_id,
|
||||
session_id=session.id,
|
||||
events=[event],
|
||||
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
|
||||
)
|
||||
self._tmpl_attrs.get("in_memory_session_service").delete_session(
|
||||
app_name=self._tmpl_attrs.get("app_name"),
|
||||
user_id=request.user_id,
|
||||
session_id=session.id,
|
||||
)
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
**kwargs,
|
||||
):
|
||||
"""Get a session for the given user.
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
Required. The ID of the user.
|
||||
session_id (str):
|
||||
Required. The ID of the session.
|
||||
**kwargs (dict[str, Any]):
|
||||
Optional. Additional keyword arguments to pass to the
|
||||
session service.
|
||||
|
||||
Returns:
|
||||
Session: The session instance.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the session is not found.
|
||||
"""
|
||||
if not self._tmpl_attrs.get("session_service"):
|
||||
self.set_up()
|
||||
session = self._tmpl_attrs.get("session_service").get_session(
|
||||
app_name=self._tmpl_attrs.get("app_name"),
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
**kwargs,
|
||||
)
|
||||
if not session:
|
||||
raise RuntimeError(
|
||||
"Session not found. Please create it using .create_session()"
|
||||
)
|
||||
return session
|
||||
|
||||
def list_sessions(self, *, user_id: str, **kwargs):
|
||||
"""List sessions for the given user.
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
Required. The ID of the user.
|
||||
**kwargs (dict[str, Any]):
|
||||
Optional. Additional keyword arguments to pass to the
|
||||
session service.
|
||||
|
||||
Returns:
|
||||
ListSessionsResponse: The list of sessions.
|
||||
"""
|
||||
if not self._tmpl_attrs.get("session_service"):
|
||||
self.set_up()
|
||||
return self._tmpl_attrs.get("session_service").list_sessions(
|
||||
app_name=self._tmpl_attrs.get("app_name"),
|
||||
user_id=user_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
state: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Creates a new session.
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
Required. The ID of the user.
|
||||
session_id (str):
|
||||
Optional. The ID of the session. If not provided, an ID
|
||||
will be generated for the session.
|
||||
state (dict[str, Any]):
|
||||
Optional. The initial state of the session.
|
||||
**kwargs (dict[str, Any]):
|
||||
Optional. Additional keyword arguments to pass to the
|
||||
session service.
|
||||
|
||||
Returns:
|
||||
Session: The newly created session instance.
|
||||
"""
|
||||
if not self._tmpl_attrs.get("session_service"):
|
||||
self.set_up()
|
||||
session = self._tmpl_attrs.get("session_service").create_session(
|
||||
app_name=self._tmpl_attrs.get("app_name"),
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
state=state,
|
||||
**kwargs,
|
||||
)
|
||||
return session
|
||||
|
||||
def delete_session(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
**kwargs,
|
||||
):
|
||||
"""Deletes a session for the given user.
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
Required. The ID of the user.
|
||||
session_id (str):
|
||||
Required. The ID of the session.
|
||||
**kwargs (dict[str, Any]):
|
||||
Optional. Additional keyword arguments to pass to the
|
||||
session service.
|
||||
"""
|
||||
if not self._tmpl_attrs.get("session_service"):
|
||||
self.set_up()
|
||||
self._tmpl_attrs.get("session_service").delete_session(
|
||||
app_name=self._tmpl_attrs.get("app_name"),
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def register_operations(self) -> Dict[str, List[str]]:
|
||||
"""Registers the operations of the ADK application."""
|
||||
return {
|
||||
"": [
|
||||
"get_session",
|
||||
"list_sessions",
|
||||
"create_session",
|
||||
"delete_session",
|
||||
],
|
||||
"stream": ["stream_query", "streaming_agent_run_with_events"],
|
||||
}
|
||||
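def _example_adk_stream_query(my_agent) -> None:
    """Illustrative sketch, not part of the original commit.

    Drives an AdkApp end to end: `my_agent` stands in for any google.adk
    BaseAgent instance, and the user ID and message are placeholders. A session
    is created implicitly because no session_id is passed.
    """
    app = AdkApp(agent=my_agent, enable_tracing=False)
    for event in app.stream_query(user_id="user-123", message="Hello!"):
        print(event)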
@@ -0,0 +1,474 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
from autogen import agentchat
|
||||
|
||||
ConversableAgent = agentchat.ConversableAgent
|
||||
ChatResult = agentchat.ChatResult
|
||||
except ImportError:
|
||||
ConversableAgent = Any
|
||||
|
||||
try:
|
||||
from opentelemetry.sdk import trace
|
||||
|
||||
TracerProvider = trace.TracerProvider
|
||||
SpanProcessor = trace.SpanProcessor
|
||||
SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
|
||||
except ImportError:
|
||||
TracerProvider = Any
|
||||
SpanProcessor = Any
|
||||
SynchronousMultiSpanProcessor = Any
|
||||
|
||||
|
||||
def _prepare_runnable_kwargs(
|
||||
runnable_kwargs: Mapping[str, Any],
|
||||
system_instruction: str,
|
||||
runnable_name: str,
|
||||
llm_config: Mapping[str, Any],
|
||||
) -> Mapping[str, Any]:
|
||||
"""Prepares the configuration for a runnable, applying defaults and enforcing constraints."""
|
||||
if runnable_kwargs is None:
|
||||
runnable_kwargs = {}
|
||||
|
||||
if (
|
||||
"human_input_mode" in runnable_kwargs
|
||||
and runnable_kwargs["human_input_mode"] != "NEVER"
|
||||
):
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
_LOGGER.warning(
|
||||
f"human_input_mode={runnable_kwargs['human_input_mode']}"
|
||||
"is not supported. Will be enforced to 'NEVER'."
|
||||
)
|
||||
runnable_kwargs["human_input_mode"] = "NEVER"
|
||||
|
||||
if "system_message" not in runnable_kwargs and system_instruction:
|
||||
runnable_kwargs["system_message"] = system_instruction
|
||||
|
||||
if "name" not in runnable_kwargs:
|
||||
runnable_kwargs["name"] = runnable_name
|
||||
|
||||
if "llm_config" not in runnable_kwargs:
|
||||
runnable_kwargs["llm_config"] = llm_config
|
||||
|
||||
return runnable_kwargs
|
||||
|
||||
|
||||
def _default_runnable_builder(
|
||||
**runnable_kwargs: Any,
|
||||
) -> "ConversableAgent":
|
||||
from autogen import agentchat
|
||||
|
||||
return agentchat.ConversableAgent(**runnable_kwargs)
|
||||
|
||||
|
||||
def _validate_callable_parameters_are_annotated(callable: Callable):
|
||||
"""Validates that the parameters of the callable have type annotations.
|
||||
|
||||
This ensures that they can be used for constructing AG2 tools that are
|
||||
usable with Gemini function calling.
|
||||
"""
|
||||
import inspect
|
||||
|
||||
parameters = dict(inspect.signature(callable).parameters)
|
||||
for name, parameter in parameters.items():
|
||||
if parameter.annotation == inspect.Parameter.empty:
|
||||
raise TypeError(
|
||||
f"Callable={callable.__name__} has untyped input_arg={name}. "
|
||||
f"Please specify a type when defining it, e.g. `{name}: str`."
|
||||
)
|
||||
|
||||
|
||||
def _validate_tools(tools: Sequence[Callable[..., Any]]):
|
||||
"""Validates that the tools are usable for tool calling."""
|
||||
for tool in tools:
|
||||
if isinstance(tool, Callable):
|
||||
_validate_callable_parameters_are_annotated(tool)
|
||||
|
||||
|
||||
def _override_active_span_processor(
|
||||
tracer_provider: "TracerProvider",
|
||||
active_span_processor: "SynchronousMultiSpanProcessor",
|
||||
):
|
||||
"""Overrides the active span processor.
|
||||
|
||||
When working with multiple AG2Agents in the same environment,
|
||||
it's crucial to manage trace exports carefully.
|
||||
Each agent needs its own span processor tied to a unique project ID.
|
||||
While we add a new span processor for each agent, this can lead to
|
||||
unexpected behavior.
|
||||
For instance, with two agents linked to different projects, traces from the
|
||||
second agent might be sent to both projects.
|
||||
To prevent this and guarantee traces go to the correct project, we overwrite
|
||||
the active span processor whenever a new AG2Agent is created.
|
||||
|
||||
Args:
|
||||
tracer_provider (TracerProvider):
|
||||
The tracer provider to use for the project.
|
||||
active_span_processor (SynchronousMultiSpanProcessor):
|
||||
The active span processor overrides the tracer provider's
|
||||
active span processor.
|
||||
"""
|
||||
if tracer_provider._active_span_processor:
|
||||
tracer_provider._active_span_processor.shutdown()
|
||||
tracer_provider._active_span_processor = active_span_processor
|
||||
|
||||
|
||||
class AG2Agent:
|
||||
"""An AG2 Agent.
|
||||
|
||||
See https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/ag2
|
||||
for details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
runnable_name: str,
|
||||
*,
|
||||
api_type: Optional[str] = None,
|
||||
llm_config: Optional[Mapping[str, Any]] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
runnable_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
runnable_builder: Optional[Callable[..., "ConversableAgent"]] = None,
|
||||
tools: Optional[Sequence[Callable[..., Any]]] = None,
|
||||
enable_tracing: bool = False,
|
||||
):
|
||||
"""Initializes the AG2 Agent.
|
||||
|
||||
Under the hood, assuming .set_up() is called, this will correspond to
|
||||
```python
|
||||
# runnable_builder
|
||||
runnable = runnable_builder(
|
||||
llm_config=llm_config,
|
||||
system_message=system_instruction,
|
||||
**runnable_kwargs,
|
||||
)
|
||||
```
|
||||
|
||||
When everything is based on their default values, this corresponds to
|
||||
```python
|
||||
# llm_config
|
||||
llm_config = {
|
||||
"config_list": [{
|
||||
"project_id": initializer.global_config.project,
|
||||
"location": initializer.global_config.location,
|
||||
"model": "gemini-1.0-pro-001",
|
||||
"api_type": "google",
|
||||
}]
|
||||
}
|
||||
|
||||
# runnable_builder
|
||||
runnable = ConversableAgent(
|
||||
llm_config=llm_config,
|
||||
name="Default AG2 Agent"
|
||||
system_message="You are a helpful AI Assistant.",
|
||||
human_input_mode="NEVER",
|
||||
)
|
||||
```
|
||||
|
||||
By default, if `llm_config` is not specified, a default configuration
|
||||
will be created using the provided `model` and `api_type`.
|
||||
|
||||
If `runnable_builder` is not specified, a default runnable builder will
|
||||
be used, configured with the `system_instruction`, `runnable_name` and
|
||||
`runnable_kwargs`.
|
||||
|
||||
Args:
|
||||
model (str):
|
||||
Required. The name of the model (e.g. "gemini-1.0-pro").
|
||||
Used to create a default `llm_config` if one is not provided.
|
||||
This parameter is ignored if `llm_config` is provided.
|
||||
runnable_name (str):
|
||||
Required. The name of the runnable.
|
||||
This name is used as the default `runnable_kwargs["name"]`
|
||||
unless `runnable_kwargs` already contains a "name", in which
|
||||
case the provided `runnable_kwargs["name"]` will be used.
|
||||
api_type (str):
|
||||
Optional. The API type to use for the language model.
|
||||
Used to create a default `llm_config` if one is not provided.
|
||||
This parameter is ignored if `llm_config` is provided.
|
||||
llm_config (Mapping[str, Any]):
|
||||
Optional. Configuration dictionary for the language model.
|
||||
If provided, this configuration will be used directly.
|
||||
Otherwise, a default `llm_config` will be created using `model`
|
||||
and `api_type`. This `llm_config` is used as the default
|
||||
`runnable_kwargs["llm_config"]` unless `runnable_kwargs` already
|
||||
contains a "llm_config", in which case the provided
|
||||
`runnable_kwargs["llm_config"]` will be used.
|
||||
system_instruction (str):
|
||||
Optional. The system instruction for the agent.
|
||||
This instruction is used as the default
|
||||
`runnable_kwargs["system_message"]` unless `runnable_kwargs`
|
||||
already contains a "system_message", in which case the provided
|
||||
`runnable_kwargs["system_message"]` will be used.
|
||||
runnable_kwargs (Mapping[str, Any]):
|
||||
Optional. Additional keyword arguments for the constructor of
|
||||
the runnable. Details of the kwargs can be found in
|
||||
https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent.
|
||||
`runnable_kwargs` only supports `human_input_mode="NEVER"`.
|
||||
Other `human_input_mode` values will trigger a warning.
|
||||
runnable_builder (Callable[..., "ConversableAgent"]):
|
||||
Optional. Callable that returns a new runnable. This can be used
|
||||
for customizing the orchestration logic of the Agent.
|
||||
If not provided, a default runnable builder will be used.
|
||||
tools (Sequence[Callable[..., Any]]):
|
||||
Optional. The tools for the agent to be able to use. All input
|
||||
callables (e.g. function or class method) will be converted
|
||||
to an AG2 tool. Defaults to None.
|
||||
enable_tracing (bool):
|
||||
Optional. Whether to enable tracing in Cloud Trace. Defaults to
|
||||
False.
|
||||
"""
|
||||
from google.cloud.aiplatform import initializer
|
||||
|
||||
# Set up llm config.
|
||||
self._project = initializer.global_config.project
|
||||
self._location = initializer.global_config.location
|
||||
self._model_name = model or "gemini-1.0-pro-001"
|
||||
self._api_type = api_type or "google"
|
||||
self._llm_config = llm_config or {
|
||||
"config_list": [
|
||||
{
|
||||
"project_id": self._project,
|
||||
"location": self._location,
|
||||
"model": self._model_name,
|
||||
"api_type": self._api_type,
|
||||
}
|
||||
]
|
||||
}
|
||||
self._system_instruction = system_instruction
|
||||
self._runnable_name = runnable_name
|
||||
self._runnable_kwargs = _prepare_runnable_kwargs(
|
||||
runnable_kwargs=runnable_kwargs,
|
||||
llm_config=self._llm_config,
|
||||
system_instruction=self._system_instruction,
|
||||
runnable_name=self._runnable_name,
|
||||
)
|
||||
|
||||
self._tools = []
|
||||
if tools:
|
||||
# We validate tools at initialization for actionable feedback before
|
||||
# they are deployed.
|
||||
_validate_tools(tools)
|
||||
self._tools = tools
|
||||
self._ag2_tool_objects = []
|
||||
self._runnable = None
|
||||
self._runnable_builder = runnable_builder
|
||||
|
||||
self._instrumentor = None
|
||||
self._enable_tracing = enable_tracing
|
||||
|
||||
def set_up(self):
|
||||
"""Sets up the agent for execution of queries at runtime.
|
||||
|
||||
It initializes the runnable, binds the runnable with tools.
|
||||
|
||||
This method should not be called for an object that is being passed to
|
||||
the ReasoningEngine service for deployment, as it initializes clients
|
||||
that cannot be serialized.
|
||||
"""
|
||||
if self._enable_tracing:
|
||||
from vertexai.reasoning_engines import _utils
|
||||
|
||||
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
|
||||
cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
|
||||
openinference_autogen = _utils._import_openinference_autogen_or_warn()
|
||||
opentelemetry = _utils._import_opentelemetry_or_warn()
|
||||
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
|
||||
if all(
|
||||
(
|
||||
cloud_trace_exporter,
|
||||
cloud_trace_v2,
|
||||
openinference_autogen,
|
||||
opentelemetry,
|
||||
opentelemetry_sdk_trace,
|
||||
)
|
||||
):
|
||||
import google.auth
|
||||
|
||||
credentials, _ = google.auth.default()
|
||||
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
|
||||
project_id=self._project,
|
||||
client=cloud_trace_v2.TraceServiceClient(
|
||||
credentials=credentials.with_quota_project(self._project),
|
||||
),
|
||||
)
|
||||
span_processor: SpanProcessor = (
|
||||
opentelemetry_sdk_trace.export.SimpleSpanProcessor(
|
||||
span_exporter=span_exporter,
|
||||
)
|
||||
)
|
||||
tracer_provider: TracerProvider = (
|
||||
opentelemetry.trace.get_tracer_provider()
|
||||
)
|
||||
# Get the appropriate tracer provider:
|
||||
# 1. If _TRACER_PROVIDER is already set, use that.
|
||||
# 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
|
||||
# variable is set, use that.
|
||||
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
|
||||
# If none of the above is set, we log a warning, and
|
||||
# create a tracer provider.
|
||||
if not tracer_provider:
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
_LOGGER.warning(
|
||||
"No tracer provider. By default, "
|
||||
"we should get one of the following providers: "
|
||||
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
|
||||
"or _PROXY_TRACER_PROVIDER."
|
||||
)
|
||||
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
|
||||
opentelemetry.trace.set_tracer_provider(tracer_provider)
|
||||
# Avoids AttributeError:
|
||||
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
|
||||
# attribute 'add_span_processor'.
|
||||
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
|
||||
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
|
||||
opentelemetry.trace.set_tracer_provider(tracer_provider)
|
||||
# Avoids OpenTelemetry client already exists error.
|
||||
_override_active_span_processor(
|
||||
tracer_provider,
|
||||
opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
|
||||
)
|
||||
tracer_provider.add_span_processor(span_processor)
|
||||
# Keep the instrumentation up-to-date.
|
||||
# When creating multiple AG2Agents,
|
||||
# we need to keep the instrumentation up-to-date.
|
||||
# We deliberately override the instrument each time,
|
||||
# so that if different agents end up using different
|
||||
# instrumentations, we guarantee that the user is always
|
||||
# working with the most recent agent's instrumentation.
|
||||
self._instrumentor = openinference_autogen.AutogenInstrumentor()
|
||||
self._instrumentor.uninstrument()
|
||||
self._instrumentor.instrument()
|
||||
else:
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
_LOGGER.warning(
|
||||
"enable_tracing=True but proceeding with tracing disabled "
|
||||
"because not all packages for tracing have been installed"
|
||||
)
|
||||
|
||||
# Set up tools.
|
||||
if self._tools and not self._ag2_tool_objects:
|
||||
from vertexai.reasoning_engines import _utils
|
||||
|
||||
autogen_tools = _utils._import_autogen_tools_or_warn()
|
||||
if autogen_tools:
|
||||
for tool in self._tools:
|
||||
self._ag2_tool_objects.append(autogen_tools.Tool(func_or_tool=tool))
|
||||
|
||||
# Set up runnable.
|
||||
runnable_builder = self._runnable_builder or _default_runnable_builder
|
||||
self._runnable = runnable_builder(
|
||||
**self._runnable_kwargs,
|
||||
)
|
||||
|
||||
def clone(self) -> "AG2Agent":
|
||||
"""Returns a clone of the AG2Agent."""
|
||||
import copy
|
||||
|
||||
return AG2Agent(
|
||||
model=self._model_name,
|
||||
api_type=self._api_type,
|
||||
llm_config=copy.deepcopy(self._llm_config),
|
||||
system_instruction=self._system_instruction,
|
||||
runnable_name=self._runnable_name,
|
||||
tools=copy.deepcopy(self._tools),
|
||||
runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
|
||||
runnable_builder=self._runnable_builder,
|
||||
enable_tracing=self._enable_tracing,
|
||||
)
|
||||
|
||||
def query(
|
||||
self,
|
||||
*,
|
||||
input: Union[str, Mapping[str, Any]],
|
||||
max_turns: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Queries the Agent with the given input.
|
||||
|
||||
Args:
|
||||
input (Union[str, Mapping[str, Any]]):
|
||||
Required. The input to be passed to the Agent.
|
||||
max_turns (int):
|
||||
Optional. The maximum number of turns to run the agent for.
|
||||
If not provided, the agent will run indefinitely.
|
||||
If `max_turns` is a `float`, it will be converted to `int`
|
||||
through rounding.
|
||||
**kwargs:
|
||||
Optional. Any additional keyword arguments to be passed to the
|
||||
`.run()` method of the corresponding runnable.
|
||||
Details of the kwargs can be found in
|
||||
https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent#run.
|
||||
The `user_input` parameter defaults to `False`, and should not
|
||||
be passed through `kwargs`.
|
||||
|
||||
Returns:
|
||||
The output of querying the Agent with the given input.
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
input = {"content": input}
|
||||
|
||||
if max_turns and isinstance(max_turns, float):
|
||||
# Supporting auto-conversion float to int.
|
||||
max_turns = round(max_turns)
|
||||
|
||||
if "user_input" in kwargs:
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
_LOGGER.warning(
|
||||
"The `user_input` parameter should not be passed through"
|
||||
"kwargs. The `user_input` defaults to `False`."
|
||||
)
|
||||
kwargs.pop("user_input")
|
||||
|
||||
if not self._runnable:
|
||||
self.set_up()
|
||||
|
||||
from vertexai.reasoning_engines import _utils
|
||||
|
||||
# `.run()` will return a `ChatResult` object, which is a dataclass.
|
||||
# We need to convert it to a JSON-serializable object.
|
||||
# More details of `ChatResult` can be found in
|
||||
# https://docs.ag2.ai/docs/api-reference/autogen/ChatResult.
|
||||
return _utils.dataclass_to_dict(
|
||||
self._runnable.run(
|
||||
input,
|
||||
user_input=False,
|
||||
tools=self._ag2_tool_objects,
|
||||
max_turns=max_turns,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
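def _example_ag2_query() -> None:
    """Illustrative sketch, not part of the original commit.

    Queries an AG2Agent with the default llm_config; the model name, runnable
    name, and prompt are placeholders. `set_up()` is invoked implicitly by
    `query()` when the runnable has not been built yet.
    """
    agent = AG2Agent(
        model="gemini-1.0-pro-001",
        runnable_name="Example AG2 Agent",
        system_instruction="You are a helpful AI Assistant.",
    )
    response = agent.query(input="What is 1 + 1?", max_turns=1)
    print(response)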
@@ -0,0 +1,643 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
from langchain_core import runnables
|
||||
from langchain_core import tools as lc_tools
|
||||
from langchain_core.language_models import base as lc_language_models
|
||||
|
||||
BaseTool = lc_tools.BaseTool
|
||||
BaseLanguageModel = lc_language_models.BaseLanguageModel
|
||||
GetSessionHistoryCallable = runnables.history.GetSessionHistoryCallable
|
||||
RunnableConfig = runnables.RunnableConfig
|
||||
RunnableSerializable = runnables.RunnableSerializable
|
||||
except ImportError:
|
||||
BaseTool = Any
|
||||
BaseLanguageModel = Any
|
||||
GetSessionHistoryCallable = Any
|
||||
RunnableConfig = Any
|
||||
RunnableSerializable = Any
|
||||
|
||||
try:
|
||||
from langchain_google_vertexai.functions_utils import _ToolsType
|
||||
|
||||
_ToolLike = _ToolsType
|
||||
except ImportError:
|
||||
_ToolLike = Any
|
||||
|
||||
try:
|
||||
from opentelemetry.sdk import trace
|
||||
|
||||
TracerProvider = trace.TracerProvider
|
||||
SpanProcessor = trace.SpanProcessor
|
||||
SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
|
||||
except ImportError:
|
||||
TracerProvider = Any
|
||||
SpanProcessor = Any
|
||||
SynchronousMultiSpanProcessor = Any
|
||||
|
||||
|
||||
def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]:
|
||||
# https://github.com/langchain-ai/langchain/blob/5784dfed001730530637793bea1795d9d5a7c244/libs/core/langchain_core/runnables/history.py#L237-L241
|
||||
runnable_kwargs = {
|
||||
# input_messages_key (str): Must be specified if the underlying
|
||||
# agent accepts a dict as input.
|
||||
"input_messages_key": "input",
|
||||
# output_messages_key (str): Must be specified if the underlying
|
||||
# agent returns a dict as output.
|
||||
"output_messages_key": "output",
|
||||
}
|
||||
if has_history:
|
||||
# history_messages_key (str): Must be specified if the underlying
|
||||
# agent accepts a dict as input and a separate key for historical
|
||||
# messages.
|
||||
runnable_kwargs["history_messages_key"] = "history"
|
||||
return runnable_kwargs
|
||||
|
||||
|
||||
def _default_output_parser():
|
||||
try:
|
||||
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
# Fallback to an older version if needed.
|
||||
from langchain.agents.output_parsers.openai_tools import (
|
||||
OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
|
||||
)
|
||||
|
||||
return ToolsAgentOutputParser()
|
||||
|
||||
|
||||
def _default_model_builder(
|
||||
model_name: str,
|
||||
*,
|
||||
project: str,
|
||||
location: str,
|
||||
model_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> "BaseLanguageModel":
|
||||
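# Temporarily point vertexai at the requested project/location while the
# ChatVertexAI model is constructed, then restore the caller's original
# configuration afterwards.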
import vertexai
|
||||
from google.cloud.aiplatform import initializer
|
||||
from langchain_google_vertexai import ChatVertexAI
|
||||
|
||||
model_kwargs = model_kwargs or {}
|
||||
current_project = initializer.global_config.project
|
||||
current_location = initializer.global_config.location
|
||||
vertexai.init(project=project, location=location)
|
||||
model = ChatVertexAI(model_name=model_name, **model_kwargs)
|
||||
vertexai.init(project=current_project, location=current_location)
|
||||
return model
|
||||
|
||||
|
||||
def _default_runnable_builder(
|
||||
model: "BaseLanguageModel",
|
||||
*,
|
||||
system_instruction: Optional[str] = None,
|
||||
tools: Optional[Sequence["_ToolLike"]] = None,
|
||||
prompt: Optional["RunnableSerializable"] = None,
|
||||
output_parser: Optional["RunnableSerializable"] = None,
|
||||
chat_history: Optional["GetSessionHistoryCallable"] = None,
|
||||
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
runnable_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> "RunnableSerializable":
|
||||
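# Builds the default LangChain runnable: prompt | model | output_parser
# wrapped in an AgentExecutor, optionally wrapped again in
# RunnableWithMessageHistory when chat_history is provided.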
from langchain_core import tools as lc_tools
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.tools.base import StructuredTool
|
||||
|
||||
# The prompt template and runnable_kwargs needs to be customized depending
|
||||
# on whether the user intends for the agent to have history. The way the
|
||||
# user would reflect that is by setting chat_history (which defaults to
|
||||
# None).
|
||||
has_history: bool = chat_history is not None
|
||||
prompt = prompt or _default_prompt(
|
||||
has_history=has_history,
|
||||
system_instruction=system_instruction,
|
||||
)
|
||||
output_parser = output_parser or _default_output_parser()
|
||||
model_tool_kwargs = model_tool_kwargs or {}
|
||||
agent_executor_kwargs = agent_executor_kwargs or {}
|
||||
runnable_kwargs = runnable_kwargs or _default_runnable_kwargs(has_history)
|
||||
if tools:
|
||||
model = model.bind_tools(tools=tools, **model_tool_kwargs)
|
||||
else:
|
||||
tools = []
|
||||
agent_executor = AgentExecutor(
|
||||
agent=prompt | model | output_parser,
|
||||
tools=[
|
||||
tool
|
||||
if isinstance(tool, lc_tools.BaseTool)
|
||||
else StructuredTool.from_function(tool)
|
||||
for tool in tools
|
||||
if isinstance(tool, (Callable, lc_tools.BaseTool))
|
||||
],
|
||||
**agent_executor_kwargs,
|
||||
)
|
||||
if has_history:
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
|
||||
return RunnableWithMessageHistory(
|
||||
runnable=agent_executor,
|
||||
get_session_history=chat_history,
|
||||
**runnable_kwargs,
|
||||
)
|
||||
return agent_executor
|
||||
|
||||
|
||||
def _default_prompt(
|
||||
has_history: bool,
|
||||
system_instruction: Optional[str] = None,
|
||||
) -> "RunnableSerializable":
|
||||
from langchain_core import prompts
|
||||
|
||||
try:
|
||||
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
# Fallback to an older version if needed.
|
||||
from langchain.agents.format_scratchpad.openai_tools import (
|
||||
format_to_openai_tool_messages as format_to_tool_messages,
|
||||
)
|
||||
|
||||
system_instructions = []
|
||||
if system_instruction:
|
||||
system_instructions = [("system", system_instruction)]
|
||||
|
||||
if has_history:
|
||||
return {
|
||||
"history": lambda x: x["history"],
|
||||
"input": lambda x: x["input"],
|
||||
"agent_scratchpad": (
|
||||
lambda x: format_to_tool_messages(x["intermediate_steps"])
|
||||
),
|
||||
} | prompts.ChatPromptTemplate.from_messages(
|
||||
system_instructions
|
||||
+ [
|
||||
prompts.MessagesPlaceholder(variable_name="history"),
|
||||
("user", "{input}"),
|
||||
prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"input": lambda x: x["input"],
|
||||
"agent_scratchpad": (
|
||||
lambda x: format_to_tool_messages(x["intermediate_steps"])
|
||||
),
|
||||
} | prompts.ChatPromptTemplate.from_messages(
|
||||
system_instructions
|
||||
+ [
|
||||
("user", "{input}"),
|
||||
prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _validate_callable_parameters_are_annotated(callable: Callable):
|
||||
"""Validates that the parameters of the callable have type annotations.
|
||||
|
||||
This ensures that they can be used for constructing LangChain tools that are
|
||||
usable with Gemini function calling.
|
||||
"""
|
||||
import inspect
|
||||
|
||||
parameters = dict(inspect.signature(callable).parameters)
|
||||
for name, parameter in parameters.items():
|
||||
if parameter.annotation == inspect.Parameter.empty:
|
||||
raise TypeError(
|
||||
f"Callable={callable.__name__} has untyped input_arg={name}. "
|
||||
f"Please specify a type when defining it, e.g. `{name}: str`."
|
||||
)
|
||||
|
||||
|
||||
def _validate_tools(tools: Sequence["_ToolLike"]):
|
||||
"""Validates that the tools are usable for tool calling."""
|
||||
for tool in tools:
|
||||
if isinstance(tool, Callable):
|
||||
_validate_callable_parameters_are_annotated(tool)
|
||||
|
||||
|
||||
def _override_active_span_processor(
|
||||
tracer_provider: "TracerProvider",
|
||||
active_span_processor: "SynchronousMultiSpanProcessor",
|
||||
):
|
||||
"""Overrides the active span processor.
|
||||
|
||||
When working with multiple LangchainAgents in the same environment,
|
||||
it's crucial to manage trace exports carefully.
|
||||
Each agent needs its own span processor tied to a unique project ID.
|
||||
While we add a new span processor for each agent, this can lead to
|
||||
unexpected behavior.
|
||||
For instance, with two agents linked to different projects, traces from the
|
||||
second agent might be sent to both projects.
|
||||
To prevent this and guarantee traces go to the correct project, we overwrite
|
||||
the active span processor whenever a new LangchainAgent is created.
|
||||
|
||||
Args:
|
||||
tracer_provider (TracerProvider):
|
||||
The tracer provider to use for the project.
|
||||
active_span_processor (SynchronousMultiSpanProcessor):
|
||||
The active span processor overrides the tracer provider's
|
||||
active span processor.
|
||||
"""
|
||||
if tracer_provider._active_span_processor:
|
||||
tracer_provider._active_span_processor.shutdown()
|
||||
tracer_provider._active_span_processor = active_span_processor
|
||||
|
||||
|
||||
class LangchainAgent:
|
||||
"""A Langchain Agent.
|
||||
|
||||
See https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/langchain
|
||||
for details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
*,
|
||||
system_instruction: Optional[str] = None,
|
||||
prompt: Optional["RunnableSerializable"] = None,
|
||||
tools: Optional[Sequence["_ToolLike"]] = None,
|
||||
output_parser: Optional["RunnableSerializable"] = None,
|
||||
chat_history: Optional["GetSessionHistoryCallable"] = None,
|
||||
model_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
runnable_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
model_builder: Optional[Callable] = None,
|
||||
runnable_builder: Optional[Callable] = None,
|
||||
enable_tracing: bool = False,
|
||||
):
|
||||
"""Initializes the LangchainAgent.
|
||||
|
||||
Under the hood, assuming .set_up() is called, this will correspond to
|
||||
|
||||
```
|
||||
model = model_builder(model_name=model, model_kwargs=model_kwargs)
|
||||
runnable = runnable_builder(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
tools=tools,
|
||||
output_parser=output_parser,
|
||||
chat_history=chat_history,
|
||||
agent_executor_kwargs=agent_executor_kwargs,
|
||||
runnable_kwargs=runnable_kwargs,
|
||||
)
|
||||
```
|
||||
|
||||
When everything is based on their default values, this corresponds to
|
||||
```
|
||||
# model_builder
|
||||
from langchain_google_vertexai import ChatVertexAI
|
||||
llm = ChatVertexAI(model_name=model, **model_kwargs)
|
||||
|
||||
# runnable_builder
|
||||
from langchain import agents
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs)
|
||||
agent_executor = agents.AgentExecutor(
|
||||
agent=prompt | llm_with_tools | output_parser,
|
||||
tools=tools,
|
||||
**agent_executor_kwargs,
|
||||
)
|
||||
runnable = RunnableWithMessageHistory(
|
||||
runnable=agent_executor,
|
||||
get_session_history=chat_history,
|
||||
**runnable_kwargs,
|
||||
)
|
||||
```
|
||||
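As an illustrative sketch (the project, location, and `get_exchange_rate`
tool below are hypothetical, and the import path is assumed to be the
Vertex AI preview namespace), a minimal end-to-end usage could look like:
```
import vertexai
from vertexai.preview.reasoning_engines import LangchainAgent

def get_exchange_rate(currency_from: str, currency_to: str) -> str:
    "Return a mock exchange rate between two currencies."
    # Placeholder tool; a real implementation would call an exchange-rate API.
    return f"1 {currency_from} = 0.9 {currency_to}"

vertexai.init(project="my-project", location="us-central1")
agent = LangchainAgent(
    model="gemini-1.0-pro",
    tools=[get_exchange_rate],
)
response = agent.query(input="What is the exchange rate from USD to SEK?")
```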
|
||||
Args:
|
||||
model (str):
|
||||
Optional. The name of the model (e.g. "gemini-1.0-pro").
|
||||
system_instruction (str):
|
||||
Optional. The system instruction to use for the agent. This
|
||||
argument should not be specified if `prompt` is specified.
|
||||
prompt (langchain_core.runnables.RunnableSerializable):
|
||||
Optional. The prompt template for the model. Defaults to a
|
||||
ChatPromptTemplate.
|
||||
tools (Sequence[langchain_core.tools.BaseTool, Callable]):
|
||||
Optional. The tools for the agent to be able to use. All input
|
||||
callables (e.g. function or class method) will be converted
|
||||
to a langchain.tools.base.StructuredTool. Defaults to None.
|
||||
output_parser (langchain_core.runnables.RunnableSerializable):
|
||||
Optional. The output parser for the model. Defaults to an
|
||||
output parser that works with Gemini function-calling.
|
||||
chat_history (langchain_core.runnables.history.GetSessionHistoryCallable):
|
||||
Optional. Callable that returns a new BaseChatMessageHistory.
|
||||
Defaults to None, i.e. chat_history is not preserved.
|
||||
model_kwargs (Mapping[str, Any]):
|
||||
Optional. Additional keyword arguments for the constructor of
|
||||
chat_models.ChatVertexAI. An example would be
|
||||
```
|
||||
{
|
||||
# temperature (float): Sampling temperature, it controls the
|
||||
# degree of randomness in token selection.
|
||||
"temperature": 0.28,
|
||||
# max_output_tokens (int): Token limit determines the
|
||||
# maximum amount of text output from one prompt.
|
||||
"max_output_tokens": 1000,
|
||||
# top_p (float): Tokens are selected from most probable to
|
||||
# least, until the sum of their probabilities equals the
|
||||
# top_p value.
|
||||
"top_p": 0.95,
|
||||
# top_k (int): How the model selects tokens for output, the
|
||||
# next token is selected from among the top_k most probable
|
||||
# tokens.
|
||||
"top_k": 40,
|
||||
}
|
||||
```
|
||||
model_tool_kwargs (Mapping[str, Any]):
|
||||
Optional. Additional keyword arguments when binding tools to the
|
||||
model using `model.bind_tools()`.
|
||||
agent_executor_kwargs (Mapping[str, Any]):
|
||||
Optional. Additional keyword arguments for the constructor of
|
||||
langchain.agents.AgentExecutor. An example would be
|
||||
```
|
||||
{
|
||||
# Whether to return the agent's trajectory of intermediate
|
||||
# steps at the end in addition to the final output.
|
||||
"return_intermediate_steps": False,
|
||||
# The maximum number of steps to take before ending the
|
||||
# execution loop.
|
||||
"max_iterations": 15,
|
||||
# The method to use for early stopping if the agent never
|
||||
# returns `AgentFinish`. Either 'force' or 'generate'.
|
||||
"early_stopping_method": "force",
|
||||
# How to handle errors raised by the agent's output parser.
|
||||
# Defaults to `False`, which raises the error.
|
||||
"handle_parsing_errors": False,
|
||||
}
|
||||
```
|
||||
runnable_kwargs (Mapping[str, Any]):
|
||||
Optional. Additional keyword arguments for the constructor of
|
||||
langchain.runnables.history.RunnableWithMessageHistory if
|
||||
chat_history is specified. If chat_history is None, this will be
|
||||
ignored.
|
||||
model_builder (Callable):
|
||||
Optional. Callable that returns a new language model. Defaults
|
||||
to a callable that returns ChatVertexAI based on `model`,
|
||||
`model_kwargs` and the parameters in `vertexai.init`.
|
||||
runnable_builder (Callable):
|
||||
Optional. Callable that returns a new runnable. This can be used
|
||||
for customizing the orchestration logic of the Agent based on
|
||||
the model returned by `model_builder` and the rest of the input
|
||||
arguments.
|
||||
enable_tracing (bool):
|
||||
Optional. Whether to enable tracing in Cloud Trace. Defaults to
|
||||
False.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `prompt` and `system_instruction` are specified.
|
||||
TypeError: If there is an invalid tool (e.g. function with an input
|
||||
that did not specify its type).
|
||||
"""
|
||||
from google.cloud.aiplatform import initializer
|
||||
|
||||
self._project = initializer.global_config.project
|
||||
self._location = initializer.global_config.location
|
||||
self._tools = []
|
||||
if tools:
|
||||
# We validate tools at initialization for actionable feedback before
|
||||
# they are deployed.
|
||||
_validate_tools(tools)
|
||||
self._tools = tools
|
||||
if prompt and system_instruction:
|
||||
raise ValueError(
|
||||
"Only one of `prompt` or `system_instruction` should be specified. "
|
||||
"Consider incorporating the system instruction into the prompt "
|
||||
"rather than passing it separately as an argument."
|
||||
)
|
||||
self._model_name = model
|
||||
self._system_instruction = system_instruction
|
||||
self._prompt = prompt
|
||||
self._output_parser = output_parser
|
||||
self._chat_history = chat_history
|
||||
self._model_kwargs = model_kwargs
|
||||
self._model_tool_kwargs = model_tool_kwargs
|
||||
self._agent_executor_kwargs = agent_executor_kwargs
|
||||
self._runnable_kwargs = runnable_kwargs
|
||||
self._model = None
|
||||
self._model_builder = model_builder
|
||||
self._runnable = None
|
||||
self._runnable_builder = runnable_builder
|
||||
self._instrumentor = None
|
||||
self._enable_tracing = enable_tracing
|
||||
|
||||
def set_up(self):
|
||||
"""Sets up the agent for execution of queries at runtime.
|
||||
|
||||
It initializes the model, binds the model with tools, and connects it
|
||||
with the prompt template and output parser.
|
||||
|
||||
This method should not be called for an object that is being passed to
|
||||
the ReasoningEngine service for deployment, as it initializes clients
|
||||
that cannot be serialized.
|
||||
"""
|
||||
if self._enable_tracing:
|
||||
from vertexai.reasoning_engines import _utils
|
||||
|
||||
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
|
||||
cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
|
||||
openinference_langchain = _utils._import_openinference_langchain_or_warn()
|
||||
opentelemetry = _utils._import_opentelemetry_or_warn()
|
||||
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
|
||||
if all(
|
||||
(
|
||||
cloud_trace_exporter,
|
||||
cloud_trace_v2,
|
||||
openinference_langchain,
|
||||
opentelemetry,
|
||||
opentelemetry_sdk_trace,
|
||||
)
|
||||
):
|
||||
import google.auth
|
||||
|
||||
credentials, _ = google.auth.default()
|
||||
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
|
||||
project_id=self._project,
|
||||
client=cloud_trace_v2.TraceServiceClient(
|
||||
credentials=credentials.with_quota_project(self._project),
|
||||
),
|
||||
)
|
||||
span_processor: SpanProcessor = (
|
||||
opentelemetry_sdk_trace.export.SimpleSpanProcessor(
|
||||
span_exporter=span_exporter,
|
||||
)
|
||||
)
|
||||
tracer_provider: TracerProvider = (
|
||||
opentelemetry.trace.get_tracer_provider()
|
||||
)
|
||||
# Get the appropriate tracer provider:
|
||||
# 1. If _TRACER_PROVIDER is already set, use that.
|
||||
# 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
|
||||
# variable is set, use that.
|
||||
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
|
||||
# If none of the above is set, we log a warning, and
|
||||
# create a tracer provider.
|
||||
if not tracer_provider:
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
_LOGGER.warning(
|
||||
"No tracer provider. By default, "
|
||||
"we should get one of the following providers: "
|
||||
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
|
||||
"or _PROXY_TRACER_PROVIDER."
|
||||
)
|
||||
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
|
||||
opentelemetry.trace.set_tracer_provider(tracer_provider)
|
||||
# Avoids AttributeError:
|
||||
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
|
||||
# attribute 'add_span_processor'.
|
||||
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
|
||||
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
|
||||
opentelemetry.trace.set_tracer_provider(tracer_provider)
|
||||
# Avoids OpenTelemetry client already exists error.
|
||||
_override_active_span_processor(
|
||||
tracer_provider,
|
||||
opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
|
||||
)
|
||||
tracer_provider.add_span_processor(span_processor)
|
||||
# Keep the instrumentation up-to-date.
|
||||
# When creating multiple LangchainAgents,
|
||||
# we need to keep the instrumentation up-to-date.
|
||||
# We deliberately override the instrumentation each time,
|
||||
# so that if different agents end up using different
|
||||
# instrumentations, we guarantee that the user is always
|
||||
# working with the most recent agent's instrumentation.
|
||||
self._instrumentor = openinference_langchain.LangChainInstrumentor()
|
||||
if self._instrumentor.is_instrumented_by_opentelemetry:
|
||||
self._instrumentor.uninstrument()
|
||||
self._instrumentor.instrument()
|
||||
else:
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
_LOGGER.warning(
|
||||
"enable_tracing=True but proceeding with tracing disabled "
|
||||
"because not all packages for tracing have been installed"
|
||||
)
|
||||
model_builder = self._model_builder or _default_model_builder
|
||||
self._model = model_builder(
|
||||
model_name=self._model_name,
|
||||
model_kwargs=self._model_kwargs,
|
||||
project=self._project,
|
||||
location=self._location,
|
||||
)
|
||||
runnable_builder = self._runnable_builder or _default_runnable_builder
|
||||
self._runnable = runnable_builder(
|
||||
prompt=self._prompt,
|
||||
model=self._model,
|
||||
tools=self._tools,
|
||||
system_instruction=self._system_instruction,
|
||||
output_parser=self._output_parser,
|
||||
chat_history=self._chat_history,
|
||||
model_tool_kwargs=self._model_tool_kwargs,
|
||||
agent_executor_kwargs=self._agent_executor_kwargs,
|
||||
runnable_kwargs=self._runnable_kwargs,
|
||||
)
|
||||
|
||||
def clone(self) -> "LangchainAgent":
|
||||
"""Returns a clone of the LangchainAgent."""
|
||||
import copy
|
||||
|
||||
return LangchainAgent(
|
||||
model=self._model_name,
|
||||
system_instruction=self._system_instruction,
|
||||
prompt=copy.deepcopy(self._prompt),
|
||||
tools=copy.deepcopy(self._tools),
|
||||
output_parser=copy.deepcopy(self._output_parser),
|
||||
chat_history=copy.deepcopy(self._chat_history),
|
||||
model_kwargs=copy.deepcopy(self._model_kwargs),
|
||||
model_tool_kwargs=copy.deepcopy(self._model_tool_kwargs),
|
||||
agent_executor_kwargs=copy.deepcopy(self._agent_executor_kwargs),
|
||||
runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
|
||||
model_builder=self._model_builder,
|
||||
runnable_builder=self._runnable_builder,
|
||||
enable_tracing=self._enable_tracing,
|
||||
)
|
||||
|
||||
def query(
|
||||
self,
|
||||
*,
|
||||
input: Union[str, Mapping[str, Any]],
|
||||
config: Optional["RunnableConfig"] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Queries the Agent with the given input and config.
|
||||
|
||||
Args:
|
||||
input (Union[str, Mapping[str, Any]]):
|
||||
Required. The input to be passed to the Agent.
|
||||
config (langchain_core.runnables.RunnableConfig):
|
||||
Optional. The config (if any) to be used for invoking the Agent.
|
||||
**kwargs:
|
||||
Optional. Any additional keyword arguments to be passed to the
|
||||
`.invoke()` method of the corresponding AgentExecutor.
|
||||
|
||||
Returns:
|
||||
The output of querying the Agent with the given input and config.
|
||||
"""
|
||||
from langchain.load import dump as langchain_load_dump
|
||||
|
||||
if isinstance(input, str):
|
||||
input = {"input": input}
|
||||
if not self._runnable:
|
||||
self.set_up()
|
||||
return langchain_load_dump.dumpd(
|
||||
self._runnable.invoke(input=input, config=config, **kwargs)
|
||||
)
|
||||
|
||||
def stream_query(
|
||||
self,
|
||||
*,
|
||||
input: Union[str, Mapping[str, Any]],
|
||||
config: Optional["RunnableConfig"] = None,
|
||||
**kwargs,
|
||||
) -> Iterable[Any]:
|
||||
"""Stream queries the Agent with the given input and config.
|
||||
|
||||
Args:
|
||||
input (Union[str, Mapping[str, Any]]):
|
||||
Required. The input to be passed to the Agent.
|
||||
config (langchain_core.runnables.RunnableConfig):
|
||||
Optional. The config (if any) to be used for invoking the Agent.
|
||||
**kwargs:
|
||||
Optional. Any additional keyword arguments to be passed to the
|
||||
`.stream()` method of the corresponding AgentExecutor.
|
||||
|
||||
Yields:
|
||||
The output of querying the Agent with the given input and config.
|
||||
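For example (hypothetical input, assuming the agent has been constructed):
```
for chunk in agent.stream_query(input="Plan a weekend trip to Stockholm"):
    print(chunk)
```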
"""
|
||||
from langchain.load import dump as langchain_load_dump
|
||||
|
||||
if isinstance(input, str):
|
||||
input = {"input": input}
|
||||
if not self._runnable:
|
||||
self.set_up()
|
||||
for chunk in self._runnable.stream(input=input, config=config, **kwargs):
|
||||
yield langchain_load_dump.dumpd(chunk)
|
||||
@@ -0,0 +1,658 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
from langchain_core import runnables
|
||||
from langchain_core import tools as lc_tools
|
||||
from langchain_core.language_models import base as lc_language_models
|
||||
|
||||
BaseTool = lc_tools.BaseTool
|
||||
BaseLanguageModel = lc_language_models.BaseLanguageModel
|
||||
RunnableConfig = runnables.RunnableConfig
|
||||
RunnableSerializable = runnables.RunnableSerializable
|
||||
except ImportError:
|
||||
BaseTool = Any
|
||||
BaseLanguageModel = Any
|
||||
RunnableConfig = Any
|
||||
RunnableSerializable = Any
|
||||
|
||||
try:
|
||||
from langchain_google_vertexai.functions_utils import _ToolsType
|
||||
|
||||
_ToolLike = _ToolsType
|
||||
except ImportError:
|
||||
_ToolLike = Any
|
||||
|
||||
try:
|
||||
from opentelemetry.sdk import trace
|
||||
|
||||
TracerProvider = trace.TracerProvider
|
||||
SpanProcessor = trace.SpanProcessor
|
||||
SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
|
||||
except ImportError:
|
||||
TracerProvider = Any
|
||||
SpanProcessor = Any
|
||||
SynchronousMultiSpanProcessor = Any
|
||||
|
||||
try:
|
||||
from langgraph_checkpoint.checkpoint import base
|
||||
|
||||
BaseCheckpointSaver = base.BaseCheckpointSaver
|
||||
except ImportError:
|
||||
try:
|
||||
from langgraph.checkpoint import base
|
||||
|
||||
BaseCheckpointSaver = base.BaseCheckpointSaver
|
||||
except ImportError:
|
||||
BaseCheckpointSaver = Any
|
||||
|
||||
|
||||
def _default_model_builder(
|
||||
model_name: str,
|
||||
*,
|
||||
project: str,
|
||||
location: str,
|
||||
model_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> "BaseLanguageModel":
|
||||
"""Default callable for building a language model.
|
||||
|
||||
Args:
|
||||
model_name (str):
|
||||
Required. The name of the model (e.g. "gemini-1.0-pro").
|
||||
project (str):
|
||||
Required. The Google Cloud project ID.
|
||||
location (str):
|
||||
Required. The Google Cloud location.
|
||||
model_kwargs (Mapping[str, Any]):
|
||||
Optional. Additional keyword arguments for the constructor of
|
||||
chat_models.ChatVertexAI.
|
||||
|
||||
Returns:
|
||||
BaseLanguageModel: The language model.
|
||||
"""
|
||||
import vertexai
|
||||
from google.cloud.aiplatform import initializer
|
||||
from langchain_google_vertexai import ChatVertexAI
|
||||
|
||||
model_kwargs = model_kwargs or {}
|
||||
current_project = initializer.global_config.project
|
||||
current_location = initializer.global_config.location
|
||||
vertexai.init(project=project, location=location)
|
||||
model = ChatVertexAI(model_name=model_name, **model_kwargs)
|
||||
vertexai.init(project=current_project, location=current_location)
|
||||
return model
|
||||
|
||||
|
||||
def _default_runnable_builder(
|
||||
model: "BaseLanguageModel",
|
||||
*,
|
||||
tools: Optional[Sequence["_ToolLike"]] = None,
|
||||
checkpointer: Optional[Any] = None,
|
||||
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
runnable_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> "RunnableSerializable":
|
||||
"""Default callable for building a runnable.
|
||||
|
||||
Args:
|
||||
model (BaseLanguageModel):
|
||||
Required. The language model.
|
||||
tools (Optional[Sequence[_ToolLike]]):
|
||||
Optional. The tools for the agent to be able to use.
|
||||
checkpointer (Optional[Checkpointer]):
|
||||
Optional. The checkpointer for the agent.
|
||||
model_tool_kwargs (Optional[Mapping[str, Any]]):
|
||||
Optional. Additional keyword arguments when binding tools to the model.
|
||||
runnable_kwargs (Optional[Mapping[str, Any]]):
|
||||
Optional. Additional keyword arguments for the runnable.
|
||||
|
||||
Returns:
|
||||
RunnableSerializable: The runnable.
|
||||
"""
|
||||
from langgraph import prebuilt as langgraph_prebuilt
|
||||
|
||||
model_tool_kwargs = model_tool_kwargs or {}
|
||||
runnable_kwargs = runnable_kwargs or {}
|
||||
if tools:
|
||||
model = model.bind_tools(tools=tools, **model_tool_kwargs)
|
||||
else:
|
||||
tools = []
|
||||
if checkpointer:
|
||||
if "checkpointer" in runnable_kwargs:
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
base.Logger(__name__).warning(
|
||||
"checkpointer is being specified in both checkpointer_builder "
|
||||
"and runnable_kwargs. Please specify it in only one of them. "
|
||||
"Overriding the checkpointer in runnable_kwargs."
|
||||
)
|
||||
runnable_kwargs["checkpointer"] = checkpointer
|
||||
return langgraph_prebuilt.create_react_agent(
|
||||
model,
|
||||
tools=tools,
|
||||
**runnable_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _validate_callable_parameters_are_annotated(callable: Callable):
|
||||
"""Validates that the parameters of the callable have type annotations.
|
||||
|
||||
This ensures that they can be used for constructing LangChain tools that are
|
||||
usable with Gemini function calling.
|
||||
|
||||
Args:
|
||||
callable (Callable): The callable to validate.
|
||||
|
||||
Raises:
|
||||
TypeError: If any parameter is not annotated.
|
||||
"""
|
||||
import inspect
|
||||
|
||||
parameters = dict(inspect.signature(callable).parameters)
|
||||
for name, parameter in parameters.items():
|
||||
if parameter.annotation == inspect.Parameter.empty:
|
||||
raise TypeError(
|
||||
f"Callable={callable.__name__} has untyped input_arg={name}. "
|
||||
f"Please specify a type when defining it, e.g. `{name}: str`."
|
||||
)
|
||||
|
||||
|
||||
def _validate_tools(tools: Sequence["_ToolLike"]):
|
||||
"""Validates that the tools are usable for tool calling.
|
||||
|
||||
Args:
|
||||
tools (Sequence[_ToolLike]): The tools to validate.
|
||||
|
||||
Raises:
|
||||
TypeError: If any tool is a callable with untyped parameters.
|
||||
"""
|
||||
for tool in tools:
|
||||
if isinstance(tool, Callable):
|
||||
_validate_callable_parameters_are_annotated(tool)
|
||||
|
||||
|
||||
def _override_active_span_processor(
|
||||
tracer_provider: "TracerProvider",
|
||||
active_span_processor: "SynchronousMultiSpanProcessor",
|
||||
):
|
||||
"""Overrides the active span processor.
|
||||
|
||||
When working with multiple LangchainAgents in the same environment,
|
||||
it's crucial to manage trace exports carefully.
|
||||
Each agent needs its own span processor tied to a unique project ID.
|
||||
While we add a new span processor for each agent, this can lead to
|
||||
unexpected behavior.
|
||||
For instance, with two agents linked to different projects, traces from the
|
||||
second agent might be sent to both projects.
|
||||
To prevent this and guarantee traces go to the correct project, we overwrite
|
||||
the active span processor whenever a new LangchainAgent is created.
|
||||
|
||||
Args:
|
||||
tracer_provider (TracerProvider):
|
||||
The tracer provider to use for the project.
|
||||
active_span_processor (SynchronousMultiSpanProcessor):
|
||||
The active span processor overrides the tracer provider's
|
||||
active span processor.
|
||||
"""
|
||||
if tracer_provider._active_span_processor:
|
||||
tracer_provider._active_span_processor.shutdown()
|
||||
tracer_provider._active_span_processor = active_span_processor
|
||||
|
||||
|
||||
class LanggraphAgent:
|
||||
"""A LangGraph Agent.
|
||||
|
||||
See https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/langgraph
|
||||
for details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
*,
|
||||
tools: Optional[Sequence["_ToolLike"]] = None,
|
||||
model_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
model_builder: Optional[Callable[..., "BaseLanguageModel"]] = None,
|
||||
runnable_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
runnable_builder: Optional[Callable[..., "RunnableSerializable"]] = None,
|
||||
checkpointer_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
checkpointer_builder: Optional[Callable[..., "BaseCheckpointSaver"]] = None,
|
||||
enable_tracing: bool = False,
|
||||
):
|
||||
"""Initializes the LangGraph Agent.
|
||||
|
||||
Under the hood, assuming .set_up() is called, this will correspond to
|
||||
```python
|
||||
model = model_builder(model_name=model, model_kwargs=model_kwargs)
|
||||
runnable = runnable_builder(
|
||||
model=model,
|
||||
tools=tools,
|
||||
model_tool_kwargs=model_tool_kwargs,
|
||||
runnable_kwargs=runnable_kwargs,
|
||||
)
|
||||
```
|
||||
|
||||
When everything is based on their default values, this corresponds to
|
||||
```python
|
||||
# model_builder
|
||||
from langchain_google_vertexai import ChatVertexAI
|
||||
llm = ChatVertexAI(model_name=model, **model_kwargs)
|
||||
|
||||
# runnable_builder
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs)
|
||||
runnable = create_react_agent(
|
||||
llm_with_tools,
|
||||
tools=tools,
|
||||
**runnable_kwargs,
|
||||
)
|
||||
```
|
||||
|
||||
By default, no checkpointer is used (i.e. there is no state history). To
|
||||
enable checkpointing, provide a `checkpointer_builder` function that
|
||||
returns a checkpointer instance.
|
||||
|
||||
**Example using Spanner:**
|
||||
```python
|
||||
def checkpointer_builder(instance_id, database_id, project_id, **kwargs):
|
||||
from langchain_google_spanner import SpannerCheckpointSaver
|
||||
|
||||
checkpointer = SpannerCheckpointSaver(instance_id, database_id, project_id)
|
||||
with checkpointer.cursor() as cur:
|
||||
cur.execute("DROP TABLE IF EXISTS checkpoints")
|
||||
cur.execute("DROP TABLE IF EXISTS checkpoint_writes")
|
||||
checkpointer.setup()
|
||||
|
||||
return checkpointer
|
||||
```
|
||||
|
||||
**Example using an in-memory checkpointer:**
|
||||
```python
|
||||
def checkpointer_builder(**kwargs):
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
||||
return MemorySaver()
|
||||
```
|
||||
|
||||
The `checkpointer_builder` function will be called with the keyword
|
||||
arguments specified in `checkpointer_kwargs`. Ensure your
|
||||
`checkpointer_builder` function accepts `**kwargs` so that it can
|
||||
ignore any arguments it does not use.
|
||||
|
||||
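As an illustrative sketch (the project, location, and thread ID values are
hypothetical, and the import path is assumed to be the Vertex AI preview
namespace), a checkpointed agent could be used as follows:
```python
import vertexai
from vertexai.preview.reasoning_engines import LanggraphAgent

def checkpointer_builder(**kwargs):
    from langgraph.checkpoint.memory import MemorySaver

    return MemorySaver()

vertexai.init(project="my-project", location="us-central1")
agent = LanggraphAgent(
    model="gemini-1.0-pro",
    checkpointer_builder=checkpointer_builder,
)
config = {"configurable": {"thread_id": "demo-thread"}}
response = agent.query(input="Hello!", config=config)
state = agent.get_state(config=config)
```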
Args:
|
||||
model (str):
|
||||
Optional. The name of the model (e.g. "gemini-1.0-pro").
|
||||
tools (Sequence[langchain_core.tools.BaseTool, Callable]):
|
||||
Optional. The tools for the agent to be able to use. All input
|
||||
callables (e.g. function or class method) will be converted
|
||||
to a langchain.tools.base.StructuredTool. Defaults to None.
|
||||
model_kwargs (Mapping[str, Any]):
|
||||
Optional. Additional keyword arguments for the constructor of
|
||||
chat_models.ChatVertexAI. An example would be
|
||||
```
|
||||
{
|
||||
# temperature (float): Sampling temperature, it controls the
|
||||
# degree of randomness in token selection.
|
||||
"temperature": 0.28,
|
||||
# max_output_tokens (int): Token limit determines the
|
||||
# maximum amount of text output from one prompt.
|
||||
"max_output_tokens": 1000,
|
||||
# top_p (float): Tokens are selected from most probable to
|
||||
# least, until the sum of their probabilities equals the
|
||||
# top_p value.
|
||||
"top_p": 0.95,
|
||||
# top_k (int): How the model selects tokens for output, the
|
||||
# next token is selected from among the top_k most probable
|
||||
# tokens.
|
||||
"top_k": 40,
|
||||
}
|
||||
```
|
||||
model_tool_kwargs (Mapping[str, Any]):
|
||||
Optional. Additional keyword arguments when binding tools to the
|
||||
model using `model.bind_tools()`.
|
||||
model_builder (Callable[..., "BaseLanguageModel"]):
|
||||
Optional. Callable that returns a new language model. Defaults
|
||||
to a callable that returns ChatVertexAI based on `model`,
|
||||
`model_kwargs` and the parameters in `vertexai.init`.
|
||||
runnable_kwargs (Mapping[str, Any]):
|
||||
Optional. Additional keyword arguments for the runnable returned by
|
||||
`runnable_builder`. By default these are passed to
|
||||
`langgraph.prebuilt.create_react_agent` (e.g. to specify a
|
||||
`checkpointer` directly).
|
||||
runnable_builder (Callable[..., "RunnableSerializable"]):
|
||||
Optional. Callable that returns a new runnable. This can be used
|
||||
for customizing the orchestration logic of the Agent based on
|
||||
the model returned by `model_builder` and the rest of the input
|
||||
arguments.
|
||||
checkpointer_kwargs (Mapping[str, Any]):
|
||||
Optional. Additional keyword arguments for the constructor of
|
||||
the checkpointer returned by `checkpointer_builder`.
|
||||
checkpointer_builder (Callable[..., "BaseCheckpointSaver"]):
|
||||
Optional. Callable that returns a checkpointer. This can be used
|
||||
for defining the checkpointer of the Agent. Defaults to None.
|
||||
enable_tracing (bool):
|
||||
Optional. Whether to enable tracing in Cloud Trace. Defaults to
|
||||
False.
|
||||
|
||||
Raises:
|
||||
TypeError: If there is an invalid tool (e.g. function with an input
|
||||
that did not specify its type).
|
||||
"""
|
||||
from google.cloud.aiplatform import initializer
|
||||
|
||||
self._project = initializer.global_config.project
|
||||
self._location = initializer.global_config.location
|
||||
self._tools = []
|
||||
if tools:
|
||||
# We validate tools at initialization for actionable feedback before
|
||||
# they are deployed.
|
||||
_validate_tools(tools)
|
||||
self._tools = tools
|
||||
self._model_name = model
|
||||
self._model_kwargs = model_kwargs
|
||||
self._model_tool_kwargs = model_tool_kwargs
|
||||
self._runnable_kwargs = runnable_kwargs
|
||||
self._checkpointer_kwargs = checkpointer_kwargs
|
||||
self._model = None
|
||||
self._model_builder = model_builder
|
||||
self._runnable = None
|
||||
self._runnable_builder = runnable_builder
|
||||
self._checkpointer_builder = checkpointer_builder
|
||||
self._instrumentor = None
|
||||
self._enable_tracing = enable_tracing
|
||||
|
||||
def set_up(self):
|
||||
"""Sets up the agent for execution of queries at runtime.
|
||||
|
||||
It initializes the model, binds the model with tools, and connects it
|
||||
with the prompt template and output parser.
|
||||
|
||||
This method should not be called for an object that is being passed to
|
||||
the ReasoningEngine service for deployment, as it initializes clients
|
||||
that cannot be serialized.
|
||||
"""
|
||||
if self._enable_tracing:
|
||||
from vertexai.reasoning_engines import _utils
|
||||
|
||||
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
|
||||
cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
|
||||
openinference_langchain = _utils._import_openinference_langchain_or_warn()
|
||||
opentelemetry = _utils._import_opentelemetry_or_warn()
|
||||
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
|
||||
if all(
|
||||
(
|
||||
cloud_trace_exporter,
|
||||
cloud_trace_v2,
|
||||
openinference_langchain,
|
||||
opentelemetry,
|
||||
opentelemetry_sdk_trace,
|
||||
)
|
||||
):
|
||||
import google.auth
|
||||
|
||||
credentials, _ = google.auth.default()
|
||||
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
|
||||
project_id=self._project,
|
||||
client=cloud_trace_v2.TraceServiceClient(
|
||||
credentials=credentials.with_quota_project(self._project),
|
||||
),
|
||||
)
|
||||
span_processor: SpanProcessor = (
|
||||
opentelemetry_sdk_trace.export.SimpleSpanProcessor(
|
||||
span_exporter=span_exporter,
|
||||
)
|
||||
)
|
||||
tracer_provider: TracerProvider = (
|
||||
opentelemetry.trace.get_tracer_provider()
|
||||
)
|
||||
# Get the appropriate tracer provider:
|
||||
# 1. If _TRACER_PROVIDER is already set, use that.
|
||||
# 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
|
||||
# variable is set, use that.
|
||||
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
|
||||
# If none of the above is set, we log a warning, and
|
||||
# create a tracer provider.
|
||||
if not tracer_provider:
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
base.Logger(__name__).warning(
|
||||
"No tracer provider. By default, "
|
||||
"we should get one of the following providers: "
|
||||
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
|
||||
"or _PROXY_TRACER_PROVIDER."
|
||||
)
|
||||
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
|
||||
opentelemetry.trace.set_tracer_provider(tracer_provider)
|
||||
# Avoids AttributeError:
|
||||
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
|
||||
# attribute 'add_span_processor'.
|
||||
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
|
||||
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
|
||||
opentelemetry.trace.set_tracer_provider(tracer_provider)
|
||||
# Avoids OpenTelemetry client already exists error.
|
||||
_override_active_span_processor(
|
||||
tracer_provider,
|
||||
opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
|
||||
)
|
||||
tracer_provider.add_span_processor(span_processor)
|
||||
# Keep the instrumentation up-to-date.
|
||||
# When creating multiple LangchainAgents,
|
||||
# we need to keep the instrumentation up-to-date.
|
||||
# We deliberately override the instrumentation each time,
|
||||
# so that if different agents end up using different
|
||||
# instrumentations, we guarantee that the user is always
|
||||
# working with the most recent agent's instrumentation.
|
||||
self._instrumentor = openinference_langchain.LangChainInstrumentor()
|
||||
if self._instrumentor.is_instrumented_by_opentelemetry:
|
||||
self._instrumentor.uninstrument()
|
||||
self._instrumentor.instrument()
|
||||
else:
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
_LOGGER.warning(
|
||||
"enable_tracing=True but proceeding with tracing disabled "
|
||||
"because not all packages for tracing have been installed"
|
||||
)
|
||||
model_builder = self._model_builder or _default_model_builder
|
||||
self._model = model_builder(
|
||||
model_name=self._model_name,
|
||||
model_kwargs=self._model_kwargs,
|
||||
project=self._project,
|
||||
location=self._location,
|
||||
)
|
||||
self._checkpointer = None
|
||||
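# Build the (optional) checkpointer first so it can be passed to the
# runnable builder (by default, langgraph's create_react_agent).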
if self._checkpointer_builder:
|
||||
checkpointer_kwargs = self._checkpointer_kwargs or {}
|
||||
self._checkpointer = self._checkpointer_builder(
|
||||
**checkpointer_kwargs,
|
||||
)
|
||||
runnable_builder = self._runnable_builder or _default_runnable_builder
|
||||
self._runnable = runnable_builder(
|
||||
model=self._model,
|
||||
tools=self._tools,
|
||||
checkpointer=self._checkpointer,
|
||||
model_tool_kwargs=self._model_tool_kwargs,
|
||||
runnable_kwargs=self._runnable_kwargs,
|
||||
)
|
||||
|
||||
def clone(self) -> "LanggraphAgent":
|
||||
"""Returns a clone of the LanggraphAgent."""
|
||||
import copy
|
||||
|
||||
return LanggraphAgent(
|
||||
model=self._model_name,
|
||||
tools=copy.deepcopy(self._tools),
|
||||
model_kwargs=copy.deepcopy(self._model_kwargs),
|
||||
model_tool_kwargs=copy.deepcopy(self._model_tool_kwargs),
|
||||
runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
|
||||
checkpointer_kwargs=copy.deepcopy(self._checkpointer_kwargs),
|
||||
model_builder=self._model_builder,
|
||||
runnable_builder=self._runnable_builder,
|
||||
checkpointer_builder=self._checkpointer_builder,
|
||||
enable_tracing=self._enable_tracing,
|
||||
)
|
||||
|
||||
def query(
|
||||
self,
|
||||
*,
|
||||
input: Union[str, Mapping[str, Any]],
|
||||
config: Optional["RunnableConfig"] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Queries the Agent with the given input and config.
|
||||
|
||||
Args:
|
||||
input (Union[str, Mapping[str, Any]]):
|
||||
Required. The input to be passed to the Agent.
|
||||
config (langchain_core.runnables.RunnableConfig):
|
||||
Optional. The config (if any) to be used for invoking the Agent.
|
||||
**kwargs:
|
||||
Optional. Any additional keyword arguments to be passed to the
|
||||
`.invoke()` method of the underlying runnable.
|
||||
|
||||
Returns:
|
||||
The output of querying the Agent with the given input and config.
|
||||
"""
|
||||
from langchain.load import dump as langchain_load_dump
|
||||
|
||||
if isinstance(input, str):
|
||||
input = {"input": input}
|
||||
if not self._runnable:
|
||||
self.set_up()
|
||||
return langchain_load_dump.dumpd(
|
||||
self._runnable.invoke(input=input, config=config, **kwargs)
|
||||
)
|
||||
|
||||
def stream_query(
|
||||
self,
|
||||
*,
|
||||
input: Union[str, Mapping[str, Any]],
|
||||
config: Optional["RunnableConfig"] = None,
|
||||
**kwargs,
|
||||
) -> Iterable[Any]:
|
||||
"""Stream queries the Agent with the given input and config.
|
||||
|
||||
Args:
|
||||
input (Union[str, Mapping[str, Any]]):
|
||||
Required. The input to be passed to the Agent.
|
||||
config (langchain_core.runnables.RunnableConfig):
|
||||
Optional. The config (if any) to be used for invoking the Agent.
|
||||
**kwargs:
|
||||
Optional. Any additional keyword arguments to be passed to the
|
||||
`.stream()` method of the underlying runnable.
|
||||
|
||||
Yields:
|
||||
The output of querying the Agent with the given input and config.
|
||||
"""
|
||||
from langchain.load import dump as langchain_load_dump
|
||||
|
||||
if isinstance(input, str):
|
||||
input = {"input": input}
|
||||
if not self._runnable:
|
||||
self.set_up()
|
||||
for chunk in self._runnable.stream(input=input, config=config, **kwargs):
|
||||
yield langchain_load_dump.dumpd(chunk)
|
||||
|
||||
def get_state_history(
|
||||
self,
|
||||
config: Optional["RunnableConfig"] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterable[Any]:
|
||||
"""Gets the state history of the Agent.
|
||||
|
||||
Args:
|
||||
config (Optional[RunnableConfig]):
|
||||
Optional. The config for invoking the Agent.
|
||||
**kwargs:
|
||||
Optional. Additional keyword arguments for the `.get_state_history()` method.
|
||||
|
||||
Yields:
|
||||
Dict[str, Any]: The state history of the Agent.
|
||||
"""
|
||||
if not self._runnable:
|
||||
self.set_up()
|
||||
for state_snapshot in self._runnable.get_state_history(config=config, **kwargs):
|
||||
yield state_snapshot._asdict()
|
||||
|
||||
def get_state(
|
||||
self,
|
||||
config: Optional["RunnableConfig"] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Gets the current state of the Agent.
|
||||
|
||||
Args:
|
||||
config (Optional[RunnableConfig]):
|
||||
Optional. The config for invoking the Agent.
|
||||
**kwargs:
|
||||
Optional. Additional keyword arguments for the `.get_state()` method.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The current state of the Agent.
|
||||
"""
|
||||
if not self._runnable:
|
||||
self.set_up()
|
||||
return self._runnable.get_state(config=config, **kwargs)._asdict()
|
||||
|
||||
def update_state(
|
||||
self,
|
||||
config: Optional["RunnableConfig"] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Updates the state of the Agent.
|
||||
|
||||
Args:
|
||||
config (Optional[RunnableConfig]):
|
||||
Optional. The config for invoking the Agent.
|
||||
**kwargs:
|
||||
Optional. Additional keyword arguments for the `.update_state()` method.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The updated state of the Agent.
|
||||
"""
|
||||
if not self._runnable:
|
||||
self.set_up()
|
||||
return self._runnable.update_state(config=config, **kwargs)
|
||||
|
||||
def register_operations(self) -> Mapping[str, Sequence[str]]:
|
||||
"""Registers the operations of the Agent.
|
||||
|
||||
This mapping defines how different operation modes (e.g., "", "stream")
|
||||
are implemented by specific methods of the Agent. The "default" mode,
|
||||
represented by the empty string `""`, is associated with the `query` API,
|
||||
while the "stream" mode is associated with the `stream_query` API.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Sequence[str]]: A mapping of operation modes to a list
|
||||
of method names that implement those operation modes.
|
||||
"""
|
||||
return {
|
||||
"": ["query", "get_state", "update_state"],
|
||||
"stream": ["stream_query", "get_state_history"],
|
||||
}
|
||||
@@ -0,0 +1,553 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
from llama_index.core.base.query_pipeline import query
|
||||
from llama_index.core.llms import function_calling
|
||||
from llama_index.core import query_pipeline
|
||||
|
||||
FunctionCallingLLM = function_calling.FunctionCallingLLM
|
||||
QueryComponent = query.QUERY_COMPONENT_TYPE
|
||||
QueryPipeline = query_pipeline.QueryPipeline
|
||||
except ImportError:
|
||||
FunctionCallingLLM = Any
|
||||
QueryComponent = Any
|
||||
QueryPipeline = Any
|
||||
|
||||
try:
|
||||
from opentelemetry.sdk import trace
|
||||
|
||||
TracerProvider = trace.TracerProvider
|
||||
SpanProcessor = trace.SpanProcessor
|
||||
SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
|
||||
except ImportError:
|
||||
TracerProvider = Any
|
||||
SpanProcessor = Any
|
||||
SynchronousMultiSpanProcessor = Any
|
||||
|
||||
|
||||
def _default_model_builder(
|
||||
model_name: str,
|
||||
*,
|
||||
project: str,
|
||||
location: str,
|
||||
model_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> "FunctionCallingLLM":
|
||||
"""Creates a default model builder for LlamaIndex."""
|
||||
import vertexai
|
||||
from google.cloud.aiplatform import initializer
|
||||
from llama_index.llms import google_genai
|
||||
|
||||
model_kwargs = model_kwargs or {}
|
||||
model = google_genai.GoogleGenAI(
|
||||
model=model_name,
|
||||
vertexai_config={"project": project, "location": location},
|
||||
**model_kwargs,
|
||||
)
|
||||
current_project = initializer.global_config.project
|
||||
current_location = initializer.global_config.location
|
||||
vertexai.init(project=current_project, location=current_location)
|
||||
return model
|
||||
|
||||
|
||||
def _default_runnable_builder(
|
||||
model: "FunctionCallingLLM",
|
||||
*,
|
||||
system_instruction: Optional[str] = None,
|
||||
prompt: Optional["QueryComponent"] = None,
|
||||
retriever: Optional["QueryComponent"] = None,
|
||||
response_synthesizer: Optional["QueryComponent"] = None,
|
||||
runnable_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> "QueryPipeline":
|
||||
"""Creates a default runnable builder for LlamaIndex."""
|
||||
try:
|
||||
from llama_index.core.query_pipeline import QueryPipeline
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please call 'pip install google-cloud-aiplatform[llama_index]'."
|
||||
)
|
||||
|
||||
prompt = prompt or _default_prompt(
|
||||
system_instruction=system_instruction,
|
||||
)
|
||||
runnable_kwargs = runnable_kwargs or {}
|
||||
pipeline = QueryPipeline(**runnable_kwargs)
|
||||
pipeline_modules = {
|
||||
"prompt": prompt,
|
||||
"model": model,
|
||||
}
|
||||
if retriever:
|
||||
pipeline_modules["retriever"] = retriever
|
||||
if response_synthesizer:
|
||||
pipeline_modules["response_synthesizer"] = response_synthesizer
|
||||
|
||||
pipeline.add_modules(pipeline_modules)
|
||||
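# Wire the pipeline DAG: prompt -> model, and, when configured,
# model -> retriever and model/retriever -> response_synthesizer (the model
# output supplies query_str and the retriever output supplies nodes).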
pipeline.add_link("prompt", "model")
|
||||
if "retriever" in pipeline_modules:
|
||||
pipeline.add_link("model", "retriever")
|
||||
if "response_synthesizer" in pipeline_modules:
|
||||
pipeline.add_link("model", "response_synthesizer", dest_key="query_str")
|
||||
if "retriever" in pipeline_modules:
|
||||
pipeline.add_link("retriever", "response_synthesizer", dest_key="nodes")
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
def _default_prompt(
|
||||
system_instruction: Optional[str] = None,
|
||||
) -> "QueryComponent":
|
||||
"""Creates a default prompt template for LlamaIndex.
|
||||
|
||||
Handles both system instruction and user input.
|
||||
|
||||
Args:
|
||||
system_instruction (str, optional): The system instruction to use.
|
||||
|
||||
Returns:
|
||||
QueryComponent: The LlamaIndex QueryComponent.
|
||||
"""
|
||||
try:
|
||||
from llama_index.core import prompts
|
||||
from llama_index.core.base.llms import types
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please call 'pip install google-cloud-aiplatform[llama_index]'."
|
||||
)
|
||||
|
||||
# Define a prompt template
|
||||
message_templates = []
|
||||
if system_instruction:
|
||||
message_templates.append(
|
||||
types.ChatMessage(role=types.MessageRole.SYSTEM, content=system_instruction)
|
||||
)
|
||||
# Add user input message
|
||||
message_templates.append(
|
||||
types.ChatMessage(role=types.MessageRole.USER, content="{input}")
|
||||
)
|
||||
|
||||
# Create the prompt template
|
||||
return prompts.ChatPromptTemplate(message_templates=message_templates)
|
||||
|
||||
|
||||
def _override_active_span_processor(
|
||||
tracer_provider: "TracerProvider",
|
||||
active_span_processor: "SynchronousMultiSpanProcessor",
|
||||
):
|
||||
"""Overrides the active span processor.
|
||||
|
||||
When working with multiple LlamaIndexQueryPipelineAgents in the same
|
||||
environment, it's crucial to manage trace exports carefully.
|
||||
Each agent needs its own span processor tied to a unique project ID.
|
||||
While we add a new span processor for each agent, this can lead to
|
||||
unexpected behavior.
|
||||
For instance, with two agents linked to different projects, traces from the
|
||||
second agent might be sent to both projects.
|
||||
To prevent this and guarantee traces go to the correct project, we overwrite
|
||||
the active span processor whenever a new LlamaIndexQueryPipelineAgent is
|
||||
created.
|
||||
|
||||
Args:
|
||||
tracer_provider (TracerProvider):
|
||||
The tracer provider to use for the project.
|
||||
active_span_processor (SynchronousMultiSpanProcessor):
|
||||
The active span processor overrides the tracer provider's
|
||||
active span processor.
|
||||
"""
|
||||
if tracer_provider._active_span_processor:
|
||||
tracer_provider._active_span_processor.shutdown()
|
||||
tracer_provider._active_span_processor = active_span_processor
|
||||
|
||||
|
||||
class LlamaIndexQueryPipelineAgent:
|
||||
"""A LlamaIndex Query Pipeline Agent.
|
||||
|
||||
This agent uses a query pipeline for LlamaIndex, including prompt, model,
|
||||
retrieval and summarization steps. More details can be found in
|
||||
https://docs.llamaindex.ai/en/stable/module_guides/querying/pipeline/.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
*,
|
||||
system_instruction: Optional[str] = None,
|
||||
prompt: Optional["QueryComponent"] = None,
|
||||
model_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
model_builder: Optional[Callable[..., "FunctionCallingLLM"]] = None,
|
||||
retriever_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
retriever_builder: Optional[Callable[..., "QueryComponent"]] = None,
|
||||
response_synthesizer_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
response_synthesizer_builder: Optional[Callable[..., "QueryComponent"]] = None,
|
||||
runnable_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
runnable_builder: Optional[Callable[..., "QueryPipeline"]] = None,
|
||||
enable_tracing: bool = False,
|
||||
):
|
||||
"""Initializes the LlamaIndexQueryPipelineAgent.
|
||||
|
||||
Under the hood, assuming .set_up() is called, this will correspond to
|
||||
```python
|
||||
# model_builder
|
||||
model = model_builder(model_name, project, location, model_kwargs)
|
||||
|
||||
# runnable_builder
|
||||
runnable = runnable_builder(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
retriever=retriever_builder(model, retriever_kwargs),
|
||||
response_synthesizer=response_synthesizer_builder(
|
||||
model, response_synthesizer_kwargs
|
||||
),
|
||||
runnable_kwargs=runnable_kwargs,
|
||||
)
|
||||
```

        When everything is left at its default values, this corresponds to a
        query pipeline `Prompt - Model`:
        ```python
        # Default Model Builder
        model = google_genai.GoogleGenAI(
            model=model_name,
            vertexai_config={
                "project": initializer.global_config.project,
                "location": initializer.global_config.location,
            },
        )

        # Default Prompt Builder
        prompt = prompts.ChatPromptTemplate(
            message_templates=[
                types.ChatMessage(
                    role=types.MessageRole.USER,
                    content="{input}",
                ),
            ],
        )

        # Default Runnable Builder
        runnable = QueryPipeline(
            modules={
                "prompt": prompt,
                "model": model,
            },
        )
        runnable.add_link("prompt", "model")
        ```

        When `system_instruction` is specified, the prompt will be updated to
        include the system instruction.
        ```python
        # Updated Prompt Builder
        prompt = prompts.ChatPromptTemplate(
            message_templates=[
                types.ChatMessage(
                    role=types.MessageRole.SYSTEM,
                    content=system_instruction,
                ),
                types.ChatMessage(
                    role=types.MessageRole.USER,
                    content="{input}",
                ),
            ],
        )
        ```

        When all inputs are specified, this corresponds to a query pipeline
        `Prompt - Model - Retriever - Summarizer`:
        ```python
        runnable = QueryPipeline(
            modules={
                "prompt": prompt,
                "model": model,
                "retriever": retriever_builder(retriever_kwargs),
                "response_synthesizer": response_synthesizer_builder(
                    response_synthesizer_kwargs
                ),
            },
        )
        runnable.add_link("prompt", "model")
        runnable.add_link("model", "retriever")
        runnable.add_link("model", "response_synthesizer", dest_key="query_str")
        runnable.add_link("retriever", "response_synthesizer", dest_key="nodes")
        ```

        Args:
            model (str):
                The name of the model (e.g. "gemini-1.0-pro").
            system_instruction (str):
                Optional. The system instruction to use for the agent.
            prompt (llama_index.core.base.query_pipeline.query.QUERY_COMPONENT_TYPE):
                Optional. The prompt template for the model.
            model_kwargs (Mapping[str, Any]):
                Optional. Keyword arguments for the model constructor of
                google_genai.GoogleGenAI. An example of model_kwargs is:
                ```python
                {
                    # api_key (string): The API key for the GoogleGenAI model.
                    # The API key can also be read from the GOOGLE_API_KEY
                    # environment variable. If `vertexai_config` is provided,
                    # the API key is ignored.
                    "api_key": "your_api_key",
                    # temperature (float): Sampling temperature; it controls the
                    # degree of randomness in token selection. If not provided,
                    # the default temperature is 0.1.
                    "temperature": 0.1,
                    # context_window (int): The context window of the model.
                    # If not provided, the default context window is 200000.
                    "context_window": 200000,
                    # max_tokens (int): Token limit that determines the maximum
                    # amount of text output from one prompt. If not provided,
                    # the default max_tokens is 256.
                    "max_tokens": 256,
                    # is_function_calling_model (bool): Whether the model is a
                    # function calling model. If not provided, the default
                    # is_function_calling_model is True.
                    "is_function_calling_model": True,
                }
                ```
            model_builder (Callable):
                Optional. Callable that returns a language model.
            retriever_kwargs (Mapping[str, Any]):
                Optional. Keyword arguments for the retriever constructor.
            retriever_builder (Callable):
                Optional. Callable that returns a retriever object.
            response_synthesizer_kwargs (Mapping[str, Any]):
                Optional. Keyword arguments for the response synthesizer constructor.
            response_synthesizer_builder (Callable):
                Optional. Callable that returns a response_synthesizer object.
            runnable_kwargs (Mapping[str, Any]):
                Optional. Keyword arguments for the runnable constructor.
            runnable_builder (Callable):
                Optional. Callable that returns a runnable (query pipeline).
            enable_tracing (bool):
                Optional. Whether to enable tracing. Defaults to False.
        """
        from google.cloud.aiplatform import initializer

        self._project = initializer.global_config.project
        self._location = initializer.global_config.location
        self._model_name = model
        self._system_instruction = system_instruction
        self._prompt = prompt

        self._model = None
        self._model_kwargs = model_kwargs or {}
        self._model_builder = model_builder

        self._retriever = None
        self._retriever_kwargs = retriever_kwargs or {}
        self._retriever_builder = retriever_builder

        self._response_synthesizer = None
        self._response_synthesizer_kwargs = response_synthesizer_kwargs or {}
        self._response_synthesizer_builder = response_synthesizer_builder

        self._runnable = None
        self._runnable_kwargs = runnable_kwargs or {}
        self._runnable_builder = runnable_builder

        self._instrumentor = None
        self._enable_tracing = enable_tracing

    def set_up(self):
        """Sets up the agent for execution of queries at runtime.

        It initializes the model and connects it with the prompt template,
        retriever and response synthesizer.

        This method should not be called for an object that is being passed to
        the ReasoningEngine service for deployment, as it initializes clients
        that cannot be serialized.
        """
        if self._enable_tracing:
            from vertexai.reasoning_engines import _utils

            cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
            cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
            openinference_llama_index = (
                _utils._import_openinference_llama_index_or_warn()
            )
            opentelemetry = _utils._import_opentelemetry_or_warn()
            opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
            if all(
                (
                    cloud_trace_exporter,
                    cloud_trace_v2,
                    openinference_llama_index,
                    opentelemetry,
                    opentelemetry_sdk_trace,
                )
            ):
                import google.auth

                credentials, _ = google.auth.default()
                span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
                    project_id=self._project,
                    client=cloud_trace_v2.TraceServiceClient(
                        credentials=credentials.with_quota_project(self._project),
                    ),
                )
                span_processor: SpanProcessor = (
                    opentelemetry_sdk_trace.export.SimpleSpanProcessor(
                        span_exporter=span_exporter,
                    )
                )
                tracer_provider: TracerProvider = (
                    opentelemetry.trace.get_tracer_provider()
                )
                # Get the appropriate tracer provider:
                # 1. If _TRACER_PROVIDER is already set, use that.
                # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
                #    variable is set, use that.
                # 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
                # If none of the above is set, we log a warning and
                # create a tracer provider.
                if not tracer_provider:
                    from google.cloud.aiplatform import base

                    _LOGGER = base.Logger(__name__)
                    _LOGGER.warning(
                        "No tracer provider. By default, "
                        "we should get one of the following providers: "
                        "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
                        "or _PROXY_TRACER_PROVIDER."
                    )
                    tracer_provider = opentelemetry_sdk_trace.TracerProvider()
                    opentelemetry.trace.set_tracer_provider(tracer_provider)
                # Avoids AttributeError:
                # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
                # attribute 'add_span_processor'.
                if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
                    tracer_provider = opentelemetry_sdk_trace.TracerProvider()
                    opentelemetry.trace.set_tracer_provider(tracer_provider)
                # Avoids the OpenTelemetry "client already exists" error.
                _override_active_span_processor(
                    tracer_provider,
                    opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
                )
                tracer_provider.add_span_processor(span_processor)
                # When creating multiple LlamaIndexQueryPipelineAgents,
                # we need to keep the instrumentation up-to-date.
                # We deliberately override the instrumentation each time,
                # so that if different agents end up using different
                # instrumentations, the user is always working with the
                # most recent agent's instrumentation.
                self._instrumentor = openinference_llama_index.LlamaIndexInstrumentor()
                if self._instrumentor.is_instrumented_by_opentelemetry:
                    self._instrumentor.uninstrument()
                self._instrumentor.instrument()
            else:
                from google.cloud.aiplatform import base

                _LOGGER = base.Logger(__name__)
                _LOGGER.warning(
                    "enable_tracing=True but proceeding with tracing disabled "
                    "because not all packages for tracing have been installed"
                )

        model_builder = self._model_builder or _default_model_builder
        self._model = model_builder(
            model_name=self._model_name,
            model_kwargs=self._model_kwargs,
            project=self._project,
            location=self._location,
        )

        if self._retriever_builder:
            self._retriever = self._retriever_builder(
                model=self._model,
                retriever_kwargs=self._retriever_kwargs,
            )

        if self._response_synthesizer_builder:
            self._response_synthesizer = self._response_synthesizer_builder(
                model=self._model,
                response_synthesizer_kwargs=self._response_synthesizer_kwargs,
            )

        runnable_builder = self._runnable_builder or _default_runnable_builder
        self._runnable = runnable_builder(
            prompt=self._prompt,
            model=self._model,
            system_instruction=self._system_instruction,
            retriever=self._retriever,
            response_synthesizer=self._response_synthesizer,
            runnable_kwargs=self._runnable_kwargs,
        )
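
    # Illustrative sketch (not part of the original file): with
    # enable_tracing=True and the optional tracing dependencies installed
    # (the Cloud Trace exporter, the OpenTelemetry SDK and the OpenInference
    # LlamaIndex instrumentation; the exact package names are an assumption
    # here), set_up() exports spans for queries to Cloud Trace in
    # self._project. Otherwise it logs a warning and runs without tracing.
    #
    #   agent = LlamaIndexQueryPipelineAgent(
    #       model="gemini-1.0-pro",
    #       enable_tracing=True,
    #   )
    #   agent.set_up()  # subsequent queries are traced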

    def clone(self) -> "LlamaIndexQueryPipelineAgent":
        """Returns a clone of the LlamaIndexQueryPipelineAgent."""
        import copy

        return LlamaIndexQueryPipelineAgent(
            model=self._model_name,
            system_instruction=self._system_instruction,
            prompt=copy.deepcopy(self._prompt),
            model_kwargs=copy.deepcopy(self._model_kwargs),
            model_builder=self._model_builder,
            retriever_kwargs=copy.deepcopy(self._retriever_kwargs),
            retriever_builder=self._retriever_builder,
            response_synthesizer_kwargs=copy.deepcopy(
                self._response_synthesizer_kwargs
            ),
            response_synthesizer_builder=self._response_synthesizer_builder,
            runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
            runnable_builder=self._runnable_builder,
            enable_tracing=self._enable_tracing,
        )
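
    # Illustrative sketch (not part of the original file): clone() rebuilds the
    # agent from its constructor arguments, so the copy has no _runnable and no
    # live clients yet. That makes it a convenient object to hand to a
    # deployment API after the local copy has already called set_up().
    #
    #   local_agent = LlamaIndexQueryPipelineAgent(model="gemini-1.0-pro")
    #   local_agent.set_up()                     # local testing; creates clients
    #   deployable_agent = local_agent.clone()   # fresh, serializable copy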

    def query(
        self,
        input: Union[str, Mapping[str, Any]],
        **kwargs: Any,
    ) -> Union[str, Dict[str, Any], Sequence[Union[str, Dict[str, Any]]]]:
        """Queries the Agent with the given input and config.

        Args:
            input (Union[str, Mapping[str, Any]]):
                Required. The input to be passed to the Agent.
            **kwargs:
                Optional. Any additional keyword arguments to be passed to the
                `.run()` method of the underlying query pipeline.

        Returns:
            The output of querying the Agent with the given input and config.
        """
        from vertexai.reasoning_engines import _utils

        if isinstance(input, str):
            input = {"input": input}

        if not self._runnable:
            self.set_up()

        if kwargs.get("batch"):
            nest_asyncio = _utils._import_nest_asyncio_or_warn()
            nest_asyncio.apply()

        return _utils.to_json_serializable_llama_index_object(
            self._runnable.run(**input, **kwargs)
        )
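
    # Illustrative call patterns (a sketch, not part of the original file):
    #
    #   agent.query(input="What is a query pipeline?")
    #   agent.query(input={"input": "What is a query pipeline?"})
    #
    # A plain string is wrapped into {"input": ...}; extra keyword arguments are
    # forwarded to QueryPipeline.run(), and passing batch=True also applies
    # nest_asyncio so the pipeline's async batch execution can run inside an
    # already-running event loop (e.g. a notebook).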
@@ -0,0 +1,23 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.tokenization._tokenizers import (
    _get_tokenizer_for_model_preview as get_tokenizer_for_model,
)


__all__ = ["get_tokenizer_for_model"]
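
# Illustrative usage (a sketch, not part of the original file; the import path
# and model name follow the public Vertex AI SDK docs and are assumptions here):
#
#   from vertexai.preview.tokenization import get_tokenizer_for_model
#
#   tokenizer = get_tokenizer_for_model("gemini-1.5-flash-001")
#   result = tokenizer.count_tokens("The quick brown fox jumps over the lazy dog.")
#   print(result.total_tokens)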
@@ -0,0 +1,23 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Classes for tuning models."""

# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.tuning._tuning import TuningJob

__all__ = [
    "TuningJob",
]
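
# Illustrative usage (a sketch, not part of the original file; the sft module,
# model name and dataset URI follow the public Vertex AI tuning docs and are
# assumptions here, since they are not shown in this diff):
#
#   from vertexai.tuning import sft
#
#   tuning_job = sft.train(
#       source_model="gemini-1.0-pro-002",
#       train_dataset="gs://my-bucket/train.jsonl",
#   )
#   print(tuning_job.resource_name)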
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff.