# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Transformers for Google GenAI SDK."""

import base64
from collections.abc import Iterable, Mapping
from enum import Enum, EnumMeta
import inspect
import io
import logging
import re
import sys
import time
import types as builtin_types
import typing
from typing import Any, GenericAlias, Optional, Sequence, Union  # type: ignore[attr-defined]

if typing.TYPE_CHECKING:
    import PIL.Image

import pydantic

from . import _api_client
from . import types

logger = logging.getLogger('google_genai._transformers')

if sys.version_info >= (3, 10):
    VersionedUnionType = builtin_types.UnionType
    _UNION_TYPES = (typing.Union, builtin_types.UnionType)
    from typing import TypeGuard
else:
    VersionedUnionType = typing._UnionGenericAlias  # type: ignore[attr-defined]
    _UNION_TYPES = (typing.Union,)
    from typing_extensions import TypeGuard


def _resource_name(
    client: _api_client.BaseApiClient,
    resource_name: str,
    *,
    collection_identifier: str,
    collection_hierarchy_depth: int = 2,
) -> str:
    # pylint: disable=line-too-long
    """Prepends resource name with project, location, collection_identifier if needed.

    The collection_identifier will only be prepended if it's not present
    and the prepending won't violate the collection hierarchy depth.
    When the prepending condition isn't met, returns the input
    resource_name.

    Args:
      client: The API client.
      resource_name: The user input resource name to be completed.
      collection_identifier: The collection identifier to be prepended. See
        collection identifiers in https://google.aip.dev/122.
      collection_hierarchy_depth: The collection hierarchy depth, counted in
        path segments. Only set this field when the resource has nested
        collections. For example, for `users/vhugo1802/events/birthday-dinner-226`
        the collection_identifier is `users` and collection_hierarchy_depth
        is 4. See nested collections in https://google.aip.dev/122.

    Example:

      resource_name = 'cachedContents/123'
      client.vertexai = True
      client.project = 'bar'
      client.location = 'us-west1'
      _resource_name(client, 'cachedContents/123',
        collection_identifier='cachedContents')
      returns: 'projects/bar/locations/us-west1/cachedContents/123'

    Example:

      resource_name = 'projects/foo/locations/us-central1/cachedContents/123'
      # resource_name = 'locations/us-central1/cachedContents/123'
      client.vertexai = True
      client.project = 'bar'
      client.location = 'us-west1'
      _resource_name(client, resource_name,
        collection_identifier='cachedContents')
      returns: 'projects/foo/locations/us-central1/cachedContents/123'

    Example:

      resource_name = '123'
      # resource_name = 'cachedContents/123'
      client.vertexai = False
      _resource_name(client, resource_name,
        collection_identifier='cachedContents')
      returns 'cachedContents/123'

    Example:

      resource_name = 'some/wrong/cachedContents/resource/name/123'
      resource_prefix = 'cachedContents'
      client.vertexai = False
      # client.vertexai = True
      _resource_name(client, resource_name,
        collection_identifier='cachedContents')
      returns: 'some/wrong/cachedContents/resource/name/123'

    Returns:
      The completed resource name.
    """
    should_prepend_collection_identifier = (
        not resource_name.startswith(f'{collection_identifier}/')
        # Check if prepending the collection identifier won't violate the
        # collection hierarchy depth.
        and f'{collection_identifier}/{resource_name}'.count('/') + 1
        == collection_hierarchy_depth
    )
    if client.vertexai:
        if resource_name.startswith('projects/'):
            return resource_name
        elif resource_name.startswith('locations/'):
            return f'projects/{client.project}/{resource_name}'
        elif resource_name.startswith(f'{collection_identifier}/'):
            return f'projects/{client.project}/locations/{client.location}/{resource_name}'
        elif should_prepend_collection_identifier:
            return f'projects/{client.project}/locations/{client.location}/{collection_identifier}/{resource_name}'
        else:
            return resource_name
    else:
        if should_prepend_collection_identifier:
            return f'{collection_identifier}/{resource_name}'
        else:
            return resource_name


def t_model(client: _api_client.BaseApiClient, model: str) -> str:
    """Normalizes a user-supplied model identifier into a resource name.

    On Vertex AI, names already starting with `projects/`, `models/` or
    `publishers/` are returned unchanged. `publisher/model` becomes
    `publishers/publisher/models/model`. A bare id is placed under
    `publishers/google/models/`. Otherwise, `models/...` and
    `tunedModels/...` are kept and a bare id gets a `models/` prefix.

    Raises:
      ValueError: If `model` is empty.
    """
    if not model:
        raise ValueError('model is required.')
    if client.vertexai:
        if (
            model.startswith('projects/')
            or model.startswith('models/')
            or model.startswith('publishers/')
        ):
            return model
        elif '/' in model:
            publisher, model_id = model.split('/', 1)
            return f'publishers/{publisher}/models/{model_id}'
        else:
            return f'publishers/google/models/{model}'
    else:
        if model.startswith('models/'):
            return model
        elif model.startswith('tunedModels/'):
            return model
        else:
            return f'models/{model}'


def t_models_url(
    api_client: _api_client.BaseApiClient, base_models: bool
) -> str:
    """Returns the collection path used to list base or tuned models."""
    if api_client.vertexai:
        if base_models:
            return 'publishers/google/models'
        else:
            return 'models'
    else:
        if base_models:
            return 'models'
        else:
            return 'tunedModels'


def t_extract_models(
    api_client: _api_client.BaseApiClient,
    response: dict[str, Any],
) -> list[dict[str, Any]]:
    """Extracts the model list from a list-models response.

    Looks for the list under `models`, then `tunedModels`, then
    `publisherModels`. Returns an empty list for an empty response, or when
    only `httpHeaders` (and no `jsonPayload`) is present. Logs a warning and
    returns an empty list when the shape is unrecognised.
    """
    if not response:
        return []

    models: Optional[list[dict[str, Any]]] = response.get('models')
    if models is not None:
        return models

    tuned_models: Optional[list[dict[str, Any]]] = response.get('tunedModels')
    if tuned_models is not None:
        return tuned_models

    publisher_models: Optional[list[dict[str, Any]]] = response.get(
        'publisherModels'
    )
    if publisher_models is not None:
        return publisher_models

    if (
        response.get('httpHeaders') is not None
        and response.get('jsonPayload') is None
    ):
        return []
    else:
        logger.warning('Cannot determine the models type.')
        logger.debug('Cannot determine the models type for response: %s', response)
        return []


def t_caches_model(
    api_client: _api_client.BaseApiClient, model: str
) -> Optional[str]:
    """Returns the model resource name in the form required by caches.

    On Vertex AI, `publishers/...` and `models/...` names are expanded to a
    full `projects/{project}/locations/{location}/...` path.
    """
    model = t_model(api_client, model)
    if not model:
        return None
    if model.startswith('publishers/') and api_client.vertexai:
        # vertex caches only support model name start with projects.
        return (
            f'projects/{api_client.project}/locations/{api_client.location}/{model}'
        )
    elif model.startswith('models/') and api_client.vertexai:
        return f'projects/{api_client.project}/locations/{api_client.location}/publishers/google/{model}'
    else:
        return model


def pil_to_blob(img: Any) -> types.Blob:
    """Serializes a PIL image into an inline `types.Blob`.

    PNG source images and RGBA images are encoded as PNG. Everything else is
    encoded as JPEG.
    """
    PngImagePlugin: Optional[builtin_types.ModuleType]
    try:
        import PIL.PngImagePlugin

        PngImagePlugin = PIL.PngImagePlugin
    except ImportError:
        PngImagePlugin = None

    bytesio = io.BytesIO()
    if (
        PngImagePlugin is not None
        and isinstance(img, PngImagePlugin.PngImageFile)
    ) or img.mode == 'RGBA':
        img.save(bytesio, format='PNG')
        mime_type = 'image/png'
    else:
        img.save(bytesio, format='JPEG')
        mime_type = 'image/jpeg'
    data = bytesio.getvalue()
    return types.Blob(mime_type=mime_type, data=data)


def t_function_response(
    function_response: types.FunctionResponseOrDict,
) -> types.FunctionResponse:
    """Converts a dict or `types.FunctionResponse` into a FunctionResponse.

    Raises:
      ValueError: If `function_response` is falsy.
      TypeError: If `function_response` has an unsupported type.
    """
    if not function_response:
        raise ValueError('function_response is required.')
    if isinstance(function_response, dict):
        return types.FunctionResponse.model_validate(function_response)
    elif isinstance(function_response, types.FunctionResponse):
        return function_response
    else:
        raise TypeError(
            'Could not parse input as FunctionResponse. Unsupported'
            f' function_response type: {type(function_response)}'
        )


def t_function_responses(
    function_responses: Union[
        types.FunctionResponseOrDict,
        Sequence[types.FunctionResponseOrDict],
    ],
) -> list[types.FunctionResponse]:
    """Converts one or a sequence of function responses into a list."""
    if not function_responses:
        raise ValueError('function_responses are required.')

    if isinstance(function_responses, Sequence):
        return [t_function_response(response) for response in function_responses]
    else:
        return [t_function_response(function_responses)]


def t_blobs(
    api_client: _api_client.BaseApiClient,
    blobs: Union[types.BlobImageUnionDict, list[types.BlobImageUnionDict]],
) -> list[types.Blob]:
    """Converts one blob-like value or a list of them into a list of Blobs."""
    if isinstance(blobs, list):
        return [t_blob(api_client, blob) for blob in blobs]
    else:
        return [t_blob(api_client, blobs)]


def t_blob(
    api_client: _api_client.BaseApiClient, blob: types.BlobImageUnionDict
) -> types.Blob:
    """Converts a Blob, a dict or a PIL image into a `types.Blob`.

    Raises:
      ValueError: If `blob` is falsy.
      TypeError: If `blob` has an unsupported type.
    """
    try:
        import PIL.Image

        PIL_Image = PIL.Image.Image
    except ImportError:
        PIL_Image = None

    if not blob:
        raise ValueError('blob is required.')

    if isinstance(blob, types.Blob):
        return blob

    if isinstance(blob, dict):
        return types.Blob.model_validate(blob)

    if PIL_Image is not None and isinstance(blob, PIL_Image):
        return pil_to_blob(blob)

    raise TypeError(
        f'Could not parse input as Blob. Unsupported blob type: {type(blob)}'
    )


def t_image_blob(
    api_client: _api_client.BaseApiClient, blob: types.BlobImageUnionDict
) -> types.Blob:
    """Like `t_blob`, but requires an `image/*` mime type (else ValueError)."""
    blob = t_blob(api_client, blob)
    if blob.mime_type and blob.mime_type.startswith('image/'):
        return blob
    raise ValueError(f'Unsupported mime type: {blob.mime_type!r}')


def t_audio_blob(
    api_client: _api_client.BaseApiClient, blob: types.BlobOrDict
) -> types.Blob:
    """Like `t_blob`, but requires an `audio/*` mime type (else ValueError)."""
    blob = t_blob(api_client, blob)
    if blob.mime_type and blob.mime_type.startswith('audio/'):
        return blob
    raise ValueError(f'Unsupported mime type: {blob.mime_type!r}')


def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
    """Converts a single part-like value into a `types.Part`.

    Accepts a `str` (text part), a PIL image (inline data via `pil_to_blob`),
    a `types.File` (URI part), a dict (validated as a Part) or a `types.Part`.

    Raises:
      ValueError: If `part` is None, is a File without uri/mime_type, or has
        an unsupported type.
    """
    try:
        import PIL.Image

        PIL_Image = PIL.Image.Image
    except ImportError:
        PIL_Image = None

    if part is None:
        raise ValueError('content part is required.')
    if isinstance(part, str):
        return types.Part(text=part)
    if PIL_Image is not None and isinstance(part, PIL_Image):
        return types.Part(inline_data=pil_to_blob(part))
    if isinstance(part, types.File):
        if not part.uri or not part.mime_type:
            raise ValueError('file uri and mime_type are required.')
        return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)
    if isinstance(part, dict):
        return types.Part.model_validate(part)
    if isinstance(part, types.Part):
        return part
    raise ValueError(f'Unsupported content part type: {type(part)}')


def t_parts(
    parts: Optional[
        Union[list[types.PartUnionDict], types.PartUnionDict, list[types.Part]]
    ],
) -> list[types.Part]:
    """Converts one part-like value or a list of them into a list of Parts.

    Raises:
      ValueError: If `parts` is None or an empty list.
    """
    if parts is None or (isinstance(parts, list) and not parts):
        raise ValueError('content parts are required.')
    if isinstance(parts, list):
        return [t_part(part) for part in parts]
    else:
        return [t_part(parts)]


def t_image_predictions(
    client: _api_client.BaseApiClient,
    predictions: Optional[Iterable[Mapping[str, Any]]],
) -> Optional[list[types.GeneratedImage]]:
    """Converts raw prediction dicts into `types.GeneratedImage` objects.

    Predictions without a truthy `image` entry are skipped. A missing
    `gcsUri` or `imageBytes` key leaves the corresponding field unset
    instead of raising KeyError. Returns None for empty input.
    """
    if not predictions:
        return None

    images: list[types.GeneratedImage] = []
    for prediction in predictions:
        image = prediction.get('image')
        if image:
            images.append(
                types.GeneratedImage(
                    image=types.Image(
                        gcs_uri=image.get('gcsUri'),
                        image_bytes=image.get('imageBytes'),
                    )
                )
            )
    return images


ContentType = Union[types.Content, types.ContentDict, types.PartUnionDict]


def t_content(
    client: _api_client.BaseApiClient,
    content: Optional[ContentType],
) -> types.Content:
    """Converts a single content-like value into a `types.Content`.

    A dict is validated as Content first. If that fails it is validated as a
    Part and wrapped. A Part that carries a function_call is wrapped in
    ModelContent, any other Part in UserContent. Anything else is passed as
    `parts` to UserContent.

    Raises:
      ValueError: If `content` is None.
    """
    if content is None:
        raise ValueError('content is required.')
    if isinstance(content, types.Content):
        return content
    if isinstance(content, dict):
        try:
            return types.Content.model_validate(content)
        except pydantic.ValidationError:
            possible_part = types.Part.model_validate(content)
            return (
                types.ModelContent(parts=[possible_part])
                if possible_part.function_call
                else types.UserContent(parts=[possible_part])
            )
    if isinstance(content, types.Part):
        return (
            types.ModelContent(parts=[content])
            if content.function_call
            else types.UserContent(parts=[content])
        )
    return types.UserContent(parts=content)


def t_contents_for_embed(
    client: _api_client.BaseApiClient,
    contents: Union[list[types.Content], list[types.ContentDict], ContentType],
) -> Union[list[str], list[types.Content]]:
    """Prepares embedding inputs.

    On Vertex AI only the text of each part is returned as a flat list of
    strings, and a warning is logged for every non-text part. Otherwise the
    converted Content objects are returned.
    """
    if isinstance(contents, list):
        transformed_contents = [t_content(client, content) for content in contents]
    else:
        transformed_contents = [t_content(client, contents)]

    if client.vertexai:
        text_parts = []
        for content in transformed_contents:
            if content is not None:
                if isinstance(content, dict):
                    content = types.Content.model_validate(content)
                if content.parts is not None:
                    for part in content.parts:
                        if part.text:
                            text_parts.append(part.text)
                        else:
                            logger.warning(
                                'Non-text part found, only returning text parts.'
                            )
        return text_parts
    else:
        return transformed_contents


def t_contents(
    client: _api_client.BaseApiClient,
    contents: Optional[
        Union[types.ContentListUnion, types.ContentListUnionDict, types.Content]
    ],
) -> list[types.Content]:
    """Converts the flexible `contents` argument into a list of Contents.

    Runs of consecutive bare parts are grouped: user parts become one
    UserContent and function-call parts become one ModelContent. Content
    objects, content dicts and inner lists of parts each become one entry.

    Raises:
      ValueError: If `contents` is None/empty or holds an unsupported item.
    """
    if contents is None or (isinstance(contents, list) and not contents):
        raise ValueError('contents are required.')
    if not isinstance(contents, list):
        return [t_content(client, contents)]

    try:
        import PIL.Image

        PIL_Image = PIL.Image.Image
    except ImportError:
        PIL_Image = None

    result: list[types.Content] = []
    accumulated_parts: list[types.Part] = []

    def _is_part(
        part: Union[types.PartUnionDict, Any],
    ) -> TypeGuard[types.PartUnionDict]:
        if (
            isinstance(part, str)
            or isinstance(part, types.File)
            or (PIL_Image is not None and isinstance(part, PIL_Image))
            or isinstance(part, types.Part)
        ):
            return True

        if isinstance(part, dict):
            try:
                types.Part.model_validate(part)
                return True
            except pydantic.ValidationError:
                return False

        return False

    def _is_user_part(part: types.Part) -> bool:
        return not part.function_call

    def _are_user_parts(parts: list[types.Part]) -> bool:
        return all(_is_user_part(part) for part in parts)

    def _append_accumulated_parts_as_content(
        result: list[types.Content],
        accumulated_parts: list[types.Part],
    ) -> None:
        if not accumulated_parts:
            return
        result.append(
            types.UserContent(parts=accumulated_parts)
            if _are_user_parts(accumulated_parts)
            else types.ModelContent(parts=accumulated_parts)
        )
        accumulated_parts[:] = []

    def _handle_current_part(
        result: list[types.Content],
        accumulated_parts: list[types.Part],
        current_part: types.PartUnionDict,
    ) -> None:
        current_part = t_part(current_part)
        if _is_user_part(current_part) == _are_user_parts(accumulated_parts):
            accumulated_parts.append(current_part)
        else:
            _append_accumulated_parts_as_content(result, accumulated_parts)
            accumulated_parts[:] = [current_part]

    # iterator over contents
    # if content type or content dict, append to result
    # if consecutive part(s),
    #   group consecutive user part(s) to a UserContent
    #   group consecutive model part(s) to a ModelContent
    #   append to result
    # if list, we only accept a list of types.PartUnion
    for content in contents:
        if (
            isinstance(content, types.Content)
            # only allowed inner list is a list of types.PartUnion
            or isinstance(content, list)
        ):
            _append_accumulated_parts_as_content(result, accumulated_parts)
            if isinstance(content, list):
                result.append(types.UserContent(parts=content))  # type: ignore[arg-type]
            else:
                result.append(content)
        elif _is_part(content):
            _handle_current_part(result, accumulated_parts, content)
        elif isinstance(content, dict):
            # PartDict is already handled in _is_part
            result.append(types.Content.model_validate(content))
        else:
            raise ValueError(f'Unsupported content type: {type(content)}')

    _append_accumulated_parts_as_content(result, accumulated_parts)

    return result


def handle_null_fields(schema: dict[str, Any]) -> None:
    """Process null fields in the schema so it is compatible with OpenAPI.

    The OpenAPI spec does not support 'type: 'null' in the schema. This function
    handles this case by adding 'nullable: True' to the null field and removing
    every {'type': 'null'} entry (including entries with extra keys such as a
    description).

    https://swagger.io/docs/specification/v3_0/data-models/data-types/#null

    Example of schema properties before and after handling null fields:
      Before:
        {
          "name": {
            "title": "Name",
            "type": "string"
          },
          "total_area_sq_mi": {
            "anyOf": [
              {
                "type": "integer"
              },
              {
                "type": "null"
              }
            ],
            "default": None,
            "title": "Total Area Sq Mi"
          }
        }

      After:
        {
          "name": {
            "title": "Name",
            "type": "string"
          },
          "total_area_sq_mi": {
            "type": "integer",
            "nullable": true,
            "default": None,
            "title": "Total Area Sq Mi"
          }
        }
    """
    if schema.get('type', None) == 'null':
        schema['nullable'] = True
        del schema['type']
    elif 'anyOf' in schema:
        any_of = schema['anyOf']
        # Filter into a new list rather than calling list.remove() while
        # iterating. Removing during iteration skips the next element, and
        # remove({'type': 'null'}) raises ValueError when the null entry has
        # extra keys.
        non_null = [
            item
            for item in any_of
            if not (isinstance(item, Mapping) and item.get('type') == 'null')
        ]
        if len(non_null) != len(any_of):
            schema['nullable'] = True
            any_of[:] = non_null
            if len(any_of) == 1:
                # If there is only one type left after removing null, remove
                # the anyOf field.
                for key, val in any_of[0].items():
                    schema[key] = val
                del schema['anyOf']


def process_schema(
    schema: dict[str, Any],
    client: _api_client.BaseApiClient,
    defs: Optional[dict[str, Any]] = None,
    *,
    order_properties: bool = True,
) -> None:
    """Updates the schema and each sub-schema inplace to be API-compatible.

    - Normalizes snake_case keys (`any_of`, `additional_properties`,
      `prefix_items`, `property_ordering`) to camelCase.
    - Inlines the $defs.
    - Converts `type: null` entries into `nullable` (see handle_null_fields).
    - Turns a string `const` into a single-value `enum`.
    - When `order_properties` is True, records the property order of objects
      with more than one property.

    Example of a schema before and after (with mldev):
      Before:

        `schema`

        {
            'items': {
                '$ref': '#/$defs/CountryInfo'
            },
            'title': 'Placeholder',
            'type': 'array'
        }

        `defs`

        {
          'CountryInfo': {
            'properties': {
              'continent': {'title': 'Continent', 'type': 'string'},
              'gdp': {'title': 'Gdp', 'type': 'integer'}
            },
            'required': ['continent', 'gdp'],
            'title': 'CountryInfo',
            'type': 'object'
          }
        }

      After (property-ordering entry omitted):

        `schema`
        {
            'items': {
              'properties': {
                'continent': {'title': 'Continent', 'type': 'string'},
                'gdp': {'title': 'Gdp', 'type': 'integer'}
              },
              'required': ['continent', 'gdp'],
              'title': 'CountryInfo',
              'type': 'object'
            },
            'title': 'Placeholder',
            'type': 'array'
        }
    """
    if schema.get('title') == 'PlaceholderLiteralEnum':
        del schema['title']

    # Standardize spelling for relevant schema fields. For example, if a dict is
    # provided directly to response_schema, it may use `any_of` instead of
    # `anyOf`. Otherwise, model_json_schema() uses `anyOf`.
    for from_name, to_name in [
        ('additional_properties', 'additionalProperties'),
        ('any_of', 'anyOf'),
        ('prefix_items', 'prefixItems'),
        ('property_ordering', 'propertyOrdering'),
    ]:
        if (value := schema.pop(from_name, None)) is not None:
            schema[to_name] = value

    if defs is None:
        defs = schema.pop('$defs', {})
        for _, sub_schema in defs.items():
            # We can skip the '$ref' check, because JSON schema forbids a '$ref'
            # from directly referencing another '$ref':
            # https://json-schema.org/understanding-json-schema/structuring#recursion
            process_schema(
                sub_schema, client, defs, order_properties=order_properties
            )

    handle_null_fields(schema)

    # After removing null fields, Optional fields with only one possible type
    # will have a $ref key that needs to be flattened.
    # For example: {'default': None, 'description': 'Name of the person',
    # 'nullable': True, '$ref': '#/$defs/TestPerson'}
    if (ref := schema.pop('$ref', None)) is not None:
        schema.update(defs[ref.split('defs/')[-1]])

    def _recurse(sub_schema: dict[str, Any]) -> dict[str, Any]:
        """Returns the processed `sub_schema`, resolving its '$ref' if any."""
        if (ref := sub_schema.pop('$ref', None)) is not None:
            sub_schema = defs[ref.split('defs/')[-1]]
        process_schema(sub_schema, client, defs, order_properties=order_properties)
        return sub_schema

    if (any_of := schema.get('anyOf')) is not None:
        schema['anyOf'] = [_recurse(sub_schema) for sub_schema in any_of]
        return

    schema_type = schema.get('type')
    if isinstance(schema_type, Enum):
        schema_type = schema_type.value
    # A sub-schema may legitimately have no 'type' (e.g. pydantic emits only a
    # title for `Any`). Only upper-case real strings so those don't crash.
    if isinstance(schema_type, str):
        schema_type = schema_type.upper()

    # model_json_schema() returns a schema with a 'const' field when a Literal
    # with one value is provided as a pydantic field.
    # For example `genre: Literal['action']` becomes:
    # {'const': 'action', 'title': 'Genre', 'type': 'string'}
    const = schema.get('const')
    if const is not None:
        if schema_type == 'STRING':
            schema['enum'] = [const]
            del schema['const']
        else:
            raise ValueError('Literal values must be strings.')

    if schema_type == 'OBJECT':
        if (properties := schema.get('properties')) is not None:
            for name, sub_schema in list(properties.items()):
                properties[name] = _recurse(sub_schema)
            if (
                len(properties.items()) > 1
                and order_properties
                and 'propertyOrdering' not in schema
            ):
                schema['property_ordering'] = list(properties.keys())
        if (additional := schema.get('additionalProperties')) is not None:
            # It is legal to set 'additionalProperties' to a bool:
            # https://json-schema.org/understanding-json-schema/reference/object#additionalproperties
            if isinstance(additional, dict):
                schema['additionalProperties'] = _recurse(additional)
    elif schema_type == 'ARRAY':
        if (items := schema.get('items')) is not None:
            schema['items'] = _recurse(items)
        if (prefixes := schema.get('prefixItems')) is not None:
            schema['prefixItems'] = [_recurse(prefix) for prefix in prefixes]


def _process_enum(
    enum: EnumMeta, client: _api_client.BaseApiClient
) -> types.Schema:
    """Builds a Schema for an Enum class whose member values are all strings.

    Raises:
      TypeError: If any member value is not a `str`.
    """
    for member in enum:  # type: ignore
        if not isinstance(member.value, str):
            raise TypeError(
                f'Enum member {member.name} value must be a string, got'
                f' {type(member.value)}'
            )

    class Placeholder(pydantic.BaseModel):
        placeholder: enum  # type: ignore[valid-type]

    enum_schema = Placeholder.model_json_schema()
    process_schema(enum_schema, client)
    enum_schema = enum_schema['properties']['placeholder']
    return types.Schema.model_validate(enum_schema)


def _is_type_dict_str_any(
    origin: Union[types.SchemaUnionDict, Any],
) -> TypeGuard[dict[str, Any]]:
    """Verifies the schema is of type dict[str, Any] for mypy type checking."""
    return isinstance(origin, dict) and all(
        isinstance(key, str) for key in origin
    )


def t_schema(
    client: _api_client.BaseApiClient, origin: Union[types.SchemaUnionDict, Any]
) -> Optional[types.Schema]:
    """Converts a schema-like value into a processed `types.Schema`.

    Accepts a dict, an Enum class, a `types.Schema`, a pydantic model class,
    or a type / generic alias / union. Returns None for a falsy `origin`.

    Raises:
      ValueError: If `origin` cannot be turned into a schema.
    """
    if not origin:
        return None
    if isinstance(origin, dict) and _is_type_dict_str_any(origin):
        process_schema(origin, client, order_properties=False)
        return types.Schema.model_validate(origin)
    if isinstance(origin, EnumMeta):
        return _process_enum(origin, client)
    if isinstance(origin, types.Schema):
        if dict(origin) == dict(types.Schema()):
            # response_schema value was coerced to an empty Schema instance
            # because it did not adhere to the Schema field annotation.
            raise ValueError('Unsupported schema type.')
        schema = origin.model_dump(exclude_unset=True)
        process_schema(schema, client, order_properties=False)
        return types.Schema.model_validate(schema)

    if (
        # in Python 3.9 Generic alias list[int] counts as a type,
        # and breaks issubclass because it's not a class.
        not isinstance(origin, GenericAlias)
        and isinstance(origin, type)
        and issubclass(origin, pydantic.BaseModel)
    ):
        schema = origin.model_json_schema()
        process_schema(schema, client)
        return types.Schema.model_validate(schema)
    elif (
        isinstance(origin, GenericAlias)
        or isinstance(origin, type)
        or isinstance(origin, VersionedUnionType)
        or typing.get_origin(origin) in _UNION_TYPES
    ):

        class Placeholder(pydantic.BaseModel):
            placeholder: origin  # type: ignore[valid-type]

        schema = Placeholder.model_json_schema()
        process_schema(schema, client)
        schema = schema['properties']['placeholder']
        return types.Schema.model_validate(schema)

    raise ValueError(f'Unsupported schema type: {origin}')


def t_speech_config(
    _: _api_client.BaseApiClient,
    origin: Union[types.SpeechConfigUnionDict, Any],
) -> Optional[types.SpeechConfig]:
    """Converts a voice name, dict or SpeechConfig into a `types.SpeechConfig`.

    Returns None for a falsy `origin`.

    Raises:
      ValueError: If `origin` has an unsupported type or shape.
    """
    if not origin:
        return None
    if isinstance(origin, types.SpeechConfig):
        return origin
    if isinstance(origin, str):
        return types.SpeechConfig(
            voice_config=types.VoiceConfig(
                prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=origin)
            )
        )
    if (
        isinstance(origin, dict)
        and 'voice_config' in origin
        and origin['voice_config'] is not None
        and 'prebuilt_voice_config' in origin['voice_config']
        and origin['voice_config']['prebuilt_voice_config'] is not None
        and 'voice_name' in origin['voice_config']['prebuilt_voice_config']
    ):
        return types.SpeechConfig(
            voice_config=types.VoiceConfig(
                prebuilt_voice_config=types.PrebuiltVoiceConfig(
                    voice_name=origin['voice_config']['prebuilt_voice_config'].get(
                        'voice_name'
                    )
                )
            )
        )
    raise ValueError(f'Unsupported speechConfig type: {type(origin)}')


def t_tool(
    client: _api_client.BaseApiClient, origin: Any
) -> Optional[Union[types.Tool, Any]]:
    """Converts a Python function/method or a dict into a `types.Tool`.

    Any other truthy value is returned unchanged. A falsy value gives None.
    """
    if not origin:
        return None
    if inspect.isfunction(origin) or inspect.ismethod(origin):
        return types.Tool(
            function_declarations=[
                types.FunctionDeclaration.from_callable(
                    client=client, callable=origin
                )
            ]
        )
    elif isinstance(origin, dict):
        return types.Tool.model_validate(origin)
    else:
        return origin


def t_tools(
    client: _api_client.BaseApiClient, origin: list[Any]
) -> list[types.Tool]:
    """Converts a list of tool-like values into Tools.

    All function declarations are merged into one trailing Tool. Tools
    without function declarations are kept as separate entries.
    """
    if not origin:
        return []
    function_tool = types.Tool(function_declarations=[])
    tools = []
    for tool in origin:
        transformed_tool = t_tool(client, tool)
        # All functions should be merged into one tool.
        if transformed_tool is not None:
            if (
                transformed_tool.function_declarations
                and function_tool.function_declarations is not None
            ):
                function_tool.function_declarations += (
                    transformed_tool.function_declarations
                )
            else:
                tools.append(transformed_tool)
    if function_tool.function_declarations:
        tools.append(function_tool)

    return tools


def t_cached_content_name(client: _api_client.BaseApiClient, name: str) -> str:
    """Completes a cached-content name via `_resource_name`."""
    return _resource_name(client, name, collection_identifier='cachedContents')


def t_batch_job_source(
    client: _api_client.BaseApiClient, src: str
) -> types.BatchJobSource:
    """Builds a batch-job source from a `gs://` or `bq://` URI.

    Raises:
      ValueError: For any other URI scheme.
    """
    if src.startswith('gs://'):
        return types.BatchJobSource(
            format='jsonl',
            gcs_uri=[src],
        )
    elif src.startswith('bq://'):
        return types.BatchJobSource(
            format='bigquery',
            bigquery_uri=src,
        )
    else:
        raise ValueError(f'Unsupported source: {src}')


def t_batch_job_destination(
    client: _api_client.BaseApiClient, dest: str
) -> types.BatchJobDestination:
    """Builds a batch-job destination from a `gs://` or `bq://` URI.

    Raises:
      ValueError: For any other URI scheme.
    """
    if dest.startswith('gs://'):
        return types.BatchJobDestination(
            format='jsonl',
            gcs_uri=dest,
        )
    elif dest.startswith('bq://'):
        return types.BatchJobDestination(
            format='bigquery',
            bigquery_uri=dest,
        )
    else:
        raise ValueError(f'Unsupported destination: {dest}')


def t_batch_job_name(client: _api_client.BaseApiClient, name: str) -> str:
    """Returns the batch job id on Vertex AI (name unchanged otherwise).

    Raises:
      ValueError: On Vertex AI, if `name` is neither a full
        batchPredictionJobs resource name nor all digits.
    """
    if not client.vertexai:
        return name

    pattern = r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'
    if re.match(pattern, name):
        return name.split('/')[-1]
    elif name.isdigit():
        return name
    else:
        raise ValueError(f'Invalid batch job name: {name}.')


LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
LRO_POLLING_TIMEOUT_SECONDS = 900.0
LRO_POLLING_MULTIPLIER = 1.5


def t_resolve_operation(
    api_client: _api_client.BaseApiClient, struct: dict[str, Any]
) -> Any:
    """Polls a long-running operation until done and returns its `response`.

    If `struct` is not an operation (its `name` lacks `/operations/`) it is
    returned unchanged. Polling uses exponential backoff, capped at
    LRO_POLLING_MAXIMUM_DELAY_SECONDS.

    Raises:
      RuntimeError: If the accumulated wait exceeds
        LRO_POLLING_TIMEOUT_SECONDS, or the operation reports an error.
    """
    if (name := struct.get('name')) and '/operations/' in name:
        operation: dict[str, Any] = struct
        total_seconds = 0.0
        delay_seconds = LRO_POLLING_INITIAL_DELAY_SECONDS
        while operation.get('done') != True:  # noqa: E712
            if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
                raise RuntimeError(f'Operation {name} timed out.\n{operation}')
            # TODO(b/374433890): Replace with LRO module once it's available.
            operation = api_client.request(  # type: ignore[assignment]
                http_method='GET', path=name, request_dict={}
            )
            time.sleep(delay_seconds)
            # Accumulate the time actually slept. Adding total_seconds to
            # itself kept it at 0.0, so the timeout could never trigger.
            total_seconds += delay_seconds
            # Exponential backoff
            delay_seconds = min(
                delay_seconds * LRO_POLLING_MULTIPLIER,
                LRO_POLLING_MAXIMUM_DELAY_SECONDS,
            )
        if error := operation.get('error'):
            raise RuntimeError(
                f'Operation {name} failed with error: {error}.\n{operation}'
            )
        return operation.get('response')
    else:
        return struct


def t_file_name(
    api_client: _api_client.BaseApiClient,
    name: Optional[Union[str, types.File, types.Video, types.GeneratedVideo]],
) -> str:
    """Extracts the bare file id from a name, URI, File or Video.

    Raises:
      ValueError: If no name is available, the value is not a string, or an
        https URI does not contain a `files/<id>` segment.
    """
    # Remove the files/ prefix since it's added to the url path.
    if isinstance(name, types.File):
        name = name.name
    elif isinstance(name, types.Video):
        name = name.uri
    elif isinstance(name, types.GeneratedVideo):
        if name.video is not None:
            name = name.video.uri
        else:
            name = None

    if name is None:
        raise ValueError('File name is required.')

    if not isinstance(name, str):
        raise ValueError(
            f'Could not convert object of type `{type(name)}` to a file name.'
        )

    if name.startswith('https://'):
        segments = name.split('files/')
        # Guard the index: a URI without 'files/' must raise the documented
        # ValueError, not an IndexError.
        match = (
            re.match('[a-z0-9]+', segments[1]) if len(segments) > 1 else None
        )
        if match is None:
            raise ValueError(f'Could not extract file name from URI: {name}')
        name = match.group(0)
    elif name.startswith('files/'):
        name = name.split('files/')[1]

    return name


def t_tuning_job_status(
    api_client: _api_client.BaseApiClient, status: str
) -> Union[types.JobState, str]:
    """Maps a tuning status string to a `types.JobState`.

    The legacy states STATE_UNSPECIFIED, CREATING, ACTIVE and FAILED are
    translated. Otherwise the JobState whose value equals `status` is
    returned, or `status` itself if none matches.
    """
    if status == 'STATE_UNSPECIFIED':
        return types.JobState.JOB_STATE_UNSPECIFIED
    elif status == 'CREATING':
        return types.JobState.JOB_STATE_RUNNING
    elif status == 'ACTIVE':
        return types.JobState.JOB_STATE_SUCCEEDED
    elif status == 'FAILED':
        return types.JobState.JOB_STATE_FAILED
    else:
        for state in types.JobState:
            if str(state.value) == status:
                return state
        return status


# Some fields don't accept url safe base64 encoding.
# We shouldn't use this transformer if the backend adheres to Cloud Type
# format https://cloud.google.com/docs/discovery/type-format.
# TODO(b/389133914,b/390320301): Remove the hack after the backend fixes the issue.
def t_bytes(api_client: _api_client.BaseApiClient, data: bytes) -> str:
    """Base64-encodes `data` with the standard alphabet and returns ASCII text.

    Any value that is not `bytes` is handed back untouched.
    """
    if isinstance(data, bytes):
        encoded = base64.b64encode(data)
        return encoded.decode('ascii')
    return data


def t_content_strict(content: types.ContentOrDict) -> types.Content:
    """Accepts only a `types.Content` or a dict that validates as one.

    Raises:
      ValueError: For any other input type.
    """
    if isinstance(content, types.Content):
        return content
    if isinstance(content, dict):
        return types.Content.model_validate(content)
    raise ValueError(
        f'Could not convert input (type "{type(content)}") to '
        '`types.Content`'
    )


def t_contents_strict(
    contents: Union[Sequence[types.ContentOrDict], types.ContentOrDict],
) -> list[types.Content]:
    """Strictly converts one content or a sequence of contents into a list."""
    items = contents if isinstance(contents, Sequence) else [contents]
    return [t_content_strict(item) for item in items]


def t_client_content(
    turns: Optional[
        Union[Sequence[types.ContentOrDict], types.ContentOrDict]
    ] = None,
    turn_complete: bool = True,
) -> types.LiveClientContent:
    """Builds a `types.LiveClientContent` message from optional turns.

    Raises:
      ValueError: If `turns` cannot be converted; the original error is
        chained as the cause.
    """
    if turns is not None:
        try:
            converted_turns = t_contents_strict(contents=turns)
            return types.LiveClientContent(
                turns=converted_turns,
                turn_complete=turn_complete,
            )
        except Exception as err:
            raise ValueError(
                f'Could not convert input (type "{type(turns)}") to '
                '`types.LiveClientContent`'
            ) from err
    return types.LiveClientContent(turn_complete=turn_complete)


def t_tool_response(
    input: Union[
        types.FunctionResponseOrDict,
        Sequence[types.FunctionResponseOrDict],
    ],
) -> types.LiveClientToolResponse:
    """Wraps one or more function responses in a LiveClientToolResponse.

    Raises:
      ValueError: If `input` is falsy, or if conversion fails (the original
        error is chained as the cause).
    """
    if not input:
        raise ValueError(f'A tool response is required, got: \n{input}')

    try:
        responses = t_function_responses(function_responses=input)
        return types.LiveClientToolResponse(function_responses=responses)
    except Exception as err:
        raise ValueError(
            f'Could not convert input (type "{type(input)}") to '
            '`types.LiveClientToolResponse`'
        ) from err