Files
evo-ai/.venv/lib/python3.10/site-packages/google/genai/_common.py
2025-04-25 15:30:54 -03:00

319 lines
10 KiB
Python

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Common utilities for the SDK."""
import base64
import datetime
import enum
import functools
import typing
from typing import Any, Callable, Optional, Union
import uuid
import warnings
import pydantic
from pydantic import alias_generators
from . import _api_client
from . import errors
def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: Any) -> None:
    """Store `value` inside the nested dict `data` at the path given by `keys`.

    Every segment of `keys` except the last one is walked (and created when
    missing). A segment ending in `[]` fans out over every element of a list;
    a segment ending in `[0]` addresses only the first element of a list.
    The call is a no-op when `value` is None or when the container being
    walked is None.

    Examples:
      set_value_by_path({}, ['a', 'b'], v)
        -> {'a': {'b': v}}
      set_value_by_path({}, ['a', 'b[]', c], [v1, v2])
        -> {'a': {'b': [{'c': v1}, {'c': v2}]}}
      set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3)
        -> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}}

    Args:
      data: The dict to mutate in place (may be None).
      keys: Path segments; the last one is the key that receives `value`.
      value: The value to store.

    Raises:
      ValueError: If a `[]` segment has to create its list but `value` is not
        a list, or if the final key already holds a different value that
        cannot be merged.
    """
    if value is None:
        return

    for depth, segment in enumerate(keys[:-1]):
        # Once the container is None there is nothing left to write into.
        if data is None:
            return
        rest = keys[depth + 1:]

        if segment.endswith('[]'):
            list_key = segment[:-2]
            if list_key not in data:
                if not isinstance(value, list):
                    raise ValueError(
                        f'value {value} must be a list given an array path {segment}'
                    )
                # One empty dict per value that is about to be distributed.
                data[list_key] = [{} for _ in value]
            if isinstance(value, list):
                # Pair the i-th list element with the i-th value.
                for index, element in enumerate(data[list_key]):
                    set_value_by_path(element, rest, value[index])
            else:
                # A scalar is broadcast to every element.
                for element in data[list_key]:
                    set_value_by_path(element, rest, value)
            return

        if segment.endswith('[0]'):
            list_key = segment[:-3]
            if list_key not in data:
                data[list_key] = [{}]
            set_value_by_path(data[list_key][0], rest, value)
            return

        # Plain key: descend, creating the intermediate dict when absent.
        data = data.setdefault(segment, {})

    if data is None:
        return

    current = data.get(keys[-1])
    if current is None:
        data[keys[-1]] = value
        return

    # A value already exists: merge rather than overwrite.
    # - An empty new value never replaces an existing one (this is triggered
    #   when handling tuning datasets).
    # - Writing the same value again is accepted silently.
    if not value or value == current:
        return

    # Two dicts are merged (shallowly) instead of replaced; this matters for
    # training and validation datasets in tuning. Consider deep merging in the
    # future.
    if isinstance(current, dict) and isinstance(value, dict):
        current.update(value)
        return

    raise ValueError(
        f'Cannot set value for an existing key. Key: {keys[-1]};'
        f' Existing value: {current}; New value: {value}.'
    )
def get_value_by_path(data: Any, keys: list[str]) -> Any:
    """Read the value found in `data` at the path given by `keys`.

    A segment ending in `[]` collects the remaining path from every element of
    a list; a segment ending in `[0]` follows only the first element. The
    special path ['_self'] returns `data` itself.

    Examples:
      get_value_by_path({'a': {'b': v}}, ['a', 'b'])
        -> v
      get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c'])
        -> [v1, v2]

    Args:
      data: The container to read from.
      keys: Path segments to follow.

    Returns:
      The value at the path, or None when a falsy container is reached or a
      segment cannot be resolved.
    """
    if keys == ['_self']:
        return data

    node = data
    for depth, segment in enumerate(keys):
        # An empty or missing container cannot be descended into.
        if not node:
            return None
        rest = keys[depth + 1:]

        if segment.endswith('[]'):
            list_key = segment[:-2]
            if list_key not in node:
                return None
            return [get_value_by_path(element, rest) for element in node[list_key]]

        if segment.endswith('[0]'):
            list_key = segment[:-3]
            if list_key in node and node[list_key]:
                return get_value_by_path(node[list_key][0], rest)
            return None

        if segment in node:
            node = node[segment]
        elif isinstance(node, BaseModel) and hasattr(node, segment):
            # Model instances expose their fields as attributes.
            node = getattr(node, segment)
        else:
            return None

    return node
def convert_to_dict(obj: object) -> Any:
    """Recursively turn Pydantic models nested in `obj` into plain dicts.

    Args:
      obj: The object to convert.

    Returns:
      For a Pydantic model, the result of `model_dump(exclude_none=True)`.
      For a dict, a new dict with every value converted. For a list, a new
      list with every item converted. Any other object is returned unchanged.
    """
    if isinstance(obj, pydantic.BaseModel):
        return obj.model_dump(exclude_none=True)
    if isinstance(obj, dict):
        return {name: convert_to_dict(item) for name, item in obj.items()}
    if isinstance(obj, list):
        return list(map(convert_to_dict, obj))
    return obj
def _remove_extra_fields(
    model: Any, response: dict[str, object]
) -> None:
    """Removes extra fields from the response that are not in the model.

    Mutates the response in place. Keys are kept when they match either a
    field name or a field alias in `model.model_fields`. Nested dicts and
    lists of dicts are cleaned recursively against the field's annotation.

    Args:
      model: An object exposing `model_fields`, a mapping of field name to
        field info with `alias` and `annotation` attributes.
      response: The response dict to prune.
    """
    # Nothing to prune; also avoids touching `model.model_fields` for an empty
    # dict, which keeps this a no-op when `model` is not a model class.
    if not response:
        return

    fields = model.model_fields
    # Response keys may use the field alias instead of the field name
    # (ex: UsageMetadata), so build the alias -> field name lookup once.
    alias_to_name = {info.alias: name for name, info in fields.items()}

    # Iterate over a snapshot because keys are popped while looping.
    for key, value in list(response.items()):
        if key not in fields and key not in alias_to_name:
            response.pop(key)
            continue

        field_name = alias_to_name.get(key, key)
        annotation = fields[field_name].annotation

        # Get the BaseModel if Optional
        if typing.get_origin(annotation) is Union:
            annotation = typing.get_args(annotation)[0]

        # if dict, assume BaseModel but also check that field type is not dict
        # example: FunctionCall.args
        if isinstance(value, dict) and typing.get_origin(annotation) is not dict:
            _remove_extra_fields(annotation, value)
        elif isinstance(value, list):
            for item in value:
                # assume a list of dict is list of BaseModel
                if isinstance(item, dict):
                    _remove_extra_fields(typing.get_args(annotation)[0], item)
# Type variable letting `_from_response` return the concrete subclass type.
T = typing.TypeVar('T', bound='BaseModel')


# Shared pydantic base class: camelCase aliases are generated for every field
# while field names stay accepted, unknown fields are rejected, and bytes are
# (de)serialized as base64 in JSON mode.
class BaseModel(pydantic.BaseModel):

    model_config = pydantic.ConfigDict(
        alias_generator=alias_generators.to_camel,
        populate_by_name=True,
        from_attributes=True,
        protected_namespaces=(),
        extra='forbid',
        # This allows us to use arbitrary types in the model. E.g. PIL.Image.
        arbitrary_types_allowed=True,
        ser_json_bytes='base64',
        val_json_bytes='base64',
        ignored_types=(typing.TypeVar,),
    )

    @classmethod
    def _from_response(
        cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
    ) -> T:
        """Build an instance of `cls` from a raw response dict.

        Fields of `response` that the model does not declare are dropped (in
        place) before validation, to maintain forward compatibility. We will
        provide another mechanism to allow users to access these fields.

        Args:
          response: The raw response; mutated in place.
          kwargs: Not used by this implementation.

        Returns:
          The validated model instance.
        """
        _remove_extra_fields(cls, response)
        return cls.model_validate(response)

    def to_json_dict(self) -> dict[str, object]:
        """Dump the model in JSON mode, leaving out fields that are None."""
        return self.model_dump(mode='json', exclude_none=True)
class CaseInSensitiveEnum(str, enum.Enum):
    """Case insensitive enum."""

    @classmethod
    def _missing_(cls, value: Any) -> Any:
        """Resolve a value that did not match any member's value.

        The value is looked up as a member name, first upper-cased and then
        lower-cased. When neither exists, a warning is emitted and a new
        pseudo-member carrying `value` as both name and value is returned, so
        unknown values do not fail the lookup.

        Args:
          value: The value passed to the enum lookup.

        Returns:
          The matching member, a pseudo-member for an unknown value, or None
          when the pseudo-member could not be created.
        """
        try:
            return cls[value.upper()]  # Try to access directly with uppercase
        except KeyError:
            pass
        try:
            return cls[value.lower()]  # Try to access directly with lowercase
        except KeyError:
            pass

        warnings.warn(f"{value} is not a valid {cls.__name__}")
        try:
            # Creating a enum instance based on the value
            # We need to use super() to avoid infinite recursion.
            unknown_enum_val = super().__new__(cls, value)
            unknown_enum_val._name_ = str(value)  # pylint: disable=protected-access
            unknown_enum_val._value_ = value  # pylint: disable=protected-access
            return unknown_enum_val
        except Exception:  # pylint: disable=broad-except
            # Best effort: report "no match" instead of failing here. Unlike a
            # bare `except:`, this lets KeyboardInterrupt/SystemExit propagate.
            return None
def timestamped_unique_name() -> str:
    """Build a name from the current time plus a short random suffix.

    Returns:
      A string of the form `<YYYYmmddHHMMSS>_<5 hex chars>`, where the
      timestamp comes from `datetime.datetime.now()` and the suffix from a
      freshly generated UUID4.
    """
    stamp = f'{datetime.datetime.now():%Y%m%d%H%M%S}'
    suffix = uuid.uuid4().hex[:5]
    return '_'.join((stamp, suffix))
def _encode_unserializable_value(value: object) -> object:
    """Returns a json.dumps() compatible form of a single value."""
    if isinstance(value, bytes):
        return base64.urlsafe_b64encode(value).decode('ascii')
    if isinstance(value, datetime.datetime):
        return value.isoformat()
    if isinstance(value, dict):
        return encode_unserializable_types(value)
    if isinstance(value, list):
        return [_encode_unserializable_value(item) for item in value]
    return value


def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
    """Converts unserializable types in dict to json.dumps() compatible types.

    This function is called in models.py after calling convert_to_dict(). The
    convert_to_dict() can convert pydantic object to dict. However, the input to
    convert_to_dict() is dict mixed of pydantic object and nested dict (the
    output of converters). So there may be bytes or datetime values left in the
    dict that pydantic's JSON-mode serialization never touched.

    Values are converted element by element, so bytes and datetime objects
    are encoded wherever they appear: directly as a value, in nested dicts, or
    inside (possibly nested or mixed) lists.

    Args:
      data: The dict to convert. It is not modified.

    Returns:
      A new dictionary with json.dumps() incompatible types (bytes, datetime)
      replaced by compatible ones (url-safe base64 string, isoformat string).
      If `data` is not a dict it is returned unchanged.
    """
    if not isinstance(data, dict):
        return data
    return {key: _encode_unserializable_value(value) for key, value in data.items()}
def experimental_warning(message: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator factory that flags a callable as experimental.

    The first call of each decorated callable emits `message` as an
    `errors.ExperimentalWarning`; later calls of that same callable stay
    silent.

    Args:
      message: The warning text.

    Returns:
      A decorator that wraps a callable with the warn-once behavior.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        # One flag per decorated callable.
        needs_warning = True

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            nonlocal needs_warning
            if needs_warning:
                # Clear the flag before warning so the warning is attempted
                # only once even if emitting it raises.
                needs_warning = False
                warnings.warn(
                    message=message,
                    category=errors.ExperimentalWarning,
                    stacklevel=2,
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator