# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Common utilities for the SDK.""" import base64 import datetime import enum import functools import typing from typing import Any, Callable, Optional, Union import uuid import warnings import pydantic from pydantic import alias_generators from . import _api_client from . import errors def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: Any) -> None: """Examples: set_value_by_path({}, ['a', 'b'], v) -> {'a': {'b': v}} set_value_by_path({}, ['a', 'b[]', c], [v1, v2]) -> {'a': {'b': [{'c': v1}, {'c': v2}]}} set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3) -> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}} """ if value is None: return for i, key in enumerate(keys[:-1]): if key.endswith('[]'): key_name = key[:-2] if data is not None and key_name not in data: if isinstance(value, list): data[key_name] = [{} for _ in range(len(value))] else: raise ValueError( f'value {value} must be a list given an array path {key}' ) if isinstance(value, list) and data is not None: for j, d in enumerate(data[key_name]): set_value_by_path(d, keys[i + 1 :], value[j]) else: if data is not None: for d in data[key_name]: set_value_by_path(d, keys[i + 1 :], value) return elif key.endswith('[0]'): key_name = key[:-3] if data is not None and key_name not in data: data[key_name] = [{}] if data is not None: set_value_by_path(data[key_name][0], keys[i + 1 :], value) return if data is not None: data = data.setdefault(key, {}) if data is not None: existing_data = data.get(keys[-1]) # If there is an existing value, merge, not overwrite. if existing_data is not None: # Don't overwrite existing non-empty value with new empty value. # This is triggered when handling tuning datasets. if not value: pass # Don't fail when overwriting value with same value elif value == existing_data: pass # Instead of overwriting dictionary with another dictionary, merge them. # This is important for handling training and validation datasets in tuning. elif isinstance(existing_data, dict) and isinstance(value, dict): # Merging dictionaries. Consider deep merging in the future. existing_data.update(value) else: raise ValueError( f'Cannot set value for an existing key. Key: {keys[-1]};' f' Existing value: {existing_data}; New value: {value}.' ) else: data[keys[-1]] = value def get_value_by_path(data: Any, keys: list[str]) -> Any: """Examples: get_value_by_path({'a': {'b': v}}, ['a', 'b']) -> v get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c']) -> [v1, v2] """ if keys == ['_self']: return data for i, key in enumerate(keys): if not data: return None if key.endswith('[]'): key_name = key[:-2] if key_name in data: return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]] else: return None elif key.endswith('[0]'): key_name = key[:-3] if key_name in data and data[key_name]: return get_value_by_path(data[key_name][0], keys[i + 1 :]) else: return None else: if key in data: data = data[key] elif isinstance(data, BaseModel) and hasattr(data, key): data = getattr(data, key) else: return None return data def convert_to_dict(obj: object) -> Any: """Recursively converts a given object to a dictionary. If the object is a Pydantic model, it uses the model's `model_dump()` method. Args: obj: The object to convert. Returns: A dictionary representation of the object, a list of objects if a list is passed, or the object itself if it is not a dictionary, list, or Pydantic model. """ if isinstance(obj, pydantic.BaseModel): return obj.model_dump(exclude_none=True) elif isinstance(obj, dict): return {key: convert_to_dict(value) for key, value in obj.items()} elif isinstance(obj, list): return [convert_to_dict(item) for item in obj] else: return obj def _remove_extra_fields( model: Any, response: dict[str, object] ) -> None: """Removes extra fields from the response that are not in the model. Mutates the response in place. """ key_values = list(response.items()) for key, value in key_values: # Need to convert to snake case to match model fields names # ex: UsageMetadata alias_map = { field_info.alias: key for key, field_info in model.model_fields.items() } if key not in model.model_fields and key not in alias_map: response.pop(key) continue key = alias_map.get(key, key) annotation = model.model_fields[key].annotation # Get the BaseModel if Optional if typing.get_origin(annotation) is Union: annotation = typing.get_args(annotation)[0] # if dict, assume BaseModel but also check that field type is not dict # example: FunctionCall.args if isinstance(value, dict) and typing.get_origin(annotation) is not dict: _remove_extra_fields(annotation, value) elif isinstance(value, list): for item in value: # assume a list of dict is list of BaseModel if isinstance(item, dict): _remove_extra_fields(typing.get_args(annotation)[0], item) T = typing.TypeVar('T', bound='BaseModel') class BaseModel(pydantic.BaseModel): model_config = pydantic.ConfigDict( alias_generator=alias_generators.to_camel, populate_by_name=True, from_attributes=True, protected_namespaces=(), extra='forbid', # This allows us to use arbitrary types in the model. E.g. PIL.Image. arbitrary_types_allowed=True, ser_json_bytes='base64', val_json_bytes='base64', ignored_types=(typing.TypeVar,) ) @classmethod def _from_response( cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object] ) -> T: # To maintain forward compatibility, we need to remove extra fields from # the response. # We will provide another mechanism to allow users to access these fields. _remove_extra_fields(cls, response) validated_response = cls.model_validate(response) return validated_response def to_json_dict(self) -> dict[str, object]: return self.model_dump(exclude_none=True, mode='json') class CaseInSensitiveEnum(str, enum.Enum): """Case insensitive enum.""" @classmethod def _missing_(cls, value: Any) -> Any: try: return cls[value.upper()] # Try to access directly with uppercase except KeyError: try: return cls[value.lower()] # Try to access directly with lowercase except KeyError: warnings.warn(f"{value} is not a valid {cls.__name__}") try: # Creating a enum instance based on the value # We need to use super() to avoid infinite recursion. unknown_enum_val = super().__new__(cls, value) unknown_enum_val._name_ = str(value) # pylint: disable=protected-access unknown_enum_val._value_ = value # pylint: disable=protected-access return unknown_enum_val except: return None def timestamped_unique_name() -> str: """Composes a timestamped unique name. Returns: A string representing a unique name. """ timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S') unique_id = uuid.uuid4().hex[0:5] return f'{timestamp}_{unique_id}' def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]: """Converts unserializable types in dict to json.dumps() compatible types. This function is called in models.py after calling convert_to_dict(). The convert_to_dict() can convert pydantic object to dict. However, the input to convert_to_dict() is dict mixed of pydantic object and nested dict(the output of converters). So they may be bytes in the dict and they are out of `ser_json_bytes` control in model_dump(mode='json') called in `convert_to_dict`, as well as datetime deserialization in Pydantic json mode. Returns: A dictionary with json.dumps() incompatible type (e.g. bytes datetime) to compatible type (e.g. base64 encoded string, isoformat date string). """ processed_data: dict[str, object] = {} if not isinstance(data, dict): return data for key, value in data.items(): if isinstance(value, bytes): processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii') elif isinstance(value, datetime.datetime): processed_data[key] = value.isoformat() elif isinstance(value, dict): processed_data[key] = encode_unserializable_types(value) elif isinstance(value, list): if all(isinstance(v, bytes) for v in value): processed_data[key] = [ base64.urlsafe_b64encode(v).decode('ascii') for v in value ] if all(isinstance(v, datetime.datetime) for v in value): processed_data[key] = [v.isoformat() for v in value] else: processed_data[key] = [encode_unserializable_types(v) for v in value] else: processed_data[key] = value return processed_data def experimental_warning(message: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Experimental warning, only warns once.""" def decorator(func: Callable[..., Any]) -> Callable[..., Any]: warning_done = False @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: nonlocal warning_done if not warning_done: warning_done = True warnings.warn( message=message, category=errors.ExperimentalWarning, stacklevel=2, ) return func(*args, **kwargs) return wrapper return decorator