structure saas with tools
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from google.cloud.aiplatform.vizier.study import Study
|
||||
from google.cloud.aiplatform.vizier.trial import Trial
|
||||
|
||||
__all__ = (
|
||||
"Study",
|
||||
"Trial",
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,192 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Cross-platform Vizier client interfaces.
|
||||
|
||||
Aside from "materialize_" methods, code written using these interfaces are
|
||||
compatible with OSS and Cloud Vertex Vizier. Note importantly that subclasses
|
||||
may have more methods than what is required by interfaces, and such methods
|
||||
are not cross compatible. Our recommendation is to explicitly type your objects
|
||||
to be `StudyInterface` or `TrialInterface` when you want to guarantee that
|
||||
a code block is cross-platform.
|
||||
|
||||
Keywords:
|
||||
|
||||
#Materialize: The method returns a deep copy of the underlying pyvizier object.
|
||||
Modifying the returned object does not update the Vizier service.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Collection, Type, TypeVar, Mapping, Any
|
||||
import abc
|
||||
|
||||
from google.cloud.aiplatform.vizier import pyvizier as vz
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class ResourceNotFoundError(LookupError):
|
||||
"""Error raised by Vizier clients when resource is not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TrialInterface(abc.ABC):
|
||||
"""Responsible for trial-level operations."""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def uid(self) -> int:
|
||||
"""Unique identifier of the trial."""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def parameters(self) -> Mapping[str, Any]:
|
||||
"""Parameters of the trial."""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def status(self) -> vz.TrialStatus:
|
||||
"""Trial's status."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self) -> None:
|
||||
"""Delete the Trial in Vizier service.
|
||||
|
||||
There is currently no promise on how this object behaves after `delete()`.
|
||||
If you are sharing a Trial object in parallel processes, proceed with
|
||||
caution.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def complete(
|
||||
self,
|
||||
measurement: Optional[vz.Measurement] = None,
|
||||
*,
|
||||
infeasible_reason: Optional[str] = None,
|
||||
) -> Optional[vz.Measurement]:
|
||||
"""Completes the trial and #materializes the measurement.
|
||||
|
||||
* If `measurement` is provided, then Vizier writes it as the trial's final
|
||||
measurement and returns it.
|
||||
* If `infeasible_reason` is provided, `measurement` is not needed.
|
||||
* If neither is provided, then Vizier selects an existing (intermediate)
|
||||
measurement to be the final measurement and returns it.
|
||||
|
||||
Args:
|
||||
measurement: Final measurement.
|
||||
infeasible_reason: Infeasible reason for missing final measurement.
|
||||
|
||||
Returns:
|
||||
The final measurement of the trial, or None if the trial is marked
|
||||
infeasible.
|
||||
|
||||
Raises:
|
||||
ValueError: If neither `measurement` nor `infeasible_reason` is provided
|
||||
but the trial does not contain any intermediate measurements.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def should_stop(self) -> bool:
|
||||
"""Returns true if the trial should stop."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def add_measurement(self, measurement: vz.Measurement) -> None:
|
||||
"""Adds an intermediate measurement."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def materialize(self, *, include_all_measurements: bool = True) -> vz.Trial:
|
||||
"""#Materializes the Trial.
|
||||
|
||||
Args:
|
||||
include_all_measurements: If True, returned Trial includes all
|
||||
intermediate measurements. The final measurement is always provided.
|
||||
|
||||
Returns:
|
||||
Trial object.
|
||||
"""
|
||||
|
||||
|
||||
class StudyInterface(abc.ABC):
|
||||
"""Responsible for study-level operations."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_or_load(
|
||||
self, display_name: str, problem: vz.ProblemStatement
|
||||
) -> StudyInterface:
|
||||
""" """
|
||||
|
||||
@abc.abstractmethod
|
||||
def suggest(
|
||||
self, *, count: Optional[int] = None, worker: str = ""
|
||||
) -> Collection[TrialInterface]:
|
||||
"""Returns Trials to be evaluated by worker.
|
||||
|
||||
Args:
|
||||
count: Number of suggestions.
|
||||
worker: When new Trials are generated, their `assigned_worker` field is
|
||||
populated with this worker. suggest() first looks for existing Trials
|
||||
that are assigned to `worker`, before generating new ones.
|
||||
|
||||
Returns:
|
||||
Trials.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self) -> None:
|
||||
"""Deletes the study."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def trials(
|
||||
self, trial_filter: Optional[vz.TrialFilter] = None
|
||||
) -> Collection[TrialInterface]:
|
||||
"""Fetches a collection of trials."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_trial(self, uid: int) -> TrialInterface:
|
||||
"""Fetches a single trial.
|
||||
|
||||
Args:
|
||||
uid: Unique identifier of the trial within study.
|
||||
|
||||
Returns:
|
||||
Trial.
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundError: If trial does not exist.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def optimal_trials(self) -> Collection[TrialInterface]:
|
||||
"""Returns optimal trial(s)."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def materialize_study_config(self) -> vz.StudyConfig:
|
||||
"""#Materializes the study config."""
|
||||
|
||||
@abc.abstractclassmethod
|
||||
def from_uid(cls: Type[_T], uid: str) -> _T:
|
||||
"""Fetches an existing study from the Vizier service.
|
||||
|
||||
Args:
|
||||
uid: Unique identifier of the study.
|
||||
|
||||
Returns:
|
||||
Study.
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundError: If study does not exist.
|
||||
"""
|
||||
@@ -0,0 +1,84 @@
|
||||
"""PyVizier classes for Pythia policies."""
|
||||
|
||||
try:
|
||||
from vizier.pyvizier import MetricInformation
|
||||
from vizier.pyvizier import MetricsConfig
|
||||
from vizier.pyvizier import MetricType
|
||||
from vizier.pyvizier import (
|
||||
ObjectiveMetricGoal,
|
||||
)
|
||||
from vizier.pyvizier import ProblemStatement
|
||||
from vizier.pyvizier import SearchSpace
|
||||
from vizier.pyvizier import (
|
||||
SearchSpaceSelector,
|
||||
)
|
||||
from vizier.pyvizier import Metadata
|
||||
from vizier.pyvizier import MetadataValue
|
||||
from vizier.pyvizier import Namespace
|
||||
from vizier.pyvizier import ExternalType
|
||||
from vizier.pyvizier import ParameterConfig
|
||||
from vizier.pyvizier import ParameterType
|
||||
from vizier.pyvizier import ScaleType
|
||||
from vizier.pyvizier import CompletedTrial
|
||||
from vizier.pyvizier import Measurement
|
||||
from vizier.pyvizier import MonotypeParameterSequence
|
||||
from vizier.pyvizier import Metric
|
||||
from vizier.pyvizier import ParameterDict
|
||||
from vizier.pyvizier import ParameterValue
|
||||
from vizier.pyvizier import Trial
|
||||
from vizier.pyvizier import ParameterValueTypes
|
||||
from vizier.pyvizier import TrialFilter
|
||||
from vizier.pyvizier import TrialStatus
|
||||
from vizier.pyvizier import TrialSuggestion
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Google-vizier is not installed, and is required to use Vizier client."
|
||||
'Please install the SDK using "pip install google-vizier"'
|
||||
)
|
||||
|
||||
from google.cloud.aiplatform.vizier.pyvizier.proto_converters import TrialConverter
|
||||
from google.cloud.aiplatform.vizier.pyvizier.proto_converters import (
|
||||
ParameterConfigConverter,
|
||||
)
|
||||
from google.cloud.aiplatform.vizier.pyvizier.proto_converters import (
|
||||
MeasurementConverter,
|
||||
)
|
||||
from google.cloud.aiplatform.vizier.pyvizier.study_config import StudyConfig
|
||||
from google.cloud.aiplatform.vizier.pyvizier.study_config import Algorithm
|
||||
from google.cloud.aiplatform.vizier.pyvizier.automated_stopping import (
|
||||
AutomatedStoppingConfig,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"MetricInformation",
|
||||
"MetricsConfig",
|
||||
"MetricType",
|
||||
"ObjectiveMetricGoal",
|
||||
"ProblemStatement",
|
||||
"SearchSpace",
|
||||
"SearchSpaceSelector",
|
||||
"Metadata",
|
||||
"MetadataValue",
|
||||
"Namespace",
|
||||
"ParameterConfigConverter",
|
||||
"ParameterValueTypes",
|
||||
"MeasurementConverter",
|
||||
"MonotypeParameterSequence",
|
||||
"TrialConverter",
|
||||
"StudyConfig",
|
||||
"Algorithm",
|
||||
"AutomatedStoppingConfig",
|
||||
"ExternalType",
|
||||
"ParameterConfig",
|
||||
"ParameterType",
|
||||
"ScaleType",
|
||||
"CompletedTrial",
|
||||
"Measurement",
|
||||
"Metric",
|
||||
"ParameterDict",
|
||||
"ParameterValue",
|
||||
"Trial",
|
||||
"TrialFilter",
|
||||
"TrialStatus",
|
||||
"TrialSuggestion",
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,79 @@
|
||||
"""Convenience classes for configuring Vizier Early-Stopping Configs."""
|
||||
import copy
|
||||
from typing import Union
|
||||
|
||||
import attr
|
||||
|
||||
from google.cloud.aiplatform.compat.types import study as study_pb2
|
||||
|
||||
AutomatedStoppingConfigProto = Union[
|
||||
study_pb2.StudySpec.DecayCurveAutomatedStoppingSpec,
|
||||
study_pb2.StudySpec.MedianAutomatedStoppingSpec,
|
||||
]
|
||||
|
||||
|
||||
@attr.s(frozen=True, init=True, slots=True, kw_only=True)
|
||||
class AutomatedStoppingConfig:
|
||||
"""A wrapper for study_pb2.automated_stopping_spec."""
|
||||
|
||||
_proto: AutomatedStoppingConfigProto = attr.ib(init=True, kw_only=True)
|
||||
|
||||
@classmethod
|
||||
def decay_curve_stopping_config(cls, use_steps: bool) -> "AutomatedStoppingConfig":
|
||||
"""Create a DecayCurve automated stopping config.
|
||||
|
||||
Vizier will early stop the Trial if it predicts the Trial objective value
|
||||
will not be better than previous Trials.
|
||||
|
||||
Args:
|
||||
use_steps: Bool. If set, use Measurement.step_count as the measure of
|
||||
training progress. Otherwise, use Measurement.elapsed_duration.
|
||||
|
||||
Returns:
|
||||
AutomatedStoppingConfig object.
|
||||
|
||||
Raises:
|
||||
ValueError: If more than one metric is configured.
|
||||
Note that Vizier Early Stopping currently only supports single-objective
|
||||
studies.
|
||||
"""
|
||||
config = study_pb2.StudySpec.DecayCurveAutomatedStoppingSpec(
|
||||
use_elapsed_duration=not use_steps
|
||||
)
|
||||
return cls(proto=config)
|
||||
|
||||
@classmethod
|
||||
def median_automated_stopping_config(
|
||||
cls, use_steps: bool
|
||||
) -> "AutomatedStoppingConfig":
|
||||
"""Create a Median automated stopping config.
|
||||
|
||||
Vizier will early stop the Trial if it predicts the Trial objective value
|
||||
will not be better than previous Trials.
|
||||
|
||||
Args:
|
||||
use_steps: Bool. If set, use Measurement.step_count as the measure of
|
||||
training progress. Otherwise, use Measurement.elapsed_duration.
|
||||
|
||||
Returns:
|
||||
AutomatedStoppingConfig object.
|
||||
|
||||
Raises:
|
||||
ValueError: If more than one metric is configured.
|
||||
Note that Vizier Early Stopping currently only supports single-objective
|
||||
studies.
|
||||
"""
|
||||
config = study_pb2.StudySpec.MedianAutomatedStoppingSpec(
|
||||
use_elapsed_duration=not use_steps
|
||||
)
|
||||
return cls(proto=config)
|
||||
|
||||
@classmethod
|
||||
def from_proto(
|
||||
cls, proto: AutomatedStoppingConfigProto
|
||||
) -> "AutomatedStoppingConfig":
|
||||
return cls(proto=proto)
|
||||
|
||||
def to_proto(self) -> AutomatedStoppingConfigProto:
|
||||
"""Returns this object as a proto."""
|
||||
return copy.deepcopy(self._proto)
|
||||
@@ -0,0 +1,525 @@
|
||||
"""Converters for OSS Vizier's protos from/to PyVizier's classes."""
|
||||
import logging
|
||||
from datetime import timezone
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from google.protobuf import duration_pb2
|
||||
from google.protobuf import struct_pb2
|
||||
from google.protobuf import timestamp_pb2
|
||||
from google.cloud.aiplatform.compat.types import study as study_pb2
|
||||
from google.cloud.aiplatform.vizier.pyvizier import ExternalType
|
||||
from google.cloud.aiplatform.vizier.pyvizier import ScaleType
|
||||
from google.cloud.aiplatform.vizier.pyvizier import ParameterType
|
||||
from google.cloud.aiplatform.vizier.pyvizier import ParameterValue
|
||||
from google.cloud.aiplatform.vizier.pyvizier import MonotypeParameterSequence
|
||||
from google.cloud.aiplatform.vizier.pyvizier import ParameterConfig
|
||||
from google.cloud.aiplatform.vizier.pyvizier import Measurement
|
||||
from google.cloud.aiplatform.vizier.pyvizier import Metric
|
||||
from google.cloud.aiplatform.vizier.pyvizier import TrialStatus
|
||||
from google.cloud.aiplatform.vizier.pyvizier import Trial
|
||||
|
||||
_ScaleTypePb2 = study_pb2.StudySpec.ParameterSpec.ScaleType
|
||||
|
||||
|
||||
class _ScaleTypeMap:
|
||||
"""Proto converter for scale type."""
|
||||
|
||||
_pyvizier_to_proto = {
|
||||
ScaleType.LINEAR: _ScaleTypePb2.UNIT_LINEAR_SCALE,
|
||||
ScaleType.LOG: _ScaleTypePb2.UNIT_LOG_SCALE,
|
||||
ScaleType.REVERSE_LOG: _ScaleTypePb2.UNIT_REVERSE_LOG_SCALE,
|
||||
}
|
||||
_proto_to_pyvizier = {v: k for k, v in _pyvizier_to_proto.items()}
|
||||
|
||||
@classmethod
|
||||
def to_proto(cls, pyvizier: ScaleType) -> _ScaleTypePb2:
|
||||
return cls._pyvizier_to_proto[pyvizier]
|
||||
|
||||
@classmethod
|
||||
def from_proto(cls, proto: _ScaleTypePb2) -> ScaleType:
|
||||
return cls._proto_to_pyvizier[proto]
|
||||
|
||||
|
||||
class ParameterConfigConverter:
|
||||
"""Converter for ParameterConfig."""
|
||||
|
||||
@classmethod
|
||||
def _set_bounds(
|
||||
cls,
|
||||
proto: study_pb2.StudySpec.ParameterSpec,
|
||||
lower: float,
|
||||
upper: float,
|
||||
parameter_type: ParameterType,
|
||||
):
|
||||
"""Sets the proto's min_value and max_value fields."""
|
||||
if parameter_type == ParameterType.INTEGER:
|
||||
proto.integer_value_spec.min_value = lower
|
||||
proto.integer_value_spec.max_value = upper
|
||||
elif parameter_type == ParameterType.DOUBLE:
|
||||
proto.double_value_spec.min_value = lower
|
||||
proto.double_value_spec.max_value = upper
|
||||
|
||||
@classmethod
|
||||
def _set_feasible_points(
|
||||
cls, proto: study_pb2.StudySpec.ParameterSpec, feasible_points: Sequence[float]
|
||||
):
|
||||
"""Sets the proto's feasible_points field."""
|
||||
feasible_points = sorted(feasible_points)
|
||||
proto.discrete_value_spec.values.clear()
|
||||
proto.discrete_value_spec.values.extend(feasible_points)
|
||||
|
||||
@classmethod
|
||||
def _set_categories(
|
||||
cls, proto: study_pb2.StudySpec.ParameterSpec, categories: Sequence[str]
|
||||
):
|
||||
"""Sets the protos' categories field."""
|
||||
proto.categorical_value_spec.values.clear()
|
||||
proto.categorical_value_spec.values.extend(categories)
|
||||
|
||||
@classmethod
|
||||
def _set_default_value(
|
||||
cls,
|
||||
proto: study_pb2.StudySpec.ParameterSpec,
|
||||
default_value: Union[float, int, str],
|
||||
):
|
||||
"""Sets the protos' default_value field."""
|
||||
which_pv_spec = proto._pb.WhichOneof("parameter_value_spec")
|
||||
getattr(proto, which_pv_spec).default_value = default_value
|
||||
|
||||
@classmethod
|
||||
def _matching_parent_values(
|
||||
cls, proto: study_pb2.StudySpec.ParameterSpec.ConditionalParameterSpec
|
||||
) -> MonotypeParameterSequence:
|
||||
"""Returns the matching parent values, if set."""
|
||||
oneof_name = proto.WhichOneof("parent_value_condition")
|
||||
if not oneof_name:
|
||||
return []
|
||||
if oneof_name in (
|
||||
"parent_discrete_values",
|
||||
"parent_int_values",
|
||||
"parent_categorical_values",
|
||||
):
|
||||
return list(getattr(getattr(proto, oneof_name), "values"))
|
||||
raise ValueError("Unknown matching_parent_vals: {}".format(oneof_name))
|
||||
|
||||
@classmethod
|
||||
def from_proto(
|
||||
cls,
|
||||
proto: study_pb2.StudySpec.ParameterSpec,
|
||||
*,
|
||||
strict_validation: bool = False
|
||||
) -> ParameterConfig:
|
||||
"""Creates a ParameterConfig.
|
||||
|
||||
Args:
|
||||
proto:
|
||||
strict_validation: If True, raise ValueError to enforce that
|
||||
from_proto(proto).to_proto == proto.
|
||||
|
||||
Returns:
|
||||
ParameterConfig object
|
||||
|
||||
Raises:
|
||||
ValueError: See the "strict_validtion" arg documentation.
|
||||
"""
|
||||
feasible_values = []
|
||||
external_type = ExternalType.INTERNAL
|
||||
oneof_name = proto._pb.WhichOneof("parameter_value_spec")
|
||||
if oneof_name == "integer_value_spec":
|
||||
bounds = (
|
||||
int(proto.integer_value_spec.min_value),
|
||||
int(proto.integer_value_spec.max_value),
|
||||
)
|
||||
external_type = ExternalType.INTEGER
|
||||
elif oneof_name == "double_value_spec":
|
||||
bounds = (
|
||||
proto.double_value_spec.min_value,
|
||||
proto.double_value_spec.max_value,
|
||||
)
|
||||
elif oneof_name == "discrete_value_spec":
|
||||
bounds = None
|
||||
feasible_values = proto.discrete_value_spec.values
|
||||
elif oneof_name == "categorical_value_spec":
|
||||
bounds = None
|
||||
feasible_values = proto.categorical_value_spec.values
|
||||
# Boolean values are encoded as categoricals, check for the special
|
||||
# hard-coded values.
|
||||
boolean_values = ["False", "True"]
|
||||
if sorted(list(feasible_values)) == boolean_values:
|
||||
external_type = ExternalType.BOOLEAN
|
||||
|
||||
default_value = None
|
||||
if getattr(proto, oneof_name).default_value:
|
||||
default_value = getattr(proto, oneof_name).default_value
|
||||
if external_type == ExternalType.INTEGER:
|
||||
default_value = int(default_value)
|
||||
|
||||
if proto.conditional_parameter_specs:
|
||||
children = []
|
||||
for conditional_ps in proto.conditional_parameter_specs:
|
||||
parent_values = cls._matching_parent_values(conditional_ps)
|
||||
children.append(
|
||||
(parent_values, cls.from_proto(conditional_ps.parameter_spec))
|
||||
)
|
||||
else:
|
||||
children = None
|
||||
|
||||
scale_type = None
|
||||
if proto.scale_type:
|
||||
scale_type = _ScaleTypeMap.from_proto(proto.scale_type)
|
||||
|
||||
try:
|
||||
config = ParameterConfig.factory(
|
||||
name=proto.parameter_id,
|
||||
feasible_values=feasible_values,
|
||||
bounds=bounds,
|
||||
children=children,
|
||||
scale_type=scale_type,
|
||||
default_value=default_value,
|
||||
external_type=external_type,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
"The provided proto was misconfigured. {}".format(proto)
|
||||
) from e
|
||||
|
||||
if strict_validation and cls.to_proto(config) != proto:
|
||||
raise ValueError(
|
||||
"The provided proto was misconfigured. Expected: {} Given: {}".format(
|
||||
cls.to_proto(config), proto
|
||||
)
|
||||
)
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _set_child_parameter_configs(
|
||||
cls,
|
||||
parent_proto: study_pb2.StudySpec.ParameterSpec,
|
||||
pc: ParameterConfig,
|
||||
):
|
||||
"""Sets the parent_proto's conditional_parameter_specs field.
|
||||
|
||||
Args:
|
||||
parent_proto: Modified in place.
|
||||
pc: Parent ParameterConfig to copy children from.
|
||||
|
||||
Raises:
|
||||
ValueError: If the child configs are invalid
|
||||
"""
|
||||
children: List[Tuple[MonotypeParameterSequence, ParameterConfig]] = []
|
||||
for child in pc.child_parameter_configs:
|
||||
children.append((child.matching_parent_values, child))
|
||||
if not children:
|
||||
return
|
||||
parent_proto.conditional_parameter_specs.clear()
|
||||
for child_pair in children:
|
||||
if len(child_pair) != 2:
|
||||
raise ValueError(
|
||||
"""Each element in children must be a tuple of
|
||||
(Sequence of valid parent values, ParameterConfig)"""
|
||||
)
|
||||
|
||||
logging.debug(
|
||||
"_set_child_parameter_configs: parent_proto=%s, children=%s",
|
||||
parent_proto,
|
||||
children,
|
||||
)
|
||||
for unsorted_parent_values, child in children:
|
||||
parent_values = sorted(unsorted_parent_values)
|
||||
child_proto = cls.to_proto(child.clone_without_children)
|
||||
conditional_parameter_spec = (
|
||||
study_pb2.StudySpec.ParameterSpec.ConditionalParameterSpec(
|
||||
parameter_spec=child_proto
|
||||
)
|
||||
)
|
||||
|
||||
if "discrete_value_spec" in parent_proto:
|
||||
conditional_parameter_spec.parent_discrete_values.values[
|
||||
:
|
||||
] = parent_values
|
||||
elif "categorical_value_spec" in parent_proto:
|
||||
conditional_parameter_spec.parent_categorical_values.values[
|
||||
:
|
||||
] = parent_values
|
||||
elif "integer_value_spec" in parent_proto:
|
||||
conditional_parameter_spec.parent_int_values.values[:] = parent_values
|
||||
else:
|
||||
raise ValueError("DOUBLE type cannot have child parameters")
|
||||
if child.child_parameter_configs:
|
||||
cls._set_child_parameter_configs(child_proto, child)
|
||||
parent_proto.conditional_parameter_specs.extend(
|
||||
[conditional_parameter_spec]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def to_proto(cls, pc: ParameterConfig) -> study_pb2.StudySpec.ParameterSpec:
|
||||
"""Returns a ParameterConfig Proto."""
|
||||
proto = study_pb2.StudySpec.ParameterSpec(parameter_id=pc.name)
|
||||
if pc.type == ParameterType.DISCRETE:
|
||||
cls._set_feasible_points(proto, [float(v) for v in pc.feasible_values])
|
||||
elif pc.type == ParameterType.CATEGORICAL:
|
||||
cls._set_categories(proto, pc.feasible_values)
|
||||
elif pc.type in (ParameterType.INTEGER, ParameterType.DOUBLE):
|
||||
cls._set_bounds(proto, pc.bounds[0], pc.bounds[1], pc.type)
|
||||
else:
|
||||
raise ValueError("Invalid ParameterConfig: {}".format(pc))
|
||||
if pc.scale_type is not None and pc.scale_type != ScaleType.UNIFORM_DISCRETE:
|
||||
proto.scale_type = _ScaleTypeMap.to_proto(pc.scale_type)
|
||||
if pc.default_value is not None:
|
||||
cls._set_default_value(proto, pc.default_value)
|
||||
|
||||
cls._set_child_parameter_configs(proto, pc)
|
||||
return proto
|
||||
|
||||
|
||||
class ParameterValueConverter:
|
||||
"""Converter for ParameterValue."""
|
||||
|
||||
@classmethod
|
||||
def from_proto(cls, proto: study_pb2.Trial.Parameter) -> Optional[ParameterValue]:
|
||||
"""Returns whichever value that is populated, or None."""
|
||||
potential_value = proto.value
|
||||
if (
|
||||
isinstance(potential_value, float)
|
||||
or isinstance(potential_value, str)
|
||||
or isinstance(potential_value, bool)
|
||||
):
|
||||
return ParameterValue(potential_value)
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def to_proto(
|
||||
cls, parameter_value: ParameterValue, name: str
|
||||
) -> study_pb2.Trial.Parameter:
|
||||
"""Returns Parameter Proto."""
|
||||
if isinstance(parameter_value.value, int):
|
||||
value = struct_pb2.Value(number_value=parameter_value.value)
|
||||
elif isinstance(parameter_value.value, bool):
|
||||
value = struct_pb2.Value(bool_value=parameter_value.value)
|
||||
elif isinstance(parameter_value.value, float):
|
||||
value = struct_pb2.Value(number_value=parameter_value.value)
|
||||
elif isinstance(parameter_value.value, str):
|
||||
value = struct_pb2.Value(string_value=parameter_value.value)
|
||||
|
||||
proto = study_pb2.Trial.Parameter(parameter_id=name, value=value)
|
||||
return proto
|
||||
|
||||
|
||||
class MeasurementConverter:
|
||||
"""Converter for MeasurementConverter."""
|
||||
|
||||
@classmethod
|
||||
def from_proto(cls, proto: study_pb2.Measurement) -> Measurement:
|
||||
"""Creates a valid instance from proto.
|
||||
|
||||
Args:
|
||||
proto: Measurement proto.
|
||||
|
||||
Returns:
|
||||
A valid instance of Measurement object. Metrics with invalid values
|
||||
are automatically filtered out.
|
||||
"""
|
||||
|
||||
metrics = dict()
|
||||
|
||||
for metric in proto.metrics:
|
||||
if (
|
||||
metric.metric_id in metrics
|
||||
and metrics[metric.metric_id].value != metric.value
|
||||
):
|
||||
logging.log_first_n(
|
||||
logging.ERROR,
|
||||
'Duplicate metric of name "%s".'
|
||||
"The newly found value %s will be used and "
|
||||
"the previously found value %s will be discarded."
|
||||
"This always happens if the proto has an empty-named metric.",
|
||||
5,
|
||||
metric.metric_id,
|
||||
metric.value,
|
||||
metrics[metric.metric_id].value,
|
||||
)
|
||||
try:
|
||||
metrics[metric.metric_id] = Metric(value=metric.value)
|
||||
except ValueError:
|
||||
pass
|
||||
return Measurement(
|
||||
metrics=metrics,
|
||||
elapsed_secs=proto.elapsed_duration.seconds,
|
||||
steps=proto.step_count,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def to_proto(cls, measurement: Measurement) -> study_pb2.Measurement:
|
||||
"""Converts to Measurement proto."""
|
||||
int_seconds = int(measurement.elapsed_secs)
|
||||
proto = study_pb2.Measurement(
|
||||
elapsed_duration=duration_pb2.Duration(
|
||||
seconds=int_seconds,
|
||||
nanos=int(1e9 * (measurement.elapsed_secs - int_seconds)),
|
||||
)
|
||||
)
|
||||
for name, metric in measurement.metrics.items():
|
||||
proto.metrics.append(
|
||||
study_pb2.Measurement.Metric(metric_id=name, value=metric.value)
|
||||
)
|
||||
|
||||
proto.step_count = measurement.steps
|
||||
return proto
|
||||
|
||||
|
||||
def _to_pyvizier_trial_status(proto_state: study_pb2.Trial.State) -> TrialStatus:
|
||||
"""from_proto conversion for Trial statuses."""
|
||||
if proto_state == study_pb2.Trial.State.REQUESTED:
|
||||
return TrialStatus.REQUESTED
|
||||
elif proto_state == study_pb2.Trial.State.ACTIVE:
|
||||
return TrialStatus.ACTIVE
|
||||
if proto_state == study_pb2.Trial.State.STOPPING:
|
||||
return TrialStatus.STOPPING
|
||||
if proto_state == study_pb2.Trial.State.SUCCEEDED:
|
||||
return TrialStatus.COMPLETED
|
||||
elif proto_state == study_pb2.Trial.State.INFEASIBLE:
|
||||
return TrialStatus.COMPLETED
|
||||
else:
|
||||
return TrialStatus.UNKNOWN
|
||||
|
||||
|
||||
def _from_pyvizier_trial_status(
|
||||
status: TrialStatus, infeasible: bool
|
||||
) -> study_pb2.Trial.State:
|
||||
"""to_proto conversion for Trial states."""
|
||||
if status == TrialStatus.REQUESTED:
|
||||
return study_pb2.Trial.State.REQUESTED
|
||||
elif status == TrialStatus.ACTIVE:
|
||||
return study_pb2.Trial.State.ACTIVE
|
||||
elif status == TrialStatus.STOPPING:
|
||||
return study_pb2.Trial.State.STOPPING
|
||||
elif status == TrialStatus.COMPLETED:
|
||||
if infeasible:
|
||||
return study_pb2.Trial.State.INFEASIBLE
|
||||
else:
|
||||
return study_pb2.Trial.State.SUCCEEDED
|
||||
else:
|
||||
return study_pb2.Trial.State.STATE_UNSPECIFIED
|
||||
|
||||
|
||||
class TrialConverter:
|
||||
"""Converter for TrialConverter."""
|
||||
|
||||
@classmethod
|
||||
def from_proto(cls, proto: study_pb2.Trial) -> Trial:
|
||||
"""Converts from Trial proto to object.
|
||||
|
||||
Args:
|
||||
proto: Trial proto.
|
||||
|
||||
Returns:
|
||||
A Trial object.
|
||||
"""
|
||||
parameters = {}
|
||||
for parameter in proto.parameters:
|
||||
value = ParameterValueConverter.from_proto(parameter)
|
||||
if value is not None:
|
||||
if parameter.parameter_id in parameters:
|
||||
raise ValueError(
|
||||
"Invalid trial proto contains duplicate parameter {}"
|
||||
": {}".format(parameter.parameter_id, proto)
|
||||
)
|
||||
parameters[parameter.parameter_id] = value
|
||||
else:
|
||||
logging.warning(
|
||||
"A parameter without a value will be dropped: %s", parameter
|
||||
)
|
||||
|
||||
final_measurement = None
|
||||
if proto.final_measurement:
|
||||
final_measurement = MeasurementConverter.from_proto(proto.final_measurement)
|
||||
|
||||
completion_time = None
|
||||
infeasibility_reason = None
|
||||
if proto.state == study_pb2.Trial.State.SUCCEEDED:
|
||||
if proto.end_time:
|
||||
completion_time = (
|
||||
proto.end_time.timestamp_pb()
|
||||
.ToDatetime()
|
||||
.replace(tzinfo=timezone.utc)
|
||||
)
|
||||
elif proto.state == study_pb2.Trial.State.INFEASIBLE:
|
||||
infeasibility_reason = proto.infeasible_reason
|
||||
|
||||
measurements = []
|
||||
for measure in proto.measurements:
|
||||
measurements.append(MeasurementConverter.from_proto(measure))
|
||||
|
||||
creation_time = None
|
||||
if proto.start_time:
|
||||
creation_time = (
|
||||
proto.start_time.timestamp_pb()
|
||||
.ToDatetime()
|
||||
.replace(tzinfo=timezone.utc)
|
||||
)
|
||||
return Trial(
|
||||
id=int(proto.name.split("/")[-1]),
|
||||
description=proto.name,
|
||||
assigned_worker=proto.client_id or None,
|
||||
is_requested=proto.state == study_pb2.Trial.State.REQUESTED,
|
||||
stopping_reason=(
|
||||
"stopping reason not supported yet"
|
||||
if proto.state == study_pb2.Trial.State.STOPPING
|
||||
else None
|
||||
),
|
||||
parameters=parameters,
|
||||
creation_time=creation_time,
|
||||
completion_time=completion_time,
|
||||
infeasibility_reason=infeasibility_reason,
|
||||
final_measurement=final_measurement,
|
||||
measurements=measurements,
|
||||
) # pytype: disable=wrong-arg-types
|
||||
|
||||
@classmethod
|
||||
def from_protos(cls, protos: Sequence[study_pb2.Trial]) -> List[Trial]:
|
||||
"""Convenience wrapper for from_proto."""
|
||||
return [TrialConverter.from_proto(proto) for proto in protos]
|
||||
|
||||
@classmethod
|
||||
def to_protos(cls, pytrials: Sequence[Trial]) -> List[study_pb2.Trial]:
|
||||
return [TrialConverter.to_proto(pytrial) for pytrial in pytrials]
|
||||
|
||||
@classmethod
|
||||
def to_proto(cls, pytrial: Trial) -> study_pb2.Trial:
|
||||
"""Converts a pyvizier Trial to a Trial proto."""
|
||||
proto = study_pb2.Trial()
|
||||
if pytrial.description is not None:
|
||||
proto.name = pytrial.description
|
||||
proto.id = str(pytrial.id)
|
||||
proto.state = _from_pyvizier_trial_status(pytrial.status, pytrial.infeasible)
|
||||
proto.client_id = pytrial.assigned_worker or ""
|
||||
|
||||
for name, value in pytrial.parameters.items():
|
||||
proto.parameters.append(ParameterValueConverter.to_proto(value, name))
|
||||
|
||||
# pytrial always adds an empty metric. Ideally, we should remove it if the
|
||||
# metric does not exist in the study config.
|
||||
# setattr() is required here as `proto.final_measurement.CopyFrom`
|
||||
# raises AttributeErrors when setting the field on the pb2 compat types.
|
||||
if pytrial.final_measurement is not None:
|
||||
setattr(
|
||||
proto,
|
||||
"final_measurement",
|
||||
MeasurementConverter.to_proto(pytrial.final_measurement),
|
||||
)
|
||||
|
||||
for measurement in pytrial.measurements:
|
||||
proto.measurements.append(MeasurementConverter.to_proto(measurement))
|
||||
|
||||
if pytrial.creation_time is not None:
|
||||
start_time = timestamp_pb2.Timestamp()
|
||||
start_time.FromDatetime(pytrial.creation_time)
|
||||
setattr(proto, "start_time", start_time)
|
||||
if pytrial.completion_time is not None:
|
||||
end_time = timestamp_pb2.Timestamp()
|
||||
end_time.FromDatetime(pytrial.completion_time)
|
||||
setattr(proto, "end_time", end_time)
|
||||
if pytrial.infeasibility_reason is not None:
|
||||
proto.infeasible_reason = pytrial.infeasibility_reason
|
||||
return proto
|
||||
@@ -0,0 +1,469 @@
|
||||
"""Convenience classes for configuring Vizier Study Configs and Search Spaces.
|
||||
|
||||
This module contains several classes, used to access/build Vizier StudyConfig
|
||||
protos:
|
||||
* `StudyConfig` class is the main class, which:
|
||||
1) Allows to easily build Vizier StudyConfig protos via a convenient
|
||||
Python API.
|
||||
2) Can be initialized from an existing StudyConfig proto, to enable easy
|
||||
Pythonic accessors to information contained in StudyConfig protos,
|
||||
and easy field editing capabilities.
|
||||
|
||||
* `SearchSpace` and `SearchSpaceSelector` classes deals with Vizier search
|
||||
spaces. Both flat spaces and conditional parameters are supported.
|
||||
"""
|
||||
import collections
|
||||
import copy
|
||||
import enum
|
||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import attr
|
||||
from google.cloud.aiplatform.vizier.pyvizier.automated_stopping import (
|
||||
AutomatedStoppingConfig,
|
||||
)
|
||||
from google.cloud.aiplatform.vizier.pyvizier import proto_converters
|
||||
from google.cloud.aiplatform.vizier.pyvizier import SearchSpace
|
||||
from google.cloud.aiplatform.vizier.pyvizier import ProblemStatement
|
||||
from google.cloud.aiplatform.vizier.pyvizier import ObjectiveMetricGoal
|
||||
from google.cloud.aiplatform.vizier.pyvizier import SearchSpaceSelector
|
||||
from google.cloud.aiplatform.vizier.pyvizier import MetricsConfig
|
||||
from google.cloud.aiplatform.vizier.pyvizier import MetricInformation
|
||||
from google.cloud.aiplatform.vizier.pyvizier import Trial
|
||||
from google.cloud.aiplatform.vizier.pyvizier import ParameterValueTypes
|
||||
from google.cloud.aiplatform.vizier.pyvizier import ParameterConfig
|
||||
from google.cloud.aiplatform.compat.types import study as study_pb2
|
||||
|
||||
################### PyTypes ###################
|
||||
# A sequence of possible internal parameter values.
|
||||
# Possible types for trial parameter values after cast to external types.
|
||||
ParameterValueSequence = Union[
|
||||
ParameterValueTypes,
|
||||
Sequence[int],
|
||||
Sequence[float],
|
||||
Sequence[str],
|
||||
Sequence[bool],
|
||||
]
|
||||
|
||||
################### Enums ###################
|
||||
|
||||
|
||||
class Algorithm(enum.Enum):
|
||||
"""Valid Values for StudyConfig.Algorithm."""
|
||||
|
||||
ALGORITHM_UNSPECIFIED = study_pb2.StudySpec.Algorithm.ALGORITHM_UNSPECIFIED
|
||||
# GAUSSIAN_PROCESS_BANDIT = study_pb2.StudySpec.Algorithm.GAUSSIAN_PROCESS_BANDIT
|
||||
GRID_SEARCH = study_pb2.StudySpec.Algorithm.GRID_SEARCH
|
||||
RANDOM_SEARCH = study_pb2.StudySpec.Algorithm.RANDOM_SEARCH
|
||||
# NSGA2 = study_pb2.StudySpec.Algorithm.NSGA2
|
||||
|
||||
|
||||
class ObservationNoise(enum.Enum):
|
||||
"""Valid Values for StudyConfig.ObservationNoise."""
|
||||
|
||||
OBSERVATION_NOISE_UNSPECIFIED = (
|
||||
study_pb2.StudySpec.ObservationNoise.OBSERVATION_NOISE_UNSPECIFIED
|
||||
)
|
||||
LOW = study_pb2.StudySpec.ObservationNoise.LOW
|
||||
HIGH = study_pb2.StudySpec.ObservationNoise.HIGH
|
||||
|
||||
|
||||
################### Classes For Various Config Protos ###################
|
||||
@attr.define(frozen=False, init=True, slots=True, kw_only=True)
|
||||
class MetricInformationConverter:
|
||||
"""A wrapper for vizier_pb2.MetricInformation."""
|
||||
|
||||
@classmethod
|
||||
def from_proto(cls, proto: study_pb2.StudySpec.MetricSpec) -> MetricInformation:
|
||||
"""Converts a MetricInformation proto to a MetricInformation object."""
|
||||
if proto.goal not in list(ObjectiveMetricGoal):
|
||||
raise ValueError("Unknown MetricInformation.goal: {}".format(proto.goal))
|
||||
|
||||
return MetricInformation(
|
||||
name=proto.metric_id,
|
||||
goal=proto.goal,
|
||||
safety_threshold=None,
|
||||
safety_std_threshold=None,
|
||||
min_value=None,
|
||||
max_value=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def to_proto(cls, obj: MetricInformation) -> study_pb2.StudySpec.MetricSpec:
|
||||
"""Returns this object as a proto."""
|
||||
return study_pb2.StudySpec.MetricSpec(metric_id=obj.name, goal=obj.goal.value)
|
||||
|
||||
|
||||
class MetricsConfig(MetricsConfig):
|
||||
"""Metrics config."""
|
||||
|
||||
@classmethod
|
||||
def from_proto(
|
||||
cls, protos: Iterable[study_pb2.StudySpec.MetricSpec]
|
||||
) -> "MetricsConfig":
|
||||
return cls(MetricInformationConverter.from_proto(m) for m in protos)
|
||||
|
||||
def to_proto(self) -> List[study_pb2.StudySpec.MetricSpec]:
|
||||
return [MetricInformationConverter.to_proto(metric) for metric in self]
|
||||
|
||||
|
||||
SearchSpaceSelector = SearchSpaceSelector
|
||||
|
||||
|
||||
@attr.define(frozen=True, init=True, slots=True, kw_only=True)
|
||||
class SearchSpace(SearchSpace):
|
||||
"""A Selector for all, or part of a SearchSpace."""
|
||||
|
||||
@classmethod
|
||||
def from_proto(cls, proto: study_pb2.StudySpec) -> "SearchSpace":
|
||||
"""Extracts a SearchSpace object from a StudyConfig proto."""
|
||||
|
||||
# For google-vizier <= 0.0.15
|
||||
if hasattr(cls, "_factory"):
|
||||
parameter_configs = []
|
||||
for pc in proto.parameters:
|
||||
parameter_configs.append(
|
||||
proto_converters.ParameterConfigConverter.from_proto(pc)
|
||||
)
|
||||
return cls._factory(parameter_configs=parameter_configs)
|
||||
|
||||
result = cls()
|
||||
for pc in proto.parameters:
|
||||
result.add(proto_converters.ParameterConfigConverter.from_proto(pc))
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def parameter_protos(self) -> List[study_pb2.StudySpec.ParameterSpec]:
|
||||
"""Returns the search space as a List of ParameterConfig protos."""
|
||||
|
||||
# For google-vizier <= 0.0.15
|
||||
if isinstance(self._parameter_configs, list):
|
||||
return [
|
||||
proto_converters.ParameterConfigConverter.to_proto(pc)
|
||||
for pc in self._parameter_configs
|
||||
]
|
||||
|
||||
return [
|
||||
proto_converters.ParameterConfigConverter.to_proto(pc)
|
||||
for _, pc in self._parameter_configs.items()
|
||||
]
|
||||
|
||||
|
||||
################### Main Class ###################
|
||||
#
|
||||
# A StudyConfig object can be initialized:
|
||||
# (1) From a StudyConfig proto using StudyConfig.from_proto():
|
||||
# study_config_proto = study_pb2.StudySpec(...)
|
||||
# study_config = pyvizier.StudyConfig.from_proto(study_config_proto)
|
||||
# # Attributes can be modified.
|
||||
# new_proto = study_config.to_proto()
|
||||
#
|
||||
# (2) By directly calling __init__ and setting attributes:
|
||||
# study_config = pyvizier.StudyConfig(
|
||||
# metric_information=[pyvizier.MetricInformation(
|
||||
# name='accuracy', goal=pyvizier.ObjectiveMetricGoal.MAXIMIZE)],
|
||||
# search_space=SearchSpace.from_proto(proto),
|
||||
# )
|
||||
# # OR:
|
||||
# study_config = pyvizier.StudyConfig()
|
||||
# study_config.metric_information.append(
|
||||
# pyvizier.MetricInformation(
|
||||
# name='accuracy', goal=pyvizier.ObjectiveMetricGoal.MAXIMIZE))
|
||||
#
|
||||
# # Since building a search space is more involved, get a reference to the
|
||||
# # search space, and add parameters to it.
|
||||
# root = study_config.search_space.select_root()
|
||||
# root.add_float_param('learning_rate', 0.001, 1.0,
|
||||
# scale_type=pyvizier.ScaleType.LOG)
|
||||
#
|
||||
@attr.define(frozen=False, init=True, slots=True, kw_only=True)
|
||||
class StudyConfig(ProblemStatement):
|
||||
"""A builder and wrapper for study_pb2.StudySpec proto."""
|
||||
|
||||
search_space: SearchSpace = attr.field(
|
||||
init=True,
|
||||
factory=SearchSpace,
|
||||
validator=attr.validators.instance_of(SearchSpace),
|
||||
on_setattr=attr.setters.validate,
|
||||
)
|
||||
|
||||
algorithm: Algorithm = attr.field(
|
||||
init=True,
|
||||
validator=attr.validators.instance_of(Algorithm),
|
||||
on_setattr=[attr.setters.convert, attr.setters.validate],
|
||||
default=Algorithm.ALGORITHM_UNSPECIFIED,
|
||||
kw_only=True,
|
||||
)
|
||||
|
||||
metric_information: MetricsConfig = attr.field(
|
||||
init=True,
|
||||
factory=MetricsConfig,
|
||||
converter=MetricsConfig,
|
||||
validator=attr.validators.instance_of(MetricsConfig),
|
||||
kw_only=True,
|
||||
)
|
||||
|
||||
observation_noise: ObservationNoise = attr.field(
|
||||
init=True,
|
||||
validator=attr.validators.instance_of(ObservationNoise),
|
||||
on_setattr=attr.setters.validate,
|
||||
default=ObservationNoise.OBSERVATION_NOISE_UNSPECIFIED,
|
||||
kw_only=True,
|
||||
)
|
||||
|
||||
automated_stopping_config: Optional[AutomatedStoppingConfig] = attr.field(
|
||||
init=True,
|
||||
default=None,
|
||||
validator=attr.validators.optional(
|
||||
attr.validators.instance_of(AutomatedStoppingConfig)
|
||||
),
|
||||
on_setattr=attr.setters.validate,
|
||||
kw_only=True,
|
||||
)
|
||||
|
||||
# An internal representation as a StudyConfig proto.
|
||||
# If this object was created from a StudyConfig proto, a copy of the original
|
||||
# proto is kept, to make sure that unknown proto fields are preserved in
|
||||
# round trip serialization.
|
||||
# TODO: Fix the broken proto validation.
|
||||
_study_config: study_pb2.StudySpec = attr.field(
|
||||
init=True, factory=study_pb2.StudySpec, kw_only=True
|
||||
)
|
||||
|
||||
# Public attributes, methods and properties.
|
||||
@classmethod
|
||||
def from_proto(cls, proto: study_pb2.StudySpec) -> "StudyConfig":
|
||||
"""Converts a StudyConfig proto to a StudyConfig object.
|
||||
|
||||
Args:
|
||||
proto: StudyConfig proto.
|
||||
|
||||
Returns:
|
||||
A StudyConfig object.
|
||||
"""
|
||||
metric_information = MetricsConfig(
|
||||
sorted(
|
||||
[MetricInformationConverter.from_proto(m) for m in proto.metrics],
|
||||
key=lambda x: x.name,
|
||||
)
|
||||
)
|
||||
|
||||
oneof_name = proto._pb.WhichOneof("automated_stopping_spec")
|
||||
if not oneof_name:
|
||||
automated_stopping_config = None
|
||||
else:
|
||||
automated_stopping_config = AutomatedStoppingConfig.from_proto(
|
||||
getattr(proto, oneof_name)
|
||||
)
|
||||
|
||||
return cls(
|
||||
search_space=SearchSpace.from_proto(proto),
|
||||
algorithm=Algorithm(proto.algorithm),
|
||||
metric_information=metric_information,
|
||||
observation_noise=ObservationNoise(proto.observation_noise),
|
||||
automated_stopping_config=automated_stopping_config,
|
||||
study_config=copy.deepcopy(proto),
|
||||
)
|
||||
|
||||
def to_proto(self) -> study_pb2.StudySpec:
|
||||
"""Serializes this object to a StudyConfig proto."""
|
||||
proto = copy.deepcopy(self._study_config)
|
||||
proto.algorithm = self.algorithm.value
|
||||
proto.observation_noise = self.observation_noise.value
|
||||
|
||||
del proto.metrics[:]
|
||||
proto.metrics.extend(self.metric_information.to_proto())
|
||||
|
||||
del proto.parameters[:]
|
||||
proto.parameters.extend(self.search_space.parameter_protos)
|
||||
|
||||
if self.automated_stopping_config is not None:
|
||||
auto_stop_proto = self.automated_stopping_config.to_proto()
|
||||
if isinstance(
|
||||
auto_stop_proto, study_pb2.StudySpec.DecayCurveAutomatedStoppingSpec
|
||||
):
|
||||
proto.decay_curve_stopping_spec = copy.deepcopy(auto_stop_proto)
|
||||
elif isinstance(
|
||||
auto_stop_proto, study_pb2.StudySpec.DecayCurveAutomatedStoppingSpec
|
||||
):
|
||||
for method_name in dir(proto.decay_curve_stopping_spec):
|
||||
if callable(
|
||||
getattr(proto.median_automated_stopping_spec, method_name)
|
||||
):
|
||||
print(method_name)
|
||||
proto.median_automated_stopping_spec = copy.deepcopy(auto_stop_proto)
|
||||
|
||||
return proto
|
||||
|
||||
@property
|
||||
def is_single_objective(self) -> bool:
|
||||
"""Returns True if only one objective metric is configured."""
|
||||
return len(self.metric_information) == 1
|
||||
|
||||
@property
|
||||
def single_objective_metric_name(self) -> Optional[str]:
|
||||
"""Returns the name of the single-objective metric, if set.
|
||||
|
||||
Returns:
|
||||
String: name of the single-objective metric.
|
||||
None: if this is not a single-objective study.
|
||||
"""
|
||||
if len(self.metric_information) == 1:
|
||||
return list(self.metric_information)[0].name
|
||||
return None
|
||||
|
||||
def _trial_to_external_values(
|
||||
self, pytrial: Trial
|
||||
) -> Dict[str, Union[float, int, str, bool]]:
|
||||
"""Returns the trial parameter values cast to external types."""
|
||||
parameter_values: Dict[str, Union[float, int, str]] = {}
|
||||
external_values: Dict[str, Union[float, int, str, bool]] = {}
|
||||
# parameter_configs is a list of Tuple[parent_name, ParameterConfig].
|
||||
parameter_configs: List[Tuple[Optional[str], ParameterConfig]] = [
|
||||
(None, p) for p in self.search_space.parameters
|
||||
]
|
||||
remaining_parameters = copy.deepcopy(pytrial.parameters)
|
||||
# Traverse the conditional tree using a BFS.
|
||||
while parameter_configs and remaining_parameters:
|
||||
parent_name, pc = parameter_configs.pop(0)
|
||||
parameter_configs.extend(
|
||||
(pc.name, child) for child in pc.child_parameter_configs
|
||||
)
|
||||
if pc.name not in remaining_parameters:
|
||||
continue
|
||||
if parent_name is not None:
|
||||
# This is a child parameter. If the parent was not seen,
|
||||
# skip this parameter config.
|
||||
if parent_name not in parameter_values:
|
||||
continue
|
||||
parent_value = parameter_values[parent_name]
|
||||
if parent_value not in pc.matching_parent_values:
|
||||
continue
|
||||
parameter_values[pc.name] = remaining_parameters[pc.name].value
|
||||
if pc.external_type is None:
|
||||
external_value = remaining_parameters[pc.name].value
|
||||
else:
|
||||
external_value = remaining_parameters[pc.name].cast(
|
||||
pc.external_type
|
||||
) # pytype: disable=wrong-arg-types
|
||||
external_values[pc.name] = external_value
|
||||
remaining_parameters.pop(pc.name)
|
||||
return external_values
|
||||
|
||||
def trial_parameters(
|
||||
self, proto: study_pb2.Trial
|
||||
) -> Dict[str, ParameterValueSequence]:
|
||||
"""Returns the trial values, cast to external types, if they exist.
|
||||
|
||||
Args:
|
||||
proto:
|
||||
|
||||
Returns:
|
||||
Parameter values dict: cast to each parameter's external_type, if exists.
|
||||
NOTE that the values in the dict may be a Sequence as opposed to a single
|
||||
element.
|
||||
|
||||
Raises:
|
||||
ValueError: If the trial parameters do not exist in this search space.
|
||||
ValueError: If the trial contains duplicate parameters.
|
||||
"""
|
||||
pytrial = proto_converters.TrialConverter.from_proto(proto)
|
||||
return self._pytrial_parameters(pytrial)
|
||||
|
||||
def _pytrial_parameters(self, pytrial: Trial) -> Dict[str, ParameterValueSequence]:
|
||||
"""Returns the trial values, cast to external types, if they exist.
|
||||
|
||||
Args:
|
||||
pytrial:
|
||||
|
||||
Returns:
|
||||
Parameter values dict: cast to each parameter's external_type, if exists.
|
||||
NOTE that the values in the dict may be a Sequence as opposed to a single
|
||||
element.
|
||||
|
||||
Raises:
|
||||
ValueError: If the trial parameters do not exist in this search space.
|
||||
ValueError: If the trial contains duplicate parameters.
|
||||
"""
|
||||
trial_external_values: Dict[
|
||||
str, Union[float, int, str, bool]
|
||||
] = self._trial_to_external_values(pytrial)
|
||||
if len(trial_external_values) != len(pytrial.parameters):
|
||||
raise ValueError(
|
||||
"Invalid trial for this search space: failed to convert "
|
||||
"all trial parameters: {}".format(pytrial)
|
||||
)
|
||||
|
||||
# Combine multi-dimensional parameter values to a list of values.
|
||||
trial_final_values: Dict[str, ParameterValueSequence] = {}
|
||||
# multi_dim_params: Dict[str, List[Tuple[int, ParameterValueSequence]]]
|
||||
multi_dim_params = collections.defaultdict(list)
|
||||
for name in trial_external_values:
|
||||
base_index = SearchSpaceSelector.parse_multi_dimensional_parameter_name(
|
||||
name
|
||||
)
|
||||
if base_index is None:
|
||||
trial_final_values[name] = trial_external_values[name]
|
||||
else:
|
||||
base_name, index = base_index
|
||||
multi_dim_params[base_name].append((index, trial_external_values[name]))
|
||||
for name in multi_dim_params:
|
||||
multi_dim_params[name].sort(key=lambda x: x[0])
|
||||
trial_final_values[name] = [x[1] for x in multi_dim_params[name]]
|
||||
|
||||
return trial_final_values
|
||||
|
||||
def trial_metrics(
|
||||
self, proto: study_pb2.Trial, *, include_all_metrics=False
|
||||
) -> Dict[str, float]:
|
||||
"""Returns the trial's final measurement metric values.
|
||||
|
||||
If the trial is not completed, or infeasible, no metrics are returned.
|
||||
By default, only metrics configured in the StudyConfig are returned
|
||||
(e.g. only objective and safety metrics).
|
||||
|
||||
Args:
|
||||
proto:
|
||||
include_all_metrics: If True, all metrics in the final measurements are
|
||||
returned. If False, only metrics configured in the StudyConfig are
|
||||
returned.
|
||||
|
||||
Returns:
|
||||
Dict[metric name, metric value]
|
||||
"""
|
||||
pytrial = proto_converters.TrialConverter.from_proto(proto)
|
||||
return self._pytrial_metrics(pytrial, include_all_metrics=include_all_metrics)
|
||||
|
||||
def _pytrial_metrics(
|
||||
self, pytrial: Trial, *, include_all_metrics=False
|
||||
) -> Dict[str, float]:
|
||||
"""Returns the trial's final measurement metric values.
|
||||
|
||||
If the trial is not completed, or infeasible, no metrics are returned.
|
||||
By default, only metrics configured in the StudyConfig are returned
|
||||
(e.g. only objective and safety metrics).
|
||||
|
||||
Args:
|
||||
pytrial:
|
||||
include_all_metrics: If True, all metrics in the final measurements are
|
||||
returned. If False, only metrics configured in the StudyConfig are
|
||||
returned.
|
||||
|
||||
Returns:
|
||||
Dict[metric name, metric value]
|
||||
"""
|
||||
configured_metrics = [m.name for m in self.metric_information]
|
||||
|
||||
metrics: Dict[str, float] = {}
|
||||
if pytrial.is_completed and not pytrial.infeasible:
|
||||
for name in pytrial.final_measurement.metrics:
|
||||
if include_all_metrics or (
|
||||
not include_all_metrics and name in configured_metrics
|
||||
):
|
||||
# Special case: Measurement always adds an empty metric by default.
|
||||
# If there is a named single objective in study_config, drop the empty
|
||||
# metric.
|
||||
if not name and self.single_objective_metric_name != name:
|
||||
continue
|
||||
metrics[name] = pytrial.final_measurement.metrics[name].value
|
||||
return metrics
|
||||
@@ -0,0 +1,300 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
from typing import Optional, Collection, Type, TypeVar
|
||||
|
||||
from google.api_core import exceptions
|
||||
from google.auth import credentials as auth_credentials
|
||||
from google.cloud.aiplatform import base
|
||||
from google.cloud.aiplatform import utils
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform.vizier import client_abc
|
||||
from google.cloud.aiplatform.vizier import pyvizier as vz
|
||||
from google.cloud.aiplatform.vizier.trial import Trial
|
||||
|
||||
|
||||
from google.cloud.aiplatform.compat.types import study as gca_study
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_LOGGER = base.Logger(__name__)
|
||||
|
||||
|
||||
class Study(base.VertexAiResourceNounWithFutureManager, client_abc.StudyInterface):
|
||||
"""Manage Study resource for Vertex Vizier."""
|
||||
|
||||
client_class = utils.VizierClientWithOverride
|
||||
|
||||
_resource_noun = "study"
|
||||
_getter_method = "get_study"
|
||||
_list_method = "list_studies"
|
||||
_delete_method = "delete_study"
|
||||
_parse_resource_name_method = "parse_study_path"
|
||||
_format_resource_name_method = "study_path"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
study_id: str,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
):
|
||||
"""Retrieves an existing managed study given a study resource name or a study id.
|
||||
|
||||
Example Usage:
|
||||
study = aiplatform.Study(study_id = '12345678')
|
||||
or
|
||||
study = aiplatform.Study(study_id = 'projects/123/locations/us-central1/studies/12345678')
|
||||
|
||||
Args:
|
||||
study_id (str):
|
||||
Required. A fully-qualified study resource name or a study ID.
|
||||
Example: "projects/123/locations/us-central1/studies/12345678" or "12345678" when
|
||||
project and location are initialized or passed.
|
||||
project (str):
|
||||
Optional. Project to retrieve study from. If not set, project
|
||||
set in aiplatform.init will be used.
|
||||
location (str):
|
||||
Optional. Location to retrieve study from. If not set, location
|
||||
set in aiplatform.init will be used.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials to use to retrieve this Feature. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
"""
|
||||
base.VertexAiResourceNounWithFutureManager.__init__(
|
||||
self,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
resource_name=study_id,
|
||||
)
|
||||
self._gca_resource = self._get_gca_resource(resource_name=study_id)
|
||||
|
||||
@classmethod
|
||||
@base.optional_sync()
|
||||
def create_or_load(
|
||||
cls,
|
||||
display_name: str,
|
||||
problem: vz.ProblemStatement,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> client_abc.StudyInterface:
|
||||
"""Creates a Study resource.
|
||||
|
||||
Example Usage:
|
||||
sc = pyvizier.StudyConfig()
|
||||
sc.algorithm = pyvizier.Algorithm.RANDOM_SEARCH
|
||||
sc.metric_information.append(
|
||||
pyvizier.MetricInformation(
|
||||
name='pr-auc', goal=pyvizier.ObjectiveMetricGoal.MAXIMIZE))
|
||||
root = sc.search_space.select_root()
|
||||
root.add_float_param(
|
||||
'learning_rate', 0.00001, 1.0, scale_type=pyvizier.ScaleType.LINEAR)
|
||||
root.add_categorical_param('optimizer', ['adagrad', 'adam', 'experimental'])
|
||||
study = aiplatform.Study.create_or_load(display_name='tuning_study', problem=sc)
|
||||
|
||||
Args:
|
||||
display_name (str):
|
||||
Required. A name to describe the Study. It's unique per study. An existing study
|
||||
will be returned if the study has the same display name.
|
||||
problem (vz.ProblemStatement):
|
||||
Required. Configurations of the study. It defines the problem to create the study.
|
||||
project (str):
|
||||
Optional. Project to retrieve study from. If not set, project
|
||||
set in aiplatform.init will be used.
|
||||
location (str):
|
||||
Optional. Location to retrieve study from. If not set, location
|
||||
set in aiplatform.init will be used.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials to use to retrieve this Feature. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
Returns:
|
||||
StudyInterface - The created study resource object.
|
||||
"""
|
||||
project = initializer.global_config.project if not project else project
|
||||
location = initializer.global_config.location if not location else location
|
||||
credentials = (
|
||||
initializer.global_config.credentials if not credentials else credentials
|
||||
)
|
||||
|
||||
api_client = cls._instantiate_client(
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
study = gca_study.Study(
|
||||
display_name=display_name, study_spec=problem.to_proto()
|
||||
)
|
||||
|
||||
try:
|
||||
study = api_client.create_study(
|
||||
parent=initializer.global_config.common_location_path(
|
||||
project,
|
||||
location,
|
||||
),
|
||||
study=study,
|
||||
)
|
||||
except exceptions.AlreadyExists:
|
||||
_LOGGER.info("The study is already created. Using existing study.")
|
||||
study = api_client.lookup_study(
|
||||
request={
|
||||
"parent": initializer.global_config.common_location_path(
|
||||
project,
|
||||
location,
|
||||
),
|
||||
"display_name": display_name,
|
||||
},
|
||||
)
|
||||
|
||||
return Study(study.name)
|
||||
|
||||
def get_trial(self, uid: int) -> client_abc.TrialInterface:
|
||||
"""Retrieves the trial under the study by given trial id.
|
||||
|
||||
Args:
|
||||
uid (int): Required. Unique identifier of the trial to search.
|
||||
Returns:
|
||||
TrialInterface - The trial resource object.
|
||||
"""
|
||||
study_path_components = self._parse_resource_name(self.resource_name)
|
||||
return Trial(
|
||||
Trial._format_resource_name(
|
||||
project=study_path_components["project"],
|
||||
location=study_path_components["location"],
|
||||
study=study_path_components["study"],
|
||||
trial=uid,
|
||||
),
|
||||
credentials=self.credentials,
|
||||
)
|
||||
|
||||
def trials(
|
||||
self, trial_filter: Optional[vz.TrialFilter] = None
|
||||
) -> Collection[client_abc.TrialInterface]:
|
||||
"""Fetches a collection of trials.
|
||||
|
||||
Args:
|
||||
trial_filter (int): Optional. A filter for the trials.
|
||||
Returns:
|
||||
Collection[TrialInterface] - A list of trials resource object belonging
|
||||
to the study.
|
||||
"""
|
||||
list_trials_request = {"parent": self.resource_name}
|
||||
trials_response = self.api_client.list_trials(request=list_trials_request)
|
||||
return [
|
||||
Trial._construct_sdk_resource_from_gapic(
|
||||
trial,
|
||||
project=self.project,
|
||||
location=self.location,
|
||||
credentials=self.credentials,
|
||||
)
|
||||
for trial in trials_response.trials
|
||||
]
|
||||
|
||||
def optimal_trials(self) -> Collection[client_abc.TrialInterface]:
|
||||
"""Returns optimal trial(s).
|
||||
|
||||
Returns:
|
||||
Collection[TrialInterface] - A list of optimal trials resource object.
|
||||
"""
|
||||
list_optimal_trials_request = {"parent": self.resource_name}
|
||||
optimal_trials_response = self.api_client.list_optimal_trials(
|
||||
request=list_optimal_trials_request
|
||||
)
|
||||
return [
|
||||
Trial._construct_sdk_resource_from_gapic(
|
||||
trial,
|
||||
project=self.project,
|
||||
location=self.location,
|
||||
credentials=self.credentials,
|
||||
)
|
||||
for trial in optimal_trials_response.optimal_trials
|
||||
]
|
||||
|
||||
def materialize_study_config(self) -> vz.StudyConfig:
|
||||
"""#Materializes the study config.
|
||||
|
||||
Returns:
|
||||
StudyConfig - A deepcopy of StudyConfig from the study.
|
||||
"""
|
||||
study = self.api_client.get_study(
|
||||
name=self.resource_name, credentials=self.credentials
|
||||
)
|
||||
return copy.deepcopy(vz.StudyConfig.from_proto(study.study_spec))
|
||||
|
||||
@classmethod
|
||||
def from_uid(
|
||||
cls: Type[_T],
|
||||
uid: str,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> _T:
|
||||
"""Fetches an existing study from the Vizier service.
|
||||
|
||||
Args:
|
||||
uid (str): Required. Unique identifier of the study.
|
||||
Returns:
|
||||
StudyInterface - The study resource object.
|
||||
"""
|
||||
project = initializer.global_config.project if not project else project
|
||||
location = initializer.global_config.location if not location else location
|
||||
credentials = (
|
||||
initializer.global_config.credentials if not credentials else credentials
|
||||
)
|
||||
|
||||
return Study(
|
||||
study_id=uid, project=project, location=location, credentials=credentials
|
||||
)
|
||||
|
||||
def suggest(
|
||||
self, *, count: Optional[int] = None, worker: str = ""
|
||||
) -> Collection[client_abc.TrialInterface]:
|
||||
"""Returns Trials to be evaluated by worker.
|
||||
|
||||
Args:
|
||||
count (int): Optional. Number of suggestions.
|
||||
worker (str): When new Trials are generated, their `assigned_worker` field is
|
||||
populated with this worker. suggest() first looks for existing Trials
|
||||
that are assigned to `worker`, before generating new ones.
|
||||
Returns:
|
||||
Collection[TrialInterface] - A list of suggested trial resource objects.
|
||||
"""
|
||||
suggest_trials_lro = self.api_client.suggest_trials(
|
||||
request={
|
||||
"parent": self.resource_name,
|
||||
"suggestion_count": count,
|
||||
"client_id": worker,
|
||||
},
|
||||
)
|
||||
_LOGGER.log_action_started_against_resource_with_lro(
|
||||
"Suggest", "study", self.__class__, suggest_trials_lro
|
||||
)
|
||||
_LOGGER.info(self.client_class.get_gapic_client_class())
|
||||
trials = suggest_trials_lro.result()
|
||||
_LOGGER.log_action_completed_against_resource("study", "suggested", self)
|
||||
return [
|
||||
Trial._construct_sdk_resource_from_gapic(
|
||||
trial,
|
||||
project=self.project,
|
||||
location=self.location,
|
||||
credentials=self.credentials,
|
||||
)
|
||||
for trial in trials.trials
|
||||
]
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Deletes the study."""
|
||||
self.api_client.delete_study(name=self.resource_name)
|
||||
@@ -0,0 +1,180 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
|
||||
from typing import Optional, TypeVar, Mapping, Any
|
||||
from google.cloud.aiplatform.vizier.client_abc import TrialInterface
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
from google.cloud.aiplatform import base
|
||||
from google.cloud.aiplatform import utils
|
||||
from google.cloud.aiplatform.vizier import study
|
||||
from google.cloud.aiplatform.vizier import pyvizier as vz
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_LOGGER = base.Logger(__name__)
|
||||
|
||||
|
||||
class Trial(base.VertexAiResourceNounWithFutureManager, TrialInterface):
|
||||
"""Manage Trial resource for Vertex Vizier."""
|
||||
|
||||
client_class = utils.VizierClientWithOverride
|
||||
|
||||
_resource_noun = "trial"
|
||||
_getter_method = "get_trial"
|
||||
_list_method = "list_trials"
|
||||
_delete_method = "delete_trial"
|
||||
_parse_resource_name_method = "parse_trial_path"
|
||||
_format_resource_name_method = "trial_path"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trial_name: str,
|
||||
study_id: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
):
|
||||
"""Retrieves an existing managed trial given a trial resource name or a study id.
|
||||
|
||||
Example Usage:
|
||||
trial = aiplatform.Trial(trial_name = 'projects/123/locations/us-central1/studies/12345678/trials/1')
|
||||
or
|
||||
trial = aiplatform.Trial(trial_name = '1', study_id = '12345678')
|
||||
|
||||
Args:
|
||||
trial_name (str):
|
||||
Required. A fully-qualified trial resource name or a trial ID.
|
||||
Example: "projects/123/locations/us-central1/studies/12345678/trials/1" or "12345678" when
|
||||
project and location are initialized or passed.
|
||||
study_id (str):
|
||||
Optional. A fully-qualified study resource name or a study ID.
|
||||
Example: "projects/123/locations/us-central1/studies/12345678" or "12345678" when
|
||||
project and location are initialized or passed.
|
||||
project (str):
|
||||
Optional. Project to retrieve trial from. If not set, project
|
||||
set in aiplatform.init will be used.
|
||||
location (str):
|
||||
Optional. Location to retrieve trial from. If not set, location
|
||||
set in aiplatform.init will be used.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials to use to retrieve this Feature. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
"""
|
||||
|
||||
base.VertexAiResourceNounWithFutureManager.__init__(
|
||||
self,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
resource_name=trial_name,
|
||||
)
|
||||
self._gca_resource = self._get_gca_resource(
|
||||
resource_name=trial_name,
|
||||
parent_resource_name_fields={
|
||||
study.Study._resource_noun: study_id,
|
||||
}
|
||||
if study_id
|
||||
else study_id,
|
||||
)
|
||||
|
||||
@property
|
||||
def uid(self) -> int:
|
||||
"""Unique identifier of the trial."""
|
||||
trial_path_components = self._parse_resource_name(self.resource_name)
|
||||
return int(trial_path_components["trial"])
|
||||
|
||||
@property
|
||||
def parameters(self) -> Mapping[str, Any]:
|
||||
"""Parameters of the trial."""
|
||||
trial = self.api_client.get_trial(name=self.resource_name)
|
||||
return vz.TrialConverter.from_proto(trial).parameters
|
||||
|
||||
@property
|
||||
def status(self) -> vz.TrialStatus:
|
||||
"""Status of the Trial."""
|
||||
trial = self.api_client.get_trial(name=self.resource_name)
|
||||
return vz.TrialConverter.from_proto(trial).status
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Deletes the Trial in Vizier service."""
|
||||
self.api_client.delete_trial(name=self.resource_name)
|
||||
|
||||
def complete(
|
||||
self,
|
||||
measurement: Optional[vz.Measurement] = None,
|
||||
*,
|
||||
infeasible_reason: Optional[str] = None
|
||||
) -> Optional[vz.Measurement]:
|
||||
"""Completes the trial and #materializes the measurement.
|
||||
|
||||
* If `measurement` is provided, then Vizier writes it as the trial's final
|
||||
measurement and returns it.
|
||||
* If `infeasible_reason` is provided, `measurement` is not needed.
|
||||
* If neither is provided, then Vizier selects an existing (intermediate)
|
||||
measurement to be the final measurement and returns it.
|
||||
|
||||
Args:
|
||||
measurement: Final measurement.
|
||||
infeasible_reason: Indefeasibly reason for missing final measurement.
|
||||
"""
|
||||
complete_trial_request = {"name": self.resource_name}
|
||||
if infeasible_reason is not None:
|
||||
complete_trial_request["infeasible_reason"] = infeasible_reason
|
||||
complete_trial_request["trial_infeasible"] = True
|
||||
if measurement is not None:
|
||||
complete_trial_request[
|
||||
"final_measurement"
|
||||
] = vz.MeasurementConverter.to_proto(measurement)
|
||||
trial = self.api_client.complete_trial(request=complete_trial_request)
|
||||
return (
|
||||
vz.MeasurementConverter.from_proto(trial.final_measurement)
|
||||
if trial.final_measurement
|
||||
else None
|
||||
)
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
"""Returns true if the Trial should stop."""
|
||||
check_trial_early_stopping_state_request = {"trial_name": self.resource_name}
|
||||
should_stop_lro = self.api_client.check_trial_early_stopping_state(
|
||||
request=check_trial_early_stopping_state_request
|
||||
)
|
||||
_LOGGER.log_action_started_against_resource_with_lro(
|
||||
"ShouldStop", "trial", self.__class__, should_stop_lro
|
||||
)
|
||||
should_stop_lro.result()
|
||||
_LOGGER.log_action_completed_against_resource("trial", "should_stop", self)
|
||||
return should_stop_lro.result().should_stop
|
||||
|
||||
def add_measurement(self, measurement: vz.Measurement) -> None:
|
||||
"""Adds an intermediate measurement."""
|
||||
add_trial_measurement_request = {
|
||||
"trial_name": self.resource_name,
|
||||
}
|
||||
add_trial_measurement_request["measurement"] = vz.MeasurementConverter.to_proto(
|
||||
measurement
|
||||
)
|
||||
self.api_client.add_trial_measurement(request=add_trial_measurement_request)
|
||||
|
||||
def materialize(self, *, include_all_measurements: bool = True) -> vz.Trial:
|
||||
"""#Materializes the Trial.
|
||||
|
||||
Args:
|
||||
include_all_measurements: If True, returned Trial includes all
|
||||
intermediate measurements. The final measurement is always provided.
|
||||
"""
|
||||
trial = self.api_client.get_trial(name=self.resource_name)
|
||||
return copy.deepcopy(vz.TrialConverter.from_proto(trial))
|
||||
Reference in New Issue
Block a user