# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Transformers for Google GenAI SDK."""
|
|
|
|
import base64
|
|
from collections.abc import Iterable, Mapping
|
|
from enum import Enum, EnumMeta
|
|
import inspect
|
|
import io
|
|
import logging
|
|
import re
|
|
import sys
|
|
import time
|
|
import types as builtin_types
|
|
import typing
|
|
from typing import Any, GenericAlias, Optional, Sequence, Union # type: ignore[attr-defined]
|
|
|
|
if typing.TYPE_CHECKING:
|
|
import PIL.Image
|
|
|
|
import pydantic
|
|
|
|
from . import _api_client
|
|
from . import types
|
|
|
|
logger = logging.getLogger('google_genai._transformers')
|
|
|
|
if sys.version_info >= (3, 10):
|
|
VersionedUnionType = builtin_types.UnionType
|
|
_UNION_TYPES = (typing.Union, builtin_types.UnionType)
|
|
from typing import TypeGuard
|
|
else:
|
|
VersionedUnionType = typing._UnionGenericAlias # type: ignore[attr-defined]
|
|
_UNION_TYPES = (typing.Union,)
|
|
from typing_extensions import TypeGuard
|
|
|
|


def _resource_name(
    client: _api_client.BaseApiClient,
    resource_name: str,
    *,
    collection_identifier: str,
    collection_hierarchy_depth: int = 2,
) -> str:
  # pylint: disable=line-too-long
  """Prepends resource name with project, location, collection_identifier if needed.

  The collection_identifier is only prepended if it is not already present
  and prepending it would not violate the collection hierarchy depth.
  When the prepending condition is not met, the input resource_name is
  returned unchanged.

  Args:
    client: The API client.
    resource_name: The user-provided resource name to be completed.
    collection_identifier: The collection identifier to be prepended. See
      collection identifiers in https://google.aip.dev/122.
    collection_hierarchy_depth: The collection hierarchy depth. Only set this
      field when the resource has nested collections. For example, in
      `users/vhugo1802/events/birthday-dinner-226` the collection_identifier
      is `users` and collection_hierarchy_depth is 4. See nested collections
      in https://google.aip.dev/122.

  Example:

    resource_name = 'cachedContents/123'
    client.vertexai = True
    client.project = 'bar'
    client.location = 'us-west1'
    _resource_name(client, 'cachedContents/123',
        collection_identifier='cachedContents')
    returns: 'projects/bar/locations/us-west1/cachedContents/123'

  Example:

    resource_name = 'projects/foo/locations/us-central1/cachedContents/123'
    # resource_name = 'locations/us-central1/cachedContents/123'
    client.vertexai = True
    client.project = 'bar'
    client.location = 'us-west1'
    _resource_name(client, resource_name,
        collection_identifier='cachedContents')
    returns: 'projects/foo/locations/us-central1/cachedContents/123'

  Example:

    resource_name = '123'
    # resource_name = 'cachedContents/123'
    client.vertexai = False
    _resource_name(client, resource_name,
        collection_identifier='cachedContents')
    returns: 'cachedContents/123'

  Example:

    resource_name = 'some/wrong/cachedContents/resource/name/123'
    collection_identifier = 'cachedContents'
    client.vertexai = False
    # client.vertexai = True
    _resource_name(client, resource_name,
        collection_identifier='cachedContents')
    returns: 'some/wrong/cachedContents/resource/name/123'

  Returns:
    The completed resource name.
  """
  should_prepend_collection_identifier = (
      not resource_name.startswith(f'{collection_identifier}/')
      # Check that prepending the collection identifier won't violate the
      # collection hierarchy depth.
      and f'{collection_identifier}/{resource_name}'.count('/') + 1
      == collection_hierarchy_depth
  )
  if client.vertexai:
    if resource_name.startswith('projects/'):
      return resource_name
    elif resource_name.startswith('locations/'):
      return f'projects/{client.project}/{resource_name}'
    elif resource_name.startswith(f'{collection_identifier}/'):
      return f'projects/{client.project}/locations/{client.location}/{resource_name}'
    elif should_prepend_collection_identifier:
      return f'projects/{client.project}/locations/{client.location}/{collection_identifier}/{resource_name}'
    else:
      return resource_name
  else:
    if should_prepend_collection_identifier:
      return f'{collection_identifier}/{resource_name}'
    else:
      return resource_name


def t_model(client: _api_client.BaseApiClient, model: str) -> str:
  if not model:
    raise ValueError('model is required.')
  if client.vertexai:
    if (
        model.startswith('projects/')
        or model.startswith('models/')
        or model.startswith('publishers/')
    ):
      return model
    elif '/' in model:
      publisher, model_id = model.split('/', 1)
      return f'publishers/{publisher}/models/{model_id}'
    else:
      return f'publishers/google/models/{model}'
  else:
    if model.startswith('models/'):
      return model
    elif model.startswith('tunedModels/'):
      return model
    else:
      return f'models/{model}'
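
# Illustrative sketch of t_model's normalization (the model IDs below are
# hypothetical examples):
#
#   t_model(gemini_client, 'gemini-2.0-flash')
#   # -> 'models/gemini-2.0-flash'                     (Gemini Developer API)
#   t_model(vertex_client, 'gemini-2.0-flash')
#   # -> 'publishers/google/models/gemini-2.0-flash'   (Vertex AI)
#   t_model(vertex_client, 'meta/llama-3')
#   # -> 'publishers/meta/models/llama-3'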


def t_models_url(
    api_client: _api_client.BaseApiClient, base_models: bool
) -> str:
  if api_client.vertexai:
    if base_models:
      return 'publishers/google/models'
    else:
      return 'models'
  else:
    if base_models:
      return 'models'
    else:
      return 'tunedModels'


def t_extract_models(
    api_client: _api_client.BaseApiClient,
    response: dict[str, Any],
) -> list[dict[str, Any]]:
  if not response:
    return []

  models: Optional[list[dict[str, Any]]] = response.get('models')
  if models is not None:
    return models

  tuned_models: Optional[list[dict[str, Any]]] = response.get('tunedModels')
  if tuned_models is not None:
    return tuned_models

  publisher_models: Optional[list[dict[str, Any]]] = response.get(
      'publisherModels'
  )
  if publisher_models is not None:
    return publisher_models

  if (
      response.get('httpHeaders') is not None
      and response.get('jsonPayload') is None
  ):
    return []
  else:
    logger.warning('Cannot determine the models type.')
    logger.debug('Cannot determine the models type for response: %s', response)
    return []


def t_caches_model(
    api_client: _api_client.BaseApiClient, model: str
) -> Optional[str]:
  model = t_model(api_client, model)
  if not model:
    return None
  if model.startswith('publishers/') and api_client.vertexai:
    # Vertex caches only support model names that start with `projects/`.
    return (
        f'projects/{api_client.project}/locations/{api_client.location}/{model}'
    )
  elif model.startswith('models/') and api_client.vertexai:
    return f'projects/{api_client.project}/locations/{api_client.location}/publishers/google/{model}'
  else:
    return model


def pil_to_blob(img: Any) -> types.Blob:
  PngImagePlugin: Optional[builtin_types.ModuleType]
  try:
    import PIL.PngImagePlugin

    PngImagePlugin = PIL.PngImagePlugin
  except ImportError:
    PngImagePlugin = None

  bytesio = io.BytesIO()
  if (
      PngImagePlugin is not None
      and isinstance(img, PngImagePlugin.PngImageFile)
  ) or img.mode == 'RGBA':
    img.save(bytesio, format='PNG')
    mime_type = 'image/png'
  else:
    img.save(bytesio, format='JPEG')
    mime_type = 'image/jpeg'
  bytesio.seek(0)
  data = bytesio.read()
  return types.Blob(mime_type=mime_type, data=data)


def t_function_response(
    function_response: types.FunctionResponseOrDict,
) -> types.FunctionResponse:
  if not function_response:
    raise ValueError('function_response is required.')
  if isinstance(function_response, dict):
    return types.FunctionResponse.model_validate(function_response)
  elif isinstance(function_response, types.FunctionResponse):
    return function_response
  else:
    raise TypeError(
        'Could not parse input as FunctionResponse. Unsupported'
        f' function_response type: {type(function_response)}'
    )


def t_function_responses(
    function_responses: Union[
        types.FunctionResponseOrDict,
        Sequence[types.FunctionResponseOrDict],
    ],
) -> list[types.FunctionResponse]:
  if not function_responses:
    raise ValueError('function_responses are required.')
  if isinstance(function_responses, Sequence):
    return [t_function_response(response) for response in function_responses]
  else:
    return [t_function_response(function_responses)]


def t_blobs(
    api_client: _api_client.BaseApiClient,
    blobs: Union[types.BlobImageUnionDict, list[types.BlobImageUnionDict]],
) -> list[types.Blob]:
  if isinstance(blobs, list):
    return [t_blob(api_client, blob) for blob in blobs]
  else:
    return [t_blob(api_client, blobs)]


def t_blob(
    api_client: _api_client.BaseApiClient, blob: types.BlobImageUnionDict
) -> types.Blob:
  try:
    import PIL.Image

    PIL_Image = PIL.Image.Image
  except ImportError:
    PIL_Image = None

  if not blob:
    raise ValueError('blob is required.')

  if isinstance(blob, types.Blob):
    return blob

  if isinstance(blob, dict):
    return types.Blob.model_validate(blob)

  if PIL_Image is not None and isinstance(blob, PIL_Image):
    return pil_to_blob(blob)

  raise TypeError(
      f'Could not parse input as Blob. Unsupported blob type: {type(blob)}'
  )
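
# Illustrative sketch of the accepted blob inputs (the PNG payload and file
# path are hypothetical; the PIL branch only applies when Pillow is installed):
#
#   t_blob(client, types.Blob(mime_type='image/png', data=png_bytes))
#   t_blob(client, {'mime_type': 'image/png', 'data': png_bytes})
#   t_blob(client, PIL.Image.open('photo.png'))  # converted via pil_to_blob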


def t_image_blob(
    api_client: _api_client.BaseApiClient, blob: types.BlobImageUnionDict
) -> types.Blob:
  blob = t_blob(api_client, blob)
  if blob.mime_type and blob.mime_type.startswith('image/'):
    return blob
  raise ValueError(f'Unsupported mime type: {blob.mime_type!r}')


def t_audio_blob(
    api_client: _api_client.BaseApiClient, blob: types.BlobOrDict
) -> types.Blob:
  blob = t_blob(api_client, blob)
  if blob.mime_type and blob.mime_type.startswith('audio/'):
    return blob
  raise ValueError(f'Unsupported mime type: {blob.mime_type!r}')


def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
  try:
    import PIL.Image

    PIL_Image = PIL.Image.Image
  except ImportError:
    PIL_Image = None

  if part is None:
    raise ValueError('content part is required.')
  if isinstance(part, str):
    return types.Part(text=part)
  if PIL_Image is not None and isinstance(part, PIL_Image):
    return types.Part(inline_data=pil_to_blob(part))
  if isinstance(part, types.File):
    if not part.uri or not part.mime_type:
      raise ValueError('file uri and mime_type are required.')
    return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)
  if isinstance(part, dict):
    return types.Part.model_validate(part)
  if isinstance(part, types.Part):
    return part
  raise ValueError(f'Unsupported content part type: {type(part)}')


def t_parts(
    parts: Optional[
        Union[list[types.PartUnionDict], types.PartUnionDict, list[types.Part]]
    ],
) -> list[types.Part]:
  if parts is None or (isinstance(parts, list) and not parts):
    raise ValueError('content parts are required.')
  if isinstance(parts, list):
    return [t_part(part) for part in parts]
  else:
    return [t_part(parts)]
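
# Illustrative sketch: t_part/t_parts normalize the flexible "part" union into
# types.Part objects (the string and dict forms shown are inputs this module
# itself accepts):
#
#   t_part('hello')                        # -> types.Part(text='hello')
#   t_parts(['hello', {'text': 'world'}])
#   # -> [types.Part(text='hello'), types.Part(text='world')]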


def t_image_predictions(
    client: _api_client.BaseApiClient,
    predictions: Optional[Iterable[Mapping[str, Any]]],
) -> Optional[list[types.GeneratedImage]]:
  if not predictions:
    return None
  images = []
  for prediction in predictions:
    if prediction.get('image'):
      images.append(
          types.GeneratedImage(
              image=types.Image(
                  gcs_uri=prediction['image']['gcsUri'],
                  image_bytes=prediction['image']['imageBytes'],
              )
          )
      )
  return images


ContentType = Union[types.Content, types.ContentDict, types.PartUnionDict]


def t_content(
    client: _api_client.BaseApiClient,
    content: Optional[ContentType],
) -> types.Content:
  if content is None:
    raise ValueError('content is required.')
  if isinstance(content, types.Content):
    return content
  if isinstance(content, dict):
    try:
      return types.Content.model_validate(content)
    except pydantic.ValidationError:
      possible_part = types.Part.model_validate(content)
      return (
          types.ModelContent(parts=[possible_part])
          if possible_part.function_call
          else types.UserContent(parts=[possible_part])
      )
  if isinstance(content, types.Part):
    return (
        types.ModelContent(parts=[content])
        if content.function_call
        else types.UserContent(parts=[content])
    )
  return types.UserContent(parts=content)
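
# Illustrative sketch: plain strings and parts become user content, while a
# part carrying a function_call becomes model content (the function name is
# hypothetical):
#
#   t_content(client, 'why is the sky blue?')  # -> types.UserContent
#   t_content(client, types.Part.from_function_call(name='f', args={}))
#   # -> types.ModelContent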


def t_contents_for_embed(
    client: _api_client.BaseApiClient,
    contents: Union[list[types.Content], list[types.ContentDict], ContentType],
) -> Union[list[str], list[types.Content]]:
  if isinstance(contents, list):
    transformed_contents = [t_content(client, content) for content in contents]
  else:
    transformed_contents = [t_content(client, contents)]

  if client.vertexai:
    text_parts = []
    for content in transformed_contents:
      if content is not None:
        if isinstance(content, dict):
          content = types.Content.model_validate(content)
        if content.parts is not None:
          for part in content.parts:
            if part.text:
              text_parts.append(part.text)
            else:
              logger.warning('Non-text part found, only returning text parts.')
    return text_parts
  else:
    return transformed_contents


def t_contents(
    client: _api_client.BaseApiClient,
    contents: Optional[
        Union[types.ContentListUnion, types.ContentListUnionDict, types.Content]
    ],
) -> list[types.Content]:
  if contents is None or (isinstance(contents, list) and not contents):
    raise ValueError('contents are required.')
  if not isinstance(contents, list):
    return [t_content(client, contents)]

  try:
    import PIL.Image

    PIL_Image = PIL.Image.Image
  except ImportError:
    PIL_Image = None

  result: list[types.Content] = []
  accumulated_parts: list[types.Part] = []

  def _is_part(
      part: Union[types.PartUnionDict, Any],
  ) -> TypeGuard[types.PartUnionDict]:
    if (
        isinstance(part, str)
        or isinstance(part, types.File)
        or (PIL_Image is not None and isinstance(part, PIL_Image))
        or isinstance(part, types.Part)
    ):
      return True

    if isinstance(part, dict):
      try:
        types.Part.model_validate(part)
        return True
      except pydantic.ValidationError:
        return False

    return False

  def _is_user_part(part: types.Part) -> bool:
    return not part.function_call

  def _are_user_parts(parts: list[types.Part]) -> bool:
    return all(_is_user_part(part) for part in parts)

  def _append_accumulated_parts_as_content(
      result: list[types.Content],
      accumulated_parts: list[types.Part],
  ) -> None:
    if not accumulated_parts:
      return
    result.append(
        types.UserContent(parts=accumulated_parts)
        if _are_user_parts(accumulated_parts)
        else types.ModelContent(parts=accumulated_parts)
    )
    accumulated_parts[:] = []

  def _handle_current_part(
      result: list[types.Content],
      accumulated_parts: list[types.Part],
      current_part: types.PartUnionDict,
  ) -> None:
    current_part = t_part(current_part)
    if _is_user_part(current_part) == _are_user_parts(accumulated_parts):
      accumulated_parts.append(current_part)
    else:
      _append_accumulated_parts_as_content(result, accumulated_parts)
      accumulated_parts[:] = [current_part]

  # Iterate over contents:
  # - If the item is a Content (or Content dict), append it to the result.
  # - If the item is one of a run of consecutive parts, group consecutive
  #   user parts into a UserContent and consecutive model parts into a
  #   ModelContent, then append that to the result.
  # - If the item is a list, only a list of types.PartUnion is accepted.
  for content in contents:
    if (
        isinstance(content, types.Content)
        # The only allowed inner list is a list of types.PartUnion.
        or isinstance(content, list)
    ):
      _append_accumulated_parts_as_content(result, accumulated_parts)
      if isinstance(content, list):
        result.append(types.UserContent(parts=content))  # type: ignore[arg-type]
      else:
        result.append(content)
    elif _is_part(content):
      _handle_current_part(result, accumulated_parts, content)
    elif isinstance(content, dict):
      # A PartDict is already handled in _is_part.
      result.append(types.Content.model_validate(content))
    else:
      raise ValueError(f'Unsupported content type: {type(content)}')

  _append_accumulated_parts_as_content(result, accumulated_parts)

  return result
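
# Illustrative sketch of the grouping behavior (inputs are this module's own
# accepted part forms):
#
#   t_contents(client, ['first question', 'second question'])
#   # -> [types.UserContent(parts=[Part(text='first question'),
#   #                              Part(text='second question')])]
#
# Consecutive user parts collapse into one UserContent; a part carrying a
# function_call would instead start a separate ModelContent.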


def handle_null_fields(schema: dict[str, Any]) -> None:
  """Processes null fields in the schema so it is compatible with OpenAPI.

  The OpenAPI spec does not support `'type': 'null'` in the schema. This
  function handles this case by adding `'nullable': True` to the null field
  and removing the `{'type': 'null'}` entry.

  https://swagger.io/docs/specification/v3_0/data-models/data-types/#null

  Example of schema properties before and after handling null fields:
    Before:
      {
        "name": {
          "title": "Name",
          "type": "string"
        },
        "total_area_sq_mi": {
          "anyOf": [
            {
              "type": "integer"
            },
            {
              "type": "null"
            }
          ],
          "default": None,
          "title": "Total Area Sq Mi"
        }
      }

    After:
      {
        "name": {
          "title": "Name",
          "type": "string"
        },
        "total_area_sq_mi": {
          "type": "integer",
          "nullable": True,
          "default": None,
          "title": "Total Area Sq Mi"
        }
      }
  """
  if schema.get('type', None) == 'null':
    schema['nullable'] = True
    del schema['type']
  elif 'anyOf' in schema:
    for item in schema['anyOf']:
      if 'type' in item and item['type'] == 'null':
        schema['nullable'] = True
        schema['anyOf'].remove({'type': 'null'})
        if len(schema['anyOf']) == 1:
          # If only one type is left after removing null, remove the anyOf
          # field and inline that type's fields.
          for key, val in schema['anyOf'][0].items():
            schema[key] = val
          del schema['anyOf']
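
# Illustrative sketch (in-place mutation, mirroring the docstring above):
#
#   schema = {'anyOf': [{'type': 'integer'}, {'type': 'null'}], 'default': None}
#   handle_null_fields(schema)
#   # schema is now {'default': None, 'nullable': True, 'type': 'integer'}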


def process_schema(
    schema: dict[str, Any],
    client: _api_client.BaseApiClient,
    defs: Optional[dict[str, Any]] = None,
    *,
    order_properties: bool = True,
) -> None:
  """Updates the schema and each sub-schema in place to be API-compatible.

  - Inlines the $defs.

  Example of a schema before and after (with mldev):
    Before:

    `schema`

    {
        'items': {
            '$ref': '#/$defs/CountryInfo'
        },
        'title': 'Placeholder',
        'type': 'array'
    }

    `defs`

    {
        'CountryInfo': {
            'properties': {
                'continent': {
                    'title': 'Continent',
                    'type': 'string'
                },
                'gdp': {
                    'title': 'Gdp',
                    'type': 'integer'
                }
            },
            'required': ['continent', 'gdp'],
            'title': 'CountryInfo',
            'type': 'object'
        }
    }

    After:

    `schema`

    {
        'items': {
            'properties': {
                'continent': {
                    'title': 'Continent',
                    'type': 'string'
                },
                'gdp': {
                    'title': 'Gdp',
                    'type': 'integer'
                }
            },
            'required': ['continent', 'gdp'],
            'title': 'CountryInfo',
            'type': 'object'
        },
        'type': 'array'
    }
  """
  if schema.get('title') == 'PlaceholderLiteralEnum':
    del schema['title']

  # Standardize spelling for relevant schema fields. For example, if a dict is
  # provided directly to response_schema, it may use `any_of` instead of
  # `anyOf`. Otherwise, model_json_schema() uses `anyOf`.
  for from_name, to_name in [
      ('additional_properties', 'additionalProperties'),
      ('any_of', 'anyOf'),
      ('prefix_items', 'prefixItems'),
      ('property_ordering', 'propertyOrdering'),
  ]:
    if (value := schema.pop(from_name, None)) is not None:
      schema[to_name] = value

  if defs is None:
    defs = schema.pop('$defs', {})
    for _, sub_schema in defs.items():
      # We can skip the '$ref' check, because JSON schema forbids a '$ref' from
      # directly referencing another '$ref':
      # https://json-schema.org/understanding-json-schema/structuring#recursion
      process_schema(
          sub_schema, client, defs, order_properties=order_properties
      )

  handle_null_fields(schema)

  # After removing null fields, Optional fields with only one possible type
  # will have a $ref key that needs to be flattened. For example:
  # {'default': None, 'description': 'Name of the person', 'nullable': True,
  #  '$ref': '#/$defs/TestPerson'}
  if (ref := schema.pop('$ref', None)) is not None:
    schema.update(defs[ref.split('defs/')[-1]])

  def _recurse(sub_schema: dict[str, Any]) -> dict[str, Any]:
    """Returns the processed `sub_schema`, resolving its '$ref' if any."""
    if (ref := sub_schema.pop('$ref', None)) is not None:
      sub_schema = defs[ref.split('defs/')[-1]]
    process_schema(sub_schema, client, defs, order_properties=order_properties)
    return sub_schema

  if (any_of := schema.get('anyOf')) is not None:
    schema['anyOf'] = [_recurse(sub_schema) for sub_schema in any_of]
    return

  schema_type = schema.get('type')
  if isinstance(schema_type, Enum):
    schema_type = schema_type.value
  if schema_type is not None:
    # 'type' may be absent, e.g. for an unconstrained sub-schema.
    schema_type = schema_type.upper()

  # model_json_schema() returns a schema with a 'const' field when a Literal
  # with one value is provided as a pydantic field. For example
  # `genre: Literal['action']` becomes:
  # {'const': 'action', 'title': 'Genre', 'type': 'string'}
  const = schema.get('const')
  if const is not None:
    if schema_type == 'STRING':
      schema['enum'] = [const]
      del schema['const']
    else:
      raise ValueError('Literal values must be strings.')

  if schema_type == 'OBJECT':
    if (properties := schema.get('properties')) is not None:
      for name, sub_schema in list(properties.items()):
        properties[name] = _recurse(sub_schema)
      if (
          len(properties.items()) > 1
          and order_properties
          and 'propertyOrdering' not in schema
      ):
        schema['property_ordering'] = list(properties.keys())
    if (additional := schema.get('additionalProperties')) is not None:
      # It is legal to set 'additionalProperties' to a bool:
      # https://json-schema.org/understanding-json-schema/reference/object#additionalproperties
      if isinstance(additional, dict):
        schema['additionalProperties'] = _recurse(additional)
  elif schema_type == 'ARRAY':
    if (items := schema.get('items')) is not None:
      schema['items'] = _recurse(items)
    if (prefixes := schema.get('prefixItems')) is not None:
      schema['prefixItems'] = [_recurse(prefix) for prefix in prefixes]
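
# Illustrative sketch: process_schema() rewrites a pydantic-style JSON schema
# dict in place (CountryInfo here is a hypothetical pydantic model):
#
#   schema = CountryInfo.model_json_schema()
#   process_schema(schema, client)
#   # '$defs' references are now inlined, 'type': 'null' unions are folded
#   # into 'nullable': True, and property ordering is recorded.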


def _process_enum(
    enum: EnumMeta, client: _api_client.BaseApiClient
) -> types.Schema:
  for member in enum:  # type: ignore
    if not isinstance(member.value, str):
      raise TypeError(
          f'Enum member {member.name} value must be a string, got'
          f' {type(member.value)}'
      )

  class Placeholder(pydantic.BaseModel):
    placeholder: enum  # type: ignore[valid-type]

  enum_schema = Placeholder.model_json_schema()
  process_schema(enum_schema, client)
  enum_schema = enum_schema['properties']['placeholder']
  return types.Schema.model_validate(enum_schema)


def _is_type_dict_str_any(
    origin: Union[types.SchemaUnionDict, Any],
) -> TypeGuard[dict[str, Any]]:
  """Verifies the schema is of type dict[str, Any] for mypy type checking."""
  return isinstance(origin, dict) and all(
      isinstance(key, str) for key in origin
  )


def t_schema(
    client: _api_client.BaseApiClient, origin: Union[types.SchemaUnionDict, Any]
) -> Optional[types.Schema]:
  if not origin:
    return None
  if isinstance(origin, dict) and _is_type_dict_str_any(origin):
    process_schema(origin, client, order_properties=False)
    return types.Schema.model_validate(origin)
  if isinstance(origin, EnumMeta):
    return _process_enum(origin, client)
  if isinstance(origin, types.Schema):
    if dict(origin) == dict(types.Schema()):
      # The response_schema value was coerced to an empty Schema instance
      # because it did not adhere to the Schema field annotation.
      raise ValueError('Unsupported schema type.')
    schema = origin.model_dump(exclude_unset=True)
    process_schema(schema, client, order_properties=False)
    return types.Schema.model_validate(schema)

  if (
      # In Python 3.9 a generic alias such as list[int] counts as a type,
      # and breaks issubclass because it's not a class.
      not isinstance(origin, GenericAlias)
      and isinstance(origin, type)
      and issubclass(origin, pydantic.BaseModel)
  ):
    schema = origin.model_json_schema()
    process_schema(schema, client)
    return types.Schema.model_validate(schema)
  elif (
      isinstance(origin, GenericAlias)
      or isinstance(origin, type)
      or isinstance(origin, VersionedUnionType)
      or typing.get_origin(origin) in _UNION_TYPES
  ):

    class Placeholder(pydantic.BaseModel):
      placeholder: origin  # type: ignore[valid-type]

    schema = Placeholder.model_json_schema()
    process_schema(schema, client)
    schema = schema['properties']['placeholder']
    return types.Schema.model_validate(schema)

  raise ValueError(f'Unsupported schema type: {origin}')
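
# Illustrative sketch: t_schema accepts a dict, an Enum, a types.Schema, a
# pydantic model class, or a plain type (Recipe is a hypothetical model):
#
#   class Recipe(pydantic.BaseModel):
#     name: str
#     servings: int
#
#   t_schema(client, Recipe)     # -> types.Schema for an OBJECT
#   t_schema(client, list[int])  # -> types.Schema for an ARRAY of INTEGERs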


def t_speech_config(
    _: _api_client.BaseApiClient,
    origin: Union[types.SpeechConfigUnionDict, Any],
) -> Optional[types.SpeechConfig]:
  if not origin:
    return None
  if isinstance(origin, types.SpeechConfig):
    return origin
  if isinstance(origin, str):
    return types.SpeechConfig(
        voice_config=types.VoiceConfig(
            prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=origin)
        )
    )
  if (
      isinstance(origin, dict)
      and 'voice_config' in origin
      and origin['voice_config'] is not None
      and 'prebuilt_voice_config' in origin['voice_config']
      and origin['voice_config']['prebuilt_voice_config'] is not None
      and 'voice_name' in origin['voice_config']['prebuilt_voice_config']
  ):
    return types.SpeechConfig(
        voice_config=types.VoiceConfig(
            prebuilt_voice_config=types.PrebuiltVoiceConfig(
                voice_name=origin['voice_config']['prebuilt_voice_config'].get(
                    'voice_name'
                )
            )
        )
    )
  raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
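
# Illustrative sketch: a bare voice name expands to a full SpeechConfig (the
# voice name below is a hypothetical example):
#
#   t_speech_config(client, 'Kore')
#   # -> types.SpeechConfig(voice_config=types.VoiceConfig(
#   #        prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name='Kore')))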


def t_tool(
    client: _api_client.BaseApiClient, origin: Any
) -> Optional[Union[types.Tool, Any]]:
  if not origin:
    return None
  if inspect.isfunction(origin) or inspect.ismethod(origin):
    return types.Tool(
        function_declarations=[
            types.FunctionDeclaration.from_callable(
                client=client, callable=origin
            )
        ]
    )
  elif isinstance(origin, dict):
    return types.Tool.model_validate(origin)
  else:
    return origin


# Only functions are supported for now.
def t_tools(
    client: _api_client.BaseApiClient, origin: list[Any]
) -> list[types.Tool]:
  if not origin:
    return []
  function_tool = types.Tool(function_declarations=[])
  tools = []
  for tool in origin:
    transformed_tool = t_tool(client, tool)
    # All functions should be merged into one tool.
    if transformed_tool is not None:
      if (
          transformed_tool.function_declarations
          and function_tool.function_declarations is not None
      ):
        function_tool.function_declarations += (
            transformed_tool.function_declarations
        )
      else:
        tools.append(transformed_tool)
  if function_tool.function_declarations:
    tools.append(function_tool)
  return tools
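
# Illustrative sketch: passing plain Python callables yields a single Tool
# whose function_declarations cover all of them (get_weather and get_time are
# hypothetical functions):
#
#   tools = t_tools(client, [get_weather, get_time])
#   # -> [types.Tool(function_declarations=[<get_weather>, <get_time>])]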


def t_cached_content_name(client: _api_client.BaseApiClient, name: str) -> str:
  return _resource_name(client, name, collection_identifier='cachedContents')


def t_batch_job_source(
    client: _api_client.BaseApiClient, src: str
) -> types.BatchJobSource:
  if src.startswith('gs://'):
    return types.BatchJobSource(
        format='jsonl',
        gcs_uri=[src],
    )
  elif src.startswith('bq://'):
    return types.BatchJobSource(
        format='bigquery',
        bigquery_uri=src,
    )
  else:
    raise ValueError(f'Unsupported source: {src}')


def t_batch_job_destination(
    client: _api_client.BaseApiClient, dest: str
) -> types.BatchJobDestination:
  if dest.startswith('gs://'):
    return types.BatchJobDestination(
        format='jsonl',
        gcs_uri=dest,
    )
  elif dest.startswith('bq://'):
    return types.BatchJobDestination(
        format='bigquery',
        bigquery_uri=dest,
    )
  else:
    raise ValueError(f'Unsupported destination: {dest}')
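
# Illustrative sketch: the URI scheme selects the batch format (the bucket and
# table names below are hypothetical):
#
#   t_batch_job_source(client, 'gs://my-bucket/requests.jsonl')
#   # -> BatchJobSource(format='jsonl', gcs_uri=['gs://my-bucket/requests.jsonl'])
#   t_batch_job_destination(client, 'bq://my-project.dataset.table')
#   # -> BatchJobDestination(format='bigquery',
#   #                        bigquery_uri='bq://my-project.dataset.table')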


def t_batch_job_name(client: _api_client.BaseApiClient, name: str) -> str:
  if not client.vertexai:
    return name

  pattern = r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'
  if re.match(pattern, name):
    return name.split('/')[-1]
  elif name.isdigit():
    return name
  else:
    raise ValueError(f'Invalid batch job name: {name}.')


LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
LRO_POLLING_TIMEOUT_SECONDS = 900.0
LRO_POLLING_MULTIPLIER = 1.5


def t_resolve_operation(
    api_client: _api_client.BaseApiClient, struct: dict[str, Any]
) -> Any:
  if (name := struct.get('name')) and '/operations/' in name:
    operation: dict[str, Any] = struct
    total_seconds = 0.0
    delay_seconds = LRO_POLLING_INITIAL_DELAY_SECONDS
    while not operation.get('done'):
      if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
        raise RuntimeError(f'Operation {name} timed out.\n{operation}')
      # TODO(b/374433890): Replace with LRO module once it's available.
      operation = api_client.request(  # type: ignore[assignment]
          http_method='GET', path=name, request_dict={}
      )
      time.sleep(delay_seconds)
      # Accumulate the actual wait (was `total_seconds += total_seconds`,
      # which kept the total at zero and defeated the timeout).
      total_seconds += delay_seconds
      # Exponential backoff
      delay_seconds = min(
          delay_seconds * LRO_POLLING_MULTIPLIER,
          LRO_POLLING_MAXIMUM_DELAY_SECONDS,
      )
    if error := operation.get('error'):
      raise RuntimeError(
          f'Operation {name} failed with error: {error}.\n{operation}'
      )
    return operation.get('response')
  else:
    return struct
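
# Illustrative sketch of the polling cadence implied by the constants above:
# successive delays are 1.0, 1.5, 2.25, 3.375, ... seconds, capped at 20.0,
# and a RuntimeError is raised once the accumulated wait exceeds 900 seconds.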


def t_file_name(
    api_client: _api_client.BaseApiClient,
    name: Optional[Union[str, types.File, types.Video, types.GeneratedVideo]],
) -> str:
  # Remove the files/ prefix since it's added to the url path.
  if isinstance(name, types.File):
    name = name.name
  elif isinstance(name, types.Video):
    name = name.uri
  elif isinstance(name, types.GeneratedVideo):
    if name.video is not None:
      name = name.video.uri
    else:
      name = None

  if name is None:
    raise ValueError('File name is required.')

  if not isinstance(name, str):
    raise ValueError(
        f'Could not convert object of type `{type(name)}` to a file name.'
    )

  if name.startswith('https://'):
    suffix = name.split('files/')[1]
    match = re.match('[a-z0-9]+', suffix)
    if match is None:
      raise ValueError(f'Could not extract file name from URI: {name}')
    name = match.group(0)
  elif name.startswith('files/'):
    name = name.split('files/')[1]

  return name
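
# Illustrative sketch of the name normalization (the file id and URL below are
# hypothetical):
#
#   t_file_name(client, 'files/abc123')  # -> 'abc123'
#   t_file_name(client,
#       'https://generativelanguage.googleapis.com/v1beta/files/abc123')
#   # -> 'abc123'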


def t_tuning_job_status(
    api_client: _api_client.BaseApiClient, status: str
) -> Union[types.JobState, str]:
  if status == 'STATE_UNSPECIFIED':
    return types.JobState.JOB_STATE_UNSPECIFIED
  elif status == 'CREATING':
    return types.JobState.JOB_STATE_RUNNING
  elif status == 'ACTIVE':
    return types.JobState.JOB_STATE_SUCCEEDED
  elif status == 'FAILED':
    return types.JobState.JOB_STATE_FAILED
  else:
    for state in types.JobState:
      if str(state.value) == status:
        return state
    return status


# Some fields don't accept URL-safe base64 encoding.
# We shouldn't use this transformer if the backend adheres to the Cloud Type
# format: https://cloud.google.com/docs/discovery/type-format.
# TODO(b/389133914,b/390320301): Remove the hack after the backend fixes the
# issue.
def t_bytes(api_client: _api_client.BaseApiClient, data: bytes) -> str:
  if not isinstance(data, bytes):
    return data
  return base64.b64encode(data).decode('ascii')
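
# Illustrative sketch: standard (not URL-safe) base64 is produced, so `+` and
# `/` may appear in the output:
#
#   t_bytes(client, b'\xfb\xff')  # -> '+/8='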


def t_content_strict(content: types.ContentOrDict) -> types.Content:
  if isinstance(content, dict):
    return types.Content.model_validate(content)
  elif isinstance(content, types.Content):
    return content
  else:
    raise ValueError(
        f'Could not convert input (type "{type(content)}") to '
        '`types.Content`'
    )


def t_contents_strict(
    contents: Union[Sequence[types.ContentOrDict], types.ContentOrDict],
) -> list[types.Content]:
  if isinstance(contents, Sequence):
    return [t_content_strict(content) for content in contents]
  else:
    return [t_content_strict(contents)]


def t_client_content(
    turns: Optional[
        Union[Sequence[types.ContentOrDict], types.ContentOrDict]
    ] = None,
    turn_complete: bool = True,
) -> types.LiveClientContent:
  if turns is None:
    return types.LiveClientContent(turn_complete=turn_complete)

  try:
    return types.LiveClientContent(
        turns=t_contents_strict(contents=turns),
        turn_complete=turn_complete,
    )
  except Exception as e:
    raise ValueError(
        f'Could not convert input (type "{type(turns)}") to '
        '`types.LiveClientContent`'
    ) from e


def t_tool_response(
    input: Union[
        types.FunctionResponseOrDict,
        Sequence[types.FunctionResponseOrDict],
    ],
) -> types.LiveClientToolResponse:
  if not input:
    raise ValueError(f'A tool response is required, got: \n{input}')

  try:
    return types.LiveClientToolResponse(
        function_responses=t_function_responses(function_responses=input)
    )
  except Exception as e:
    raise ValueError(
        f'Could not convert input (type "{type(input)}") to '
        '`types.LiveClientToolResponse`'
    ) from e
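
# Illustrative sketch: a single FunctionResponse dict is wrapped into a
# LiveClientToolResponse (the function name and payload are hypothetical):
#
#   t_tool_response({'id': 'call-1', 'name': 'get_weather',
#                    'response': {'temp_c': 21}})
#   # -> types.LiveClientToolResponse(function_responses=[
#   #        types.FunctionResponse(id='call-1', name='get_weather',
#   #                               response={'temp_c': 21})])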