structure saas with tools

Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions


@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Vertex Gen AI Evaluation Service Module."""
from vertexai.preview.evaluation import _base
from vertexai.preview.evaluation import autorater_utils
from vertexai.preview.evaluation import eval_task
from vertexai.preview.evaluation import metrics
from vertexai.preview.evaluation import prompt_template
EvalResult = _base.EvalResult
EvalTask = eval_task.EvalTask
PairwiseMetric = metrics.PairwiseMetric
PointwiseMetric = metrics.PointwiseMetric
CustomMetric = metrics.CustomMetric
PromptTemplate = prompt_template.PromptTemplate
PairwiseMetricPromptTemplate = metrics.PairwiseMetricPromptTemplate
PointwiseMetricPromptTemplate = metrics.PointwiseMetricPromptTemplate
MetricPromptTemplateExamples = metrics.MetricPromptTemplateExamples
AutoraterConfig = autorater_utils.AutoraterConfig
CustomOutputConfig = metrics.CustomOutputConfig
RubricBasedMetric = metrics.RubricBasedMetric
RubricGenerationConfig = metrics.RubricGenerationConfig
PredefinedRubricMetrics = metrics.PredefinedRubricMetrics
__all__ = [
"EvalTask",
"EvalResult",
"PairwiseMetric",
"PointwiseMetric",
"CustomMetric",
"PromptTemplate",
"PairwiseMetricPromptTemplate",
"PointwiseMetricPromptTemplate",
"MetricPromptTemplateExamples",
"AutoraterConfig",
"CustomOutputConfig",
"RubricBasedMetric",
"RubricGenerationConfig",
"PredefinedRubricMetrics",
]
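
A minimal usage sketch of the public surface re-exported by this module, consistent with the EvalTask docstring examples later in this commit; the project, experiment name, and dataset values are placeholders.

import pandas as pd
import vertexai
from vertexai.preview.evaluation import EvalTask, MetricPromptTemplateExamples

# Placeholder project/location; replace with real values before running.
vertexai.init(project="my-project", location="us-central1")

# Bring-your-own-response (BYOR): responses are already in the dataset.
eval_dataset = pd.DataFrame({
    "prompt": ["Summarize the release notes."],
    "response": ["The release adds solar forecasting and fixes two bugs."],
})

eval_task = EvalTask(
    dataset=eval_dataset,
    metrics=[MetricPromptTemplateExamples.Pointwise.FLUENCY],
    experiment="my-eval-experiment",
)
result = eval_task.evaluate(experiment_run_name="byor-fluency-run")
print(result.summary_metrics)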


@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Base classes for evaluation."""
import dataclasses
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
from google.cloud.aiplatform_v1beta1.services import (
evaluation_service as gapic_evaluation_services,
)
from google.cloud.aiplatform_v1beta1.types import (
evaluation_service as gapic_eval_service_types,
)
from vertexai.preview.evaluation.metrics import (
_base as metrics_base,
)
if TYPE_CHECKING:
import pandas as pd
AutoraterConfig = gapic_eval_service_types.AutoraterConfig
@dataclasses.dataclass
class EvaluationRunConfig:
"""Evaluation Run Configurations.
Attributes:
dataset: The dataset to evaluate.
metrics: The list of metric names, or Metric instances to evaluate.
metric_column_mapping: An optional dictionary that maps the metric prompt
template input variable names to the corresponding evaluation dataset
column names, used during evaluation. For example, if the
input_variables of the metric prompt template are ["context",
"reference"], the metric_column_mapping can be { "context":
"news_context", "reference": "ground_truth", "response":
"model_1_response" } if the dataset has columns "news_context",
"ground_truth" and "model_1_response".
client: The evaluation service client.
evaluation_service_qps: The custom QPS limit for the evaluation service.
retry_timeout: How long to keep retrying the evaluation requests, in
seconds.
autorater_config: The autorater config for model based evaluation.
"""
dataset: "pd.DataFrame"
metrics: List[Union[str, metrics_base._Metric]]
metric_column_mapping: Dict[str, str]
client: gapic_evaluation_services.EvaluationServiceClient
evaluation_service_qps: float
retry_timeout: float
autorater_config: Optional[AutoraterConfig] = None
def validate_dataset_column(self, column_name: str) -> None:
"""Validates that the mapped column name exists in the dataset.
Args:
column_name: The column name to validate.
Raises:
KeyError: If any of the column names are not in the dataset.
"""
if (
self.metric_column_mapping.get(column_name, column_name)
not in self.dataset.columns
):
raise KeyError(
"Required column"
f" `{self.metric_column_mapping.get(column_name, column_name)}` not"
" found in the evaluation dataset. The columns in the evaluation"
f" dataset are {list(self.dataset.columns)}."
)
@dataclasses.dataclass
class EvalResult:
"""Evaluation result.
Attributes:
summary_metrics: A dictionary of summary evaluation metrics for an
evaluation run.
metrics_table: A pandas.DataFrame table containing evaluation dataset
inputs, predictions, explanations, and metric results per row.
metadata: The metadata for the evaluation run.
"""
summary_metrics: Dict[str, float]
metrics_table: Optional["pd.DataFrame"] = None
metadata: Optional[Dict[str, str]] = None
@dataclasses.dataclass
class AutoraterEvalResult:
"""Evaluation result for autorater evaluation."""
def __init__(
self,
eval_result: Optional[List[Dict[str, Any]]],
eval_dataset_metadata: Optional[Dict[str, Any]],
autorater_config: Optional[AutoraterConfig],
**kwargs,
):
"""Initializes an AutoraterEvalResult.
Args:
eval_result: Evaluation result from an evaluation run.
eval_dataset_metadata: Evaluation dataset metadata.
autorater_config: Autorater configuration.
**kwargs: Additional arguments added to AutoraterEvalResult.
"""
self.eval_result = eval_result
self.eval_dataset_metadata = eval_dataset_metadata
self.autorater_config = autorater_config
self.__dict__.update(kwargs)
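
A short sketch of consuming the EvalResult dataclass defined above; the metric keys and values are placeholders for what an evaluation run would normally populate.

import pandas as pd
from vertexai.preview.evaluation import EvalResult

result = EvalResult(
    summary_metrics={"row_count": 2, "fluency/mean": 4.5},
    metrics_table=pd.DataFrame({
        "prompt": ["p1", "p2"],
        "response": ["r1", "r2"],
        "fluency/score": [5.0, 4.0],
    }),
    metadata={"experiment": "my-eval-experiment", "experiment_run": "run-1"},
)

# summary_metrics aggregates per-row results; metrics_table keeps row-level detail.
print(result.summary_metrics["fluency/mean"])
print(result.metrics_table[["prompt", "fluency/score"]])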

File diff suppressed because it is too large


@@ -0,0 +1,272 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utility functions for all pre-evaluation steps."""
from __future__ import annotations
from concurrent import futures
from typing import Callable, Optional, Set, TYPE_CHECKING, Union, List
from google.cloud.aiplatform import base
from google.cloud.aiplatform_v1beta1.types import (
content as gapic_content_types,
)
from vertexai import generative_models
from vertexai.preview.evaluation import _base as evaluation_base
from vertexai.preview.evaluation import constants
from vertexai.preview.evaluation import multimodal_utils
from vertexai.preview.evaluation import (
prompt_template as prompt_template_base,
)
if TYPE_CHECKING:
import pandas as pd
try:
from tqdm import tqdm
except ImportError:
raise ImportError(
'tqdm is not installed. Please install the SDK using "pip install'
' google-cloud-aiplatform[evaluation]"'
)
_LOGGER = base.Logger(__name__)
_SUCCESSFUL_FINISH_REASONS = [
gapic_content_types.Candidate.FinishReason.STOP,
gapic_content_types.Candidate.FinishReason.MAX_TOKENS,
# Many responses have this finish reason
gapic_content_types.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED,
]
def _assemble_prompt(
row: "pd.Series",
prompt_template: Union[prompt_template_base.PromptTemplate, str],
) -> str:
"""Assembles the prompt template with the given row data."""
if isinstance(prompt_template, str):
prompt_template = prompt_template_base.PromptTemplate(prompt_template)
_check_variable_columns_exist(row, prompt_template.variables)
return str(
prompt_template.assemble(
**row[list(prompt_template.variables)].astype(str).to_dict()
)
)
def _generate_content_text_response(
model: generative_models.GenerativeModel, prompt: str, max_attempts: int = 3
) -> str:
"""Generates a text response from the Gemini model for a text prompt, with retries.
Args:
model: The Gemini model instance.
prompt: The prompt to send to the model.
max_attempts: Maximum number of attempts for response generation.
Returns:
The text response from the model, or a generic response error string if all
attempts fail or the prompt/response is blocked for safety reasons.
"""
for attempt in range(max_attempts):
try:
response = model.generate_content(prompt)
if not response.candidates:
error_message = (
f"The model response was blocked due to"
f" {response._raw_response.prompt_feedback.block_reason.name}.\n"
f"Blocked reason message:"
f" {response._raw_response.prompt_feedback.block_reason_message}.\n"
"The input prompt may be blocked for safety reasons.\n"
f"Prompt: {prompt}.\n"
f"Attempt: {attempt + 1}/{max_attempts}"
)
_LOGGER.warning(error_message)
break
else:
candidate = response.candidates[0]
if candidate.finish_reason not in _SUCCESSFUL_FINISH_REASONS:
error_message = (
"The model response did not finish"
" successfully.\n"
f"Finish reason: {candidate.finish_reason}.\n"
f"Finish message: {candidate.finish_message}.\n"
f"Safety ratings: {candidate.safety_ratings}.\n"
"Please adjust the model safety_settings, or"
" try a different prompt.\n"
f"Attempt: {attempt + 1}/{max_attempts}"
)
_LOGGER.warning(error_message)
else:
return response.candidates[0].content.parts[0].text
except Exception as e:
error_message = (
f"Failed to generate response candidates from Gemini model"
f" {model._model_name}.\n"
f"Error: {e}.\n"
f"Prompt: {prompt}.\n"
f"Attempt: {attempt + 1}/{max_attempts}"
)
_LOGGER.warning(error_message)
if attempt < max_attempts - 1:
_LOGGER.info(
f"Retrying response generation for prompt: {prompt}, attempt"
f" {attempt + 1}/{max_attempts}..."
)
final_error_message = (
f"Failed to generate response from Gemini model {model._model_name}.\n"
f"Prompt: {prompt}."
)
_LOGGER.error(final_error_message)
return constants.RESPONSE_ERROR
def _generate_responses_from_gemini_model(
model: generative_models.GenerativeModel,
df: "pd.DataFrame",
rubric_generation_prompt_template: Optional[str] = None,
) -> List[str]:
"""Generates responses from the Gemini model for the given evaluation dataset.
Args:
model: The Gemini model instance.
df: Evaluation Dataset.
rubric_generation_prompt_template: Optional prompt template used to assemble
rubric generation prompts from the dataset rows.
Returns:
The list of model responses.
"""
_LOGGER.info(
f"Generating a total of {df.shape[0]} "
f"responses from Gemini model {model._model_name.split('/')[-1]}."
)
tasks = []
with tqdm(total=len(df)) as pbar:
with futures.ThreadPoolExecutor(max_workers=constants.MAX_WORKERS) as executor:
for idx, row in df.iterrows():
if rubric_generation_prompt_template:
input_columns = prompt_template_base.PromptTemplate(
rubric_generation_prompt_template
).variables
if multimodal_utils.is_multimodal_instance(
row[list(input_columns)].to_dict()
):
prompt = multimodal_utils._assemble_multi_modal_prompt(
rubric_generation_prompt_template, row, idx, input_columns
)
else:
prompt = _assemble_prompt(
row, rubric_generation_prompt_template
)
else:
prompt = row[constants.Dataset.PROMPT_COLUMN]
task = executor.submit(
_generate_content_text_response,
prompt=prompt,
model=model,
)
task.add_done_callback(lambda _: pbar.update(1))
tasks.append(task)
responses = [future.result() for future in tasks]
return responses
def _generate_response_from_custom_model_fn(
model_fn: Callable[[str], str], eval_dataset: "pd.DataFrame"
) -> List[str]:
"""Generates responses from a custom model function.
Args:
model_fn: The custom model function.
eval_dataset: Evaluation Dataset.
Returns:
The list of model responses.
"""
max_workers = 5
_LOGGER.info(
f"Generating a total of {eval_dataset.shape[0]} "
"responses from the custom model function."
)
tasks = []
try:
with tqdm(total=len(eval_dataset)) as pbar:
with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
for _, row in eval_dataset.iterrows():
task = executor.submit(
model_fn, row[constants.Dataset.PROMPT_COLUMN]
)
task.add_done_callback(lambda _: pbar.update(1))
tasks.append(task)
except (ValueError, IndexError) as e:
_LOGGER.warning(f"Failed to generate response from model function: {e}")
responses = [task.result() for task in tasks]
return responses
def populate_eval_dataset_with_model_responses(
responses: List[str],
evaluation_run_config: evaluation_base.EvaluationRunConfig,
is_baseline_model: bool = False,
) -> None:
"""Populates the evaluation dataset with model responses.
Args:
responses: The list of model responses.
evaluation_run_config: Evaluation Run Configurations.
is_baseline_model: Whether the model is a baseline model for
PairwiseMetric.
"""
df = evaluation_run_config.dataset.copy()
if is_baseline_model:
evaluation_run_config.dataset = df.assign(baseline_model_response=responses)
else:
evaluation_run_config.dataset = df.assign(response=responses)
_LOGGER.info(
f"All {evaluation_run_config.dataset.shape[0]} responses are successfully"
f" generated from model."
)
def _check_variable_columns_exist(
dataset_row: "pd.Series", variable_names_set: Set[str]
) -> None:
"""Checks if all variable names exist in the dataset columns.
Args:
dataset_row: A single row of the evaluation dataset.
variable_names_set: A set of variable names.
Raises:
ValueError: If any variable names do not exist in the dataset columns
or the prompt template is invalid.
"""
actual_column_names_set = set(dataset_row.to_dict().keys())
if not variable_names_set.issubset(actual_column_names_set):
missing_columns = variable_names_set - actual_column_names_set
raise ValueError(
"Failed to assemble prompt template: The following column(s) are"
f" missing: {', '.join(missing_columns)}. "
f"Please verify prompt_template variables {variable_names_set} and "
f"evaluation dataset column names {actual_column_names_set}."
)
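
A standalone sketch of the prompt-assembly logic implemented by _assemble_prompt and _check_variable_columns_exist above: every template variable must exist as a dataset column, and the row values fill the template. The template and row values are illustrative.

import pandas as pd
from vertexai.preview.evaluation import PromptTemplate

template = PromptTemplate("Instruction: {instruction}. Article: {context}. Summary:")
row = pd.Series({
    "instruction": "Summarize the text",
    "context": "Solar output rose 12% last year.",
})

# Mirror of _check_variable_columns_exist: fail fast if a variable has no column.
missing = template.variables - set(row.index)
if missing:
    raise ValueError(f"Missing column(s) for prompt template: {missing}")

# Mirror of _assemble_prompt: fill the template from the row values.
prompt = str(template.assemble(**row[list(template.variables)].astype(str).to_dict()))
print(prompt)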


@@ -0,0 +1,238 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Autorater Utils Class and Functions."""
import logging
import time
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Union
from vertexai import generative_models
from vertexai.preview.evaluation import _base as evaluation_base
from vertexai.preview.evaluation import eval_task
from vertexai.preview.evaluation.metrics import pairwise_metric
from vertexai.preview.evaluation.metrics import pointwise_metric
from vertexai.preview.tuning import sft
from sklearn import metrics
if TYPE_CHECKING:
import pandas as pd
AutoraterConfig = evaluation_base.AutoraterConfig
AutoraterEvalResult = evaluation_base.AutoraterEvalResult
EvalTask = eval_task.EvalTask
PointwiseMetric = pointwise_metric.PointwiseMetric
PairwiseMetric = pairwise_metric.PairwiseMetric
_SCORE = "score"
_METRIC = "metric"
_PAIRWISE_CHOICE = "pairwise_choice"
_HUMAN_RATING = "human_rating"
_HUMAN_PAIRWISE_CHOICE = "human_pairwise_choice"
_ACCURACY_BALANCED = "accuracy_balanced"
_F1_SCORE_BALANCED = "f1_score_balanced"
_CONFUSION_MATRIX = "confusion_matrix"
_CONFUSION_MATRIX_LABELS = "confusion_matrix_labels"
_METRICS_CATEGORY_LIMIT = 10
_NAN = "nan"
_ERROR = "error"
def tune_autorater(
*,
base_model: Union[str, generative_models.GenerativeModel],
train_dataset: str,
validation_dataset: Optional[str] = None,
tuned_model_display_name: Optional[str] = None,
epochs: Optional[int] = None,
learning_rate_multiplier: Optional[float] = None,
adapter_size: Optional[Literal[1, 4, 8, 16]] = None,
labels: Optional[Dict[str, str]] = None,
time_out_hours: int = 10,
) -> AutoraterConfig:
"""Tunes an autorater model using LoRA.
Args:
base_model: Model name for tuning, e.g., "gemini-1.0-pro-002".
train_dataset: Cloud Storage path to file containing training dataset for
tuning. The dataset should be in JSONL format.
validation_dataset: Cloud Storage path to file containing validation
dataset for tuning. The dataset should be in JSONL format.
tuned_model_display_name: The display name of the
[TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
128 characters long and can consist of any UTF-8 characters.
epochs: Number of training epochs for this tuning job.
learning_rate_multiplier: Learning rate multiplier for tuning.
adapter_size: Adapter size for tuning.
labels: User-defined metadata to be associated with trained models.
time_out_hours: Timeout in hours for tuning job. Default value is 10
hours.
Returns:
An `AutoraterConfig` object with the tuned model endpoint.
"""
tune_job = sft.train(
source_model=base_model,
train_dataset=train_dataset,
validation_dataset=validation_dataset,
tuned_model_display_name=tuned_model_display_name,
epochs=epochs,
learning_rate_multiplier=learning_rate_multiplier,
adapter_size=adapter_size,
labels=labels,
)
time_out_seconds = time_out_hours * 60 * 60
while not tune_job.refresh().has_ended and time_out_seconds > 0:
time.sleep(60)
time_out_seconds -= 60
if tune_job.has_succeeded:
return AutoraterConfig(autorater_model=tune_job.tuned_model_endpoint_name)
else:
raise ValueError(
"Failed to tune autorater model. Please check the logs for more details."
)
def _get_evaluation_result(
metric: Union[PointwiseMetric, PairwiseMetric],
autorater_eval_results: List[str],
human_eval_results: List[str],
) -> Dict[str, Any]:
"""Get evaluation result for autorater."""
filtered_autorater_eval_results = []
filtered_human_eval_results = []
for autorater_eval_result, human_eval_result in zip(
autorater_eval_results, human_eval_results
):
# Filter failed pointwise evaluation results.
if autorater_eval_result.lower() == _NAN or human_eval_result.lower() == _NAN:
continue
# Filter failed pairwise evaluation results.
if (
autorater_eval_result.lower() == _ERROR
or human_eval_result.lower() == _ERROR
):
continue
filtered_autorater_eval_results.append(autorater_eval_result)
filtered_human_eval_results.append(human_eval_result)
labels = list(
sorted(set(filtered_autorater_eval_results) | set(filtered_human_eval_results))
)
eval_result = {_METRIC: metric.metric_name}
eval_result[_ACCURACY_BALANCED] = metrics.balanced_accuracy_score(
filtered_human_eval_results, filtered_autorater_eval_results
)
eval_result[_F1_SCORE_BALANCED] = metrics.f1_score(
filtered_human_eval_results,
filtered_autorater_eval_results,
average="weighted",
)
if len(labels) > _METRICS_CATEGORY_LIMIT:
logging.warning(
"Confusion matrix is not provided as the number of"
" rating rubric values %d is greater than the limit %d.",
len(labels),
_METRICS_CATEGORY_LIMIT,
)
else:
eval_result[_CONFUSION_MATRIX] = metrics.confusion_matrix(
filtered_human_eval_results,
filtered_autorater_eval_results,
labels=labels,
)
eval_result[_CONFUSION_MATRIX_LABELS] = labels
return eval_result
def evaluate_autorater(
*,
evaluate_autorater_input: "pd.DataFrame",
eval_metrics: List[Union[PointwiseMetric, PairwiseMetric]],
autorater_config: Optional[AutoraterConfig] = None,
eval_dataset_metadata: Optional[Dict[str, Any]] = None,
**kwargs,
) -> AutoraterEvalResult:
"""Evaluates the autorater model using human evaluation results.
Args:
evaluate_autorater_input: Autorater evaluation input, including
evaluation results from human evaluation and autorater model.
eval_metrics: List of model based metrics.
autorater_config: Autorater configuration.
eval_dataset_metadata: Evaluation dataset metadata.
**kwargs: Additional arguments added to AutoraterEvalResult.
Returns:
Autorater evaluation result.
"""
eval_result = []
for metric in eval_metrics:
if isinstance(metric, PointwiseMetric):
autorater_score = list(
map(
lambda x: str(float(x)),
list(evaluate_autorater_input[metric.metric_name + "/" + _SCORE]),
)
)
human_score = list(
map(
lambda x: str(float(x)),
list(
evaluate_autorater_input[
metric.metric_name + "/" + _HUMAN_RATING
]
),
)
)
eval_result.append(
_get_evaluation_result(metric, autorater_score, human_score)
)
elif isinstance(metric, PairwiseMetric):
autorater_choice = list(
map(
str,
list(
evaluate_autorater_input[
metric.metric_name + "/" + _PAIRWISE_CHOICE
]
),
)
)
human_choice = list(
map(
str,
list(
evaluate_autorater_input[
metric.metric_name + "/" + _HUMAN_PAIRWISE_CHOICE
]
),
)
)
eval_result.append(
_get_evaluation_result(metric, autorater_choice, human_choice)
)
else:
continue
return AutoraterEvalResult(
eval_result=eval_result,
eval_dataset_metadata=eval_dataset_metadata,
autorater_config=autorater_config,
**kwargs,
)
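
An illustrative call to evaluate_autorater as defined above; the column names follow the "<metric_name>/score" and "<metric_name>/human_rating" convention read by the function, and the metric prompt template, scores, and metadata are placeholders.

import pandas as pd
from vertexai.preview.evaluation import PointwiseMetric, autorater_utils

fluency = PointwiseMetric(
    metric="fluency",
    metric_prompt_template="Rate the fluency of the response: {response}",
)

# Autorater scores paired with human ratings for the same rows.
autorater_input = pd.DataFrame({
    "fluency/score": [5.0, 4.0, 3.0],
    "fluency/human_rating": [5.0, 4.0, 2.0],
})

result = autorater_utils.evaluate_autorater(
    evaluate_autorater_input=autorater_input,
    eval_metrics=[fluency],
    eval_dataset_metadata={"dataset_name": "placeholder-dataset"},
)
# Each entry reports balanced accuracy, weighted F1, and (for small label sets)
# a confusion matrix comparing autorater scores against human ratings.
print(result.eval_result)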


@@ -0,0 +1,206 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Constants for evaluation."""
import dataclasses
# The number of concurrent workers to use for making model inference and
# evaluation requests.
MAX_WORKERS = 100
RESPONSE_ERROR = "Error"
@dataclasses.dataclass(frozen=True)
class Metric:
"""Namespace for Metrics."""
# Model-based Pointwise Metrics.
COHERENCE = "coherence"
FLUENCY = "fluency"
SAFETY = "safety"
GROUNDEDNESS = "groundedness"
INSTRUCTION_FOLLOWING = "instruction_following"
VERBOSITY = "verbosity"
TEXT_QUALITY = "text_quality"
SUMMARIZATION_QUALITY = "summarization_quality"
QUESTION_ANSWERING_QUALITY = "question_answering_quality"
MULTI_TURN_CHAT_QUALITY = "multi_turn_chat_quality"
MULTI_TURN_SAFETY = "multi_turn_safety"
RUBRIC_BASED_INSTRUCTION_FOLLOWING = "rubric_based_instruction_following"
# Model-based Pairwise Metrics.
PAIRWISE_COHERENCE = "pairwise_coherence"
PAIRWISE_FLUENCY = "pairwise_fluency"
PAIRWISE_SAFETY = "pairwise_safety"
PAIRWISE_GROUNDEDNESS = "pairwise_groundedness"
PAIRWISE_INSTRUCTION_FOLLOWING = "pairwise_instruction_following"
PAIRWISE_VERBOSITY = "pairwise_verbosity"
PAIRWISE_TEXT_QUALITY = "pairwise_text_quality"
PAIRWISE_SUMMARIZATION_QUALITY = "pairwise_summarization_quality"
PAIRWISE_QUESTION_ANSWERING_QUALITY = "pairwise_question_answering_quality"
PAIRWISE_MULTI_TURN_CHAT_QUALITY = "pairwise_multi_turn_chat_quality"
PAIRWISE_MULTI_TURN_SAFETY = "pairwise_multi_turn_safety"
POINTWISE_METRIC = "pointwise_metric"
PAIRWISE_METRIC = "pairwise_metric"
# Automatic Metrics.
EXACT_MATCH = "exact_match"
BLEU = "bleu"
ROUGE = "rouge"
ROUGE_1 = "rouge_1"
ROUGE_2 = "rouge_2"
ROUGE_L = "rouge_l"
ROUGE_L_SUM = "rouge_l_sum"
TOOL_CALL_VALID = "tool_call_valid"
TOOL_NAME_MATCH = "tool_name_match"
TOOL_PARAMETER_KEY_MATCH = "tool_parameter_key_match"
TOOL_PARAMETER_KV_MATCH = "tool_parameter_kv_match"
TRAJECTORY_EXACT_MATCH = "trajectory_exact_match"
TRAJECTORY_IN_ORDER_MATCH = "trajectory_in_order_match"
TRAJECTORY_ANY_ORDER_MATCH = "trajectory_any_order_match"
TRAJECTORY_PRECISION = "trajectory_precision"
TRAJECTORY_RECALL = "trajectory_recall"
TRAJECTORY_SINGLE_TOOL_USE = "trajectory_single_tool_use"
LATENCY = "latency_in_seconds"
FAILURE = "failure"
AUTOMATIC_METRIC_LIST = (
EXACT_MATCH,
BLEU,
ROUGE,
ROUGE_1,
ROUGE_2,
ROUGE_L,
ROUGE_L_SUM,
TOOL_CALL_VALID,
TOOL_NAME_MATCH,
TOOL_PARAMETER_KEY_MATCH,
TOOL_PARAMETER_KV_MATCH,
)
TRAJECTORY_METRIC_LIST = (
TRAJECTORY_EXACT_MATCH,
TRAJECTORY_IN_ORDER_MATCH,
TRAJECTORY_ANY_ORDER_MATCH,
TRAJECTORY_PRECISION,
TRAJECTORY_RECALL,
TRAJECTORY_SINGLE_TOOL_USE,
)
DEFAULT_METRIC_LIST = (
LATENCY,
FAILURE,
)
POINTWISE_METRIC_PROMPT_TEMPLATE_EXAMPLE_LIST = (
COHERENCE,
FLUENCY,
SAFETY,
GROUNDEDNESS,
INSTRUCTION_FOLLOWING,
VERBOSITY,
TEXT_QUALITY,
SUMMARIZATION_QUALITY,
QUESTION_ANSWERING_QUALITY,
MULTI_TURN_CHAT_QUALITY,
MULTI_TURN_SAFETY,
)
PAIRWISE_METRIC_PROMPT_TEMPLATE_EXAMPLE_LIST = (
PAIRWISE_COHERENCE,
PAIRWISE_FLUENCY,
PAIRWISE_SAFETY,
PAIRWISE_GROUNDEDNESS,
PAIRWISE_INSTRUCTION_FOLLOWING,
PAIRWISE_VERBOSITY,
PAIRWISE_TEXT_QUALITY,
PAIRWISE_SUMMARIZATION_QUALITY,
PAIRWISE_QUESTION_ANSWERING_QUALITY,
PAIRWISE_MULTI_TURN_CHAT_QUALITY,
PAIRWISE_MULTI_TURN_SAFETY,
)
@dataclasses.dataclass(frozen=True)
class MetricResult:
ROW_COUNT_KEY = "row_count"
SCORE_KEY = "score"
EXPLANATION_KEY = "explanation"
CUSTOM_OUTPUT_KEY = "custom_output"
RAW_OUTPUT_KEY = "raw_output"
RAW_OUTPUTS_KEY = "raw_outputs"
PAIRWISE_CHOICE_KEY = "pairwise_choice"
IS_UNSAFE_KEY = "is_unsafe"
IS_UNSAFE_PROBABILITY_KEY = "is_unsafe_probability"
VIOLATED_POLICIES_KEY = "violated_policies"
RUBRIC_LEVEL_INSTRUCTION_FOLLOWING_KEY = "per_rubric_result"
# Automatic Metrics.
EXACT_MATCH_RESULTS = "exact_match_results"
BLEU_RESULTS = "bleu_results"
ROUGE_RESULTS = "rouge_results"
TOOL_CALL_VALID_RESULTS = "tool_call_valid_results"
TOOL_NAME_MATCH_RESULTS = "tool_name_match_results"
TOOL_PARAMETER_KEY_MATCH_RESULTS = "tool_parameter_key_match_results"
TOOL_PARAMETER_KV_MATCH_RESULTS = "tool_parameter_kv_match_results"
TRAJECTORY_EXACT_MATCH_RESULTS = "trajectory_exact_match_results"
TRAJECTORY_IN_ORDER_MATCH_RESULTS = "trajectory_in_order_match_results"
TRAJECTORY_ANY_ORDER_MATCH_RESULTS = "trajectory_any_order_match_results"
TRAJECTORY_PRECISION_RESULTS = "trajectory_precision_results"
TRAJECTORY_RECALL_RESULTS = "trajectory_recall_results"
TRAJECTORY_SINGLE_TOOL_USE_RESULTS = "trajectory_single_tool_use_results"
POINTWISE_METRIC_RESULT = "pointwise_metric_result"
PAIRWISE_METRIC_RESULT = "pairwise_metric_result"
RUBRIC_BASED_INSTRUCTION_FOLLOWING_RESULT = (
"rubric_based_instruction_following_result"
)
AUTOMATIC_METRIC_RESULTS_LIST = (
EXACT_MATCH_RESULTS,
BLEU_RESULTS,
ROUGE_RESULTS,
TOOL_CALL_VALID_RESULTS,
TOOL_NAME_MATCH_RESULTS,
TOOL_PARAMETER_KEY_MATCH_RESULTS,
TOOL_PARAMETER_KV_MATCH_RESULTS,
TRAJECTORY_EXACT_MATCH_RESULTS,
TRAJECTORY_IN_ORDER_MATCH_RESULTS,
TRAJECTORY_ANY_ORDER_MATCH_RESULTS,
TRAJECTORY_PRECISION_RESULTS,
TRAJECTORY_RECALL_RESULTS,
TRAJECTORY_SINGLE_TOOL_USE_RESULTS,
)
@dataclasses.dataclass(frozen=True)
class Dataset:
# Default evaluation dataset schema column names.
MODEL_RESPONSE_COLUMN = "response"
BASELINE_MODEL_RESPONSE_COLUMN = "baseline_model_response"
PROMPT_COLUMN = "prompt"
REFERENCE_COLUMN = "reference"
PREDICTED_TRAJECTORY_COLUMN = "predicted_trajectory"
REFERENCE_TRAJECTORY_COLUMN = "reference_trajectory"
RUBRICS_COLUMN = "rubrics"
@dataclasses.dataclass(frozen=True)
class QuotaLimit:
"""Generative AI on Vertex AI quota limits."""
# Default Evaluation Service QPS limit.
EVAL_SERVICE_QPS = 10
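
A brief sketch of how these constant namespaces are typically consulted; the metric name below is only an example.

from vertexai.preview.evaluation import constants

metric_name = "rouge_l_sum"
if metric_name in constants.Metric.AUTOMATIC_METRIC_LIST:
    print(f"{metric_name} is computed without a judge model.")

# Default dataset column names and the default evaluation service QPS limit.
print(constants.Dataset.PROMPT_COLUMN, constants.Dataset.MODEL_RESPONSE_COLUMN)
print(constants.QuotaLimit.EVAL_SERVICE_QPS)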


@@ -0,0 +1,630 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Evaluation Task class."""
import logging
from typing import Any, Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Union
import uuid
import warnings
from google.api_core import exceptions
import vertexai
from google.cloud.aiplatform import base
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.metadata import metadata
from vertexai import generative_models
from vertexai.preview import reasoning_engines
from vertexai.preview.evaluation import _base as eval_base
from vertexai.preview.evaluation import _evaluation
from vertexai.preview.evaluation import constants
from vertexai.preview.evaluation import utils as eval_utils
from vertexai.preview.evaluation.metrics import (
_base as metrics_base,
)
from vertexai.preview.evaluation.metrics import pairwise_metric
from vertexai.preview.evaluation.metrics import pointwise_metric
from vertexai.preview.evaluation.metrics import (
rubric_based_metric,
)
import numpy as np
if TYPE_CHECKING:
import pandas as pd
from google.colab import sheets
# pylint: disable=g-import-not-at-top
try:
from IPython import display as IPython_display
except ImportError:
IPython_display = None
_LOGGER = base.Logger(__name__)
logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")
AutoraterConfig = eval_base.AutoraterConfig
EvalResult = eval_base.EvalResult
GenerativeModel = generative_models.GenerativeModel
_RunnableType = Union[reasoning_engines.Queryable, Callable[[str], Dict[str, str]]]
_ModelType = Union[generative_models.GenerativeModel, Callable[[str], str]]
class EvalTask:
"""A class representing an EvalTask.
An evaluation task assesses the ability of a Gen AI model, agent or
application to perform a specific task in response to prompts.
Each evaluation task includes an evaluation dataset (a set of test cases) and
a set of metrics for assessment. These tasks provide the
framework for running evaluations in a standardized and repeatable way,
allowing for comparative assessment with varying run-specific parameters.
Dataset Details:
Default dataset column names:
* prompt_column_name: "prompt"
* reference_column_name: "reference"
* response_column_name: "response"
* baseline_model_response_column_name: "baseline_model_response"
* rubrics_column_name: "rubrics"
Requirement for different use cases:
* Bring-your-own-response (BYOR): You already have the data that you
want to evaluate stored in the dataset. Response column name can be
customized by providing `response_column_name` parameter, or in the
`metric_column_mapping`. For BYOR pairwise evaluation, the baseline
model response column name can be customized by providing
`baseline_model_response_column_name` parameter, or
in the `metric_column_mapping`. If the `response` column or
`baseline_model_response` column is present while the
corresponding model is specified, an error will be raised.
* Perform model/agent inference without a prompt template: You have a dataset
containing the input prompts to the model/agent and want to perform
inference before evaluation. A column named `prompt` is required
in the evaluation dataset and is used directly as input to the model/agent.
* Perform model/agent inference with a prompt template: You have a dataset
containing the input variables to the prompt template and want to
assemble the prompts for inference. Evaluation dataset
must contain column names corresponding to the variable names in
the prompt template. For example, if prompt template is
"Instruction: {instruction}, context: {context}", the dataset must
contain `instruction` and `context` columns.
Metrics Details:
The supported metrics descriptions, rating rubrics, and the required
input variables can be found on the Vertex AI public documentation page.
[Evaluation methods and metrics](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval).
Usage Examples:
1. To perform bring-your-own-response (BYOR) evaluation, provide the model
responses in the `response` column in the dataset. If a pairwise metric is
used for BYOR evaluation, provide the baseline model responses in the
`baseline_model_response` column.
```
eval_dataset = pd.DataFrame({
"prompt" : [...],
"reference": [...],
"response" : [...],
"baseline_model_response": [...],
})
eval_task = EvalTask(
dataset=eval_dataset,
metrics=[
"bleu",
"rouge_l_sum",
MetricPromptTemplateExamples.Pointwise.FLUENCY,
MetricPromptTemplateExamples.Pairwise.SAFETY
],
experiment="my-experiment",
)
eval_result = eval_task.evaluate(experiment_run_name="eval-experiment-run")
```
2. To perform evaluation with Gemini model inference, specify the `model`
parameter with a `GenerativeModel` instance. The input column name to the
model is `prompt` and must be present in the dataset.
```
eval_dataset = pd.DataFrame({
"reference": [...],
"prompt" : [...],
})
result = EvalTask(
dataset=eval_dataset,
metrics=["exact_match", "bleu", "rouge_1", "rouge_l_sum"],
experiment="my-experiment",
).evaluate(
model=GenerativeModel("gemini-1.5-pro"),
experiment_run_name="gemini-eval-run"
)
```
3. If a `prompt_template` is specified, the `prompt` column is not required.
Prompts can be assembled from the evaluation dataset, and all prompt
template variable names must be present in the dataset columns.
```
eval_dataset = pd.DataFrame({
"context" : [...],
"instruction": [...],
})
result = EvalTask(
dataset=eval_dataset,
metrics=[MetricPromptTemplateExamples.Pointwise.SUMMARIZATION_QUALITY],
).evaluate(
model=GenerativeModel("gemini-1.5-pro"),
prompt_template="{instruction}. Article: {context}. Summary:",
)
```
4. To perform evaluation with custom model inference, specify the `model`
parameter with a custom inference function. The input column name to the
custom inference function is `prompt` and must be present in the dataset.
```
from openai import OpenAI
client = OpenAI()
def custom_model_fn(input: str) -> str:
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": input}
]
)
return response.choices[0].message.content
eval_dataset = pd.DataFrame({
"prompt" : [...],
"reference": [...],
})
result = EvalTask(
dataset=eval_dataset,
metrics=[MetricPromptTemplateExamples.Pointwise.SAFETY],
experiment="my-experiment",
).evaluate(
model=custom_model_fn,
experiment_run_name="gpt-eval-run"
)
```
5. To perform pairwise metric evaluation with model inference step, specify
the `baseline_model` input to a `PairwiseMetric` instance and the candidate
`model` input to the `EvalTask.evaluate()` function. The input column name
to both models is `prompt` and must be present in the dataset.
```
baseline_model = GenerativeModel("gemini-1.0-pro")
candidate_model = GenerativeModel("gemini-1.5-pro")
pairwise_groundedness = PairwiseMetric(
metric_prompt_template=MetricPromptTemplateExamples.get_prompt_template(
"pairwise_groundedness"
),
baseline_model=baseline_model,
)
eval_dataset = pd.DataFrame({
"prompt" : [...],
})
result = EvalTask(
dataset=eval_dataset,
metrics=[pairwise_groundedness],
experiment="my-pairwise-experiment",
).evaluate(
model=candidate_model,
experiment_run_name="gemini-pairwise-eval-run",
)
```
"""
def __init__(
self,
*,
dataset: Union["pd.DataFrame", str, Dict[str, Any], "sheets.InteractiveSheet"],
metrics: List[
Union[
Literal[
"exact_match",
"bleu",
"rouge_1",
"rouge_2",
"rouge_l",
"rouge_l_sum",
"tool_call_valid",
"tool_name_match",
"tool_parameter_key_match",
"tool_parameter_kv_match",
"trajectory_exact_match",
"trajectory_in_order_match",
"trajectory_any_order_match",
"trajectory_precision",
"trajectory_recall",
"rubric_based_instruction_following",
],
metrics_base.CustomMetric,
metrics_base._AutomaticMetric,
pointwise_metric.PointwiseMetric,
pairwise_metric.PairwiseMetric,
rubric_based_metric.RubricBasedMetric,
]
],
experiment: Optional[str] = None,
metric_column_mapping: Optional[Dict[str, str]] = None,
output_uri_prefix: Optional[str] = "",
autorater_config: Optional[AutoraterConfig] = None,
):
"""Initializes an EvalTask.
Args:
dataset: The dataset to be evaluated.
Supports the following dataset formats:
* pandas.DataFrame: Used directly for evaluation.
* Dict: Converted to a pandas DataFrame before evaluation.
* str: Interpreted as a file path or URI. Supported formats include:
* Local JSONL or CSV files: Loaded from the local filesystem.
* GCS JSONL or CSV files: Loaded from Google Cloud Storage
(e.g., 'gs://bucket/data.csv').
* BigQuery table URI: Loaded from Google Cloud BigQuery
(e.g., 'bq://project-id.dataset.table_name').
metrics: The list of metric names, or Metric instances to evaluate.
Prompt template is required for PairwiseMetric.
experiment: The name of the experiment to log the evaluations to.
metric_column_mapping: An optional dictionary that maps the metric
prompt template input variable names to the corresponding evaluation
dataset column names, used during evaluation.
For example, if the input_variables of the metric prompt template
are ["context", "reference"], the metric_column_mapping can be
{
"context": "news_context",
"reference": "ground_truth",
"response": "model_1_response"
}
if the dataset has columns "news_context", "ground_truth" and
"model_1_response".
output_uri_prefix: GCS location to store the metrics_table from
evaluation results.
autorater_config: The autorater config for model based evaluation.
If autorater config is specified on a metric, it will override the
autorater config specified here.
"""
self._dataset = eval_utils.load_dataset(dataset)
self._metrics = metrics
self._experiment = experiment
self._metric_column_mapping = eval_utils.initialize_metric_column_mapping(
metric_column_mapping, self._dataset
)
self.output_uri_prefix = output_uri_prefix
self._autorater_config = autorater_config
@property
def dataset(self) -> "pd.DataFrame":
"""Returns evaluation dataset."""
return self._dataset
@property
def metrics(self) -> List[Union[str, metrics_base.CustomMetric]]:
"""Returns metrics."""
return self._metrics
@property
def autorater_config(self) -> Optional[AutoraterConfig]:
"""Returns autorater config."""
return self._autorater_config
@property
def experiment(self) -> Optional[str]:
"""Returns experiment name."""
return self._experiment
def _evaluate_with_experiment(
self,
model: Optional[_ModelType] = None,
runnable: Optional[_RunnableType] = None,
prompt_template: Optional[str] = None,
experiment_run_name: Optional[str] = None,
evaluation_service_qps: Optional[float] = None,
retry_timeout: float = 120.0,
output_file_name: Optional[str] = None,
) -> EvalResult:
"""Runs an evaluation for the EvalTask with an experiment.
Args:
model: A GenerativeModel instance or a custom model function to generate
responses to evaluate. If not provided, the evaluation is computed with
the `response` column in the `dataset`.
runnable: The runnable to generate responses to evaluate. If not provided,
the evaluation is computed with the `response` and/or `predicted_trajectory`
column in the `dataset`.
prompt_template: The prompt template to use for the evaluation. If not
set, the prompt template that was used to create the EvalTask will be
used.
experiment_run_name: The name of the experiment run to log the evaluation
to if an experiment is set for this EvalTask. If not provided, a random
unique experiment run name is used.
evaluation_service_qps: The custom QPS limit for the evaluation service.
retry_timeout: How long to keep retrying the evaluation requests for
the whole evaluation dataset, in seconds.
output_file_name: The file name with csv suffix to store the output
metrics_table to be tracked in the experiment run.
Returns:
The evaluation result.
"""
self._validate_experiment_run()
with vertexai.preview.start_run(experiment_run_name):
self._log_eval_experiment_param(
model=model,
runnable=runnable,
prompt_template=prompt_template,
output_file_name=output_file_name,
)
eval_result = _evaluation.evaluate(
dataset=self._dataset,
metrics=self._metrics,
model=model,
runnable=runnable,
prompt_template=prompt_template,
metric_column_mapping=self._metric_column_mapping,
evaluation_service_qps=evaluation_service_qps,
retry_timeout=retry_timeout,
autorater_config=self._autorater_config,
)
eval_result.summary_metrics = {
k: ("NaN" if isinstance(v, float) and np.isnan(v) else v)
for k, v in eval_result.summary_metrics.items()
}
eval_result.metadata = {
"experiment": self._experiment,
"experiment_run": experiment_run_name,
}
try:
vertexai.preview.log_metrics(eval_result.summary_metrics)
except (TypeError, exceptions.InvalidArgument) as e:
_LOGGER.warning(f"Experiment metrics logging failed: {str(e)}")
return eval_result
def evaluate(
self,
*,
model: Optional[_ModelType] = None,
runnable: Optional[_RunnableType] = None,
prompt_template: Optional[str] = None,
experiment_run_name: Optional[str] = None,
response_column_name: Optional[str] = None,
baseline_model_response_column_name: Optional[str] = None,
evaluation_service_qps: Optional[float] = None,
retry_timeout: float = 120.0,
output_file_name: Optional[str] = "",
) -> EvalResult:
"""Runs an evaluation for the EvalTask.
Args:
model: A GenerativeModel instance or a custom model function to generate
responses to evaluate. If not provided, the evaluation can be performed
in the bring-your-own-response (BYOR) mode.
runnable: The runnable to generate responses to evaluate. If not provided,
the evaluation is computed with the `response` and/or `predicted_trajectory`
column in the `dataset`.
prompt_template: The prompt template to use for the evaluation. If not
set, the prompt template that was used to create the EvalTask will be
used.
experiment_run_name: The name of the experiment run to log the evaluation
to if an experiment is set for this EvalTask. If not provided, a random
unique experiment run name is used.
response_column_name: The column name of model response in the dataset. If
provided, this will override the `metric_column_mapping` of the `EvalTask`.
baseline_model_response_column_name: The column name of baseline model
response in the dataset for pairwise metrics. If provided, this will
override the `metric_column_mapping` of the `EvalTask`.
evaluation_service_qps: The custom QPS limit for the evaluation service.
retry_timeout: How long to keep retrying the evaluation requests for
the whole evaluation dataset, in seconds.
output_file_name: The file name with csv suffix to store the output
metrics_table.
Returns:
The evaluation result.
"""
global_experiment_name = (
metadata._experiment_tracker.experiment_name
) # pylint: disable=protected-access
if experiment_run_name and not self._experiment and not global_experiment_name:
raise ValueError(
"Experiment is not set. Please initialize EvalTask with an"
" experiment, or initialize a global experiment with"
" `vertexai.init(experiment='experiment_name')` for logging this"
" evaluation run."
)
self._verify_and_set_response_column_name(
response_column_name=response_column_name,
metric_column_mapping_key=constants.Dataset.MODEL_RESPONSE_COLUMN,
)
self._verify_and_set_response_column_name(
response_column_name=baseline_model_response_column_name,
metric_column_mapping_key=constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
)
if self.output_uri_prefix and not output_file_name:
output_file_name = f"eval_results_{utils.timestamped_unique_name()}.csv"
experiment_run_name = experiment_run_name or f"{uuid.uuid4()}"
if self._experiment and global_experiment_name:
metadata._experiment_tracker.set_experiment( # pylint: disable=protected-access
experiment=self._experiment, backing_tensorboard=False
)
eval_result = self._evaluate_with_experiment(
model=model,
runnable=runnable,
prompt_template=prompt_template,
experiment_run_name=experiment_run_name,
evaluation_service_qps=evaluation_service_qps,
retry_timeout=retry_timeout,
output_file_name=output_file_name,
)
metadata._experiment_tracker.set_experiment( # pylint: disable=protected-access
experiment=global_experiment_name, backing_tensorboard=False
)
elif self._experiment and not global_experiment_name:
metadata._experiment_tracker.set_experiment( # pylint: disable=protected-access
experiment=self._experiment, backing_tensorboard=False
)
eval_result = self._evaluate_with_experiment(
model=model,
runnable=runnable,
prompt_template=prompt_template,
experiment_run_name=experiment_run_name,
evaluation_service_qps=evaluation_service_qps,
retry_timeout=retry_timeout,
output_file_name=output_file_name,
)
metadata._experiment_tracker.reset() # pylint: disable=protected-access
elif not self._experiment and global_experiment_name:
eval_result = self._evaluate_with_experiment(
model=model,
runnable=runnable,
prompt_template=prompt_template,
experiment_run_name=experiment_run_name,
evaluation_service_qps=evaluation_service_qps,
retry_timeout=retry_timeout,
output_file_name=output_file_name,
)
else:
eval_result = _evaluation.evaluate(
dataset=self._dataset,
metrics=self._metrics,
model=model,
runnable=runnable,
prompt_template=prompt_template,
metric_column_mapping=self._metric_column_mapping,
evaluation_service_qps=evaluation_service_qps,
retry_timeout=retry_timeout,
autorater_config=self._autorater_config,
)
eval_utils.upload_evaluation_results(
eval_result, self.output_uri_prefix, output_file_name
)
return eval_result
def _validate_experiment_run(self) -> None:
"""Checks if an experiment run already exists."""
if (
metadata._experiment_tracker.experiment_run
): # pylint: disable=protected-access
raise ValueError(
"Experiment run already exists. Please specify the name of the"
" experiment run to assign current session within this evaluation."
)
def _log_eval_experiment_param(
self,
model: _ModelType = None,
runnable: _RunnableType = None,
prompt_template: Optional[str] = None,
output_file_name: Optional[str] = None,
) -> None:
"""Logs variable input parameters of an evaluation to an experiment run."""
eval_metadata = {}
if prompt_template is not None:
eval_metadata.update({"prompt_template": prompt_template})
if model:
if isinstance(model, GenerativeModel):
eval_metadata.update(
{
"model_name": model._model_name, # pylint: disable=protected-access
}
)
if (
model._generation_config
and isinstance( # pylint: disable=protected-access
model._generation_config,
dict, # pylint: disable=protected-access
)
):
eval_metadata.update(
**model._generation_config
) # pylint: disable=protected-access
if model._safety_settings and isinstance(
model._safety_settings, dict
): # pylint: disable=protected-access
safety_settings = (
model._safety_settings
) # pylint: disable=protected-access
safety_settings_as_str = {
category.name: threshold.name
for category, threshold in safety_settings.items()
}
eval_metadata.update(safety_settings_as_str)
if runnable:
if isinstance(runnable, reasoning_engines.LangchainAgent):
eval_metadata.update(
{
"model_name": runnable._model_name,
"tools": runnable._tools,
} # pylint: disable=protected-access
)
if self.output_uri_prefix and output_file_name:
eval_metadata.update(
{"output_file": self.output_uri_prefix + "/" + output_file_name}
)
if eval_metadata:
_LOGGER.info(
f"Logging Eval experiment evaluation metadata: {eval_metadata}"
)
try:
vertexai.preview.log_params(eval_metadata)
except (ValueError, TypeError) as e:
_LOGGER.warning(
f"Experiment evaluation metadata logging failed: {str(e)}"
)
def _verify_and_set_response_column_name(
self, response_column_name: str, metric_column_mapping_key: str
) -> None:
"""Verifies and sets the model response column names."""
if response_column_name:
if response_column_name in self._dataset.columns:
self._metric_column_mapping[
metric_column_mapping_key
] = response_column_name
else:
raise ValueError(
f"(Baseline) Model response column {response_column_name} is not"
" found in the dataset."
)
def display_runs(self):
"""Displays experiment runs associated with this EvalTask."""
if not self._experiment:
raise ValueError("Experiment is not set.")
elif IPython_display:
IPython_display.display(
vertexai.preview.get_experiment_df(self._experiment)
)
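
A sketch of the autorater_config, output_uri_prefix, and response_column_name options surfaced by this class, complementing the docstring examples above; the endpoint resource name, bucket, and dataset contents are placeholders.

import pandas as pd
from vertexai.preview.evaluation import (
    AutoraterConfig,
    EvalTask,
    MetricPromptTemplateExamples,
)

eval_dataset = pd.DataFrame({
    "prompt": ["Summarize the release notes."],
    "model_1_response": ["The release adds solar forecasting and fixes two bugs."],
})

eval_task = EvalTask(
    dataset=eval_dataset,
    metrics=[MetricPromptTemplateExamples.Pointwise.FLUENCY],
    # Placeholder tuned-autorater endpoint; see tune_autorater in this commit.
    autorater_config=AutoraterConfig(
        autorater_model="projects/PROJECT/locations/us-central1/endpoints/ENDPOINT_ID",
    ),
    # Placeholder GCS prefix where the metrics_table CSV is uploaded.
    output_uri_prefix="gs://my-bucket/eval-results",
)

# response_column_name overrides the default "response" column mapping.
result = eval_task.evaluate(response_column_name="model_1_response")
print(result.summary_metrics)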


@@ -0,0 +1,324 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utility functions for metrics."""
import datetime
import io
from typing import Any, Dict, List, Optional, Union, Callable
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform_v1beta1.types import (
evaluation_service as gapic_eval_service_types,
)
from vertexai.preview.evaluation import metrics
from vertexai.preview.evaluation import prompt_template
from vertexai.preview.evaluation import utils
from vertexai.preview.evaluation.metrics import _schema
from vertexai.preview.evaluation.metrics import (
custom_output_config,
)
from vertexai import generative_models
from vertexai.generative_models import _generative_models
import jsonschema
import ruamel.yaml
from ruamel.yaml import scalarstring
AutoraterConfig = gapic_eval_service_types.AutoraterConfig
GenerativeModel = generative_models.GenerativeModel
CustomOutputConfig = custom_output_config.CustomOutputConfig
PairwiseMetric = metrics.PairwiseMetric
PointwiseMetric = metrics.PointwiseMetric
RubricBasedMetric = metrics.RubricBasedMetric
RubricGenerationConfig = metrics.RubricGenerationConfig
PromptTemplate = prompt_template.PromptTemplate
# Initialize schema validator.
_schema = ruamel.yaml.YAML(typ="safe").load(_schema.AUTORATER_METRIC_SCHEMA)
_schema_validator = jsonschema.Draft202012Validator(schema=_schema)
def dump(
metric: Union[PointwiseMetric, PairwiseMetric, RubricBasedMetric],
file_path: str,
version: Optional[str] = None,
):
"""Dumps a metric object to a YAML file.
Args:
metric: The metric to be dumped to a file.
file_path: The path to the file. Local and GCS files are supported.
version: Optional. The version of the metric. Defaults to the timestamp
when the metric file is created.
"""
yaml_data = dumps(metric, version)
if file_path.startswith(utils._GCS_PREFIX):
utils._upload_string_to_gcs(file_path, yaml_data)
else:
with open(file_path, "w") as f:
f.write(yaml_data)
def dumps(
metric: Union[PointwiseMetric, PairwiseMetric, RubricBasedMetric],
version: Optional[str] = None,
) -> str:
"""Dumps a metric object to YAML data.
Args:
metric: The metric to be dumped to YAML data.
version: Optional. The version of the metric. Defaults to the timestamp
when the metric file is created.
Returns:
The YAML data of the metric.
"""
steps = []
metric_name = None
if isinstance(metric, PointwiseMetric) or isinstance(metric, PairwiseMetric):
metric_name = metric.metric_name
steps.append(_dump_metric(metric))
elif isinstance(metric, RubricBasedMetric):
metric_name = metric.critique_metric.metric_name
steps.append(_dump_rubric(metric.generation_config))
steps.append(_dump_metric(metric.critique_metric))
metadata = {
"name": metric_name,
"version": (
datetime.datetime.now().strftime("%Y%m%d%H%M%S")
if version is None
else version
),
"required_inputs": _parse_required_inputs(metric),
}
metric_config = {
"metadata": metadata,
"steps": steps,
}
yaml = ruamel.yaml.YAML()
yaml.indent(sequence=4, offset=2)
with io.StringIO() as s:
yaml.dump(metric_config, s)
return s.getvalue()
def _dump_metric(metric: Union[PointwiseMetric, PairwiseMetric]) -> Dict[str, Any]:
"""Dumps a metric object to autorater metric schema."""
output_type = None
if metric.custom_output_config and metric.custom_output_config.return_raw_output:
output_type = "raw"
step = {
"type": (
"pairwise_metric"
if isinstance(metric, PairwiseMetric)
else "pointwise_metric"
),
"prompt": {
"template": scalarstring.preserve_literal(metric.metric_prompt_template),
},
}
if metric.system_instruction:
step["prompt"]["system_instruction"] = metric.system_instruction
if output_type:
step["output"] = {
"type": output_type,
}
if metric.autorater_config:
step["model"] = {
"model_name_or_endpoint": (metric.autorater_config.autorater_model),
}
options = {}
if metric.autorater_config.flip_enabled:
options["flip_enabled"] = metric.autorater_config.flip_enabled
if metric.autorater_config.sampling_count:
options["sample_count"] = metric.autorater_config.sampling_count
if options:
step["options"] = options
return step
def _dump_rubric(generation_config: RubricGenerationConfig) -> Dict[str, Any]:
"""Dumps a rubric generation config to autorater metric schema."""
# TODO: b/396217889 - add support for custom output.
step = {
"type": "rubric",
"prompt": {
"template": scalarstring.preserve_literal(
generation_config.prompt_template
),
},
}
if generation_config.model and isinstance(generation_config.model, GenerativeModel):
step["model"] = {
"model_name_or_endpoint": generation_config.model._model_name,
}
return step
def _parse_required_inputs(
metric: Union[PointwiseMetric, PairwiseMetric, RubricBasedMetric],
) -> List[str]:
"""Parses required inputs from a metric object."""
if isinstance(metric, PointwiseMetric) or isinstance(metric, PairwiseMetric):
return list(PromptTemplate(metric.metric_prompt_template).variables)
elif isinstance(metric, RubricBasedMetric):
met = PromptTemplate(metric.critique_metric.metric_prompt_template).variables
gen = PromptTemplate(metric.generation_config.prompt_template).variables
return list(met.union(gen))
else:
raise ValueError(f"Unsupported metric type: {type(metric)}")
def load(
file_path: str,
baseline_model: Optional[Union[GenerativeModel, Callable[[str], str]]] = None,
) -> Union[PointwiseMetric, PairwiseMetric, RubricBasedMetric]:
"""Loads a metric object from a YAML file.
Args:
file_path: Path to the file containing the autorater metric configuration.
Local and GCS files are supported.
baseline_model: Optional. The baseline model to use for pairwise metrics.
Returns:
The metric object loaded from the file.
"""
if file_path.startswith(utils._GCS_PREFIX):
file_contents = utils._read_gcs_file_contents(file_path)
return loads(file_contents, baseline_model)
with open(file_path, "r") as f:
return loads(f.read(), baseline_model)
def loads(
yaml_data: str,
baseline_model: Optional[Union[GenerativeModel, Callable[[str], str]]] = None,
) -> Union[PointwiseMetric, PairwiseMetric, RubricBasedMetric]:
"""Loads a metric object from YAML data.
Args:
yaml_data: YAML data containing the autorater metric configuration.
baseline_model: Optional. The baseline model to use for pairwise metrics.
Returns:
The metric object loaded from the YAML data.
"""
yaml = ruamel.yaml.YAML(typ="safe")
yaml_obj = yaml.load(yaml_data)
try:
_schema_validator.validate(yaml_obj)
except jsonschema.exceptions.ValidationError as e:
raise ValueError(
f"Invalid autorater metric config: {e.message} for {e.path.pop()}"
) from e
metadata = yaml_obj["metadata"]
steps = yaml_obj["steps"]
required_inputs = set(metadata["required_inputs"])
metric = None
rubric = None
for step in steps:
_validate_template(step["prompt"]["template"], required_inputs)
model_name = None
flip = None
sampling = None
if "model" in step:
model_name = _parse_model_name(step["model"]["model_name_or_endpoint"])
if "options" in step:
flip = step["options"].get("flip_enabled", False)
sampling = step["options"].get("sample_count", 1)
autorater = None
if model_name:
autorater = AutoraterConfig(
autorater_model=model_name,
flip_enabled=flip,
sampling_count=sampling,
)
system_instruction = step["prompt"].get("system_instruction")
custom_output = None
if "output" in step and step["output"]["type"] == "raw":
custom_output = CustomOutputConfig(return_raw_output=True)
if step["type"] == "pointwise_metric":
if metric is not None:
raise ValueError("Only one metric step is supported.")
if baseline_model:
raise ValueError("Baseline model provided for pointwise metric.")
metric = PointwiseMetric(
metric=metadata["name"],
metric_prompt_template=step["prompt"]["template"],
system_instruction=system_instruction,
autorater_config=autorater,
custom_output_config=custom_output,
)
elif step["type"] == "pairwise_metric":
if metric is not None:
raise ValueError("Only one metric step is supported.")
metric = PairwiseMetric(
metric=metadata["name"],
metric_prompt_template=step["prompt"]["template"],
system_instruction=system_instruction,
baseline_model=baseline_model,
autorater_config=autorater,
custom_output_config=custom_output,
)
elif step["type"] == "rubric":
if rubric is not None:
raise ValueError("Only one rubric step is supported.")
model = None
if model_name:
model = generative_models.GenerativeModel(model_name=model_name)
rubric = RubricGenerationConfig(
prompt_template=step["prompt"]["template"],
model=model,
)
if metric is None:
raise ValueError("A metric step must be provided.")
if rubric is not None:
return RubricBasedMetric(
generation_config=rubric,
critique_metric=metric,
)
return metric
def _parse_model_name(model_name_or_endpoint: str) -> str:
"""Parses model name or endpoint.
Args:
model_name_or_endpoint: Model Garden model name or tuned model endpoint
resource name can be provided.
Returns:
The model resource name.
"""
project = initializer.global_config.project
location = initializer.global_config.location
model_name = _generative_models._reconcile_model_name(
model_name_or_endpoint, project, location
)
return _generative_models._get_resource_name_from_model_name(
model_name, project, location
)
def _validate_template(template: str, required_inputs: List[str]) -> None:
"""Validates the template contains only required inputs."""
placeholders = PromptTemplate(template).variables
if not placeholders.issubset(required_inputs):
raise ValueError(
"Template contains placeholders that are not in required inputs:"
f" {placeholders - required_inputs}"
)
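# Example (illustrative editor's sketch, not part of the original module): loading a
# pointwise metric from an in-memory YAML config with the loads() helper above. The
# metric name, template text, and required inputs are made-up values chosen to satisfy
# the schema; the `model` block is omitted so no project initialization is needed.
_EXAMPLE_METRIC_YAML = """
metadata:
  name: example_fluency
  version: "0.1"
  required_inputs:
    - prompt
    - response
steps:
  - type: pointwise_metric
    prompt:
      system_instruction: You are an expert evaluator.
      template: |
        Rate the fluency of the response on a 1-5 scale.
        Prompt: {prompt}
        Response: {response}
"""

example_metric = loads(_EXAMPLE_METRIC_YAML)  # returns a PointwiseMetric named "example_fluency"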

View File

@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Evaluation Metrics Module."""
from vertexai.preview.evaluation.metrics import _base
from vertexai.preview.evaluation.metrics import _rouge
from vertexai.preview.evaluation.metrics import (
_trajectory_single_tool_use,
)
from vertexai.preview.evaluation.metrics import (
custom_output_config,
)
from vertexai.preview.evaluation.metrics import (
metric_prompt_template,
)
from vertexai.preview.evaluation.metrics import (
metric_prompt_template_examples,
)
from vertexai.preview.evaluation.metrics import pairwise_metric
from vertexai.preview.evaluation.metrics import pointwise_metric
from vertexai.preview.evaluation.metrics import (
predefined_rubric_metrics,
)
from vertexai.preview.evaluation.metrics import (
rubric_based_metric,
)
PairwiseMetric = pairwise_metric.PairwiseMetric
PointwiseMetric = pointwise_metric.PointwiseMetric
CustomMetric = _base.CustomMetric
PairwiseMetricPromptTemplate = metric_prompt_template.PairwiseMetricPromptTemplate
PointwiseMetricPromptTemplate = metric_prompt_template.PointwiseMetricPromptTemplate
MetricPromptTemplateExamples = (
metric_prompt_template_examples.MetricPromptTemplateExamples
)
Rouge = _rouge.Rouge
TrajectorySingleToolUse = _trajectory_single_tool_use.TrajectorySingleToolUse
CustomOutputConfig = custom_output_config.CustomOutputConfig
RubricBasedMetric = rubric_based_metric.RubricBasedMetric
RubricGenerationConfig = _base.RubricGenerationConfig
PredefinedRubricMetrics = predefined_rubric_metrics.PredefinedRubricMetrics
__all__ = [
"CustomMetric",
"PairwiseMetric",
"PointwiseMetric",
"PairwiseMetricPromptTemplate",
"PointwiseMetricPromptTemplate",
"MetricPromptTemplateExamples",
"Rouge",
"TrajectorySingleToolUse",
"CustomOutputConfig",
"RubricBasedMetric",
"RubricGenerationConfig",
"PredefinedRubricMetrics",
]

View File

@@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Base classes for evaluation metrics."""
import abc
from typing import Any, Callable, Dict, Literal, Optional, Union, List
from google.cloud.aiplatform_v1beta1.types import (
evaluation_service as gapic_eval_service_types,
)
from vertexai import generative_models
from vertexai.preview.evaluation import constants
from vertexai.preview.evaluation.metrics import (
custom_output_config as custom_output_config_class,
)
from vertexai.preview.evaluation.metrics import (
metric_prompt_template as metric_prompt_template_base,
)
_ModelType = Union[generative_models.GenerativeModel, Callable[[str], str]]
class _Metric(abc.ABC):
"""The abstract class for evaluation metric."""
def __init__(self, metric: str):
self._metric = metric
def __str__(self):
return self.metric_name
@property
def metric_name(self) -> str:
return self._metric
class _ModelBasedMetric(_Metric):
"""A Model-based Metric.
An evaluation metric that evaluates generative AI model responses with
another generative model as a judge. This metric can be used to evaluate a
single model, or two models side-by-side.
For more details on when to use model-based metrics, see
[Evaluation methods and metrics](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval).
"""
def __init__(
self,
*,
metric: str,
metric_prompt_template: Union[
metric_prompt_template_base.PointwiseMetricPromptTemplate,
metric_prompt_template_base.PairwiseMetricPromptTemplate,
str,
],
system_instruction: Optional[str] = None,
autorater_config: Optional[gapic_eval_service_types.AutoraterConfig] = None,
custom_output_config: Optional[
custom_output_config_class.CustomOutputConfig
] = None,
):
"""Initializes the model-based evaluation metric.
Args:
metric: Generic model based metric name.
metric_prompt_template: A metric prompt template for performing
the model-based evaluation. A freeform string is also accepted.
system_instruction: The system instruction to be used in the metric
prompt.
autorater_config: The config for judge model.
custom_output_config: Config for custom output from the judge model.
"""
super().__init__(metric=metric)
self.metric_prompt_template = str(metric_prompt_template)
self.system_instruction = system_instruction
self.autorater_config = autorater_config
self.custom_output_config = custom_output_config
class CustomMetric(_Metric):
"""The custom evaluation metric.
A fully-customized CustomMetric that can be used to evaluate a single model
by defining a metric function for a computation-based metric. The
    CustomMetric is computed client-side using the user-defined metric
    function in the SDK only, not by the Vertex Gen AI Evaluation Service.
Attributes:
name: The name of the metric.
metric_function: The user-defined evaluation function to compute a metric
score. Must use the dataset row dictionary as the metric function
input and return per-instance metric result as a dictionary output.
        The metric score must be mapped to the name of the CustomMetric as
        the key.
"""
def __init__(
self,
name: str,
metric_function: Callable[
[Dict[str, Any]],
Dict[str, Any],
],
):
"""Initializes the evaluation metric."""
super().__init__(name)
self.name = name
self.metric_function = metric_function
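# Example (illustrative editor's sketch, not part of the original module): a CustomMetric
# whose result dictionary is keyed by the metric name, as the docstring above requires.
# The "response" and "reference" column names are assumed dataset columns.
def _example_exact_match(row: Dict[str, Any]) -> Dict[str, Any]:
    """Scores 1.0 when the model response equals the reference, else 0.0."""
    score = 1.0 if row.get("response") == row.get("reference") else 0.0
    return {"example_exact_match": score}

example_exact_match_metric = CustomMetric(
    name="example_exact_match",
    metric_function=_example_exact_match,
)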
class _AutomaticMetric(_Metric):
"""An automatic metric that computes deterministic score based on reference.
An lexicon-based evaluation metric that evaluate a generative model's
response on the given evaluation task with reference ground truth answers.
It is a type of pointwise evaluation metric.
For more details on when to use automatic metrics, see
[Evaluation methods and
metrics](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval).
"""
def __init__(
self,
metric: Literal[constants.Metric.ROUGE],
):
"""Initializes the automatic evaluation metric.
Args:
metric: The automatic evaluation metric name.
"""
super().__init__(metric=metric)
class RubricGenerationConfig:
"""The rubric generation config."""
def __init__(
self,
prompt_template: str,
model: Optional[_ModelType] = None,
parsing_fn: Optional[Callable[[str], List[str]]] = None,
):
"""Initializes the rubric generation config.
Args:
prompt_template: The prompt template for rubric generation.
model: The model to use for rubric generation.
parsing_fn: The function to parse the rubric generation response.
"""
self.prompt_template = prompt_template
self.model = model
self.parsing_fn = parsing_fn
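# Example (illustrative editor's sketch, not part of the original module): a user-level
# snippet composing a rubric-based metric from a rubric-generation step and a critique
# metric, using the package-level exports. The metric name and template strings are made up.
from vertexai.preview.evaluation import (
    PointwiseMetric,
    RubricBasedMetric,
    RubricGenerationConfig,
)

_example_rubric_generation = RubricGenerationConfig(
    prompt_template="Write yes/no rubrics that test whether a response answers: {prompt}",
)
_example_critique = PointwiseMetric(
    metric="example_rubric_quality",
    metric_prompt_template="Check {response} against each rubric:\n{rubrics}",
)
example_rubric_metric = RubricBasedMetric(
    generation_config=_example_rubric_generation,
    critique_metric=_example_critique,
)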

View File

@@ -0,0 +1,802 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Library for metrics computation with Gen AI Evaluation Service."""
import json
from typing import Any, Dict, List, Union
from google import api_core
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform_v1beta1.services import (
evaluation_service as gapic_evaluation_services,
)
from google.cloud.aiplatform_v1beta1.types import (
evaluation_service as gapic_eval_service_types,
)
from vertexai.preview.evaluation import _base as eval_base
from vertexai.preview.evaluation import constants
from vertexai.preview.evaluation import multimodal_utils
from vertexai.preview.evaluation import (
prompt_template as prompt_template_base,
)
from vertexai.preview.evaluation import utils
from vertexai.preview.evaluation.metrics import (
_base as metrics_base,
)
from vertexai.preview.evaluation.metrics import (
_default_templates,
)
from vertexai.preview.evaluation.metrics import _rouge
from vertexai.preview.evaluation.metrics import (
_trajectory_single_tool_use,
)
from vertexai.preview.evaluation.metrics import (
custom_output_config as custom_output_config_class,
)
from vertexai.preview.evaluation.metrics import pairwise_metric
from vertexai.preview.evaluation.metrics import pointwise_metric
from google.protobuf import json_format
_LOGGER = base.Logger(__name__)
_METRIC_NAME_TO_METRIC_SPEC = {
# Automatic Metrics.
constants.Metric.EXACT_MATCH: (gapic_eval_service_types.ExactMatchSpec()),
constants.Metric.BLEU: gapic_eval_service_types.BleuSpec(),
constants.Metric.ROUGE: gapic_eval_service_types.RougeSpec(),
constants.Metric.ROUGE_1: gapic_eval_service_types.RougeSpec(rouge_type="rouge1"),
constants.Metric.ROUGE_2: gapic_eval_service_types.RougeSpec(rouge_type="rouge2"),
constants.Metric.ROUGE_L: gapic_eval_service_types.RougeSpec(rouge_type="rougeL"),
constants.Metric.ROUGE_L_SUM: gapic_eval_service_types.RougeSpec(
rouge_type="rougeLsum"
),
constants.Metric.TOOL_CALL_VALID: (gapic_eval_service_types.ToolCallValidSpec()),
constants.Metric.TOOL_NAME_MATCH: (gapic_eval_service_types.ToolNameMatchSpec()),
constants.Metric.TOOL_PARAMETER_KV_MATCH: (
gapic_eval_service_types.ToolParameterKVMatchSpec()
),
constants.Metric.TOOL_PARAMETER_KEY_MATCH: (
gapic_eval_service_types.ToolParameterKeyMatchSpec()
),
# Pointwise Metrics.
constants.Metric.POINTWISE_METRIC: (gapic_eval_service_types.PointwiseMetricSpec()),
# Pairwise Metrics.
constants.Metric.PAIRWISE_METRIC: (gapic_eval_service_types.PairwiseMetricSpec()),
constants.Metric.RUBRIC_BASED_INSTRUCTION_FOLLOWING: (
gapic_eval_service_types.RubricBasedInstructionFollowingSpec()
),
constants.Metric.TRAJECTORY_EXACT_MATCH: (
gapic_eval_service_types.TrajectoryExactMatchSpec()
),
constants.Metric.TRAJECTORY_IN_ORDER_MATCH: (
gapic_eval_service_types.TrajectoryInOrderMatchSpec()
),
constants.Metric.TRAJECTORY_ANY_ORDER_MATCH: (
gapic_eval_service_types.TrajectoryAnyOrderMatchSpec()
),
constants.Metric.TRAJECTORY_PRECISION: (
gapic_eval_service_types.TrajectoryPrecisionSpec()
),
constants.Metric.TRAJECTORY_RECALL: (
gapic_eval_service_types.TrajectoryRecallSpec()
),
constants.Metric.TRAJECTORY_SINGLE_TOOL_USE: (
gapic_eval_service_types.TrajectorySingleToolUseSpec()
),
}
_QUESTION_TEMPLATE = """<question>{question}"""
def _format_rubrics(questions: List[str]) -> str:
"""Formats the list of rubrics into a question block."""
question_block = "\n".join(
_QUESTION_TEMPLATE.format(question=q.strip()) for q in questions
)
return question_block
def build_custom_output_format_config(
custom_output_config: custom_output_config_class.CustomOutputConfig,
) -> Union[gapic_eval_service_types.CustomOutputFormatConfig, None]:
"""Builds a CustomOutputFormatConfig from user input."""
custom_output_cfg = gapic_eval_service_types.CustomOutputFormatConfig()
if custom_output_config.return_raw_output:
custom_output_cfg.return_raw_output = True
return custom_output_cfg
else:
return None
def build_trajectory(
trajectory: Union[str, List[Dict[str, Any]]],
) -> gapic_eval_service_types.Trajectory:
"""Builds a trajectory from user input."""
if not trajectory:
return
if isinstance(trajectory, str):
trajectory = json.loads(trajectory)
if isinstance(trajectory, List):
try:
tool_calls = []
for tool_call_dict in trajectory:
tool_input_str = json.dumps(tool_call_dict["tool_input"])
tool_calls.append(
gapic_eval_service_types.ToolCall(
tool_name=tool_call_dict["tool_name"], tool_input=tool_input_str
)
)
return gapic_eval_service_types.Trajectory(tool_calls=tool_calls)
except KeyError as e:
_LOGGER.error(f"Failed to parse trajectory: {e}")
else:
_LOGGER.error(
f"Unsupported trajectory type: {type(trajectory)}, expected list or"
" a JSON array."
)
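# Example (illustrative editor's sketch, not part of the original module): the
# list-of-dict trajectory format that build_trajectory accepts. Tool names and
# inputs below are made up.
_example_trajectory = [
    {"tool_name": "search", "tool_input": {"query": "weather in Paris"}},
    {"tool_name": "get_forecast", "tool_input": {"city": "Paris", "days": 3}},
]
# An equivalent JSON-string form is also accepted:
# build_trajectory(json.dumps(_example_trajectory))
_example_trajectory_proto = build_trajectory(_example_trajectory)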
def build_request(
metric: Union[str, metrics_base._Metric],
row_dict: Dict[str, Any],
evaluation_run_config: eval_base.EvaluationRunConfig,
) -> gapic_eval_service_types.EvaluateInstancesRequest:
"""Builds a metric instance and form the request for the evaluation service.
Args:
      metric: The metric name or a Metric instance to evaluate.
row_dict: An evaluation dataset instance as a dictionary.
evaluation_run_config: Evaluation run configurations.
Returns:
A single EvaluateInstancesRequest.
Raises:
ValueError: If required request fields are not provided.
"""
project = initializer.global_config.project
location = initializer.global_config.location
if not project or not location:
raise ValueError(
"No project or location specified. Please run `vertexai.init()` to"
" provide these parameters."
)
location_path = (
gapic_evaluation_services.EvaluationServiceClient.common_location_path(
project, location
)
)
if isinstance(metric, pointwise_metric.PointwiseMetric):
metric_name = constants.Metric.POINTWISE_METRIC
elif isinstance(metric, pairwise_metric.PairwiseMetric):
metric_name = constants.Metric.PAIRWISE_METRIC
else:
metric_name = str(metric)
try:
metric_spec = _METRIC_NAME_TO_METRIC_SPEC[metric_name]
except KeyError as e:
raise ValueError(f"Metric name: {metric_name} is not supported.") from e
model_based_metric_instance_input = {}
metric_column_mapping = evaluation_run_config.metric_column_mapping
if isinstance(
metric, metrics_base._ModelBasedMetric # pylint: disable=protected-access
):
metric_spec.metric_prompt_template = metric.metric_prompt_template
metric_spec.system_instruction = metric.system_instruction
if metric.custom_output_config:
metric_spec.custom_output_format_config = build_custom_output_format_config(
metric.custom_output_config
)
for variable in prompt_template_base.PromptTemplate(
metric.metric_prompt_template
).variables:
model_based_metric_instance_input[variable] = row_dict.get(
metric_column_mapping.get(variable),
"",
)
if isinstance(metric, pairwise_metric.PairwiseMetric):
metric_column_mapping = evaluation_run_config.metric_column_mapping
metric_spec.candidate_response_field_name = metric_column_mapping.get(
constants.Dataset.MODEL_RESPONSE_COLUMN,
constants.Dataset.MODEL_RESPONSE_COLUMN,
)
metric_spec.baseline_response_field_name = metric_column_mapping.get(
constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
)
elif isinstance(metric, _rouge.Rouge):
metric_spec.rouge_type = metric.rouge_type
metric_spec.use_stemmer = metric.use_stemmer
metric_spec.split_summaries = metric.split_summaries
elif isinstance(metric, _trajectory_single_tool_use.TrajectorySingleToolUse):
metric_spec.tool_name = metric.tool_name
response = row_dict.get(
metric_column_mapping.get(constants.Dataset.MODEL_RESPONSE_COLUMN), ""
)
reference = row_dict.get(
metric_column_mapping.get(constants.Dataset.REFERENCE_COLUMN), ""
)
predicted_trajectory = build_trajectory(
row_dict.get(
metric_column_mapping.get(constants.Dataset.PREDICTED_TRAJECTORY_COLUMN),
"",
)
)
reference_trajectory = build_trajectory(
row_dict.get(
metric_column_mapping.get(constants.Dataset.REFERENCE_TRAJECTORY_COLUMN),
"",
)
)
if isinstance(metric, metrics_base._ModelBasedMetric):
if metric_spec.metric_prompt_template in (
_default_templates.INSTRUCTION_FOLLOWING_RUBRIC_CRITIQUE_TEMPLATE,
_default_templates.MULTIMODAL_UNDERSTANDING_RUBRIC_CRITIQUE_TEMPLATE,
_default_templates.TEXT_QUALITY_RUBRIC_CRITIQUE_TEMPLATE,
_default_templates.PAIRWISE_INSTRUCTION_FOLLOWING_RUBRIC_CRITIQUE_TEMPLATE,
_default_templates.PAIRWISE_MULTIMODAL_UNDERSTANDING_RUBRIC_CRITIQUE_TEMPLATE,
_default_templates.PAIRWISE_TEXT_QUALITY_RUBRIC_CRITIQUE_TEMPLATE,
):
model_based_metric_instance_input[
constants.Dataset.RUBRICS_COLUMN
] = _format_rubrics(
model_based_metric_instance_input[constants.Dataset.RUBRICS_COLUMN]
)
if (
constants.Dataset.RUBRICS_COLUMN in model_based_metric_instance_input
and isinstance(
model_based_metric_instance_input[constants.Dataset.RUBRICS_COLUMN],
List,
)
):
model_based_metric_instance_input[
constants.Dataset.RUBRICS_COLUMN
] = "\n".join(
model_based_metric_instance_input[constants.Dataset.RUBRICS_COLUMN]
)
if metric_name == constants.Metric.EXACT_MATCH:
instance = gapic_eval_service_types.ExactMatchInput(
metric_spec=metric_spec,
instances=[
gapic_eval_service_types.ExactMatchInstance(
prediction=response,
reference=reference,
)
],
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
exact_match_input=instance,
)
elif metric_name == constants.Metric.BLEU:
instance = gapic_eval_service_types.BleuInput(
metric_spec=metric_spec,
instances=[
gapic_eval_service_types.BleuInstance(
prediction=response,
reference=reference,
)
],
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
bleu_input=instance,
)
elif metric_name in (
constants.Metric.ROUGE,
constants.Metric.ROUGE_1,
constants.Metric.ROUGE_2,
constants.Metric.ROUGE_L,
constants.Metric.ROUGE_L_SUM,
):
instance = gapic_eval_service_types.RougeInput(
metric_spec=metric_spec,
instances=[
gapic_eval_service_types.RougeInstance(
prediction=response,
reference=reference,
)
],
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
rouge_input=instance,
)
elif metric_name == constants.Metric.TOOL_CALL_VALID:
instance = gapic_eval_service_types.ToolCallValidInput(
metric_spec=metric_spec,
instances=[
gapic_eval_service_types.ToolCallValidInstance(
prediction=response,
reference=reference,
)
],
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
tool_call_valid_input=instance,
)
elif metric_name == constants.Metric.TOOL_NAME_MATCH:
instance = gapic_eval_service_types.ToolNameMatchInput(
metric_spec=metric_spec,
instances=[
gapic_eval_service_types.ToolNameMatchInstance(
prediction=response,
reference=reference,
)
],
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
tool_name_match_input=instance,
)
elif metric_name == constants.Metric.TOOL_PARAMETER_KEY_MATCH:
instance = gapic_eval_service_types.ToolParameterKeyMatchInput(
metric_spec=metric_spec,
instances=[
gapic_eval_service_types.ToolParameterKeyMatchInstance(
prediction=response,
reference=reference,
)
],
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
tool_parameter_key_match_input=instance,
)
elif metric_name == constants.Metric.TOOL_PARAMETER_KV_MATCH:
instance = gapic_eval_service_types.ToolParameterKVMatchInput(
metric_spec=metric_spec,
instances=[
gapic_eval_service_types.ToolParameterKVMatchInstance(
prediction=response,
reference=reference,
)
],
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
tool_parameter_kv_match_input=instance,
)
elif metric_name == constants.Metric.POINTWISE_METRIC:
if multimodal_utils.is_multimodal_instance(model_based_metric_instance_input):
instance = gapic_eval_service_types.PointwiseMetricInput(
metric_spec=metric_spec,
instance=gapic_eval_service_types.PointwiseMetricInstance(
content_map_instance=multimodal_utils.convert_multimodal_response_to_content_map(
model_based_metric_instance_input
),
),
)
else:
instance = gapic_eval_service_types.PointwiseMetricInput(
metric_spec=metric_spec,
instance=gapic_eval_service_types.PointwiseMetricInstance(
json_instance=json.dumps(model_based_metric_instance_input),
),
)
autorater_config = evaluation_run_config.autorater_config
if (
isinstance(metric, metrics_base._ModelBasedMetric)
and metric.autorater_config
):
autorater_config = metric.autorater_config
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
pointwise_metric_input=instance,
autorater_config=autorater_config,
)
elif metric_name == constants.Metric.PAIRWISE_METRIC:
if multimodal_utils.is_multimodal_instance(model_based_metric_instance_input):
instance = gapic_eval_service_types.PairwiseMetricInput(
metric_spec=metric_spec,
instance=gapic_eval_service_types.PairwiseMetricInstance(
content_map_instance=multimodal_utils.convert_multimodal_response_to_content_map(
model_based_metric_instance_input
),
),
)
else:
instance = gapic_eval_service_types.PairwiseMetricInput(
metric_spec=metric_spec,
instance=gapic_eval_service_types.PairwiseMetricInstance(
json_instance=json.dumps(model_based_metric_instance_input),
),
)
autorater_config = evaluation_run_config.autorater_config
if (
isinstance(metric, metrics_base._ModelBasedMetric)
and metric.autorater_config
):
autorater_config = metric.autorater_config
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
pairwise_metric_input=instance,
autorater_config=autorater_config,
)
elif metric_name == constants.Metric.RUBRIC_BASED_INSTRUCTION_FOLLOWING:
required_rbif_fields = [
constants.Dataset.MODEL_RESPONSE_COLUMN,
constants.Dataset.PROMPT_COLUMN,
]
for field in required_rbif_fields:
column_name = metric_column_mapping.get(field)
value = row_dict.get(column_name)
if value is None and field in required_rbif_fields:
raise ValueError(
f"Missing required field: `{field}` for "
f"{constants.Metric.RUBRIC_BASED_INSTRUCTION_FOLLOWING}."
)
else:
model_based_metric_instance_input[field] = value
instance = gapic_eval_service_types.RubricBasedInstructionFollowingInput(
metric_spec=metric_spec,
instance=gapic_eval_service_types.RubricBasedInstructionFollowingInstance(
json_instance=json.dumps(model_based_metric_instance_input),
),
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
rubric_based_instruction_following_input=instance,
)
elif metric_name == constants.Metric.TRAJECTORY_EXACT_MATCH:
instance = gapic_eval_service_types.TrajectoryExactMatchInput(
metric_spec=metric_spec,
instances=[
gapic_eval_service_types.TrajectoryExactMatchInstance(
predicted_trajectory=predicted_trajectory,
reference_trajectory=reference_trajectory,
)
],
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
trajectory_exact_match_input=instance,
)
elif metric_name == constants.Metric.TRAJECTORY_IN_ORDER_MATCH:
instance = gapic_eval_service_types.TrajectoryInOrderMatchInput(
metric_spec=metric_spec,
instances=[
gapic_eval_service_types.TrajectoryInOrderMatchInstance(
predicted_trajectory=predicted_trajectory,
reference_trajectory=reference_trajectory,
)
],
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
trajectory_in_order_match_input=instance,
)
elif metric_name == constants.Metric.TRAJECTORY_ANY_ORDER_MATCH:
instance = gapic_eval_service_types.TrajectoryAnyOrderMatchInput(
metric_spec=metric_spec,
instances=[
gapic_eval_service_types.TrajectoryAnyOrderMatchInstance(
predicted_trajectory=predicted_trajectory,
reference_trajectory=reference_trajectory,
)
],
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
trajectory_any_order_match_input=instance,
)
elif metric_name == constants.Metric.TRAJECTORY_PRECISION:
instance = gapic_eval_service_types.TrajectoryPrecisionInput(
metric_spec=metric_spec,
instances=[
gapic_eval_service_types.TrajectoryPrecisionInstance(
predicted_trajectory=predicted_trajectory,
reference_trajectory=reference_trajectory,
)
],
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
trajectory_precision_input=instance,
)
elif metric_name == constants.Metric.TRAJECTORY_RECALL:
instance = gapic_eval_service_types.TrajectoryRecallInput(
metric_spec=metric_spec,
instances=[
gapic_eval_service_types.TrajectoryRecallInstance(
predicted_trajectory=predicted_trajectory,
reference_trajectory=reference_trajectory,
)
],
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
trajectory_recall_input=instance,
)
elif metric_name == constants.Metric.TRAJECTORY_SINGLE_TOOL_USE:
instance = gapic_eval_service_types.TrajectorySingleToolUseInput(
metric_spec=metric_spec,
instances=[
gapic_eval_service_types.TrajectorySingleToolUseInstance(
predicted_trajectory=predicted_trajectory,
)
],
)
return gapic_eval_service_types.EvaluateInstancesRequest(
location=location_path,
trajectory_single_tool_use_input=instance,
)
else:
raise ValueError(f"Unknown metric type: {metric_name}")
def _parse_autometric_results(
metric_result_dict: Dict[str, Any],
) -> Dict[str, Any]:
"""Parses the automatic metric results from the evaluation results.
Args:
metric_result_dict: The metric results dictionary.
Returns:
      A dictionary containing the score of the metric.
"""
for value in metric_result_dict.values():
return {
constants.MetricResult.SCORE_KEY: value[0].get(
constants.MetricResult.SCORE_KEY
)
}
def _parse_pointwise_results(
metric_result_dict: Dict[str, Any],
metric: Union[str, metrics_base._Metric],
) -> Dict[str, Any]:
"""Parses the model-based pointwise metric results from the evaluation results.
Args:
metric_result_dict: The metric results dictionary.
metric: The metric to evaluate.
Returns:
One of the following:
1. A dictionary containing raw outputs from the judge model if
return_raw_output is set to True in custom_output_config.
2. A dictionary containing metric score and explanation of the
metric if custom_output_config is not set.
"""
if (
isinstance(metric, pointwise_metric.PointwiseMetric)
and getattr(metric, "custom_output_config", None)
and getattr(metric.custom_output_config, "return_raw_output", False)
):
raw_outputs = (
metric_result_dict.get(constants.MetricResult.CUSTOM_OUTPUT_KEY)
.get(constants.MetricResult.RAW_OUTPUTS_KEY)
.get(constants.MetricResult.RAW_OUTPUT_KEY)
)
if (
isinstance(metric, pointwise_metric.PointwiseMetric)
and getattr(metric, "custom_output_config", None)
and getattr(metric.custom_output_config, "parsing_fn", None)
):
parsing_fn = metric.custom_output_config.parsing_fn
return parsing_fn(raw_outputs)
return {constants.MetricResult.RAW_OUTPUT_KEY: raw_outputs}
else:
return {
constants.MetricResult.SCORE_KEY: metric_result_dict.get(
constants.MetricResult.SCORE_KEY
),
constants.MetricResult.EXPLANATION_KEY: metric_result_dict.get(
constants.MetricResult.EXPLANATION_KEY
),
}
def _parse_pairwise_results(
metric_result_dict: Dict[str, Any],
metric: Union[str, metrics_base._Metric],
) -> Dict[str, Any]:
"""Parses the pairwise metric results from the evaluation results.
Args:
metric_result_dict: The metric results dictionary.
metric: The metric to evaluate.
Returns:
One of the following:
1. A dictionary containing raw outputs from the judge model if
return_raw_output is set to True in custom_output_config.
2. A dictionary containing metric score and explanation of the
metric if custom_output_config is not set.
"""
if (
isinstance(metric, pairwise_metric.PairwiseMetric)
and getattr(metric, "custom_output_config", None)
and getattr(metric.custom_output_config, "return_raw_output", False)
):
raw_outputs = (
metric_result_dict.get(constants.MetricResult.CUSTOM_OUTPUT_KEY)
.get(constants.MetricResult.RAW_OUTPUTS_KEY)
.get(constants.MetricResult.RAW_OUTPUT_KEY)
)
if (
isinstance(metric, pairwise_metric.PairwiseMetric)
and getattr(metric, "custom_output_config", None)
and getattr(metric.custom_output_config, "parsing_fn", None)
):
parsing_fn = metric.custom_output_config.parsing_fn
return parsing_fn(raw_outputs)
return {constants.MetricResult.RAW_OUTPUT_KEY: raw_outputs}
else:
return {
constants.MetricResult.PAIRWISE_CHOICE_KEY: metric_result_dict.get(
constants.MetricResult.PAIRWISE_CHOICE_KEY,
),
constants.MetricResult.EXPLANATION_KEY: metric_result_dict.get(
constants.MetricResult.EXPLANATION_KEY
),
}
def _parse_rubric_based_instruction_following_results(
metric_result_dict: Dict[str, Any],
) -> Dict[str, Any]:
"""Parses the rubric-based instruction following metric results from the evaluation results.
Args:
metric_result_dict: The metric results dictionary.
Returns:
A dictionary containing a list of rubrics and corresponding verdicts and
an overall instruction following score.
"""
rubric_critique_results = []
for rc_result in metric_result_dict["rubric_critique_results"]:
if "verdict" not in rc_result:
rc_result["verdict"] = False # proto3 shows False bool as unset
rubric_critique_results.append(
{
"rubric": rc_result["rubric"],
"verdict": rc_result["verdict"],
}
)
return {
constants.MetricResult.RUBRIC_LEVEL_INSTRUCTION_FOLLOWING_KEY: (
rubric_critique_results
),
constants.MetricResult.SCORE_KEY: (
metric_result_dict.get(constants.MetricResult.SCORE_KEY)
),
}
def handle_response(
response: Union[str, gapic_eval_service_types.EvaluateInstancesResponse],
metric: Union[str, metrics_base._Metric],
) -> Union[str, Dict[str, Any]]:
"""Handles the response from the evaluation service.
Args:
response: The response from the evaluation service.
metric: The metric to evaluate to check the output type.
Returns:
A parsed metric result dictionary, or an error message string.
"""
if isinstance(response, str):
return response
metric_type = response._pb.WhichOneof( # pylint: disable=protected-access
"evaluation_results"
)
if metric_type == constants.MetricResult.EXACT_MATCH_RESULTS:
metric_result = response.exact_match_results
elif metric_type == constants.MetricResult.BLEU_RESULTS:
metric_result = response.bleu_results
elif metric_type == constants.MetricResult.ROUGE_RESULTS:
metric_result = response.rouge_results
elif metric_type == constants.MetricResult.TOOL_CALL_VALID_RESULTS:
metric_result = response.tool_call_valid_results
elif metric_type == constants.MetricResult.TOOL_NAME_MATCH_RESULTS:
metric_result = response.tool_name_match_results
elif metric_type == constants.MetricResult.TOOL_PARAMETER_KEY_MATCH_RESULTS:
metric_result = response.tool_parameter_key_match_results
elif metric_type == constants.MetricResult.TOOL_PARAMETER_KV_MATCH_RESULTS:
metric_result = response.tool_parameter_kv_match_results
elif metric_type == constants.MetricResult.POINTWISE_METRIC_RESULT:
metric_result = response.pointwise_metric_result
elif metric_type == constants.MetricResult.PAIRWISE_METRIC_RESULT:
metric_result = response.pairwise_metric_result
elif metric_type == constants.MetricResult.TRAJECTORY_EXACT_MATCH_RESULTS:
metric_result = response.trajectory_exact_match_results
elif metric_type == constants.MetricResult.TRAJECTORY_IN_ORDER_MATCH_RESULTS:
metric_result = response.trajectory_in_order_match_results
elif metric_type == constants.MetricResult.TRAJECTORY_ANY_ORDER_MATCH_RESULTS:
metric_result = response.trajectory_any_order_match_results
elif metric_type == constants.MetricResult.TRAJECTORY_PRECISION_RESULTS:
metric_result = response.trajectory_precision_results
elif metric_type == constants.MetricResult.TRAJECTORY_RECALL_RESULTS:
metric_result = response.trajectory_recall_results
elif metric_type == constants.MetricResult.TRAJECTORY_SINGLE_TOOL_USE_RESULTS:
metric_result = response.trajectory_single_tool_use_results
elif (
metric_type == constants.MetricResult.RUBRIC_BASED_INSTRUCTION_FOLLOWING_RESULT
):
metric_result = response.rubric_based_instruction_following_result
else:
raise ValueError(f"Unknown metric type: {metric_type}")
metric_result_dict = json_format.MessageToDict(
metric_result._pb, # pylint: disable=protected-access
preserving_proto_field_name=True,
)
    if metric_type in constants.MetricResult.AUTOMATIC_METRIC_RESULTS_LIST:
result = _parse_autometric_results(metric_result_dict)
elif metric_type == constants.MetricResult.POINTWISE_METRIC_RESULT:
result = _parse_pointwise_results(metric_result_dict, metric)
elif metric_type == constants.MetricResult.PAIRWISE_METRIC_RESULT:
result = _parse_pairwise_results(metric_result_dict, metric)
elif (
metric_type == constants.MetricResult.RUBRIC_BASED_INSTRUCTION_FOLLOWING_RESULT
):
result = _parse_rubric_based_instruction_following_results(metric_result_dict)
else:
raise ValueError(f"Unknown metric type: {metric_type}")
return result
def evaluate_instances(
client: gapic_evaluation_services.EvaluationServiceClient,
request: gapic_eval_service_types.EvaluateInstancesRequest,
rate_limiter: utils.RateLimiter,
retry_timeout: float,
) -> gapic_eval_service_types.EvaluateInstancesResponse:
"""Evaluates an instance using Vertex Gen AI Evaluation Service.
Args:
client: The Vertex Gen AI evaluation service client for evaluation.
request: An EvaluateInstancesRequest.
rate_limiter: The rate limiter for evaluation service requests.
retry_timeout: How long to keep retrying the evaluation requests, in seconds.
Returns:
An EvaluateInstancesResponse from Vertex Gen AI Evaluation Service.
"""
rate_limiter.sleep_and_advance()
return client.evaluate_instances(
request=request,
retry=api_core.retry.Retry(
initial=0.250,
maximum=90.0,
multiplier=1.45,
timeout=retry_timeout,
predicate=api_core.retry.if_exception_type(
api_core.exceptions.Aborted,
api_core.exceptions.DeadlineExceeded,
api_core.exceptions.ResourceExhausted,
api_core.exceptions.ServiceUnavailable,
api_core.exceptions.Cancelled,
),
),
)

View File

@@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""ROUGE Metric."""
from typing import Literal
from vertexai.preview.evaluation import constants
from vertexai.preview.evaluation.metrics import _base
class Rouge(_base._AutomaticMetric): # pylint: disable=protected-access
"""The ROUGE Metric.
    Calculates the recall of n-grams in the prediction as compared to the
    reference and returns a score ranging between 0 and 1. Supported rouge
    types are rouge1 through rouge9 (rouge[1-9]), rougeL, and rougeLsum.
"""
_metric_name = constants.Metric.ROUGE
def __init__(
self,
*,
rouge_type: Literal[
"rouge1",
"rouge2",
"rouge3",
"rouge4",
"rouge5",
"rouge6",
"rouge7",
"rouge8",
"rouge9",
"rougeL",
"rougeLsum",
],
use_stemmer: bool = False,
split_summaries: bool = False
):
"""Initializes the ROUGE metric.
Args:
          rouge_type: The ROUGE type to compute. Supported types are rouge1
            through rouge9 (rouge[1-9]), rougeL, and rougeLsum.
use_stemmer: Whether to use stemmer to compute rouge score.
split_summaries: Whether to split summaries while using 'rougeLsum' to
compute rouge score.
"""
self._rouge_type = rouge_type
self._use_stemmer = use_stemmer
self._split_summaries = split_summaries
super().__init__(
metric=Rouge._metric_name,
)
@property
def rouge_type(self) -> str:
return self._rouge_type
@property
def use_stemmer(self) -> bool:
return self._use_stemmer
@property
def split_summaries(self) -> bool:
return self._split_summaries
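# Example (illustrative editor's sketch, not part of the original module): a
# ROUGE-L-Sum metric with stemming and summary splitting enabled.
example_rouge_l_sum = Rouge(
    rouge_type="rougeLsum",
    use_stemmer=True,
    split_summaries=True,
)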

View File

@@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Schema for autorater metric configuration."""
AUTORATER_METRIC_SCHEMA = """
$schema: https://json-schema.org/draft/2020-12/schema
title: AutoRater Metric Configuration
description: A metric definition for model-based evaluation.
type: object
properties:
metadata:
description: Useful information about the metric.
type: object
properties:
name:
description: Name of the metric.
type: string
description:
description: Description of the metric.
type: string
author:
description: Author of the metric.
type: string
contact:
description: PoC for the metric.
type: string
version:
description: Version of the metric.
type: string
classification:
description: Classification of the metric.
type: string
enum:
- experimental
- benchmarked
- deprecated
required_inputs:
description: Input fields used in the metric prompt template.
type: array
items:
type: string
minItems: 1
uniqueItems: true
benchmarks:
description: List of benchmarks used for the metric.
type: array
items:
type: object
properties:
dataset:
description: Dataset used for benchmarking.
type: string
results:
description: Results from benchmarking.
type: string
required:
- results
minItems: 1
uniqueItems: true
usage:
description: Links to documentation or notebooks with example usage.
type: array
items:
type: string
minItems: 1
uniqueItems: true
required:
- name
- version
- required_inputs
steps:
description: List of steps used for the autorater workflow.
type: array
items:
type: object
properties:
type:
description: Type of the step.
type: string
enum:
- pointwise_metric
- pairwise_metric
- rubric
prompt:
description: Prompt template for the step.
type: object
properties:
system_instruction:
description: System instruction for the model.
type: string
template:
description: Template to populate with inputs from the dataset.
type: string
required:
- template
model:
description: Configuration of the model for the step.
type: object
properties:
model_name_or_endpoint:
description: Name or endpoint of the model.
type: string
required:
- model_name_or_endpoint
options:
description: Options for the step.
type: object
properties:
sample_count:
description: Number of samples for each instance in the dataset.
type: integer
flip_enabled:
description: Whether to flip candidate and baseline responses.
type: boolean
output:
description: Output of the step.
type: object
properties:
type:
description: Type of the output.
type: string
enum:
- raw
required:
- type
required:
- type
- prompt
minItems: 1
uniqueItems: true
required:
- metadata
- steps
"""

View File

@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vertexai.preview.evaluation import constants
from vertexai.preview.evaluation.metrics import _base
class TrajectorySingleToolUse(
_base._AutomaticMetric
): # pylint: disable=protected-access
"""The TrajectorySingleToolUse Metric.
    Evaluates whether a specified tool is present in the predicted trajectory.
"""
_metric_name = constants.Metric.TRAJECTORY_SINGLE_TOOL_USE
def __init__(
self,
tool_name: str,
):
"""Initializes the TrajectorySingleToolUse metric.
Args:
tool_name: name of the tool to check.
"""
self._tool_name = tool_name
super().__init__(
metric=TrajectorySingleToolUse._metric_name,
)
@property
def tool_name(self) -> str:
return self._tool_name
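# Example (illustrative editor's sketch, not part of the original module): checking
# whether a hypothetical "search" tool appears in the predicted trajectory.
example_single_tool_use = TrajectorySingleToolUse(tool_name="search")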

View File

@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Custom output config for model-based metrics."""
from typing import Any, Callable, Dict, Optional
class CustomOutputConfig:
"""Custom output config for model-based metrics.
Attributes:
      return_raw_output: Whether to return the raw output from the judge
        model.
      parsing_fn: Function to parse the raw output from the judge model.
"""
def __init__(
self,
return_raw_output: bool = False,
parsing_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
):
"""Initializes CustomOutputConfig."""
self.return_raw_output = return_raw_output
self.parsing_fn = parsing_fn
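# Example (illustrative editor's sketch, not part of the original module): returning the
# judge model's raw output and post-processing it with a made-up parsing function.
def _example_parse_raw(raw_output: str) -> Dict[str, Any]:
    # Keep the raw text under a single key; real parsing logic would go here.
    return {"raw_text": raw_output}


example_custom_output = CustomOutputConfig(
    return_raw_output=True,
    parsing_fn=_example_parse_raw,
)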

View File

@@ -0,0 +1,395 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Metric prompt template classes for model-based metrics evaluation."""
from typing import Dict, List, Optional
from google.cloud.aiplatform import base
from vertexai.preview.evaluation import (
prompt_template,
)
_LOGGER = base.Logger(__name__)
_NEWLINE = "\n"
def serialize_dict_in_order(elements: Optional[Dict[str, str]]):
"""Serializes dictionary to ordered string value without brackets."""
if elements is None:
return ""
return _NEWLINE.join(f"{key}: {value}" for key, value in sorted(elements.items()))
class _MetricPromptTemplate(prompt_template.PromptTemplate):
"""Metric prompt template for generic model-based metrics evaluation."""
def __init__(
self,
*,
criteria: Dict[str, str],
rating_rubric: Dict[str, str],
input_variables: List[str],
instruction: Optional[str] = None,
evaluation_steps: Optional[Dict[str, str]] = None,
metric_definition: Optional[str] = None,
few_shot_examples: Optional[List[str]] = None,
):
"""Initializes a metric prompt template."""
self._input_variables = input_variables
self._instruction = instruction
self._metric_definition = metric_definition
self._criteria = criteria
self._rating_rubric = rating_rubric
self._evaluation_steps = evaluation_steps
self._few_shot_examples = few_shot_examples
self.template = self.__str__()
@property
def prompt_data(self) -> str:
return self.template
class PointwiseMetricPromptTemplate(_MetricPromptTemplate):
"""Pointwise metric prompt template for pointwise model-based metrics."""
def __init__(
self,
*,
criteria: Dict[str, str],
rating_rubric: Dict[str, str],
input_variables: Optional[List[str]] = None,
instruction: Optional[str] = None,
metric_definition: Optional[str] = None,
evaluation_steps: Optional[Dict[str, str]] = None,
few_shot_examples: Optional[List[str]] = None,
):
"""Initializes a pointwise metric prompt template.
Args:
criteria: The standards and measures used to evaluate the model
responses. It is a dictionary of criterion names and criterion
definitions.
rating_rubric: A dictionary mapping of rating name and rating
definition, used to assign ratings or scores based on specific
criteria.
input_variables: An optional list of input fields to use in the metric
prompt template for generating model-based evaluation results. Model
"response" column is included by default. If metric_column_mapping is
provided, the mapping values of the input fields will be used to
retrieve data from the evaluation dataset.
instruction: The general instruction to the model that performs the
evaluation. If not provided, a default pointwise metric instruction
will be used.
metric_definition: The optional metric definition. It is a string
describing the metric to be evaluated at a high level. If not
provided, this field will not be included in the prompt template.
          evaluation_steps: Optional guidelines for the evaluation steps. A
            dictionary of evaluation step names and evaluation step
            definitions. If not provided, a default set of pointwise metric
            evaluation steps will be used.
few_shot_examples: The optional list of few-shot examples to be used in
the prompt, to provide the model with demonstrations of how to perform
the evaluation, and improve the evaluation accuracy. If not provided,
this field will not be included in the prompt template.
"""
if not input_variables:
input_variables = []
_LOGGER.info(
"The `input_variables` parameter is empty. Only the `response`"
" column is used for computing this model-based metric."
)
input_variables = list(set(input_variables + ["response"]))
instruction = instruction or self.get_default_pointwise_instruction()
evaluation_steps = (
evaluation_steps or self.get_default_pointwise_evaluation_steps()
)
super().__init__(
input_variables=input_variables,
criteria=criteria,
rating_rubric=rating_rubric,
instruction=instruction,
metric_definition=metric_definition,
evaluation_steps=evaluation_steps,
few_shot_examples=few_shot_examples,
)
def get_default_pointwise_instruction(self) -> str:
"""Returns the default instruction for the metric prompt template."""
return (
"You are an expert evaluator. Your task is to evaluate the quality of"
" the responses generated by AI models. We will provide you with the"
" user prompt and an AI-generated responses.\nYou should first read"
" the user input carefully for analyzing the task, and then evaluate"
" the quality of the responses based on the Criteria provided in the"
" Evaluation section below.\nYou will assign the response a rating"
" following the Rating Rubric and Evaluation Steps. Give step by step"
" explanations for your rating, and only choose ratings from the Rating"
" Rubric."
)
def get_default_pointwise_evaluation_steps(self) -> Dict[str, str]:
"""Returns the default evaluation steps for the metric prompt template."""
return {
"Step 1": (
"Assess the response in aspects of all criteria provided. Provide"
" assessment according to each criterion."
),
"Step 2": (
"Score based on the rating rubric. Give a brief rationale to"
" explain your evaluation considering each individual criterion."
),
}
def __str__(self):
"""Serializes the pointwise metric prompt template to a string."""
metric_prompt_template_str = [
"# Instruction",
f"{self._instruction}",
_NEWLINE,
"# Evaluation",
]
if self._metric_definition:
metric_prompt_template_str.extend(
[
"## Metric Definition",
f"{self._metric_definition}\n",
]
)
metric_prompt_template_str.extend(
[
"## Criteria",
f"{serialize_dict_in_order(self._criteria)}\n",
"## Rating Rubric",
f"{serialize_dict_in_order(self._rating_rubric)}\n",
]
)
if self._evaluation_steps:
metric_prompt_template_str.extend(
[
"## Evaluation Steps",
f"{serialize_dict_in_order(self._evaluation_steps)}\n",
]
)
if self._few_shot_examples:
metric_prompt_template_str.extend(
[
"## Evaluation Examples",
f"{_NEWLINE.join(self._few_shot_examples)}\n",
]
)
metric_prompt_template_str.extend(
["\n# User Inputs and AI-generated Response", "## User Inputs"]
)
for input_variable in self._input_variables:
if input_variable == "response":
continue
metric_prompt_template_str.extend(
[
f"### {input_variable}",
f"{{{input_variable}}}\n",
]
)
metric_prompt_template_str.extend(
[
_NEWLINE,
"\n## AI-generated Response",
"{response}",
]
)
return _NEWLINE.join(metric_prompt_template_str)
def __repr__(self):
return (
f"PointwiseMetricPromptTemplate(prompt_data={self.prompt_data},"
f" variables={self.variables})"
)
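# Example (illustrative editor's sketch, not part of the original module): a pointwise
# template with made-up criteria and a two-level rating rubric; "prompt" is an assumed
# input column, and "response" is added automatically by the constructor.
example_pointwise_template = PointwiseMetricPromptTemplate(
    criteria={"fluency": "The response is grammatical and reads naturally."},
    rating_rubric={"1": "Fluent, with no grammatical errors.", "0": "Not fluent."},
    input_variables=["prompt"],
)
# str(example_pointwise_template) yields the serialized prompt containing
# {prompt} and {response} placeholders.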
class PairwiseMetricPromptTemplate(_MetricPromptTemplate):
"""Pairwise metric prompt template for pairwise model-based metrics."""
def __init__(
self,
*,
criteria: Dict[str, str],
rating_rubric: Dict[str, str],
input_variables: Optional[List[str]] = None,
instruction: Optional[str] = None,
metric_definition: Optional[str] = None,
evaluation_steps: Optional[Dict[str, str]] = None,
few_shot_examples: Optional[List[str]] = None,
):
"""Initializes a pairwise metric prompt template.
Args:
criteria: The standards and measures used to evaluate the model
responses. It is a dictionary of criterion names and criterion
definitions.
rating_rubric: A dictionary mapping of rating name and rating
definition, used to assign ratings or scores based on specific
criteria.
input_variables: An optional list of input fields to use in the metric
prompt template for generating model-based evaluation results.
Candidate model "response" column and "baseline_model_response" column
are included by default. If metric_column_mapping is provided, the
mapping values of the input fields will be used to retrieve data from
the evaluation dataset.
instruction: The general instruction to the model that performs the
evaluation. If not provided, a default pairwise metric instruction
will be used.
metric_definition: The optional metric definition. It is a string
describing the metric to be evaluated at a high level. If not
provided, this field will not be included in the prompt template.
          evaluation_steps: Optional guidelines for the evaluation steps. A
            dictionary of evaluation step names and evaluation step
            definitions. If not provided, a default set of pairwise metric
            evaluation steps will be used.
few_shot_examples: The optional list of few-shot examples to be used in
the prompt, to provide the model with demonstrations of how to perform
the evaluation, and improve the evaluation accuracy. If not provided,
this field will not be included in the prompt template.
"""
if not input_variables:
input_variables = []
_LOGGER.info(
"The `input_variables` parameter is empty. Only the `response`"
" column and `baseline_model_response` columns are used for"
" computing this model-based metric."
)
input_variables = list(
set(input_variables + ["response", "baseline_model_response"])
)
instruction = instruction or self.get_default_pairwise_instruction()
evaluation_steps = (
evaluation_steps or self.get_default_pairwise_evaluation_steps()
)
super().__init__(
input_variables=input_variables,
criteria=criteria,
rating_rubric=rating_rubric,
instruction=instruction,
metric_definition=metric_definition,
evaluation_steps=evaluation_steps,
few_shot_examples=few_shot_examples,
)
def get_default_pairwise_instruction(self) -> str:
"""Returns the default instruction for the metric prompt template."""
return (
"You are an expert evaluator. Your task is to evaluate the quality of"
" the responses generated by two AI models. We will provide you with"
" the user input and a pair of AI-generated responses (Response A and"
" Response B).\nYou should first read the user input carefully for"
" analyzing the task, and then evaluate the quality of the responses"
" based on based on the Criteria provided in the Evaluation section"
" below.\nYou will first judge responses individually, following the"
" Rating Rubric and Evaluation Steps. Then you will give step by step"
" explanations for your judgement, compare results to declare the"
" winner based on the Rating Rubric and Evaluation Steps."
)
def get_default_pairwise_evaluation_steps(self) -> Dict[str, str]:
"""Returns the default evaluation steps for the metric prompt template."""
return {
"Step 1": "Analyze Response A based on all the Criteria.",
"Step 2": "Analyze Response B based on all the Criteria.",
"Step 3": (
"Compare the overall performance of Response A and Response B based"
" on your analyses and assessment."
),
"Step 4": (
'Output your preference of "A", "SAME" or "B" to the'
" pairwise_choice field according to the Rating Rubrics."
),
"Step 5": "Output your assessment reasoning in the explanation field",
}
def __str__(self):
"""Serializes the pairwise metric prompt template to a string."""
metric_prompt_template_str = [
"# Instruction",
f"{self._instruction}",
_NEWLINE,
"# Evaluation",
]
if self._metric_definition:
metric_prompt_template_str.extend(
[
"## Metric Definition",
f"{self._metric_definition}\n",
]
)
metric_prompt_template_str.extend(
[
"## Criteria",
f"{serialize_dict_in_order(self._criteria)}\n",
"## Rating Rubric",
f"{serialize_dict_in_order(self._rating_rubric)}\n",
]
)
if self._evaluation_steps:
metric_prompt_template_str.extend(
[
"## Evaluation Steps",
f"{serialize_dict_in_order(self._evaluation_steps)}\n",
]
)
if self._few_shot_examples:
metric_prompt_template_str.extend(
[
"## Evaluation Examples",
f"{_NEWLINE.join(self._few_shot_examples)}\n",
]
)
metric_prompt_template_str.extend(
["\n# User Inputs and AI-generated Responses", "## User Inputs"]
)
for input_variable in self._input_variables:
if input_variable in ["response", "baseline_model_response"]:
continue
metric_prompt_template_str.extend(
[
f"### {input_variable}",
f"{{{input_variable}}}\n",
]
)
metric_prompt_template_str.extend(
[
"\n## AI-generated Responses",
"### Response A",
"{baseline_model_response}\n",
"### Response B",
"{response}",
]
)
return _NEWLINE.join(metric_prompt_template_str)
def __repr__(self):
return (
f"PairwiseMetricPromptTemplate(prompt_data={self.prompt_data},"
f" variables={self.variables})"
)
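# Example (illustrative editor's sketch, not part of the original module): a pairwise
# template comparing a candidate response against a baseline; criteria and rubric values
# are made up, and "prompt" is an assumed input column. "response" and
# "baseline_model_response" are added automatically by the constructor.
example_pairwise_template = PairwiseMetricPromptTemplate(
    criteria={"helpfulness": "The response directly addresses the user's request."},
    rating_rubric={
        "A": "Response A is better.",
        "SAME": "Both responses are equally good.",
        "B": "Response B is better.",
    },
    input_variables=["prompt"],
)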

View File

@@ -0,0 +1,197 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Example metric prompt templates for model-based evaluation."""
from typing import List
from vertexai.preview.evaluation import constants
from vertexai.preview.evaluation.metrics import (
_default_templates,
)
from vertexai.preview.evaluation.metrics import pairwise_metric
from vertexai.preview.evaluation.metrics import pointwise_metric
class MetricPromptTemplateExamples:
"""Examples of metric prompt templates for model-based evaluation."""
_PROMPT_TEMPLATE_MAP = {
constants.Metric.COHERENCE: _default_templates.COHERENCE_PROMPT_TEMPLATE,
constants.Metric.FLUENCY: _default_templates.FLUENCY_PROMPT_TEMPLATE,
constants.Metric.SAFETY: _default_templates.SAFETY_PROMPT_TEMPLATE,
constants.Metric.GROUNDEDNESS: (
_default_templates.GROUNDEDNESS_PROMPT_TEMPLATE
),
constants.Metric.INSTRUCTION_FOLLOWING: (
_default_templates.INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE
),
constants.Metric.VERBOSITY: _default_templates.VERBOSITY_PROMPT_TEMPLATE,
constants.Metric.TEXT_QUALITY: (
_default_templates.TEXT_QUALITY_PROMPT_TEMPLATE
),
constants.Metric.SUMMARIZATION_QUALITY: (
_default_templates.SUMMARIZATION_QUALITY_PROMPT_TEMPLATE
),
constants.Metric.QUESTION_ANSWERING_QUALITY: (
_default_templates.QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE
),
constants.Metric.MULTI_TURN_CHAT_QUALITY: (
_default_templates.MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE
),
constants.Metric.MULTI_TURN_SAFETY: (
_default_templates.MULTI_TURN_SAFETY_PROMPT_TEMPLATE
),
constants.Metric.PAIRWISE_COHERENCE: (
_default_templates.PAIRWISE_COHERENCE_PROMPT_TEMPLATE
),
constants.Metric.PAIRWISE_FLUENCY: (
_default_templates.PAIRWISE_FLUENCY_PROMPT_TEMPLATE
),
constants.Metric.PAIRWISE_SAFETY: (
_default_templates.PAIRWISE_SAFETY_PROMPT_TEMPLATE
),
constants.Metric.PAIRWISE_GROUNDEDNESS: (
_default_templates.PAIRWISE_GROUNDEDNESS_PROMPT_TEMPLATE
),
constants.Metric.PAIRWISE_INSTRUCTION_FOLLOWING: (
_default_templates.PAIRWISE_INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE
),
constants.Metric.PAIRWISE_VERBOSITY: (
_default_templates.PAIRWISE_VERBOSITY_PROMPT_TEMPLATE
),
constants.Metric.PAIRWISE_TEXT_QUALITY: (
_default_templates.PAIRWISE_TEXT_QUALITY_PROMPT_TEMPLATE
),
constants.Metric.PAIRWISE_SUMMARIZATION_QUALITY: (
_default_templates.PAIRWISE_SUMMARIZATION_QUALITY_PROMPT_TEMPLATE
),
constants.Metric.PAIRWISE_QUESTION_ANSWERING_QUALITY: (
_default_templates.PAIRWISE_QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE
),
constants.Metric.PAIRWISE_MULTI_TURN_CHAT_QUALITY: (
_default_templates.PAIRWISE_MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE
),
constants.Metric.PAIRWISE_MULTI_TURN_SAFETY: (
_default_templates.PAIRWISE_MULTI_TURN_SAFETY_PROMPT_TEMPLATE
),
}
@classmethod
def get_prompt_template(cls, metric_name: str) -> str:
"""Returns the prompt template for the given metric name."""
return cls._PROMPT_TEMPLATE_MAP[metric_name]
@classmethod
def list_example_metric_names(cls) -> List[str]:
"""Returns a list of all metric prompt templates."""
return list(cls._PROMPT_TEMPLATE_MAP.keys())
class Pointwise:
"""Example PointwiseMetric instances."""
FLUENCY = pointwise_metric.PointwiseMetric(
metric=constants.Metric.FLUENCY,
metric_prompt_template=_default_templates.FLUENCY_PROMPT_TEMPLATE,
)
COHERENCE = pointwise_metric.PointwiseMetric(
metric=constants.Metric.COHERENCE,
metric_prompt_template=_default_templates.COHERENCE_PROMPT_TEMPLATE,
)
SAFETY = pointwise_metric.PointwiseMetric(
metric=constants.Metric.SAFETY,
metric_prompt_template=_default_templates.SAFETY_PROMPT_TEMPLATE,
)
GROUNDEDNESS = pointwise_metric.PointwiseMetric(
metric=constants.Metric.GROUNDEDNESS,
metric_prompt_template=_default_templates.GROUNDEDNESS_PROMPT_TEMPLATE,
)
INSTRUCTION_FOLLOWING = pointwise_metric.PointwiseMetric(
metric=constants.Metric.INSTRUCTION_FOLLOWING,
metric_prompt_template=_default_templates.INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE,
)
VERBOSITY = pointwise_metric.PointwiseMetric(
metric=constants.Metric.VERBOSITY,
metric_prompt_template=_default_templates.VERBOSITY_PROMPT_TEMPLATE,
)
TEXT_QUALITY = pointwise_metric.PointwiseMetric(
metric=constants.Metric.TEXT_QUALITY,
metric_prompt_template=_default_templates.TEXT_QUALITY_PROMPT_TEMPLATE,
)
SUMMARIZATION_QUALITY = pointwise_metric.PointwiseMetric(
metric=constants.Metric.SUMMARIZATION_QUALITY,
metric_prompt_template=_default_templates.SUMMARIZATION_QUALITY_PROMPT_TEMPLATE,
)
QUESTION_ANSWERING_QUALITY = pointwise_metric.PointwiseMetric(
metric=constants.Metric.QUESTION_ANSWERING_QUALITY,
metric_prompt_template=_default_templates.QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE,
)
MULTI_TURN_CHAT_QUALITY = pointwise_metric.PointwiseMetric(
metric=constants.Metric.MULTI_TURN_CHAT_QUALITY,
metric_prompt_template=_default_templates.MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE,
)
MULTI_TURN_SAFETY_QUALITY = pointwise_metric.PointwiseMetric(
metric=constants.Metric.MULTI_TURN_SAFETY,
metric_prompt_template=_default_templates.MULTI_TURN_SAFETY_PROMPT_TEMPLATE,
)
class Pairwise:
"""Example PairwiseMetric instances."""
FLUENCY = pairwise_metric.PairwiseMetric(
metric=constants.Metric.PAIRWISE_FLUENCY,
metric_prompt_template=_default_templates.PAIRWISE_FLUENCY_PROMPT_TEMPLATE,
)
COHERENCE = pairwise_metric.PairwiseMetric(
metric=constants.Metric.PAIRWISE_COHERENCE,
metric_prompt_template=_default_templates.PAIRWISE_COHERENCE_PROMPT_TEMPLATE,
)
SAFETY = pairwise_metric.PairwiseMetric(
metric=constants.Metric.PAIRWISE_SAFETY,
metric_prompt_template=_default_templates.PAIRWISE_SAFETY_PROMPT_TEMPLATE,
)
GROUNDEDNESS = pairwise_metric.PairwiseMetric(
metric=constants.Metric.PAIRWISE_GROUNDEDNESS,
metric_prompt_template=_default_templates.PAIRWISE_GROUNDEDNESS_PROMPT_TEMPLATE,
)
INSTRUCTION_FOLLOWING = pairwise_metric.PairwiseMetric(
metric=constants.Metric.PAIRWISE_INSTRUCTION_FOLLOWING,
metric_prompt_template=_default_templates.PAIRWISE_INSTRUCTION_FOLLOWING_PROMPT_TEMPLATE,
)
VERBOSITY = pairwise_metric.PairwiseMetric(
metric=constants.Metric.PAIRWISE_VERBOSITY,
metric_prompt_template=_default_templates.PAIRWISE_VERBOSITY_PROMPT_TEMPLATE,
)
TEXT_QUALITY = pairwise_metric.PairwiseMetric(
metric=constants.Metric.PAIRWISE_TEXT_QUALITY,
metric_prompt_template=_default_templates.PAIRWISE_TEXT_QUALITY_PROMPT_TEMPLATE,
)
SUMMARIZATION_QUALITY = pairwise_metric.PairwiseMetric(
metric=constants.Metric.PAIRWISE_SUMMARIZATION_QUALITY,
metric_prompt_template=_default_templates.PAIRWISE_SUMMARIZATION_QUALITY_PROMPT_TEMPLATE,
)
QUESTION_ANSWERING_QUALITY = pairwise_metric.PairwiseMetric(
metric=constants.Metric.PAIRWISE_QUESTION_ANSWERING_QUALITY,
metric_prompt_template=_default_templates.PAIRWISE_QUESTION_ANSWERING_QUALITY_PROMPT_TEMPLATE,
)
MULTI_TURN_CHAT_QUALITY = pairwise_metric.PairwiseMetric(
metric=constants.Metric.PAIRWISE_MULTI_TURN_CHAT_QUALITY,
metric_prompt_template=_default_templates.PAIRWISE_MULTI_TURN_CHAT_QUALITY_PROMPT_TEMPLATE,
)
MULTI_TURN_SAFETY_QUALITY = pairwise_metric.PairwiseMetric(
metric=constants.Metric.PAIRWISE_MULTI_TURN_SAFETY,
metric_prompt_template=_default_templates.PAIRWISE_MULTI_TURN_SAFETY_PROMPT_TEMPLATE,
)
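# --- Illustrative usage sketch (not part of the original module) ---
# The two typical ways this class is consumed: looking up a raw prompt
# template by metric name (the 'fluency' name mirrors the PointwiseMetric
# docstring in this package), or passing one of the pre-built metric
# instances, e.g. MetricPromptTemplateExamples.Pointwise.FLUENCY or
# MetricPromptTemplateExamples.Pairwise.SAFETY, straight into an EvalTask
# `metrics` list.
def _example_template_lookup() -> None:
    print(MetricPromptTemplateExamples.list_example_metric_names())
    fluency_template = MetricPromptTemplateExamples.get_prompt_template("fluency")
    print(fluency_template)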

View File

@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Model-based Pairwise Metric."""
from typing import Callable, Optional, Union
from google.cloud.aiplatform_v1beta1.types import (
evaluation_service as gapic_eval_service_types,
)
from vertexai.preview import generative_models
from vertexai.preview.evaluation.metrics import _base
from vertexai.preview.evaluation.metrics import (
custom_output_config as custom_output_config_class,
)
from vertexai.preview.evaluation.metrics import (
metric_prompt_template as metric_prompt_template_base,
)
class PairwiseMetric(_base._ModelBasedMetric): # pylint: disable=protected-access
"""A Model-based Pairwise Metric.
A model-based evaluation metric that compares two generative models' responses
side-by-side, and allows users to A/B test their generative models to
determine which model is performing better.
For more details on when to use pairwise metrics, see
[Evaluation methods and
metrics](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval#pointwise_versus_pairwise).
Result Details:
* In `EvalResult.summary_metrics`, win rates for both the baseline and
      candidate model are computed. The win rate is the proportion of a
      model's wins out of all comparison attempts, expressed as a decimal
      value between 0 and 1.
* In `EvalResult.metrics_table`, a pairwise metric produces two
evaluation results per dataset row:
* `pairwise_choice`: The choice shows whether the candidate model or
the baseline model performs better, or if they are equally good.
* `explanation`: The rationale behind each verdict using
chain-of-thought reasoning. The explanation helps users scrutinize
the judgment and builds appropriate trust in the decisions.
See [documentation
page](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval#understand-results)
for more details on understanding the metric results.
Usage Examples:
```
baseline_model = GenerativeModel("gemini-1.0-pro")
candidate_model = GenerativeModel("gemini-1.5-pro")
pairwise_groundedness = PairwiseMetric(
metric_prompt_template=MetricPromptTemplateExamples.get_prompt_template(
"pairwise_groundedness"
),
baseline_model=baseline_model,
)
eval_dataset = pd.DataFrame({
"prompt" : [...],
})
pairwise_task = EvalTask(
dataset=eval_dataset,
metrics=[pairwise_groundedness],
experiment="my-pairwise-experiment",
)
pairwise_result = pairwise_task.evaluate(
model=candidate_model,
experiment_run_name="gemini-pairwise-eval-run",
)
```
"""
def __init__(
self,
*,
metric: str,
metric_prompt_template: Union[
metric_prompt_template_base.PairwiseMetricPromptTemplate, str
],
baseline_model: Optional[
Union[generative_models.GenerativeModel, Callable[[str], str]]
] = None,
system_instruction: Optional[str] = None,
autorater_config: Optional[gapic_eval_service_types.AutoraterConfig] = None,
custom_output_config: Optional[
custom_output_config_class.CustomOutputConfig
] = None,
):
"""Initializes a pairwise evaluation metric.
Args:
metric: The pairwise evaluation metric name.
metric_prompt_template: Pairwise metric prompt template for performing
the pairwise model-based evaluation. A freeform string is also accepted.
baseline_model: The baseline model for side-by-side comparison. If not
specified, `baseline_model_response` column is required in the dataset
        to perform bring-your-own-response (BYOR) evaluation.
system_instruction: The system instruction for the evaluation.
      autorater_config: The config for the judge model.
custom_output_config: Config for custom output from the judge model.
"""
super().__init__(
metric_prompt_template=metric_prompt_template,
metric=metric,
system_instruction=system_instruction,
autorater_config=autorater_config,
custom_output_config=custom_output_config,
)
self._baseline_model = baseline_model
@property
def baseline_model(
self,
) -> Union[generative_models.GenerativeModel, Callable[[str], str]]:
return self._baseline_model
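# --- Illustrative usage sketch (not part of the original module) ---
# A bring-your-own-response (BYOR) pairwise metric: no `baseline_model` is
# passed, so the evaluation dataset must supply both `response` and
# `baseline_model_response` columns (see the __init__ docstring above). The
# metric name and freeform template below are placeholders; a freeform string
# is accepted in place of a PairwiseMetricPromptTemplate.
def _example_byor_pairwise_metric() -> PairwiseMetric:
    template = (
        "Compare the two responses to the user prompt and decide which one is"
        " better.\n\nPrompt: {prompt}\n"
        "Response A: {baseline_model_response}\n"
        "Response B: {response}"
    )
    return PairwiseMetric(
        metric="pairwise_text_quality",
        metric_prompt_template=template,
    )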

View File

@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Model-based Pointwise Metric."""
from typing import Optional, Union
from google.cloud.aiplatform_v1beta1.types import (
evaluation_service as gapic_eval_service_types,
)
from vertexai.preview.evaluation.metrics import _base
from vertexai.preview.evaluation.metrics import (
custom_output_config as custom_output_config_class,
)
from vertexai.preview.evaluation.metrics import (
metric_prompt_template as metric_prompt_template_base,
)
class PointwiseMetric(_base._ModelBasedMetric): # pylint: disable=protected-access
"""A Model-based Pointwise Metric.
    A model-based evaluation metric that evaluates a single generative model's
response.
For more details on when to use model-based pointwise metrics, see
[Evaluation methods and metrics](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval).
Usage Examples:
```
candidate_model = GenerativeModel("gemini-1.5-pro")
eval_dataset = pd.DataFrame({
"prompt" : [...],
})
fluency_metric = PointwiseMetric(
metric="fluency",
metric_prompt_template=MetricPromptTemplateExamples.get_prompt_template('fluency'),
)
pointwise_eval_task = EvalTask(
dataset=eval_dataset,
metrics=[
fluency_metric,
MetricPromptTemplateExamples.Pointwise.GROUNDEDNESS,
],
)
pointwise_result = pointwise_eval_task.evaluate(
model=candidate_model,
)
```
"""
def __init__(
self,
*,
metric: str,
metric_prompt_template: Union[
metric_prompt_template_base.PointwiseMetricPromptTemplate, str
],
system_instruction: Optional[str] = None,
autorater_config: Optional[gapic_eval_service_types.AutoraterConfig] = None,
custom_output_config: Optional[
custom_output_config_class.CustomOutputConfig
] = None,
):
"""Initializes a pointwise evaluation metric.
Args:
metric: The pointwise evaluation metric name.
metric_prompt_template: Pointwise metric prompt template for performing
the model-based evaluation. A freeform string is also accepted.
system_instruction: The system instruction for the evaluation.
      autorater_config: The config for the judge model.
custom_output_config: Config for custom output from the judge model.
"""
super().__init__(
metric_prompt_template=metric_prompt_template,
metric=metric,
system_instruction=system_instruction,
autorater_config=autorater_config,
custom_output_config=custom_output_config,
)
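# --- Illustrative usage sketch (not part of the original module) ---
# A custom pointwise metric built from a freeform string template, which the
# __init__ docstring above accepts in place of a PointwiseMetricPromptTemplate.
# The metric name and placeholder names are illustrative.
def _example_custom_pointwise_metric() -> PointwiseMetric:
    template = (
        "Rate the response from 1 to 5 for helpfulness and explain your"
        " rating.\n\nPrompt: {prompt}\nResponse: {response}"
    )
    return PointwiseMetric(
        metric="custom_helpfulness",
        metric_prompt_template=template,
    )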

View File

@@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.cloud.aiplatform_v1beta1.types import (
evaluation_service as gapic_eval_service_types,
)
from vertexai.preview.evaluation import utils
from vertexai.preview.evaluation.metrics import (
_base as metrics_base,
)
from vertexai.preview.evaluation.metrics import (
_default_templates,
)
from vertexai.preview.evaluation.metrics import (
custom_output_config,
)
from vertexai.preview.evaluation.metrics import pairwise_metric
from vertexai.preview.evaluation.metrics import pointwise_metric
from vertexai.preview.evaluation.metrics import (
rubric_based_metric,
)
AutoraterConfig = gapic_eval_service_types.AutoraterConfig
_POINTWISE_OUTPUT_CONFIG = custom_output_config.CustomOutputConfig(
return_raw_output=True,
parsing_fn=utils.parse_pointwise_rubric_result,
)
_PAIRWISE_OUTPUT_CONFIG = custom_output_config.CustomOutputConfig(
return_raw_output=True,
parsing_fn=utils.parse_pairwise_rubric_result,
)
_PAIRWISE_AUTORATER_CONFIG = AutoraterConfig(
sampling_count=1,
)
class PredefinedRubricMetrics:
"""Predefined rubric-based metrics."""
class Pointwise:
"""Pointwise rubric-based metrics."""
INSTRUCTION_FOLLOWING = rubric_based_metric.RubricBasedMetric(
generation_config=metrics_base.RubricGenerationConfig(
prompt_template=_default_templates.INSTRUCTION_FOLLOWING_RUBRIC_GENERATION_PROMPT_TEMPLATE,
),
critique_metric=pointwise_metric.PointwiseMetric(
metric="rb_instruction_following",
metric_prompt_template=_default_templates.INSTRUCTION_FOLLOWING_RUBRIC_CRITIQUE_TEMPLATE,
custom_output_config=_POINTWISE_OUTPUT_CONFIG,
),
)
MULTIMODAL_UNDERSTANDING = rubric_based_metric.RubricBasedMetric(
generation_config=metrics_base.RubricGenerationConfig(
prompt_template=_default_templates.MULTIMODAL_UNDERSTANDING_RUBRIC_GENERATION_PROMPT_TEMPLATE
),
critique_metric=pointwise_metric.PointwiseMetric(
metric="rb_multimodal_understanding",
metric_prompt_template=_default_templates.MULTIMODAL_UNDERSTANDING_RUBRIC_CRITIQUE_TEMPLATE,
custom_output_config=_POINTWISE_OUTPUT_CONFIG,
),
)
TEXT_QUALITY = rubric_based_metric.RubricBasedMetric(
generation_config=metrics_base.RubricGenerationConfig(
prompt_template=_default_templates.TEXT_QUALITY_RUBRIC_GENERATION_PROMPT_TEMPLATE
),
critique_metric=pointwise_metric.PointwiseMetric(
metric="rb_text_quality",
metric_prompt_template=_default_templates.TEXT_QUALITY_RUBRIC_CRITIQUE_TEMPLATE,
custom_output_config=_POINTWISE_OUTPUT_CONFIG,
),
)
class Pairwise:
"""Pairwise rubric-based metrics."""
INSTRUCTION_FOLLOWING = rubric_based_metric.RubricBasedMetric(
generation_config=metrics_base.RubricGenerationConfig(
prompt_template=_default_templates.INSTRUCTION_FOLLOWING_RUBRIC_GENERATION_PROMPT_TEMPLATE,
),
critique_metric=pairwise_metric.PairwiseMetric(
metric="pairwise_rb_instruction_following",
metric_prompt_template=_default_templates.PAIRWISE_INSTRUCTION_FOLLOWING_RUBRIC_CRITIQUE_TEMPLATE,
custom_output_config=_PAIRWISE_OUTPUT_CONFIG,
autorater_config=_PAIRWISE_AUTORATER_CONFIG,
),
)
MULTIMODAL_UNDERSTANDING = rubric_based_metric.RubricBasedMetric(
generation_config=metrics_base.RubricGenerationConfig(
prompt_template=_default_templates.MULTIMODAL_UNDERSTANDING_RUBRIC_GENERATION_PROMPT_TEMPLATE
),
critique_metric=pairwise_metric.PairwiseMetric(
metric="pairwise_rb_multimodal_understanding",
metric_prompt_template=_default_templates.PAIRWISE_MULTIMODAL_UNDERSTANDING_RUBRIC_CRITIQUE_TEMPLATE,
custom_output_config=_PAIRWISE_OUTPUT_CONFIG,
autorater_config=_PAIRWISE_AUTORATER_CONFIG,
),
)
TEXT_QUALITY = rubric_based_metric.RubricBasedMetric(
generation_config=metrics_base.RubricGenerationConfig(
prompt_template=_default_templates.TEXT_QUALITY_RUBRIC_GENERATION_PROMPT_TEMPLATE
),
critique_metric=pairwise_metric.PairwiseMetric(
metric="pairwise_rb_text_quality",
metric_prompt_template=_default_templates.PAIRWISE_TEXT_QUALITY_RUBRIC_CRITIQUE_TEMPLATE,
custom_output_config=_PAIRWISE_OUTPUT_CONFIG,
autorater_config=_PAIRWISE_AUTORATER_CONFIG,
),
)
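# --- Illustrative usage sketch (not part of the original module) ---
# Intended flow for a predefined rubric-based metric: generate rubrics for the
# dataset first, then pass the same metric object to an EvalTask for the
# critique step (EvalTask usage follows the docstring examples elsewhere in
# this package). Rubric generation calls the judge model, so Vertex AI must be
# initialized with valid credentials; `eval_dataset` is assumed to be a pandas
# DataFrame with a `prompt` column.
def _example_predefined_rubric_metric(eval_dataset):
    rb_metric = PredefinedRubricMetrics.Pointwise.INSTRUCTION_FOLLOWING
    # Adds a rubrics column (skipped if the dataset already has one).
    dataset_with_rubrics = rb_metric.generate_rubrics(eval_dataset)
    return dataset_with_rubrics, rb_metric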

View File

@@ -0,0 +1,104 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections
from typing import Union, TYPE_CHECKING
from google.cloud.aiplatform import base
from vertexai import generative_models
from vertexai.preview.evaluation import _pre_eval_utils
from vertexai.preview.evaluation import constants
from vertexai.preview.evaluation import utils
from vertexai.preview.evaluation.metrics import (
_base as metrics_base,
)
from vertexai.preview.evaluation.metrics import pairwise_metric
from vertexai.preview.evaluation.metrics import pointwise_metric
if TYPE_CHECKING:
import pandas as pd
_DEFAULT_MODEL_NAME = "gemini-2.0-flash-001"
_LOGGER = base.Logger(__name__)
class RubricBasedMetric(metrics_base._Metric):
"""Config for Rubric-Based Eval."""
def __init__(
self,
*,
generation_config: metrics_base.RubricGenerationConfig,
critique_metric: Union[
pointwise_metric.PointwiseMetric, pairwise_metric.PairwiseMetric
]
):
"""Initializes RubricBasedMetric.
Args:
generation_config: Config for rubric generation.
critique_metric: Pointwise/pairwise metric for rubric critique.
"""
super().__init__(metric=critique_metric._metric)
self.generation_config = generation_config
self.critique_metric = critique_metric
def generate_rubrics(
self,
eval_dataset: "pd.Dataframe",
) -> "pd.DataFrame":
"""Generates rubrics for given eval dataset."""
if not self.generation_config.model:
model = generative_models.GenerativeModel(model_name=_DEFAULT_MODEL_NAME)
else:
model = self.generation_config.model
if constants.Dataset.RUBRICS_COLUMN in eval_dataset.columns:
_LOGGER.warning(
"Rubrics column already exists in the dataset. Skipping rubric"
" generation."
)
return eval_dataset
responses = _pre_eval_utils._generate_responses_from_gemini_model(
model,
eval_dataset,
self.generation_config.prompt_template,
)
if self.generation_config.parsing_fn:
parsing_fn = self.generation_config.parsing_fn
else:
parsing_fn = utils.parse_rubrics
dataset_with_rubrics = eval_dataset.copy()
aggregated = collections.defaultdict(list)
for idx, response in enumerate(responses):
result = parsing_fn(response)
if isinstance(result, dict):
questions = result.pop("questions", None)
if questions is not None:
aggregated[constants.Dataset.RUBRICS_COLUMN].append(
(idx, questions)
)
for key, value in result.items():
aggregated[key].append((idx, value))
else:
aggregated[constants.Dataset.RUBRICS_COLUMN].append((idx, result))
for key, values in aggregated.items():
dataset_with_rubrics[key] = None
dataset_with_rubrics[key] = dataset_with_rubrics[key].astype(object)
for idx, value in values:
dataset_with_rubrics.at[idx, key] = value
return dataset_with_rubrics
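# --- Illustrative usage sketch (not part of the original module) ---
# A custom rubric-based metric wired up end to end. The prompt template
# strings, metric name, and the {rubrics} placeholder are assumptions for
# illustration; when generation_config.model is unset, generate_rubrics falls
# back to _DEFAULT_MODEL_NAME as implemented above.
def _example_custom_rubric_based_metric(eval_dataset: "pd.DataFrame"):
    metric = RubricBasedMetric(
        generation_config=metrics_base.RubricGenerationConfig(
            prompt_template="Write yes/no rubric questions for: {prompt}",
        ),
        critique_metric=pointwise_metric.PointwiseMetric(
            metric="rb_custom_quality",
            metric_prompt_template=(
                "Answer each rubric question about the response.\n"
                "Rubrics: {rubrics}\nPrompt: {prompt}\nResponse: {response}"
            ),
        ),
    )
    # The returned DataFrame carries the generated rubrics; it and `metric`
    # can then be handed to an EvalTask for the critique step.
    return metric.generate_rubrics(eval_dataset)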

View File

@@ -0,0 +1,146 @@
"""Utility functions for multimodal evaluation."""
import logging
import re
from typing import Any, Dict, Union, List, Set
from google.cloud.aiplatform import base
from google.cloud.aiplatform_v1beta1.types import content
from google.cloud.aiplatform_v1beta1.types import (
evaluation_service as gapic_eval_service_types,
)
from vertexai import generative_models
from vertexai.preview.evaluation import (
prompt_template as prompt_template_base,
)
from google.protobuf import json_format
ContentMap = gapic_eval_service_types.ContentMap
Content = content.Content
Part = content.Part
_CONTENTS_DETECTOR = "contents {"
_PARTS_DETECTOR = "parts {"
_LOGGER = base.Logger(__name__)
def _string_to_content_list(input_str: str) -> ContentMap.Contents:
"""Converts a string to a list if possible, otherwise returns None."""
try:
return json_format.Parse(
input_str,
ContentMap.Contents.pb(ContentMap.Contents()),
)
except json_format.ParseError as e:
if _CONTENTS_DETECTOR in input_str and _PARTS_DETECTOR in input_str:
logging.warning(
"Failed to parse %s to ContentMap.Contents: %s", input_str, e
)
return None
def _is_multimodal_response(response: str) -> bool:
"""Checks if the model response contains multimodal input."""
content_list = _string_to_content_list(response)
if content_list is None:
if _CONTENTS_DETECTOR in response and _PARTS_DETECTOR in response:
logging.warning(
"Response contains multimodal input: %s. Please check whether"
" the response format conforms to ContentMap type.",
response,
)
return False
else:
return True
def is_multimodal_instance(
model_based_metric_instance_input: Dict[str, str],
) -> bool:
"""Checks if the evaluation instance contains multimodal input."""
for placeholder in model_based_metric_instance_input:
if _is_multimodal_response(model_based_metric_instance_input[placeholder]):
return True
return False
def convert_multimodal_response_to_content_map(
model_based_metric_instance_input: Dict[str, str],
) -> ContentMap:
"""Converts a multimodal model response to a ContentMap."""
content_map = ContentMap()
for placeholder in model_based_metric_instance_input.keys():
content_list = _string_to_content_list(
model_based_metric_instance_input[placeholder]
)
if content_list is None:
content_map.values[placeholder] = ContentMap.Contents(
contents=[
Content(
parts=[
Part(text=model_based_metric_instance_input[placeholder])
]
)
]
)
else:
content_map.values[placeholder] = content_list
return content_map
def _split_metric_prompt_template(
metric_prompt_template: str,
placeholders: Set[str],
) -> List[str]:
"""Splits the metric prompt template into a list of strings by placeholders."""
placeholders_with_brackets = [
re.escape("{" + placeholder + "}") for placeholder in placeholders
]
pattern = "|".join(f"({placeholder})" for placeholder in placeholders_with_brackets)
split_metric_prompt_template = re.split(pattern, metric_prompt_template)
return [element for element in split_metric_prompt_template if element]
def _assemble_multi_modal_prompt(
metric_prompt_template: Union[prompt_template_base.PromptTemplate, str],
data_row: Dict[str, Any],
row_index: int,
placeholders: Set[str],
) -> List[Union[str, generative_models.Part]]:
"""Fills in the split metric prompt template elements with multimodal data to be sent to the model."""
split_template_elements = _split_metric_prompt_template(
str(metric_prompt_template), placeholders
)
part_inputs = []
for element in split_template_elements:
placeholder = element.replace("{", "").replace("}", "")
if placeholder in data_row.keys():
content_list = _string_to_content_list(data_row[placeholder])
if content_list is None:
part_inputs.append(data_row[placeholder])
else:
for content_inp in content_list.contents:
for part in content_inp.parts:
if part.HasField("text"):
part_inputs.append(part.text)
elif part.HasField("file_data"):
part_inputs.append(
generative_models.Part.from_uri(
part.file_data.file_uri,
mime_type=part.file_data.mime_type,
)
)
else:
_LOGGER.warning(
"The multimodal input you provided "
f"at row {row_index} "
"contains part types that are not "
"yet supported. Currently supported"
"part types are text and file_data"
)
else:
part_inputs.append(element)
return part_inputs
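# --- Illustrative usage sketch (not part of the original module) ---
# Plain-text placeholder values are not treated as multimodal, so
# convert_multimodal_response_to_content_map wraps each one in a single text
# Part, as implemented above. The instance dict mirrors the
# `model_based_metric_instance_input` shape used in this module.
def _example_plain_text_content_map() -> ContentMap:
    instance_input = {"prompt": "Describe the image.", "response": "A cat."}
    assert not is_multimodal_instance(instance_input)
    content_map = convert_multimodal_response_to_content_map(instance_input)
    # Each placeholder maps to Contents holding one Content with one text Part.
    return content_map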

View File

@@ -0,0 +1,251 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Python functions which run only within a Jupyter or Colab notebook."""
import random
import string
import sys
from typing import List, Optional, Tuple
from vertexai.preview.evaluation import _base as eval_base
from vertexai.preview.evaluation import constants
# pylint: disable=g-import-not-at-top
try:
import pandas as pd
except ImportError:
    pd = None
_MARKDOWN_H2 = "##"
_MARKDOWN_H3 = "###"
_DEFAULT_COLUMNS_TO_DISPLAY = [
constants.Dataset.MODEL_RESPONSE_COLUMN,
constants.Dataset.BASELINE_MODEL_RESPONSE_COLUMN,
constants.Dataset.PROMPT_COLUMN,
constants.MetricResult.ROW_COUNT_KEY,
]
_DEFAULT_RADAR_RANGE = (0, 5)
def _get_ipython_shell_name() -> str:
if "IPython" in sys.modules:
# pylint: disable=g-import-not-at-top, g-importing-member
from IPython import get_ipython
return get_ipython().__class__.__name__
return ""
def is_ipython_available() -> bool:
    return bool(_get_ipython_shell_name())
def _filter_df(
df: pd.DataFrame, substrings: Optional[List[str]] = None
) -> pd.DataFrame:
"""Filters a DataFrame to include only columns containing the given substrings."""
if substrings is None:
return df
return df.copy().filter(
[
column_name
for column_name in df.columns
if any(substring in column_name for substring in substrings)
]
)
def display_eval_result(
eval_result: "eval_base.EvalResult",
title: Optional[str] = None,
metrics: Optional[List[str]] = None,
) -> None:
"""Displays evaluation results in a notebook using IPython.display.
Args:
eval_result: An object containing evaluation results with
`summary_metrics` and `metrics_table` attributes.
title: A string title to display above the results.
metrics: A list of metric name substrings to filter displayed columns. If
provided, only metrics whose names contain any of these strings will be
displayed.
"""
if not is_ipython_available():
return
# pylint: disable=g-import-not-at-top, g-importing-member
from IPython.display import display
from IPython.display import Markdown
summary_metrics, metrics_table = (
eval_result.summary_metrics,
eval_result.metrics_table,
)
summary_metrics_df = pd.DataFrame.from_dict(summary_metrics, orient="index").T
if metrics:
columns_to_keep = metrics + _DEFAULT_COLUMNS_TO_DISPLAY
summary_metrics_df = _filter_df(summary_metrics_df, columns_to_keep)
metrics_table = _filter_df(metrics_table, columns_to_keep)
# Display the title in Markdown.
if title:
display(Markdown(f"{_MARKDOWN_H2} {title}"))
# Display the summary metrics.
display(Markdown(f"{_MARKDOWN_H3} Summary Metrics"))
display(summary_metrics_df)
# Display the metrics table.
display(Markdown(f"{_MARKDOWN_H3} Row-based Metrics"))
display(metrics_table)
def display_explanations(
eval_result: "eval_base.EvalResult",
num: int = 1,
metrics: Optional[List[str]] = None,
) -> None:
"""Displays the explanations in a notebook using IPython.display.
Args:
eval_result: An object containing evaluation results. It is expected to
have attributes `summary_metrics` and `metrics_table`.
num: The number of row samples to display. Defaults to 1. If the number of
rows is less than `num`, all rows will be displayed.
metrics: A list of metric name substrings to filter displayed columns. If
provided, only metrics whose names contain any of these strings will be
displayed.
"""
if not is_ipython_available():
return
# pylint: disable=g-import-not-at-top, g-importing-member
from IPython.display import display
from IPython.display import HTML
style = "white-space: pre-wrap; width: 1500px; overflow-x: auto;"
metrics_table = eval_result.metrics_table
if num < 1:
raise ValueError("Num must be greater than 0.")
num = min(num, len(metrics_table))
df = metrics_table.sample(n=num)
if metrics:
columns_to_keep = metrics + _DEFAULT_COLUMNS_TO_DISPLAY
df = _filter_df(df, columns_to_keep)
for _, row in df.iterrows():
for col in df.columns:
display(HTML(f"<div style='{style}'><h4>{col}:</h4>{row[col]}</div>"))
display(HTML("<hr>"))
def display_radar_plot(
eval_results_with_title: List[Tuple[str, "eval_base.EvalResult"]],
metrics: List[str],
radar_range: Tuple[float, float] = _DEFAULT_RADAR_RANGE,
) -> None:
"""Plots a radar plot comparing evaluation results.
Args:
eval_results_with_title: List of (title, eval_result) tuples.
metrics: A list of metrics whose mean values will be plotted.
radar_range: Range of the radar plot axes.
"""
# pylint: disable=g-import-not-at-top
try:
import plotly.graph_objects as go
except ImportError as exc:
raise ImportError(
'`plotly` is not installed. Please install using "!pip install plotly"'
) from exc
fig = go.Figure()
for title, eval_result in eval_results_with_title:
summary_metrics = eval_result.summary_metrics
if metrics:
summary_metrics = {
key.replace("/mean", ""): summary_metrics[key]
for key in summary_metrics
if any(selected_metric + "/mean" in key for selected_metric in metrics)
}
fig.add_trace(
go.Scatterpolar(
r=list(summary_metrics.values()),
theta=list(summary_metrics.keys()),
fill="toself",
name=title,
)
)
fig.update_layout(
polar=dict(radialaxis=dict(visible=True, range=radar_range)),
showlegend=True,
)
fig.show()
def display_bar_plot(
eval_results_with_title: List[Tuple[str, "eval_base.EvalResult"]],
metrics: List[str],
) -> None:
"""Plots a bar plot comparing evaluation results.
Args:
eval_results_with_title: List of (title, eval_result) tuples.
metrics: A list of metrics whose mean values will be plotted.
"""
# pylint: disable=g-import-not-at-top
try:
import plotly.graph_objects as go
except ImportError as exc:
raise ImportError(
'`plotly` is not installed. Please install using "!pip install plotly"'
) from exc
data = []
for title, eval_result in eval_results_with_title:
summary_metrics = eval_result.summary_metrics
mean_summary_metrics = [f"{metric}/mean" for metric in metrics]
updated_summary_metrics = []
if metrics:
for k, v in summary_metrics.items():
if k in mean_summary_metrics:
updated_summary_metrics.append((k, v))
summary_metrics = dict(updated_summary_metrics)
data.append(
go.Bar(
x=list(summary_metrics.keys()),
y=list(summary_metrics.values()),
name=title,
)
)
fig = go.Figure(data=data)
fig.update_layout(barmode="group", showlegend=True)
fig.show()
def generate_uuid(length: int = 8) -> str:
"""Generates a uuid of a specified length (default=8)."""
return "".join(random.choices(string.ascii_lowercase + string.digits, k=length))

View File

@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Prompt template for creating prompts with variables."""
import re
from typing import Set
_VARIABLE_NAME_REGEX = r"\{([_a-zA-Z][_a-zA-Z0-9]*)\}"
class PromptTemplate:
"""A prompt template for creating prompts with variables.
The `PromptTemplate` class allows users to define a template string with
variables represented in curly braces `{variable}`. The variable
names cannot contain spaces and must start with a letter or underscore,
followed by letters, digits, or underscore. These variables can be
replaced with specific values using the `assemble` method, providing
flexibility in generating dynamic prompts.
Usage:
```
template_str = "Hello, {name}! Today is {day}. How are you?"
prompt_template = PromptTemplate(template_str)
completed_prompt = prompt_template.assemble(name="John", day="Monday")
print(completed_prompt)
```
"""
def __init__(self, template: str):
"""Initializes the PromptTemplate with a given template.
Args:
template: The template string with variables. Variables should be
represented in curly braces `{variable}`.
"""
self.template = str(template)
self.variables = self._get_variables()
def _get_variables(self) -> Set[str]:
"""Extracts and return a set of variable names from the template."""
return set(re.findall(_VARIABLE_NAME_REGEX, self.template))
def assemble(self, **kwargs) -> "PromptTemplate":
"""Replaces only the provided variables in the template with specific values.
Args:
**kwargs: Keyword arguments where keys are placeholder names and values
are the replacements.
Returns:
A new PromptTemplate instance with the updated template string.
"""
assembled_string = self.template
for variable_name, value in kwargs.items():
if variable_name not in self.variables:
raise ValueError(
f"Invalid variable name '{variable_name}'. "
f"Valid variables are: {self.variables}"
)
placeholder = "{" + variable_name + "}"
assembled_string = assembled_string.replace(placeholder, str(value))
return PromptTemplate(assembled_string)
def __str__(self) -> str:
"""Returns the template string."""
return self.template
def __repr__(self) -> str:
"""Returns a string representation of the PromptTemplate."""
return f"PromptTemplate('{self.template}')"

View File

@@ -0,0 +1,640 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utility functions for evaluation."""
import functools
import io
import json
import os
import re
import sys
import tempfile
import threading
import time
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Pattern,
Tuple,
TYPE_CHECKING,
Union,
)
from google.cloud import bigquery
from google.cloud import storage
from google.cloud.aiplatform import base
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import _ipython_utils
from google.cloud.aiplatform_v1beta1.services import (
evaluation_service as gapic_evaluation_services,
)
from vertexai.evaluation import _base as eval_base
from vertexai.evaluation.metrics import _base as metrics_base
from vertexai.evaluation.metrics import (
metric_prompt_template as metric_prompt_template_base,
)
if TYPE_CHECKING:
import pandas as pd
_BQ_PREFIX = "bq://"
_GCS_PREFIX = "gs://"
_LOGGER = base.Logger(__name__)
_QUESTION_REGEX = re.compile(r"Question:(.*?)Verdict:", re.DOTALL)
_VERDICT_REGEX = re.compile("Verdict:(.*)")
_QUESTION_BLOCK_REGEX = re.compile("<question>(.*?)</question>", re.DOTALL)
_RESPONSE_A_REGEX = re.compile(
r"\[\[Response A Answers:\]\](.*?)\[\[Rubric Score:", re.DOTALL
)
_RESPONSE_B_REGEX = re.compile(
r"\[\[Response B Answers:\]\](.*?)\[\[Rubric Score:", re.DOTALL
)
_SXS_RATING_REGEX = re.compile(r"\[\[(SxSRating:[AB<>=]+)\]\]", re.DOTALL)
RATING_TO_VERDICT = {
"B>>A": "Candidate response is better than the baseline response.",
"A<<B": "Candiate response is better than the baseline response.",
"B>A": "Candidate response is slightly better than the baseline response.",
"A<B": "Candidate response is slightly better than the baseline response.",
"A=B": "Both responses are equally good.",
"B=A": "Both responses are equally good.",
"A>B": "Baseline response is slightly better than the candidate response.",
"B<A": "Baseline response is slightly better than the candidate response.",
"B<<A": "Baseline response is better than the candidate response.",
"A>>B": "Baseline response is better than the candidate response.",
}
_RATING_TO_SCORE = {
"B>>A": 1,
"A<<B": 1,
"B>A": 0.5,
"A<B": 0.5,
"A=B": 0,
"B=A": 0,
"A>B": -0.5,
"B<A": -0.5,
"B<<A": -1,
"A>>B": -1,
}
class _EvaluationServiceClientWithOverride(utils.ClientWithOverride):
_is_temporary = False
_default_version = compat.V1
_version_map = (
(
compat.V1,
gapic_evaluation_services.EvaluationServiceClient,
),
)
class RateLimiter:
"""Helper class for rate-limiting requests to Vertex AI to improve QoS.
Attributes:
seconds_per_event: The time interval (in seconds) between events to
maintain the desired rate.
last: The timestamp of the last event.
_lock: A lock to ensure thread safety.
"""
def __init__(self, rate: Optional[float] = None):
"""Initializes the rate limiter.
A simple rate limiter for controlling the frequency of API calls. This class
implements a token bucket algorithm to limit the rate at which events
can occur. It's designed for cases where the batch size (number of events
per call) is always 1 for traffic shaping and rate limiting.
Args:
rate: The number of queries allowed per second.
Raises:
ValueError: If the rate is not positive.
"""
if not rate or rate <= 0:
raise ValueError("Rate must be a positive number")
self.seconds_per_event = 1.0 / rate
self.last = time.time() - self.seconds_per_event
self._lock = threading.Lock()
def _admit(self) -> float:
"""Checks if an event can be admitted or calculates the remaining delay."""
now = time.time()
time_since_last = now - self.last
if time_since_last >= self.seconds_per_event:
self.last = now
return 0
else:
return self.seconds_per_event - time_since_last
def sleep_and_advance(self):
"""Blocks the current thread until the next event can be admitted."""
with self._lock:
delay = self._admit()
if delay > 0:
time.sleep(delay)
self.last = time.time()
def rate_limit(rate: Optional[float] = None) -> Callable[[Any], Any]:
"""Decorator version of rate limiter."""
def _rate_limit(method):
limiter = RateLimiter(rate)
@functools.wraps(method)
def wrapper(*args, **kwargs):
limiter.sleep_and_advance()
return method(*args, **kwargs)
return wrapper
return _rate_limit
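# --- Illustrative usage sketch (not part of the original module) ---
# Throttles an arbitrary callable to roughly `rate` calls per second; the
# decorated function below is a stand-in for a per-row evaluation request.
@rate_limit(rate=2.0)
def _example_throttled_call(row_id: int) -> int:
    # With rate=2.0, consecutive calls are spaced at least ~0.5s apart by
    # RateLimiter.sleep_and_advance().
    return row_id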
def create_evaluation_service_client(
api_base_path_override: Optional[str] = None,
) -> _EvaluationServiceClientWithOverride:
"""Creates a client for the evaluation service.
Args:
api_base_path_override: Optional. Override default api base path.
Returns:
Instantiated Vertex AI EvaluationServiceClient with optional
overrides.
"""
return initializer.global_config.create_client(
client_class=_EvaluationServiceClientWithOverride,
location_override=initializer.global_config.location,
api_base_path_override=api_base_path_override,
)
def load_dataset(
source: Union[str, "pd.DataFrame", Dict[str, Any]],
) -> "pd.DataFrame":
"""Loads dataset from various sources into a DataFrame.
Args:
source: The dataset source. Supports the following dataset formats:
* pandas.DataFrame: Used directly for evaluation.
* Dict: Converted to a pandas DataFrame before evaluation.
* str: Interpreted as a file path or URI. Supported formats include:
* Local JSONL or CSV files: Loaded from the local filesystem.
* GCS JSONL or CSV files: Loaded from Google Cloud Storage (e.g.,
'gs://bucket/data.csv').
* BigQuery table URI: Loaded from Google Cloud
BigQuery (e.g., 'bq://project-id.dataset.table_name').
Returns:
The dataset in pandas DataFrame format.
"""
try:
import pandas as pd
except ImportError:
raise ImportError(
'Pandas is not installed. Please install the SDK using "pip install'
' google-cloud-aiplatform[evaluation]"'
)
if "google.colab" in sys.modules:
from google.colab import sheets
if isinstance(source, sheets.InteractiveSheet):
return source.as_df().copy()
if isinstance(source, pd.DataFrame):
return source.copy()
elif isinstance(source, dict):
return pd.DataFrame(source)
elif isinstance(source, str):
if source.startswith(_BQ_PREFIX):
return _load_bigquery(source[len(_BQ_PREFIX) :])
_, extension = os.path.splitext(source)
file_type = extension.lower()[1:]
if file_type == "jsonl":
return _load_jsonl(source)
elif file_type == "csv":
return _load_csv(source)
else:
raise ValueError(
f"Unsupported file type: {file_type} from {source}. Please"
" provide a valid GCS path with `jsonl` or `csv` suffix or a valid"
" BigQuery table URI."
)
else:
raise TypeError(
"Unsupported dataset type. Must be a `pd.DataFrame`, Python dictionary,"
" valid GCS path with `jsonl` or `csv` suffix or a valid BigQuery"
" table URI."
)
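# --- Illustrative usage sketch (not part of the original module) ---
# The accepted dataset sources, per the load_dataset docstring above. The GCS
# and BigQuery URIs in the comments are placeholders.
def _example_load_dataset() -> "pd.DataFrame":
    import pandas as pd

    from_df = load_dataset(pd.DataFrame({"prompt": ["Hi"], "response": ["Hello"]}))
    from_dict = load_dataset({"prompt": ["Hi"], "response": ["Hello"]})
    # from_gcs = load_dataset("gs://my-bucket/eval_data.jsonl")
    # from_bq = load_dataset("bq://my-project.my_dataset.my_table")
    return pd.concat([from_df, from_dict], ignore_index=True)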
def _load_jsonl(filepath: str) -> "pd.DataFrame":
"""Loads data from a JSONL file into a DataFrame."""
try:
import pandas as pd
except ImportError:
raise ImportError(
'Pandas is not installed. Please install the SDK using "pip install'
' google-cloud-aiplatform[evaluation]"'
)
if filepath.startswith(_GCS_PREFIX):
file_contents = _read_gcs_file_contents(filepath)
return pd.read_json(file_contents, lines=True)
else:
with open(filepath, "r") as f:
return pd.read_json(f, lines=True)
def _load_csv(filepath: str) -> "pd.DataFrame":
"""Loads data from a CSV file into a DataFrame."""
try:
import pandas as pd
except ImportError:
raise ImportError(
'Pandas is not installed. Please install the SDK using "pip install'
' google-cloud-aiplatform[evaluation]"'
)
if filepath.startswith(_GCS_PREFIX):
file_contents = _read_gcs_file_contents(filepath)
return pd.read_csv(io.StringIO(file_contents), encoding="utf-8")
else:
return pd.read_csv(filepath, encoding="utf-8")
def _load_bigquery(table_id: str) -> "pd.DataFrame":
"""Loads data from a BigQuery table into a DataFrame."""
bigquery_client = bigquery.Client(project=initializer.global_config.project)
table = bigquery_client.get_table(table_id)
return bigquery_client.list_rows(table).to_dataframe()
def _read_gcs_file_contents(filepath: str) -> str:
"""Reads the contents of a file from Google Cloud Storage.
Args:
filepath: The GCS file path (e.g., 'gs://bucket_name/file.csv')
Returns:
str: The contents of the file.
"""
storage_client = storage.Client(
project=initializer.global_config.project,
credentials=initializer.global_config.credentials,
)
bucket_name, blob_path = filepath[len(_GCS_PREFIX) :].split("/", 1)
bucket = storage_client.get_bucket(bucket_name)
blob = bucket.blob(blob_path)
return blob.download_as_string().decode("utf-8")
def _upload_file_to_gcs(upload_gcs_path: str, filename: str) -> None:
storage_client = storage.Client(
project=initializer.global_config.project,
credentials=initializer.global_config.credentials,
)
storage.Blob.from_string(
uri=upload_gcs_path, client=storage_client
).upload_from_filename(filename)
def _upload_string_to_gcs(upload_gcs_path: str, contents: str) -> None:
"""Uploads the provided string to a GCS bucket."""
storage_client = storage.Client(
project=initializer.global_config.project,
credentials=initializer.global_config.credentials,
)
storage.Blob.from_string(
uri=upload_gcs_path, client=storage_client
).upload_from_string(contents)
def _upload_pandas_df_to_gcs(
df: "pd.DataFrame", upload_gcs_path: str, file_type: str
) -> None:
"""Uploads the provided Pandas DataFrame to a GCS bucket.
Args:
df: The Pandas DataFrame to upload.
upload_gcs_path: The GCS path to upload the data file.
file_type: The file type of the data file.
"""
with tempfile.TemporaryDirectory() as temp_dir:
if file_type == "csv":
local_dataset_path = os.path.join(temp_dir, "metrics_table.csv")
df.to_csv(path_or_buf=local_dataset_path)
elif file_type == "jsonl":
local_dataset_path = os.path.join(temp_dir, "metrics_table.jsonl")
df.to_json(path_or_buf=local_dataset_path, orient="records", lines=True)
else:
raise ValueError(
f"Unsupported file type: {file_type} from {upload_gcs_path}."
" Please provide a valid GCS path with `jsonl` or `csv` suffix."
)
storage_client = storage.Client(
project=initializer.global_config.project,
credentials=initializer.global_config.credentials,
)
storage.Blob.from_string(
uri=upload_gcs_path, client=storage_client
).upload_from_filename(filename=local_dataset_path)
def _upload_evaluation_summary_to_gcs(
summary_metrics: Dict[str, float],
upload_gcs_path: str,
candidate_model_name: Optional[str] = None,
baseline_model_name: Optional[str] = None,
dataset_uri: Optional[str] = None,
metrics: Optional[List[Union[str, metrics_base._Metric]]] = None,
) -> None:
"""Uploads the evaluation summary to a GCS bucket."""
summary = {
"summary_metrics": summary_metrics,
}
if candidate_model_name:
summary["candidate_model_name"] = candidate_model_name
if baseline_model_name:
summary["baseline_model_name"] = baseline_model_name
if dataset_uri:
summary["dataset_uri"] = dataset_uri
if metrics:
metric_descriptions = {}
for metric in metrics:
if isinstance(metric, metrics_base._ModelBasedMetric) and isinstance(
metric._raw_metric_prompt_template,
metric_prompt_template_base._MetricPromptTemplate,
):
metric_descriptions[metric.metric_name] = {
"criteria": metric._raw_metric_prompt_template._criteria,
"rating_rubric": metric._raw_metric_prompt_template._rating_rubric,
}
summary["metric_descriptions"] = metric_descriptions
with tempfile.TemporaryDirectory() as temp_dir:
local_summary_path = os.path.join(temp_dir, "summary_metrics.json")
        with open(local_summary_path, "w") as f:
            json.dump(summary, f)
_upload_file_to_gcs(upload_gcs_path, local_summary_path)
def upload_evaluation_results(
eval_result: eval_base.EvalResult,
destination_uri_prefix: str,
file_name: Optional[str] = None,
candidate_model_name: Optional[str] = None,
baseline_model_name: Optional[str] = None,
dataset_uri: Optional[str] = None,
metrics: Optional[List[Union[str, metrics_base._Metric]]] = None,
) -> None:
"""Uploads eval results to GCS destination.
Args:
eval_result: Eval results to upload.
destination_uri_prefix: GCS folder to store the data.
file_name: Optional. File name to store the metrics table.
candidate_model_name: Optional. Candidate model name.
baseline_model_name: Optional. Baseline model name.
dataset_uri: Optional. URI pointing to the dataset.
metrics: Optional. List of metrics used for evaluation.
"""
if not destination_uri_prefix:
_ipython_utils.display_gen_ai_evaluation_results_button()
return
if eval_result.metrics_table is None:
return
if destination_uri_prefix.startswith(_GCS_PREFIX):
if file_name:
base_name, extension = os.path.splitext(file_name)
file_type = extension.lower()[1:]
output_folder = destination_uri_prefix + "/" + base_name
metrics_table_path = output_folder + "/" + file_name
_upload_pandas_df_to_gcs(
eval_result.metrics_table, metrics_table_path, file_type
)
_upload_evaluation_summary_to_gcs(
eval_result.summary_metrics,
output_folder + "/summary_metrics.json",
candidate_model_name,
baseline_model_name,
dataset_uri,
metrics,
)
_ipython_utils.display_gen_ai_evaluation_results_button(
metrics_table_path.split(_GCS_PREFIX)[1]
)
else:
raise ValueError(
f"Unsupported destination URI: {destination_uri_prefix}."
f" Please provide a valid GCS bucket URI prefix starting with"
f" {_GCS_PREFIX}."
)
def initialize_metric_column_mapping(
metric_column_mapping: Optional[Dict[str, str]], dataset: "pd.DataFrame"
):
"""Initializes metric column mapping with dataset columns."""
initialized_metric_column_mapping = {}
for column in dataset.columns:
initialized_metric_column_mapping[column] = column
if metric_column_mapping:
for key, value in metric_column_mapping.items():
if key in initialized_metric_column_mapping:
_LOGGER.warning(
f"Cannot override `{key}` column with `{key}:{value}` mapping"
f" because `{key}` column is present in the evaluation"
" dataset. `metric_column_mapping` cannot override keys"
" that are already in evaluation dataset columns."
)
else:
initialized_metric_column_mapping[key] = value
return initialized_metric_column_mapping
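# --- Illustrative usage sketch (not part of the original module) ---
# Dataset columns always map to themselves; extra entries are applied only for
# keys that are not already dataset columns (otherwise a warning is logged),
# as implemented above.
def _example_metric_column_mapping() -> Dict[str, str]:
    import pandas as pd

    dataset = pd.DataFrame({"prompt": ["Hi"], "model_output": ["Hello"]})
    mapping = initialize_metric_column_mapping({"response": "model_output"}, dataset)
    # {'prompt': 'prompt', 'model_output': 'model_output', 'response': 'model_output'}
    return mapping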
def parse_intermediate_steps(intermediate_steps: List[Dict[str, Any]]):
"""Parses intermediate steps from the response to create trajectory."""
trajectory = []
try:
for step in intermediate_steps:
step_input, _ = step[0], step[1]
tool_name = step_input["kwargs"]["tool"]
tool_input = step_input["kwargs"]["tool_input"]
trajectory.append(
{
"tool_name": tool_name,
"tool_input": tool_input,
}
)
except Exception as e: # pylint: disable=broad-exception-caught
_LOGGER.error(
f"Failed to parse intermediate steps: {e}. The runnable you are using"
" is likely not compatible with the evaluation service. Please ensure"
" that the runnable you are using is compatible with the evaluation"
" service, if not, consider building a custom runnable function."
)
return trajectory
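# --- Illustrative usage sketch (not part of the original module) ---
# The step shape this parser expects, inferred from the indexing above: each
# step is an (action, observation) pair whose action dict carries
# kwargs["tool"] and kwargs["tool_input"] (a LangChain-style serialization).
def _example_parse_intermediate_steps() -> List[Dict[str, Any]]:
    steps = [
        (
            {"kwargs": {"tool": "search", "tool_input": {"query": "weather"}}},
            "observation: sunny",
        )
    ]
    # -> [{'tool_name': 'search', 'tool_input': {'query': 'weather'}}]
    return parse_intermediate_steps(steps)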
def parse_rubrics(rubric_generation_response: str) -> Dict[str, Any]:
"""Parses the rubric generation responses."""
try:
_, response = rubric_generation_response.split("```json")
except ValueError:
_LOGGER.warning(
"Failed to parse rubric generation response. Does not contain ```json"
)
return {"questions": ""}
try:
result = json.loads(response.strip("\n` "))
except json.JSONDecodeError:
_LOGGER.warning(
"Failed to parse rubric generation response. Does not contain valid"
" JSON."
)
return {"questions": ""}
return result
def parse_pairwise_rubric_verdict_pairs(prediction: str, regex: Pattern[str]) -> str:
"""Parses the pairwise rubric critique responses."""
response = "Unable to parse rubric verdict pairs from response."
response_matches = regex.findall(prediction)
if response_matches:
response_pairs = parse_question_blocks(response_matches[0])
response = "\n".join(f"{q}: {v}" for q, v in response_pairs)
return response
def parse_pairwise_rubric_result(
predictions: List[str],
) -> Dict[str, Any]:
"""Parses the pairwise rubric critique responses."""
prediction = predictions[0] # currently only supports one sample
rating_str = "Unable to parse verdict."
response_a = parse_pairwise_rubric_verdict_pairs(prediction, _RESPONSE_A_REGEX)
response_b = parse_pairwise_rubric_verdict_pairs(prediction, _RESPONSE_B_REGEX)
sxs_rating_matches = _SXS_RATING_REGEX.findall(prediction.replace(" ", ""))
if sxs_rating_matches:
rating_str = sxs_rating_matches[0].strip("[]")
rating_str = rating_str[rating_str.find(":") + 1 :]
return {
"pairwise_choice": (
RATING_TO_VERDICT[rating_str]
if rating_str in RATING_TO_VERDICT
else rating_str
),
"score": (
_RATING_TO_SCORE[rating_str] if rating_str in _RATING_TO_SCORE else None
),
"baseline_rubric_verdict_pairs": response_a,
"candidate_rubric_verdict_pairs": response_b,
"raw_outputs": predictions,
}
def parse_verdict(txt: str):
"""Parses the verdict from the rubric critique response."""
if not isinstance(txt, str) or not txt:
return None
try:
if verdict := _VERDICT_REGEX.findall(txt):
verdict = verdict[0]
if "yes" in verdict.lower():
return True
elif "no" in verdict.lower():
return False
except Exception: # pylint: disable=broad-exception-caught
return None
def parse_question(txt: str):
"""Parses the question from the rubric critique response."""
if not isinstance(txt, str) or not txt:
return None
try:
txt = txt.split("Verdict:")[0]
if "Question:" in txt:
return txt.split("Question:")[-1].strip()
if not (question := _QUESTION_REGEX.findall(txt)):
return txt.strip().split("\n")[0].removeprefix("STEP 1:").strip()
return question[0].strip()
except Exception: # pylint: disable=broad-exception-caught
return None
def parse_question_blocks(txt: str) -> List[Tuple[str, bool]]:
"""Parses the question blocks from the rubric critique response."""
if not txt.startswith("<question>\n"):
txt = "<question>\n" + txt
responses = []
question_blocks = _QUESTION_BLOCK_REGEX.findall(txt)
if not question_blocks:
question_blocks = [txt]
for block in question_blocks:
q = parse_question(block)
v = parse_verdict(block)
if q is not None and v is not None:
responses.append((q, v))
return responses
def parse_pointwise_rubric_result(results: List[str]) -> Dict[str, Any]:
"""Parses the pointwise rubric critique responses."""
self_consistency_results = {}
for sample_result in results:
rubric_verdict_pairs = parse_question_blocks(sample_result)
for rubric, verdict in rubric_verdict_pairs:
if rubric not in self_consistency_results:
self_consistency_results[rubric] = 0
self_consistency_results[rubric] += 1 if verdict else -1
rubric_results = {}
for rubric, verdict_counts in self_consistency_results.items():
rubric_results[rubric] = verdict_counts > 0
rubric_results_str = "\n".join(f"{q}: {v}" for q, v in rubric_results.items())
row_results = {
"score": (
sum(rubric_results.values()) / len(rubric_results) if rubric_results else 0
)
}
row_results["rubric_verdict_pairs"] = rubric_results_str
row_results["raw_outputs"] = results
return row_results
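# --- Illustrative usage sketch (not part of the original module) ---
# A minimal rubric-critique output in the format the regexes above expect:
# <question> blocks that each contain a "Question:" line and a "Verdict:"
# line. Verdicts are majority-voted across samples per rubric before the
# score (fraction of rubrics passed) is computed.
def _example_parse_pointwise_rubric_result() -> Dict[str, Any]:
    sample = (
        "<question>\nQuestion: Is the response polite?\nVerdict: yes\n</question>\n"
        "<question>\nQuestion: Does it answer the question?\nVerdict: no\n</question>\n"
    )
    # -> score of 0.5, with one passed and one failed rubric verdict pair.
    return parse_pointwise_rubric_result([sample, sample])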