structure saas with tools

Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions


@@ -0,0 +1,23 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Google Gen AI SDK"""
from .client import Client
from . import version
__version__ = version.__version__
__all__ = ['Client']
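
A minimal usage sketch of the exported Client (the API key and model name below are placeholders, not part of this diff):

from google import genai

client = genai.Client(api_key='YOUR_API_KEY')  # placeholder key
response = client.models.generate_content(
    model='gemini-2.0-flash', contents='Why is the sky blue?'
)
print(response.text)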

File diff suppressed because it is too large.


@@ -0,0 +1,29 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utilities for the API Modules of the Google Gen AI SDK."""
from typing import Optional
from . import _api_client
class BaseModule:
def __init__(self, api_client_: _api_client.BaseApiClient):
self._api_client = api_client_
@property
def vertexai(self) -> Optional[bool]:
return self._api_client.vertexai
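
For context, a hedged sketch of how an API module might build on BaseModule; the subclass and method names below are invented for illustration and do not appear in this diff:

class Models(BaseModule):
  def list_models(self) -> list[str]:
    # A real module delegates HTTP calls to the shared self._api_client and
    # can branch on self.vertexai to pick the Vertex AI or Gemini API backend.
    ...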


@@ -0,0 +1,284 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import sys
import types as builtin_types
import typing
from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union # type: ignore[attr-defined]
import pydantic
from . import _extra_utils
from . import types
if sys.version_info >= (3, 10):
VersionedUnionType = builtin_types.UnionType
else:
VersionedUnionType = typing._UnionGenericAlias # type: ignore[attr-defined]
_py_builtin_type_to_schema_type = {
str: types.Type.STRING,
int: types.Type.INTEGER,
float: types.Type.NUMBER,
bool: types.Type.BOOLEAN,
list: types.Type.ARRAY,
dict: types.Type.OBJECT,
}
def _is_builtin_primitive_or_compound(
annotation: inspect.Parameter.annotation, # type: ignore[valid-type]
) -> bool:
  return annotation in _py_builtin_type_to_schema_type
def _is_default_value_compatible(
default_value: Any, annotation: inspect.Parameter.annotation # type: ignore[valid-type]
) -> bool:
  # None type is expected to be handled outside this function.
if _is_builtin_primitive_or_compound(annotation):
return isinstance(default_value, annotation)
if (
isinstance(annotation, _GenericAlias)
or isinstance(annotation, builtin_types.GenericAlias)
or isinstance(annotation, VersionedUnionType)
):
origin = get_origin(annotation)
if origin in (Union, VersionedUnionType): # type: ignore[comparison-overlap]
return any(
_is_default_value_compatible(default_value, arg)
for arg in get_args(annotation)
)
if origin is dict: # type: ignore[comparison-overlap]
return isinstance(default_value, dict)
if origin is list: # type: ignore[comparison-overlap]
if not isinstance(default_value, list):
return False
      # Trickiest case: each element of the list may itself be a union type,
      # so `any` logic must be applied within `all`.
      # See test case test_generic_alias_complex_array_with_default_value:
      #   a: typing.List[int | str | float | bool]
      #   default_value: [1, 'a', 1.1, True]
return all(
any(
_is_default_value_compatible(item, arg)
for arg in get_args(annotation)
)
for item in default_value
)
if origin is Literal: # type: ignore[comparison-overlap]
return default_value in get_args(annotation)
# return False for any other unrecognized annotation
return False
def _parse_schema_from_parameter(
api_option: Literal['VERTEX_AI', 'GEMINI_API'],
param: inspect.Parameter,
func_name: str,
) -> types.Schema:
  """Parses a schema from a function parameter.

  Handles annotations from the simplest case to the most complex case.
  """
schema = types.Schema()
default_value_error_msg = (
f'Default value {param.default} of parameter {param} of function'
f' {func_name} is not compatible with the parameter annotation'
f' {param.annotation}.'
)
if _is_builtin_primitive_or_compound(param.annotation):
if param.default is not inspect.Parameter.empty:
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
schema.type = _py_builtin_type_to_schema_type[param.annotation]
return schema
if (
isinstance(param.annotation, VersionedUnionType)
      # Only parse simple UnionTypes, e.g. int | str | float | bool;
      # complex UnionTypes fall through to the raise branch below.
and all(
(_is_builtin_primitive_or_compound(arg) or arg is type(None))
for arg in get_args(param.annotation)
)
):
schema.type = _py_builtin_type_to_schema_type[dict]
schema.any_of = []
unique_types = set()
for arg in get_args(param.annotation):
if arg.__name__ == 'NoneType': # Optional type
schema.nullable = True
continue
schema_in_any_of = _parse_schema_from_parameter(
api_option,
inspect.Parameter(
'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
),
func_name,
)
if (
schema_in_any_of.model_dump_json(exclude_none=True)
not in unique_types
):
schema.any_of.append(schema_in_any_of)
unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
if len(schema.any_of) == 1: # param: list | None -> Array
schema.type = schema.any_of[0].type
schema.any_of = None
if (
param.default is not inspect.Parameter.empty
and param.default is not None
):
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
return schema
if isinstance(param.annotation, _GenericAlias) or isinstance(
param.annotation, builtin_types.GenericAlias
):
origin = get_origin(param.annotation)
args = get_args(param.annotation)
if origin is dict:
schema.type = _py_builtin_type_to_schema_type[dict]
if param.default is not inspect.Parameter.empty:
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
return schema
if origin is Literal:
if not all(isinstance(arg, str) for arg in args):
raise ValueError(
f'Literal type {param.annotation} must be a list of strings.'
)
schema.type = _py_builtin_type_to_schema_type[str]
schema.enum = list(args)
if param.default is not inspect.Parameter.empty:
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
return schema
if origin is list:
schema.type = _py_builtin_type_to_schema_type[list]
schema.items = _parse_schema_from_parameter(
api_option,
inspect.Parameter(
'item',
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=args[0],
),
func_name,
)
if param.default is not inspect.Parameter.empty:
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
return schema
if origin is Union:
schema.any_of = []
schema.type = _py_builtin_type_to_schema_type[dict]
unique_types = set()
for arg in args:
# The first check is for NoneType in Python 3.9, since the __name__
# attribute is not available in Python 3.9
if type(arg) is type(None) or (
hasattr(arg, '__name__') and arg.__name__ == 'NoneType'
): # Optional type
schema.nullable = True
continue
schema_in_any_of = _parse_schema_from_parameter(
api_option,
inspect.Parameter(
'item',
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=arg,
),
func_name,
)
if (
len(param.annotation.__args__) == 2
and type(None) in param.annotation.__args__
): # Optional type
for optional_arg in param.annotation.__args__:
if (
hasattr(optional_arg, '__origin__')
and optional_arg.__origin__ is list
):
# Optional type with list, for example Optional[list[str]]
schema.items = schema_in_any_of.items
if (
schema_in_any_of.model_dump_json(exclude_none=True)
not in unique_types
):
schema.any_of.append(schema_in_any_of)
unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
if len(schema.any_of) == 1: # param: Union[List, None] -> Array
schema.type = schema.any_of[0].type
schema.any_of = None
if (
param.default is not None
and param.default is not inspect.Parameter.empty
):
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
return schema
    # All other generic aliases fall through to the raise branch below.
if (
# for user defined class, we only support pydantic model
_extra_utils.is_annotation_pydantic_model(param.annotation)
):
if (
param.default is not inspect.Parameter.empty
and param.default is not None
):
schema.default = param.default
schema.type = _py_builtin_type_to_schema_type[dict]
schema.properties = {}
for field_name, field_info in param.annotation.model_fields.items():
schema.properties[field_name] = _parse_schema_from_parameter(
api_option,
inspect.Parameter(
field_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=field_info.annotation,
),
func_name,
)
schema.required = _get_required_fields(schema)
return schema
  raise ValueError(
      f'Failed to parse the parameter {param} of function {func_name} for'
      ' automatic function calling. Automatic function calling works best'
      ' with simpler function signature schemas; consider manually parsing'
      f' the declaration of function {func_name}.'
  )
def _get_required_fields(schema: types.Schema) -> Optional[list[str]]:
if not schema.properties:
return None
return [
field_name
for field_name, field_schema in schema.properties.items()
if not field_schema.nullable and field_schema.default is None
]
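
A small, hedged demonstration of the parser above; the function below is invented for illustration:

import inspect

def get_weather(city: str, unit: str = 'celsius') -> str:
  ...

sig = inspect.signature(get_weather)
for param in sig.parameters.values():
  schema = _parse_schema_from_parameter('GEMINI_API', param, 'get_weather')
  # city -> Schema(type=STRING); unit -> Schema(type=STRING, default='celsius')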


@@ -0,0 +1,318 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Common utilities for the SDK."""
import base64
import datetime
import enum
import functools
import typing
from typing import Any, Callable, Optional, Union
import uuid
import warnings
import pydantic
from pydantic import alias_generators
from . import _api_client
from . import errors
def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: Any) -> None:
"""Examples:
set_value_by_path({}, ['a', 'b'], v)
-> {'a': {'b': v}}
  set_value_by_path({}, ['a', 'b[]', 'c'], [v1, v2])
-> {'a': {'b': [{'c': v1}, {'c': v2}]}}
set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3)
-> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}}
"""
if value is None:
return
for i, key in enumerate(keys[:-1]):
if key.endswith('[]'):
key_name = key[:-2]
if data is not None and key_name not in data:
if isinstance(value, list):
data[key_name] = [{} for _ in range(len(value))]
else:
raise ValueError(
f'value {value} must be a list given an array path {key}'
)
if isinstance(value, list) and data is not None:
for j, d in enumerate(data[key_name]):
set_value_by_path(d, keys[i + 1 :], value[j])
else:
if data is not None:
for d in data[key_name]:
set_value_by_path(d, keys[i + 1 :], value)
return
elif key.endswith('[0]'):
key_name = key[:-3]
if data is not None and key_name not in data:
data[key_name] = [{}]
if data is not None:
set_value_by_path(data[key_name][0], keys[i + 1 :], value)
return
if data is not None:
data = data.setdefault(key, {})
if data is not None:
existing_data = data.get(keys[-1])
# If there is an existing value, merge, not overwrite.
if existing_data is not None:
# Don't overwrite existing non-empty value with new empty value.
# This is triggered when handling tuning datasets.
if not value:
pass
# Don't fail when overwriting value with same value
elif value == existing_data:
pass
# Instead of overwriting dictionary with another dictionary, merge them.
# This is important for handling training and validation datasets in tuning.
elif isinstance(existing_data, dict) and isinstance(value, dict):
# Merging dictionaries. Consider deep merging in the future.
existing_data.update(value)
else:
raise ValueError(
f'Cannot set value for an existing key. Key: {keys[-1]};'
f' Existing value: {existing_data}; New value: {value}.'
)
else:
data[keys[-1]] = value
def get_value_by_path(data: Any, keys: list[str]) -> Any:
"""Examples:
get_value_by_path({'a': {'b': v}}, ['a', 'b'])
-> v
get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c'])
-> [v1, v2]
"""
if keys == ['_self']:
return data
for i, key in enumerate(keys):
if not data:
return None
if key.endswith('[]'):
key_name = key[:-2]
if key_name in data:
return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]]
else:
return None
elif key.endswith('[0]'):
key_name = key[:-3]
if key_name in data and data[key_name]:
return get_value_by_path(data[key_name][0], keys[i + 1 :])
else:
return None
else:
if key in data:
data = data[key]
elif isinstance(data, BaseModel) and hasattr(data, key):
data = getattr(data, key)
else:
return None
return data
def convert_to_dict(obj: object) -> Any:
"""Recursively converts a given object to a dictionary.
If the object is a Pydantic model, it uses the model's `model_dump()` method.
Args:
obj: The object to convert.
Returns:
A dictionary representation of the object, a list of objects if a list is
passed, or the object itself if it is not a dictionary, list, or Pydantic
model.
"""
if isinstance(obj, pydantic.BaseModel):
return obj.model_dump(exclude_none=True)
elif isinstance(obj, dict):
return {key: convert_to_dict(value) for key, value in obj.items()}
elif isinstance(obj, list):
return [convert_to_dict(item) for item in obj]
else:
return obj
def _remove_extra_fields(
model: Any, response: dict[str, object]
) -> None:
"""Removes extra fields from the response that are not in the model.
Mutates the response in place.
"""
key_values = list(response.items())
for key, value in key_values:
    # Need to convert to snake case to match model field names.
    # Example: UsageMetadata.
alias_map = {
field_info.alias: key for key, field_info in model.model_fields.items()
}
if key not in model.model_fields and key not in alias_map:
response.pop(key)
continue
key = alias_map.get(key, key)
annotation = model.model_fields[key].annotation
# Get the BaseModel if Optional
if typing.get_origin(annotation) is Union:
annotation = typing.get_args(annotation)[0]
# if dict, assume BaseModel but also check that field type is not dict
# example: FunctionCall.args
if isinstance(value, dict) and typing.get_origin(annotation) is not dict:
_remove_extra_fields(annotation, value)
elif isinstance(value, list):
for item in value:
# assume a list of dict is list of BaseModel
if isinstance(item, dict):
_remove_extra_fields(typing.get_args(annotation)[0], item)
T = typing.TypeVar('T', bound='BaseModel')
class BaseModel(pydantic.BaseModel):
model_config = pydantic.ConfigDict(
alias_generator=alias_generators.to_camel,
populate_by_name=True,
from_attributes=True,
protected_namespaces=(),
extra='forbid',
# This allows us to use arbitrary types in the model. E.g. PIL.Image.
arbitrary_types_allowed=True,
ser_json_bytes='base64',
val_json_bytes='base64',
ignored_types=(typing.TypeVar,)
)
@classmethod
def _from_response(
cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
) -> T:
# To maintain forward compatibility, we need to remove extra fields from
# the response.
# We will provide another mechanism to allow users to access these fields.
_remove_extra_fields(cls, response)
validated_response = cls.model_validate(response)
return validated_response
def to_json_dict(self) -> dict[str, object]:
return self.model_dump(exclude_none=True, mode='json')
class CaseInSensitiveEnum(str, enum.Enum):
"""Case insensitive enum."""
@classmethod
def _missing_(cls, value: Any) -> Any:
try:
return cls[value.upper()] # Try to access directly with uppercase
except KeyError:
try:
return cls[value.lower()] # Try to access directly with lowercase
except KeyError:
warnings.warn(f"{value} is not a valid {cls.__name__}")
try:
          # Create an enum instance based on the value.
# We need to use super() to avoid infinite recursion.
unknown_enum_val = super().__new__(cls, value)
unknown_enum_val._name_ = str(value) # pylint: disable=protected-access
unknown_enum_val._value_ = value # pylint: disable=protected-access
return unknown_enum_val
        except Exception:
return None
def timestamped_unique_name() -> str:
"""Composes a timestamped unique name.
Returns:
A string representing a unique name.
"""
timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
unique_id = uuid.uuid4().hex[0:5]
return f'{timestamp}_{unique_id}'
def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
  """Converts unserializable types in a dict to json.dumps()-compatible types.

  This function is called in models.py after convert_to_dict().
  convert_to_dict() can convert a pydantic object to a dict, but its input is
  a mix of pydantic objects and nested dicts (the output of converters), so
  the dict may contain bytes that fall outside the `ser_json_bytes` control of
  model_dump(mode='json'), and likewise datetime values outside Pydantic's
  JSON-mode serialization.

  Returns:
    A dictionary with json.dumps()-incompatible types (e.g. bytes, datetime)
    converted to compatible types (e.g. base64-encoded strings, isoformat
    date strings).
  """
processed_data: dict[str, object] = {}
if not isinstance(data, dict):
return data
for key, value in data.items():
if isinstance(value, bytes):
processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
elif isinstance(value, datetime.datetime):
processed_data[key] = value.isoformat()
elif isinstance(value, dict):
processed_data[key] = encode_unserializable_types(value)
elif isinstance(value, list):
if all(isinstance(v, bytes) for v in value):
processed_data[key] = [
base64.urlsafe_b64encode(v).decode('ascii') for v in value
]
      elif all(isinstance(v, datetime.datetime) for v in value):
processed_data[key] = [v.isoformat() for v in value]
else:
processed_data[key] = [encode_unserializable_types(v) for v in value]
else:
processed_data[key] = value
return processed_data
def experimental_warning(message: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
  """Returns a decorator that emits an experimental warning only once."""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
warning_done = False
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
nonlocal warning_done
if not warning_done:
warning_done = True
warnings.warn(
message=message,
category=errors.ExperimentalWarning,
stacklevel=2,
)
return func(*args, **kwargs)
return wrapper
return decorator
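
A hedged round-trip example for the path helpers above; keys and values are illustrative:

data: dict = {}
set_value_by_path(data, ['a', 'b[]', 'c'], ['v1', 'v2'])
# data == {'a': {'b': [{'c': 'v1'}, {'c': 'v2'}]}}
assert get_value_by_path(data, ['a', 'b[]', 'c']) == ['v1', 'v2']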


@@ -0,0 +1,403 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Extra utils depending on types that are shared between sync and async modules."""
import inspect
import logging
import sys
import typing
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
import pydantic
from . import _common
from . import errors
from . import types
if sys.version_info >= (3, 10):
from types import UnionType
else:
UnionType = typing._UnionGenericAlias # type: ignore[attr-defined]
_DEFAULT_MAX_REMOTE_CALLS_AFC = 10
logger = logging.getLogger('google_genai.models')
def _create_generate_content_config_model(
config: types.GenerateContentConfigOrDict,
) -> types.GenerateContentConfig:
if isinstance(config, dict):
return types.GenerateContentConfig(**config)
else:
return config
def format_destination(
src: str,
config: Optional[types.CreateBatchJobConfigOrDict] = None,
) -> types.CreateBatchJobConfig:
"""Formats the destination uri based on the source uri."""
config = (
types._CreateBatchJobParameters(config=config).config
or types.CreateBatchJobConfig()
)
unique_name = None
if not config.display_name:
unique_name = _common.timestamped_unique_name()
config.display_name = f'genai_batch_job_{unique_name}'
if not config.dest:
if src.startswith('gs://') and src.endswith('.jsonl'):
# If source uri is "gs://bucket/path/to/src.jsonl", then the destination
# uri prefix will be "gs://bucket/path/to/src/dest".
config.dest = f'{src[:-6]}/dest'
elif src.startswith('bq://'):
# If source uri is "bq://project.dataset.src", then the destination
# uri will be "bq://project.dataset.src_dest_TIMESTAMP_UUID".
unique_name = unique_name or _common.timestamped_unique_name()
config.dest = f'{src}_dest_{unique_name}'
else:
raise ValueError(f'Unsupported source: {src}')
return config
def get_function_map(
config: Optional[types.GenerateContentConfigOrDict] = None,
is_caller_method_async: bool = False,
) -> dict[str, Callable[..., Any]]:
"""Returns a function map from the config."""
function_map: dict[str, Callable[..., Any]] = {}
if not config:
return function_map
config_model = _create_generate_content_config_model(config)
if config_model.tools:
for tool in config_model.tools:
if callable(tool):
if inspect.iscoroutinefunction(tool) and not is_caller_method_async:
raise errors.UnsupportedFunctionError(
f'Function {tool.__name__} is a coroutine function, which is not'
' supported for automatic function calling. Please manually'
f' invoke {tool.__name__} to get the function response.'
)
function_map[tool.__name__] = tool
return function_map
def convert_number_values_for_dict_function_call_args(
args: dict[str, Any],
) -> dict[str, Any]:
  """Converts float values with no decimal part in the dict to integers."""
return {
key: convert_number_values_for_function_call_args(value)
for key, value in args.items()
}
def convert_number_values_for_function_call_args(
args: Union[dict[str, object], list[object], object],
) -> Union[dict[str, object], list[object], object]:
"""Converts float values with no decimal to integers."""
if isinstance(args, float) and args.is_integer():
return int(args)
if isinstance(args, dict):
return {
key: convert_number_values_for_function_call_args(value)
for key, value in args.items()
}
if isinstance(args, list):
return [
convert_number_values_for_function_call_args(value) for value in args
]
return args
def is_annotation_pydantic_model(annotation: Any) -> bool:
try:
return inspect.isclass(annotation) and issubclass(
annotation, pydantic.BaseModel
)
# for python 3.10 and below, inspect.isclass(annotation) has inconsistent
# results with versions above. for example, inspect.isclass(dict[str, int]) is
# True in 3.10 and below but False in 3.11 and above.
except TypeError:
return False
def convert_if_exist_pydantic_model(
value: Any, annotation: Any, param_name: str, func_name: str
) -> Any:
if isinstance(value, dict) and is_annotation_pydantic_model(annotation):
try:
return annotation(**value)
except pydantic.ValidationError as e:
raise errors.UnknownFunctionCallArgumentError(
f'Failed to parse parameter {param_name} for function'
f' {func_name} from function call part because function call argument'
f' value {value} is not compatible with parameter annotation'
f' {annotation}, due to error {e}'
)
if isinstance(value, list) and get_origin(annotation) == list:
item_type = get_args(annotation)[0]
return [
convert_if_exist_pydantic_model(item, item_type, param_name, func_name)
for item in value
]
if isinstance(value, dict) and get_origin(annotation) == dict:
_, value_type = get_args(annotation)
return {
k: convert_if_exist_pydantic_model(v, value_type, param_name, func_name)
for k, v in value.items()
}
# example 1: typing.Union[int, float]
# example 2: int | float equivalent to UnionType[int, float]
if get_origin(annotation) in (Union, UnionType):
for arg in get_args(annotation):
if (
(get_args(arg) and get_origin(arg) is list)
or isinstance(value, arg)
or (isinstance(value, dict) and is_annotation_pydantic_model(arg))
):
try:
return convert_if_exist_pydantic_model(
value, arg, param_name, func_name
)
# do not raise here because there could be multiple pydantic model types
# in the union type.
except pydantic.ValidationError:
continue
# if none of the union type is matched, raise error
raise errors.UnknownFunctionCallArgumentError(
f'Failed to parse parameter {param_name} for function'
f' {func_name} from function call part because function call argument'
f' value {value} cannot be converted to parameter annotation'
f' {annotation}.'
)
# the only exception for value and annotation type to be different is int and
# float. see convert_number_values_for_function_call_args function for context
if isinstance(value, int) and annotation is float:
return value
if not isinstance(value, annotation):
raise errors.UnknownFunctionCallArgumentError(
f'Failed to parse parameter {param_name} for function {func_name} from'
f' function call part because function call argument value {value} is'
f' not compatible with parameter annotation {annotation}.'
)
return value
def convert_argument_from_function(
args: dict[str, Any], function: Callable[..., Any]
) -> dict[str, Any]:
signature = inspect.signature(function)
func_name = function.__name__
converted_args = {}
for param_name, param in signature.parameters.items():
if param_name in args:
converted_args[param_name] = convert_if_exist_pydantic_model(
args[param_name],
param.annotation,
param_name,
func_name,
)
return converted_args
def invoke_function_from_dict_args(
args: Dict[str, Any], function_to_invoke: Callable[..., Any]
) -> Any:
converted_args = convert_argument_from_function(args, function_to_invoke)
try:
return function_to_invoke(**converted_args)
except Exception as e:
raise errors.FunctionInvocationError(
f'Failed to invoke function {function_to_invoke.__name__} with'
f' converted arguments {converted_args} from model returned function'
f' call argument {args} because of error {e}'
)
async def invoke_function_from_dict_args_async(
args: Dict[str, Any], function_to_invoke: Callable[..., Any]
) -> Any:
converted_args = convert_argument_from_function(args, function_to_invoke)
try:
return await function_to_invoke(**converted_args)
except Exception as e:
raise errors.FunctionInvocationError(
f'Failed to invoke function {function_to_invoke.__name__} with'
f' converted arguments {converted_args} from model returned function'
f' call argument {args} because of error {e}'
)
def get_function_response_parts(
response: types.GenerateContentResponse,
function_map: dict[str, Callable[..., Any]],
) -> list[types.Part]:
"""Returns the function response parts from the response."""
func_response_parts = []
if (
response.candidates is not None
and isinstance(response.candidates[0].content, types.Content)
and response.candidates[0].content.parts is not None
):
for part in response.candidates[0].content.parts:
if not part.function_call:
continue
func_name = part.function_call.name
if func_name is not None and part.function_call.args is not None:
func = function_map[func_name]
args = convert_number_values_for_dict_function_call_args(
part.function_call.args
)
func_response: dict[str, Any]
try:
func_response = {
'result': invoke_function_from_dict_args(args, func)
}
except Exception as e: # pylint: disable=broad-except
func_response = {'error': str(e)}
func_response_part = types.Part.from_function_response(
name=func_name, response=func_response
)
func_response_parts.append(func_response_part)
return func_response_parts
async def get_function_response_parts_async(
response: types.GenerateContentResponse,
function_map: dict[str, Callable[..., Any]],
) -> list[types.Part]:
"""Returns the function response parts from the response."""
func_response_parts = []
if (
response.candidates is not None
and isinstance(response.candidates[0].content, types.Content)
and response.candidates[0].content.parts is not None
):
for part in response.candidates[0].content.parts:
if not part.function_call:
continue
func_name = part.function_call.name
if func_name is not None and part.function_call.args is not None:
func = function_map[func_name]
args = convert_number_values_for_dict_function_call_args(
part.function_call.args
)
func_response: dict[str, Any]
try:
if inspect.iscoroutinefunction(func):
func_response = {
'result': await invoke_function_from_dict_args_async(args, func)
}
else:
func_response = {
'result': invoke_function_from_dict_args(args, func)
}
except Exception as e: # pylint: disable=broad-except
func_response = {'error': str(e)}
func_response_part = types.Part.from_function_response(
name=func_name, response=func_response
)
func_response_parts.append(func_response_part)
return func_response_parts
def should_disable_afc(
config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
  """Returns whether automatic function calling should be disabled."""
if not config:
return False
config_model = _create_generate_content_config_model(config)
# If max_remote_calls is less or equal to 0, warn and disable AFC.
if (
config_model
and config_model.automatic_function_calling
and config_model.automatic_function_calling.maximum_remote_calls
is not None
and int(config_model.automatic_function_calling.maximum_remote_calls) <= 0
):
logger.warning(
'max_remote_calls in automatic_function_calling_config'
f' {config_model.automatic_function_calling.maximum_remote_calls} is'
' less than or equal to 0. Disabling automatic function calling.'
' Please set max_remote_calls to a positive integer.'
)
return True
# Default to enable AFC if not specified.
if (
not config_model.automatic_function_calling
or config_model.automatic_function_calling.disable is None
):
return False
if (
config_model.automatic_function_calling.disable
and config_model.automatic_function_calling.maximum_remote_calls
is not None
# exclude the case where max_remote_calls is set to 10 by default.
and 'maximum_remote_calls'
in config_model.automatic_function_calling.model_fields_set
and int(config_model.automatic_function_calling.maximum_remote_calls) > 0
):
logger.warning(
'`automatic_function_calling.disable` is set to `True`. And'
' `automatic_function_calling.maximum_remote_calls` is a'
' positive number'
f' {config_model.automatic_function_calling.maximum_remote_calls}.'
' Disabling automatic function calling. If you want to enable'
' automatic function calling, please set'
' `automatic_function_calling.disable` to `False` or leave it unset,'
' and set `automatic_function_calling.maximum_remote_calls` to a'
' positive integer or leave'
' `automatic_function_calling.maximum_remote_calls` unset.'
)
return config_model.automatic_function_calling.disable
def get_max_remote_calls_afc(
config: Optional[types.GenerateContentConfigOrDict] = None,
) -> int:
  """Returns the remaining remote calls for automatic function calling."""
  if not config:
    return _DEFAULT_MAX_REMOTE_CALLS_AFC
if should_disable_afc(config):
raise ValueError(
'automatic function calling is not enabled, but SDK is trying to get'
' max remote calls.'
)
config_model = _create_generate_content_config_model(config)
if (
not config_model.automatic_function_calling
or config_model.automatic_function_calling.maximum_remote_calls is None
):
return _DEFAULT_MAX_REMOTE_CALLS_AFC
return int(config_model.automatic_function_calling.maximum_remote_calls)
def should_append_afc_history(
config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
if not config:
return True
config_model = _create_generate_content_config_model(config)
if not config_model.automatic_function_calling:
return True
return not config_model.automatic_function_calling.ignore_call_history
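
A hedged illustration of the AFC helpers above; the tool function and config are invented for this example:

def get_current_weather(city: str) -> str:
  return f'Sunny in {city}'

config = {'tools': [get_current_weather]}
assert not should_disable_afc(config)  # AFC stays enabled by default
assert get_max_remote_calls_afc(config) == 10  # _DEFAULT_MAX_REMOTE_CALLS_AFC
assert get_function_map(config) == {'get_current_weather': get_current_weather}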

File diff suppressed because it is too large.


@@ -0,0 +1,597 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Replay API client."""
import base64
import copy
import datetime
import inspect
import io
import json
import os
import re
from typing import Any, Literal, Optional, Union
import google.auth
from requests.exceptions import HTTPError
from . import errors
from ._api_client import BaseApiClient
from ._api_client import HttpRequest
from ._api_client import HttpResponse
from ._common import BaseModel
from .types import HttpOptions, HttpOptionsOrDict
from .types import GenerateVideosOperation
def _redact_version_numbers(version_string: str) -> str:
"""Redacts version numbers in the form x.y.z from a string."""
return re.sub(r'\d+\.\d+\.\d+', '{VERSION_NUMBER}', version_string)
def _redact_language_label(language_label: str) -> str:
  """Redacts the language label because replay requests are shared across languages."""
return re.sub(r'gl-python/', '{LANGUAGE_LABEL}/', language_label)
def _redact_request_headers(headers: dict[str, str]) -> dict[str, str]:
"""Redacts headers that should not be recorded."""
redacted_headers = {}
for header_name, header_value in headers.items():
if header_name.lower() == 'x-goog-api-key':
redacted_headers[header_name] = '{REDACTED}'
elif header_name.lower() == 'user-agent':
redacted_headers[header_name] = _redact_language_label(
_redact_version_numbers(header_value)
)
elif header_name.lower() == 'x-goog-api-client':
redacted_headers[header_name] = _redact_language_label(
_redact_version_numbers(header_value)
)
elif header_name.lower() == 'x-goog-user-project':
continue
elif header_name.lower() == 'authorization':
continue
else:
redacted_headers[header_name] = header_value
return redacted_headers
def _redact_request_url(url: str) -> str:
# Redact all the url parts before the resource name, so the test can work
# against any project, location, version, or whether it's EasyGCP.
result = re.sub(
r'.*/projects/[^/]+/locations/[^/]+/',
'{VERTEX_URL_PREFIX}/',
url,
)
result = re.sub(
r'.*-aiplatform.googleapis.com/[^/]+/',
'{VERTEX_URL_PREFIX}/',
result,
)
result = re.sub(
r'.*aiplatform.googleapis.com/[^/]+/',
'{VERTEX_URL_PREFIX}/',
result,
)
result = re.sub(
r'https://generativelanguage.googleapis.com/[^/]+',
'{MLDEV_URL_PREFIX}',
result,
)
return result
def _redact_project_location_path(path: str) -> str:
# Redact a field in the request that is known to vary based on project and
# location.
if 'projects/' in path and 'locations/' in path:
result = re.sub(
r'projects/[^/]+/locations/[^/]+/',
'{PROJECT_AND_LOCATION_PATH}/',
path,
)
return result
else:
return path
def _redact_request_body(body: dict[str, object]) -> None:
"""Redacts fields in the request body in place."""
for key, value in body.items():
if isinstance(value, str):
body[key] = _redact_project_location_path(value)
def redact_http_request(http_request: HttpRequest) -> None:
http_request.headers = _redact_request_headers(http_request.headers)
http_request.url = _redact_request_url(http_request.url)
if not isinstance(http_request.data, bytes):
_redact_request_body(http_request.data)
def _current_file_path_and_line() -> str:
  """Returns the current file path and line number."""
current_frame = inspect.currentframe()
if (
current_frame is not None
and current_frame.f_back is not None
and current_frame.f_back.f_back is not None
):
frame = current_frame.f_back.f_back
filepath = inspect.getfile(frame)
lineno = frame.f_lineno
return f'File: {filepath}, Line: {lineno}'
return ''
def _debug_print(message: str) -> None:
print(
'DEBUG (test',
os.environ.get('PYTEST_CURRENT_TEST'),
')',
_current_file_path_and_line(),
':\n ',
message,
)
class ReplayRequest(BaseModel):
"""Represents a single request in a replay."""
method: str
url: str
headers: dict[str, str]
body_segments: list[dict[str, object]]
class ReplayResponse(BaseModel):
"""Represents a single response in a replay."""
status_code: int = 200
headers: dict[str, str]
body_segments: list[dict[str, object]]
byte_segments: Optional[list[bytes]] = None
sdk_response_segments: list[dict[str, object]]
def model_post_init(self, __context: Any) -> None:
# Remove headers that are not deterministic so the replay files don't change
# every time they are recorded.
self.headers.pop('Date', None)
self.headers.pop('Server-Timing', None)
class ReplayInteraction(BaseModel):
"""Represents a single interaction, request and response in a replay."""
request: ReplayRequest
response: ReplayResponse
class ReplayFile(BaseModel):
"""Represents a recorded session."""
replay_id: str
interactions: list[ReplayInteraction]
class ReplayApiClient(BaseApiClient):
  """For integration testing: replays recorded responses or records new ones."""
def __init__(
self,
mode: Literal['record', 'replay', 'auto', 'api'],
replay_id: str,
replays_directory: Optional[str] = None,
vertexai: bool = False,
api_key: Optional[str] = None,
credentials: Optional[google.auth.credentials.Credentials] = None,
project: Optional[str] = None,
location: Optional[str] = None,
http_options: Optional[HttpOptions] = None,
):
super().__init__(
vertexai=vertexai,
api_key=api_key,
credentials=credentials,
project=project,
location=location,
http_options=http_options,
)
self.replays_directory = replays_directory
if not self.replays_directory:
self.replays_directory = os.environ.get(
'GOOGLE_GENAI_REPLAYS_DIRECTORY', None
)
# Valid replay modes are replay-only or record-and-replay.
self.replay_session: Union[ReplayFile, None] = None
self._mode = mode
self._replay_id = replay_id
def initialize_replay_session(self, replay_id: str) -> None:
self._replay_id = replay_id
self._initialize_replay_session()
def _get_replay_file_path(self) -> str:
return self._generate_file_path_from_replay_id(
self.replays_directory, self._replay_id
)
def _should_call_api(self) -> bool:
return self._mode in ['record', 'api'] or (
self._mode == 'auto'
and not os.path.isfile(self._get_replay_file_path())
)
def _should_update_replay(self) -> bool:
return self._should_call_api() and self._mode != 'api'
def _initialize_replay_session_if_not_loaded(self) -> None:
if not self.replay_session:
self._initialize_replay_session()
def _initialize_replay_session(self) -> None:
_debug_print('Test is using replay id: ' + self._replay_id)
self._replay_index = 0
self._sdk_response_index = 0
replay_file_path = self._get_replay_file_path()
# This should not be triggered from the constructor.
replay_file_exists = os.path.isfile(replay_file_path)
if self._mode == 'replay' and not replay_file_exists:
raise ValueError(
'Replay files do not exist for replay id: ' + self._replay_id
)
if self._mode in ['replay', 'auto'] and replay_file_exists:
with open(replay_file_path, 'r') as f:
self.replay_session = ReplayFile.model_validate(json.loads(f.read()))
if self._should_update_replay():
self.replay_session = ReplayFile(
replay_id=self._replay_id, interactions=[]
)
def _generate_file_path_from_replay_id(self, replay_directory: Optional[str], replay_id: str) -> str:
session_parts = replay_id.split('/')
if len(session_parts) < 3:
raise ValueError(
f'{replay_id}: Session ID must be in the format of'
' module/function/[vertex|mldev]'
)
if replay_directory is None:
path_parts = []
else:
path_parts = [replay_directory]
path_parts.extend(session_parts)
return os.path.join(*path_parts) + '.json'
def close(self) -> None:
if not self._should_update_replay() or not self.replay_session:
return
replay_file_path = self._get_replay_file_path()
os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
with open(replay_file_path, 'w') as f:
f.write(self.replay_session.model_dump_json(exclude_unset=True, indent=2))
self.replay_session = None
def _record_interaction(
self,
http_request: HttpRequest,
http_response: Union[HttpResponse, errors.APIError, bytes],
) -> None:
if not self._should_update_replay():
return
redact_http_request(http_request)
request = ReplayRequest(
method=http_request.method,
url=http_request.url,
headers=http_request.headers,
body_segments=[http_request.data],
)
if isinstance(http_response, HttpResponse):
response = ReplayResponse(
headers=dict(http_response.headers),
body_segments=list(http_response.segments()),
byte_segments=[
seg[:100] + b'...' for seg in http_response.byte_segments()
],
status_code=http_response.status_code,
sdk_response_segments=[],
)
elif isinstance(http_response, errors.APIError):
response = ReplayResponse(
headers=dict(http_response.response.headers),
body_segments=[http_response._to_replay_record()],
status_code=http_response.code,
sdk_response_segments=[],
)
elif isinstance(http_response, bytes):
response = ReplayResponse(
headers={},
body_segments=[],
byte_segments=[http_response],
sdk_response_segments=[],
)
else:
raise ValueError(
'Unsupported http_response type: ' + str(type(http_response))
)
if self.replay_session is None:
raise ValueError('No replay session found.')
self.replay_session.interactions.append(
ReplayInteraction(request=request, response=response)
)
def _match_request(
self,
http_request: HttpRequest,
interaction: ReplayInteraction,
) -> None:
assert http_request.url == interaction.request.url
assert http_request.headers == interaction.request.headers, (
'Request headers mismatch:\n'
f'Actual: {http_request.headers}\n'
f'Expected: {interaction.request.headers}'
)
assert http_request.method == interaction.request.method
# Sanitize the request body, rewrite any fields that vary.
request_data_copy = copy.deepcopy(http_request.data)
# Both the request and recorded request must be redacted before comparing
# so that the comparison is fair.
if not isinstance(request_data_copy, bytes):
_redact_request_body(request_data_copy)
actual_request_body = [request_data_copy]
expected_request_body = interaction.request.body_segments
assert actual_request_body == expected_request_body, (
'Request body mismatch:\n'
f'Actual: {actual_request_body}\n'
f'Expected: {expected_request_body}'
)
def _build_response_from_replay(self, http_request: HttpRequest) -> HttpResponse:
redact_http_request(http_request)
if self.replay_session is None:
raise ValueError('No replay session found.')
interaction = self.replay_session.interactions[self._replay_index]
# Replay is on the right side of the assert so the diff makes more sense.
self._match_request(http_request, interaction)
self._replay_index += 1
self._sdk_response_index = 0
errors.APIError.raise_for_response(interaction.response)
return HttpResponse(
headers=interaction.response.headers,
response_stream=[
json.dumps(segment)
for segment in interaction.response.body_segments
],
byte_stream=interaction.response.byte_segments,
)
def _verify_response(self, response_model: BaseModel) -> None:
if self._mode == 'api':
return
if not self.replay_session:
raise ValueError('No replay session found.')
    # _replay_index is advanced in _build_response_from_replay, so subtract 1.
interaction = self.replay_session.interactions[self._replay_index - 1]
if self._should_update_replay():
if isinstance(response_model, list):
response_model = response_model[0]
if response_model and 'http_headers' in response_model.model_fields:
response_model.http_headers.pop('Date', None) # type: ignore[attr-defined]
interaction.response.sdk_response_segments.append(
response_model.model_dump(exclude_none=True)
)
return
if isinstance(response_model, list):
response_model = response_model[0]
print('response_model: ', response_model.model_dump(exclude_none=True))
if isinstance(response_model, GenerateVideosOperation):
actual = response_model.model_dump(
exclude={'result'}, exclude_none=True, mode='json'
)
else:
actual = response_model.model_dump(exclude_none=True, mode='json')
expected = interaction.response.sdk_response_segments[
self._sdk_response_index
]
assert (
actual == expected
), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
self._sdk_response_index += 1
def _request(
self,
http_request: HttpRequest,
stream: bool = False,
) -> HttpResponse:
self._initialize_replay_session_if_not_loaded()
if self._should_call_api():
_debug_print('api mode request: %s' % http_request)
try:
result = super()._request(http_request, stream)
except errors.APIError as e:
self._record_interaction(http_request, e)
raise e
if stream:
result_segments = []
for segment in result.segments():
result_segments.append(json.dumps(segment))
result = HttpResponse(result.headers, result_segments)
self._record_interaction(http_request, result)
        # Return a rebuilt HttpResponse with the buffered segments, since the
        # underlying stream has already been consumed.
else:
self._record_interaction(http_request, result)
_debug_print('api mode result: %s' % result.json)
return result
else:
return self._build_response_from_replay(http_request)
async def _async_request(
self,
http_request: HttpRequest,
stream: bool = False,
) -> HttpResponse:
self._initialize_replay_session_if_not_loaded()
if self._should_call_api():
_debug_print('api mode request: %s' % http_request)
try:
result = await super()._async_request(http_request, stream)
except errors.APIError as e:
self._record_interaction(http_request, e)
raise e
if stream:
result_segments = []
async for segment in result.async_segments():
result_segments.append(json.dumps(segment))
result = HttpResponse(result.headers, result_segments)
self._record_interaction(http_request, result)
        # Return a rebuilt HttpResponse with the buffered segments, since the
        # underlying stream has already been consumed.
else:
self._record_interaction(http_request, result)
_debug_print('api mode result: %s' % result.json)
return result
else:
return self._build_response_from_replay(http_request)
def upload_file(
self,
file_path: Union[str, io.IOBase],
upload_url: str,
upload_size: int,
*,
http_options: Optional[HttpOptionsOrDict] = None,
) -> HttpResponse:
if isinstance(file_path, io.IOBase):
offset = file_path.tell()
content = file_path.read()
file_path.seek(offset, os.SEEK_SET)
request = HttpRequest(
method='POST',
url='',
data={'bytes': base64.b64encode(content).decode('utf-8')},
headers={}
)
else:
request = HttpRequest(
method='POST', url='', data={'file_path': file_path}, headers={}
)
if self._should_call_api():
result: Union[str, HttpResponse]
try:
result = super().upload_file(
file_path, upload_url, upload_size, http_options=http_options
)
except HTTPError as e:
result = HttpResponse(
dict(e.response.headers), [json.dumps({'reason': e.response.reason})]
)
result.status_code = e.response.status_code
raise e
self._record_interaction(request, result)
return result
else:
return self._build_response_from_replay(request)
async def async_upload_file(
self,
file_path: Union[str, io.IOBase],
upload_url: str,
upload_size: int,
*,
http_options: Optional[HttpOptionsOrDict] = None,
) -> HttpResponse:
if isinstance(file_path, io.IOBase):
offset = file_path.tell()
content = file_path.read()
file_path.seek(offset, os.SEEK_SET)
request = HttpRequest(
method='POST',
url='',
data={'bytes': base64.b64encode(content).decode('utf-8')},
headers={},
)
else:
request = HttpRequest(
method='POST', url='', data={'file_path': file_path}, headers={}
)
if self._should_call_api():
result: HttpResponse
try:
result = await super().async_upload_file(
file_path, upload_url, upload_size, http_options=http_options
)
except HTTPError as e:
result = HttpResponse(
dict(e.response.headers), [json.dumps({'reason': e.response.reason})]
)
result.status_code = e.response.status_code
raise e
self._record_interaction(request, result)
return result
else:
return self._build_response_from_replay(request)
def download_file(
self, path: str, *, http_options: Optional[HttpOptionsOrDict] = None
) -> Union[HttpResponse, bytes, Any]:
self._initialize_replay_session_if_not_loaded()
request = self._build_request(
'get', path=path, request_dict={}, http_options=http_options
)
if self._should_call_api():
try:
result = super().download_file(path, http_options=http_options)
except HTTPError as e:
result = HttpResponse(
dict(e.response.headers), [json.dumps({'reason': e.response.reason})]
)
result.status_code = e.response.status_code
raise e
self._record_interaction(request, result)
return result
else:
return self._build_response_from_replay(request).byte_stream[0]
async def async_download_file(
self, path: str, *, http_options: Optional[HttpOptionsOrDict] = None
) -> Any:
self._initialize_replay_session_if_not_loaded()
request = self._build_request(
'get', path=path, request_dict={}, http_options=http_options
)
if self._should_call_api():
try:
result = await super().async_download_file(
path, http_options=http_options
)
except HTTPError as e:
result = HttpResponse(
dict(e.response.headers), [json.dumps({'reason': e.response.reason})]
)
result.status_code = e.response.status_code
raise e
self._record_interaction(request, result)
return result
else:
return self._build_response_from_replay(request).byte_stream[0]
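
A hedged sketch of driving the replay client in a test; the replay id, directory, and key below are placeholders:

client = ReplayApiClient(
    mode='replay',
    replay_id='models/generate_content/mldev',  # module/function/[vertex|mldev]
    replays_directory='/path/to/replays',
    api_key='fake-api-key',  # never compared: api keys are redacted on record
)
# In 'replay' mode each _request() is matched against the recorded interaction
# and answered from the replay file; 'record' and 'auto' additionally write new
# recordings when close() is called.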


@@ -0,0 +1,149 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import time
from unittest.mock import MagicMock, patch
import pytest
from .api_client import BaseApiClient
@patch('genai.api_client.BaseApiClient._build_request')
@patch('genai.api_client.BaseApiClient._request')
def test_request_streamed_non_blocking(mock_request, mock_build_request):
api_client = BaseApiClient(api_key='test_api_key')
http_method = 'GET'
path = 'test/path'
request_dict = {'key': 'value'}
mock_http_request = MagicMock()
mock_build_request.return_value = mock_http_request
def delayed_segments():
chunks = ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
for chunk in chunks:
time.sleep(0.1) # 100ms delay
yield chunk
mock_response = MagicMock()
mock_response.segments.side_effect = delayed_segments
mock_request.return_value = mock_response
chunks = []
start_time = time.time()
for chunk in api_client.request_streamed(http_method, path, request_dict):
chunks.append(chunk)
assert len(chunks) <= 3
end_time = time.time()
mock_build_request.assert_called_once_with(
http_method, path, request_dict, None
)
mock_request.assert_called_once_with(mock_http_request, stream=True)
assert chunks == ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
assert end_time - start_time > 0.3
@patch('genai.api_client.BaseApiClient._build_request')
@patch('genai.api_client.BaseApiClient._async_request')
@pytest.mark.asyncio
async def test_async_request(mock_async_request, mock_build_request):
  api_client = BaseApiClient(api_key='test_api_key')
http_method = 'GET'
path = 'test/path'
request_dict = {'key': 'value'}
mock_http_request = MagicMock()
mock_build_request.return_value = mock_http_request
class MockResponse:
def __init__(self, text):
self.text = text
async def delayed_response(http_request, stream):
await asyncio.sleep(0.1) # 100ms delay
return MockResponse('value')
mock_async_request.side_effect = delayed_response
async_coroutine1 = api_client.async_request(http_method, path, request_dict)
async_coroutine2 = api_client.async_request(http_method, path, request_dict)
async_coroutine3 = api_client.async_request(http_method, path, request_dict)
start_time = time.time()
results = await asyncio.gather(
async_coroutine1, async_coroutine2, async_coroutine3
)
end_time = time.time()
mock_build_request.assert_called_with(http_method, path, request_dict, None)
assert mock_build_request.call_count == 3
mock_async_request.assert_called_with(
http_request=mock_http_request, stream=False
)
assert mock_async_request.call_count == 3
assert results == ['value', 'value', 'value']
assert 0.1 <= end_time - start_time < 0.15
@patch('genai.api_client.BaseApiClient._build_request')
@patch('genai.api_client.BaseApiClient._async_request')
@pytest.mark.asyncio
async def test_async_request_streamed_non_blocking(
mock_async_request, mock_build_request
):
  api_client = BaseApiClient(api_key='test_api_key')
http_method = 'GET'
path = 'test/path'
request_dict = {'key': 'value'}
mock_http_request = MagicMock()
mock_build_request.return_value = mock_http_request
class MockResponse:
def __init__(self, segments):
self._segments = segments
    # Should mock an async generator here, but the source code combines sync
    # and async streaming in a single segments() method.
    # TODO: fix the above.
def segments(self):
for segment in self._segments:
time.sleep(0.1) # 100ms delay
yield segment
async def delayed_response(http_request, stream):
return MockResponse(['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}'])
mock_async_request.side_effect = delayed_response
chunks = []
start_time = time.time()
async for chunk in await api_client.async_request_streamed(
http_method, path, request_dict
):
chunks.append(chunk)
assert len(chunks) <= 3
end_time = time.time()
mock_build_request.assert_called_once_with(
http_method, path, request_dict, None
)
mock_async_request.assert_called_once_with(
http_request=mock_http_request, stream=True
)
assert chunks == ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
assert end_time - start_time > 0.3

File diff suppressed because it is too large.

File diff suppressed because it is too large.

File diff suppressed because it is too large.


@@ -0,0 +1,532 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Iterator
import sys
from typing import AsyncIterator, Awaitable, Optional, Union, get_args
from . import _transformers as t
from . import types
from .models import AsyncModels, Models
from .types import Content, ContentOrDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard
def _validate_content(content: Content) -> bool:
if not content.parts:
return False
for part in content.parts:
if part == Part():
return False
if part.text is not None and part.text == "":
return False
return True
def _validate_contents(contents: list[Content]) -> bool:
if not contents:
return False
for content in contents:
if not _validate_content(content):
return False
return True
def _validate_response(response: GenerateContentResponse) -> bool:
if not response.candidates:
return False
if not response.candidates[0].content:
return False
return _validate_content(response.candidates[0].content)
def _extract_curated_history(
comprehensive_history: list[Content],
) -> list[Content]:
"""Extracts the curated (valid) history from a comprehensive history.
The comprehensive history contains all turns (user input and model responses),
including any invalid or rejected model outputs. This function filters
that history to return only the valid turns.
A "turn" starts with one user input (a single content) and then follows by
corresponding model response (which may consist of multiple contents).
Turns are assumed to alternate: user input, model output, user input, model
output, etc.
Args:
comprehensive_history: A list representing the complete chat history,
including invalid turns.
Returns:
The curated history: a list of valid turns.
"""
if not comprehensive_history:
return []
curated_history = []
length = len(comprehensive_history)
i = 0
current_input = comprehensive_history[i]
if current_input.role != "user":
raise ValueError("History must start with a user turn.")
while i < length:
if comprehensive_history[i].role not in ["user", "model"]:
raise ValueError(
f"Role must be user or model, but got {comprehensive_history[i].role}"
)
if comprehensive_history[i].role == "user":
current_input = comprehensive_history[i]
i += 1
else:
current_output = []
is_valid = True
while i < length and comprehensive_history[i].role == "model":
current_output.append(comprehensive_history[i])
if is_valid and not _validate_content(comprehensive_history[i]):
is_valid = False
i += 1
if is_valid:
curated_history.append(current_input)
curated_history.extend(current_output)
return curated_history
class _BaseChat:
"""Base chat session."""
def __init__(
self,
*,
model: str,
config: Optional[GenerateContentConfigOrDict] = None,
history: list[ContentOrDict],
):
self._model = model
self._config = config
content_models = []
for content in history:
if not isinstance(content, Content):
content_model = Content.model_validate(content)
else:
content_model = content
content_models.append(content_model)
self._comprehensive_history = content_models
"""Comprehensive history is the full history of the chat, including turns of the invalid contents from the model and their associated inputs.
"""
self._curated_history = _extract_curated_history(content_models)
"""Curated history is the set of valid turns that will be used in the subsequent send requests.
"""
def record_history(
self,
user_input: Content,
model_output: list[Content],
automatic_function_calling_history: list[Content],
is_valid: bool,
) -> None:
"""Records the chat history.
Maintaining both comprehensive and curated histories.
Args:
user_input: The user's input content.
model_output: A list of `Content` from the model's response. This can be
an empty list if the model produced no output.
automatic_function_calling_history: A list of `Content` representing the
history of automatic function calls, including the user input as the
first entry.
is_valid: A boolean flag indicating whether the current model output is
considered valid.
"""
input_contents = (
# Because the AFC input contains the entire curated chat history in
# addition to the new user input, we need to truncate the AFC history
# to deduplicate the existing chat history.
automatic_function_calling_history[len(self._curated_history):]
if automatic_function_calling_history
else [user_input]
)
# Appends an empty content when model returns empty response, so that the
# history is always alternating between user and model.
output_contents = (
model_output if model_output else [Content(role="model", parts=[])]
)
self._comprehensive_history.extend(input_contents)
self._comprehensive_history.extend(output_contents)
if is_valid:
self._curated_history.extend(input_contents)
self._curated_history.extend(output_contents)
def get_history(self, curated: bool = False) -> list[Content]:
"""Returns the chat history.
Args:
curated: A boolean flag indicating whether to return the curated (valid)
history or the comprehensive (all turns) history. Defaults to False
(returns the comprehensive history).
Returns:
A list of `Content` objects representing the chat history.
"""
if curated:
return self._curated_history
else:
return self._comprehensive_history
def _is_part_type(
contents: Union[list[PartUnionDict], PartUnionDict],
) -> TypeGuard[t.ContentType]:
if isinstance(contents, list):
return all(_is_part_type(part) for part in contents)
else:
allowed_part_types = get_args(types.PartUnion)
if type(contents) in allowed_part_types:
return True
else:
# Some images don't pass isinstance(item, PIL.Image.Image)
# For example <class 'PIL.JpegImagePlugin.JpegImageFile'>
if types.PIL_Image is not None and isinstance(contents, types.PIL_Image):
return True
return False
class Chat(_BaseChat):
"""Chat session."""
def __init__(
self,
*,
modules: Models,
model: str,
config: Optional[GenerateContentConfigOrDict] = None,
history: list[ContentOrDict],
):
self._modules = modules
super().__init__(
model=model,
config=config,
history=history,
)
def send_message(
self,
message: Union[list[PartUnionDict], PartUnionDict],
config: Optional[GenerateContentConfigOrDict] = None,
) -> GenerateContentResponse:
"""Sends the conversation history with the additional message and returns the model's response.
Args:
message: The message to send to the model.
config: Optional config to override the default Chat config for this
request.
Returns:
The model's response.
Usage:
.. code-block:: python
chat = client.chats.create(model='gemini-1.5-flash')
response = chat.send_message('tell me a story')
"""
if not _is_part_type(message):
raise ValueError(
f"Message must be a valid part type: {types.PartUnion} or"
f" {types.PartUnionDict}, got {type(message)}"
)
input_content = t.t_content(self._modules._api_client, message)
response = self._modules.generate_content(
model=self._model,
contents=self._curated_history + [input_content], # type: ignore[arg-type]
config=config if config else self._config,
)
model_output = (
[response.candidates[0].content]
if response.candidates and response.candidates[0].content
else []
)
automatic_function_calling_history = (
response.automatic_function_calling_history
if response.automatic_function_calling_history
else []
)
self.record_history(
user_input=input_content,
model_output=model_output,
automatic_function_calling_history=automatic_function_calling_history,
is_valid=_validate_response(response),
)
return response
def send_message_stream(
self,
message: Union[list[PartUnionDict], PartUnionDict],
config: Optional[GenerateContentConfigOrDict] = None,
) -> Iterator[GenerateContentResponse]:
"""Sends the conversation history with the additional message and yields the model's response in chunks.
Args:
message: The message to send to the model.
config: Optional config to override the default Chat config for this
request.
Yields:
The model's response in chunks.
Usage:
.. code-block:: python
chat = client.chats.create(model='gemini-1.5-flash')
for chunk in chat.send_message_stream('tell me a story'):
print(chunk.text)
"""
if not _is_part_type(message):
raise ValueError(
f"Message must be a valid part type: {types.PartUnion} or"
f" {types.PartUnionDict}, got {type(message)}"
)
input_content = t.t_content(self._modules._api_client, message)
output_contents = []
finish_reason = None
is_valid = True
chunk = None
if isinstance(self._modules, Models):
for chunk in self._modules.generate_content_stream(
model=self._model,
contents=self._curated_history + [input_content], # type: ignore[arg-type]
config=config if config else self._config,
):
if not _validate_response(chunk):
is_valid = False
if chunk.candidates and chunk.candidates[0].content:
output_contents.append(chunk.candidates[0].content)
if chunk.candidates and chunk.candidates[0].finish_reason:
finish_reason = chunk.candidates[0].finish_reason
yield chunk
automatic_function_calling_history = (
chunk.automatic_function_calling_history
if chunk.automatic_function_calling_history
else []
)
self.record_history(
user_input=input_content,
model_output=output_contents,
automatic_function_calling_history=automatic_function_calling_history,
is_valid=is_valid
and bool(output_contents)
and finish_reason is not None,
)
class Chats:
"""A util class to create chat sessions."""
def __init__(self, modules: Models):
self._modules = modules
def create(
self,
*,
model: str,
config: Optional[GenerateContentConfigOrDict] = None,
history: Optional[list[ContentOrDict]] = None,
) -> Chat:
"""Creates a new chat session.
Args:
model: The model to use for the chat.
config: The configuration to use for the generate content request.
history: The history to use for the chat.
Returns:
A new chat session.
"""
return Chat(
modules=self._modules,
model=model,
config=config,
history=history if history else [],
)
class AsyncChat(_BaseChat):
"""Async chat session."""
def __init__(
self,
*,
modules: AsyncModels,
model: str,
config: Optional[GenerateContentConfigOrDict] = None,
history: list[ContentOrDict],
):
self._modules = modules
super().__init__(
model=model,
config=config,
history=history,
)
async def send_message(
self,
message: Union[list[PartUnionDict], PartUnionDict],
config: Optional[GenerateContentConfigOrDict] = None,
) -> GenerateContentResponse:
"""Sends the conversation history with the additional message and returns model's response.
Args:
message: The message to send to the model.
config: Optional config to override the default Chat config for this
request.
Returns:
The model's response.
Usage:
.. code-block:: python
chat = client.aio.chats.create(model='gemini-1.5-flash')
response = await chat.send_message('tell me a story')
"""
if not _is_part_type(message):
raise ValueError(
f"Message must be a valid part type: {types.PartUnion} or"
f" {types.PartUnionDict}, got {type(message)}"
)
input_content = t.t_content(self._modules._api_client, message)
response = await self._modules.generate_content(
model=self._model,
contents=self._curated_history + [input_content], # type: ignore[arg-type]
config=config if config else self._config,
)
model_output = (
[response.candidates[0].content]
if response.candidates and response.candidates[0].content
else []
)
automatic_function_calling_history = (
response.automatic_function_calling_history
if response.automatic_function_calling_history
else []
)
self.record_history(
user_input=input_content,
model_output=model_output,
automatic_function_calling_history=automatic_function_calling_history,
is_valid=_validate_response(response),
)
return response
async def send_message_stream(
self,
message: Union[list[PartUnionDict], PartUnionDict],
config: Optional[GenerateContentConfigOrDict] = None,
) -> AsyncIterator[GenerateContentResponse]:
"""Sends the conversation history with the additional message and yields the model's response in chunks.
Args:
message: The message to send to the model.
config: Optional config to override the default Chat config for this
request.
Yields:
The model's response in chunks.
Usage:
.. code-block:: python
chat = client.aio.chats.create(model='gemini-1.5-flash')
async for chunk in await chat.send_message_stream('tell me a story'):
print(chunk.text)
"""
if not _is_part_type(message):
raise ValueError(
f"Message must be a valid part type: {types.PartUnion} or"
f" {types.PartUnionDict}, got {type(message)}"
)
input_content = t.t_content(self._modules._api_client, message)
async def async_generator(): # type: ignore[no-untyped-def]
output_contents = []
finish_reason = None
is_valid = True
chunk = None
async for chunk in await self._modules.generate_content_stream( # type: ignore[attr-defined]
model=self._model,
contents=self._curated_history + [input_content], # type: ignore[arg-type]
config=config if config else self._config,
):
if not _validate_response(chunk):
is_valid = False
if chunk.candidates and chunk.candidates[0].content:
output_contents.append(chunk.candidates[0].content)
if chunk.candidates and chunk.candidates[0].finish_reason:
finish_reason = chunk.candidates[0].finish_reason
yield chunk
if not output_contents or finish_reason is None:
is_valid = False
self.record_history(
user_input=input_content,
model_output=output_contents,
automatic_function_calling_history=chunk.automatic_function_calling_history if chunk.automatic_function_calling_history else [],
is_valid=is_valid,
)
return async_generator() # type: ignore[no-untyped-call, no-any-return]
class AsyncChats:
"""A util class to create async chat sessions."""
def __init__(self, modules: AsyncModels):
self._modules = modules
def create(
self,
*,
model: str,
config: Optional[GenerateContentConfigOrDict] = None,
history: Optional[list[ContentOrDict]] = None,
) -> AsyncChat:
"""Creates a new chat session.
Args:
model: The model to use for the chat.
config: The configuration to use for the generate content request.
history: The history to use for the chat.
Returns:
A new chat session.
"""
return AsyncChat(
modules=self._modules,
model=model,
config=config,
history=history if history else [],
)
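# A hedged usage sketch, not part of the SDK module: demonstrates how
# `_extract_curated_history` drops a turn whose model output is invalid
# (here, an empty parts list) while keeping valid turns intact.
if __name__ == '__main__':
  history = [
      Content(role='user', parts=[Part(text='Hi')]),
      Content(role='model', parts=[Part(text='Hello!')]),
      Content(role='user', parts=[Part(text='Tell me a story')]),
      Content(role='model', parts=[]),  # invalid: empty model output
  ]
  curated = _extract_curated_history(history)
  # The second turn (the user input plus the empty model content) is dropped.
  assert [c.role for c in curated] == ['user', 'model']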

View File

@@ -0,0 +1,290 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from typing import Optional, Union
import google.auth
import pydantic
from ._api_client import BaseApiClient
from ._replay_api_client import ReplayApiClient
from .batches import AsyncBatches, Batches
from .caches import AsyncCaches, Caches
from .chats import AsyncChats, Chats
from .files import AsyncFiles, Files
from .live import AsyncLive
from .models import AsyncModels, Models
from .operations import AsyncOperations, Operations
from .tunings import AsyncTunings, Tunings
from .types import HttpOptions, HttpOptionsDict
class AsyncClient:
"""Client for making asynchronous (non-blocking) requests."""
def __init__(self, api_client: BaseApiClient):
self._api_client = api_client
self._models = AsyncModels(self._api_client)
self._tunings = AsyncTunings(self._api_client)
self._caches = AsyncCaches(self._api_client)
self._batches = AsyncBatches(self._api_client)
self._files = AsyncFiles(self._api_client)
self._live = AsyncLive(self._api_client)
self._operations = AsyncOperations(self._api_client)
@property
def models(self) -> AsyncModels:
return self._models
@property
def tunings(self) -> AsyncTunings:
return self._tunings
@property
def caches(self) -> AsyncCaches:
return self._caches
@property
def batches(self) -> AsyncBatches:
return self._batches
@property
def chats(self) -> AsyncChats:
return AsyncChats(modules=self.models)
@property
def files(self) -> AsyncFiles:
return self._files
@property
def live(self) -> AsyncLive:
return self._live
@property
def operations(self) -> AsyncOperations:
return self._operations
class DebugConfig(pydantic.BaseModel):
"""Configuration options that change client network behavior when testing."""
client_mode: Optional[str] = pydantic.Field(
default_factory=lambda: os.getenv('GOOGLE_GENAI_CLIENT_MODE', None)
)
replays_directory: Optional[str] = pydantic.Field(
default_factory=lambda: os.getenv('GOOGLE_GENAI_REPLAYS_DIRECTORY', None)
)
replay_id: Optional[str] = pydantic.Field(
default_factory=lambda: os.getenv('GOOGLE_GENAI_REPLAY_ID', None)
)
class Client:
"""Client for making synchronous requests.
Use this client to make a request to the Gemini Developer API or Vertex AI
API and then wait for the response.
To initialize the client, provide the required arguments either directly
or via environment variables. Gemini API users and Vertex AI users in
express mode can provide an API key with the input argument
`api_key="your-api-key"` or by defining `GOOGLE_API_KEY="your-api-key"` as an
environment variable.
Vertex AI API users can provide the input arguments `vertexai=True,
project="your-project-id", location="us-central1"` or define the
`GOOGLE_GENAI_USE_VERTEXAI=true`, `GOOGLE_CLOUD_PROJECT`, and
`GOOGLE_CLOUD_LOCATION` environment variables.
Attributes:
api_key: The `API key <https://ai.google.dev/gemini-api/docs/api-key>`_ to
use for authentication. Applies to the Gemini Developer API only.
vertexai: Indicates whether the client should use the Vertex AI
API endpoints. Defaults to False (uses Gemini Developer API endpoints).
credentials: The credentials to use for authentication when calling the
Vertex AI APIs. Credentials can be obtained from environment variables and
default credentials. For more information, see
`Set up Application Default Credentials
<https://cloud.google.com/docs/authentication/provide-credentials-adc>`_.
Applies to the Vertex AI API only.
project: The `Google Cloud project ID <https://cloud.google.com/vertex-ai/docs/start/cloud-environment>`_ to
use for quota. Can be obtained from environment variables (for example,
``GOOGLE_CLOUD_PROJECT``). Applies to the Vertex AI API only.
location: The `location <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations>`_
to send API requests to (for example, ``us-central1``). Can be obtained
from environment variables. Applies to the Vertex AI API only.
debug_config: Config settings that control network behavior of the client.
This is typically used when running test code.
http_options: HTTP options to use for the client. These options will be
applied to all requests made by the client. Example usage:
`client = genai.Client(http_options=types.HttpOptions(api_version='v1'))`.
Usage for the Gemini Developer API:
.. code-block:: python
from google import genai
client = genai.Client(api_key='my-api-key')
Usage for the Vertex AI API:
.. code-block:: python
from google import genai
client = genai.Client(
vertexai=True, project='my-project-id', location='us-central1'
)
"""
def __init__(
self,
*,
vertexai: Optional[bool] = None,
api_key: Optional[str] = None,
credentials: Optional[google.auth.credentials.Credentials] = None,
project: Optional[str] = None,
location: Optional[str] = None,
debug_config: Optional[DebugConfig] = None,
http_options: Optional[Union[HttpOptions, HttpOptionsDict]] = None,
):
"""Initializes the client.
Args:
vertexai (bool): Indicates whether the client should use the Vertex AI
API endpoints. Defaults to False (uses Gemini Developer API endpoints).
api_key (str): The `API key
<https://ai.google.dev/gemini-api/docs/api-key>`_ to use for
authentication. Applies to the Gemini Developer API only.
credentials (google.auth.credentials.Credentials): The credentials to use
for authentication when calling the Vertex AI APIs. Credentials can be
obtained from environment variables and default credentials. For more
information, see `Set up Application Default Credentials
<https://cloud.google.com/docs/authentication/provide-credentials-adc>`_.
Applies to the Vertex AI API only.
project (str): The `Google Cloud project ID
<https://cloud.google.com/vertex-ai/docs/start/cloud-environment>`_ to
use for quota. Can be obtained from environment variables (for example,
``GOOGLE_CLOUD_PROJECT``). Applies to the Vertex AI API only.
location (str): The `location
<https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations>`_
to send API requests to (for example, ``us-central1``). Can be obtained
from environment variables. Applies to the Vertex AI API only.
debug_config (DebugConfig): Config settings that control network behavior
of the client. This is typically used when running test code.
http_options (Union[HttpOptions, HttpOptionsDict]): HTTP options to use
for the client.
"""
self._debug_config = debug_config or DebugConfig()
if isinstance(http_options, dict):
http_options = HttpOptions(**http_options)
self._api_client = self._get_api_client(
vertexai=vertexai,
api_key=api_key,
credentials=credentials,
project=project,
location=location,
debug_config=self._debug_config,
http_options=http_options,
)
self._aio = AsyncClient(self._api_client)
self._models = Models(self._api_client)
self._tunings = Tunings(self._api_client)
self._caches = Caches(self._api_client)
self._batches = Batches(self._api_client)
self._files = Files(self._api_client)
self._operations = Operations(self._api_client)
@staticmethod
def _get_api_client(
vertexai: Optional[bool] = None,
api_key: Optional[str] = None,
credentials: Optional[google.auth.credentials.Credentials] = None,
project: Optional[str] = None,
location: Optional[str] = None,
debug_config: Optional[DebugConfig] = None,
http_options: Optional[HttpOptions] = None,
) -> BaseApiClient:
if debug_config and debug_config.client_mode in [
'record',
'replay',
'auto',
]:
return ReplayApiClient(
mode=debug_config.client_mode, # type: ignore[arg-type]
replay_id=debug_config.replay_id, # type: ignore[arg-type]
replays_directory=debug_config.replays_directory,
vertexai=vertexai, # type: ignore[arg-type]
api_key=api_key,
credentials=credentials,
project=project,
location=location,
http_options=http_options,
)
return BaseApiClient(
vertexai=vertexai,
api_key=api_key,
credentials=credentials,
project=project,
location=location,
http_options=http_options,
)
@property
def chats(self) -> Chats:
return Chats(modules=self.models)
@property
def aio(self) -> AsyncClient:
return self._aio
@property
def models(self) -> Models:
return self._models
@property
def tunings(self) -> Tunings:
return self._tunings
@property
def caches(self) -> Caches:
return self._caches
@property
def batches(self) -> Batches:
return self._batches
@property
def files(self) -> Files:
return self._files
@property
def operations(self) -> Operations:
return self._operations
@property
def vertexai(self) -> bool:
"""Returns whether the client is using the Vertex AI API."""
return self._api_client.vertexai or False
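# A hedged sketch, not part of the SDK module: shows how `DebugConfig`
# routes the client onto a `ReplayApiClient` for record/replay testing.
# The directory and replay id below are hypothetical placeholders.
if __name__ == '__main__':
  debug_config = DebugConfig(
      client_mode='replay',
      replays_directory='tests/replays',  # hypothetical path
      replay_id='models.generate_content.json',  # hypothetical id
  )
  client = Client(api_key='fake-api-key', debug_config=debug_config)
  # With client_mode in {'record', 'replay', 'auto'}, the underlying client
  # is a ReplayApiClient rather than a plain BaseApiClient.
  print(type(client._api_client).__name__)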

View File

@@ -0,0 +1,163 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Error classes for the GenAI SDK."""
from typing import Any, Optional, TYPE_CHECKING, Union
import httpx
import json
if TYPE_CHECKING:
from .replay_api_client import ReplayResponse
class APIError(Exception):
"""General errors raised by the GenAI API."""
code: int
response: Union['ReplayResponse', httpx.Response]
status: Optional[str] = None
message: Optional[str] = None
def __init__(
self,
code: int,
response_json: Any,
response: Optional[Union['ReplayResponse', httpx.Response]] = None,
):
self.response = response
self.details = response_json
self.message = self._get_message(response_json)
self.status = self._get_status(response_json)
self.code = code if code else self._get_code(response_json)
super().__init__(f'{self.code} {self.status}. {self.details}')
def _get_status(self, response_json: Any) -> Any:
return response_json.get(
'status', response_json.get('error', {}).get('status', None)
)
def _get_message(self, response_json: Any) -> Any:
return response_json.get(
'message', response_json.get('error', {}).get('message', None)
)
def _get_code(self, response_json: Any) -> Any:
return response_json.get(
'code', response_json.get('error', {}).get('code', None)
)
def _to_replay_record(self) -> dict[str, Any]:
"""Returns a dictionary representation of the error for replay recording.
details is not included since it may expose internal information in the
replay file.
"""
return {
'error': {
'code': self.code,
'message': self.message,
'status': self.status,
}
}
@classmethod
def raise_for_response(
cls, response: Union['ReplayResponse', httpx.Response]
) -> None:
"""Raises an error with detailed error message if the response has an error status."""
if response.status_code == 200:
return
if isinstance(response, httpx.Response):
try:
response.read()
response_json = response.json()
except json.decoder.JSONDecodeError:
message = response.text
response_json = {
'message': message,
'status': response.reason_phrase,
}
else:
response_json = response.body_segments[0].get('error', {})
status_code = response.status_code
if 400 <= status_code < 500:
raise ClientError(status_code, response_json, response)
elif 500 <= status_code < 600:
raise ServerError(status_code, response_json, response)
else:
raise cls(status_code, response_json, response)
@classmethod
async def raise_for_async_response(
cls, response: Union['ReplayResponse', httpx.Response]
) -> None:
"""Raises an error with detailed error message if the response has an error status."""
if response.status_code == 200:
return
if isinstance(response, httpx.Response):
try:
await response.aread()
response_json = response.json()
except json.decoder.JSONDecodeError:
message = response.text
response_json = {
'message': message,
'status': response.reason_phrase,
}
else:
response_json = response.body_segments[0].get('error', {})
status_code = response.status_code
if 400 <= status_code < 500:
raise ClientError(status_code, response_json, response)
elif 500 <= status_code < 600:
raise ServerError(status_code, response_json, response)
else:
raise cls(status_code, response_json, response)
class ClientError(APIError):
"""Client error raised by the GenAI API."""
pass
class ServerError(APIError):
"""Server error raised by the GenAI API."""
pass
class UnknownFunctionCallArgumentError(ValueError):
"""Raised when the function call argument cannot be converted to the parameter annotation."""
pass
class UnsupportedFunctionError(ValueError):
"""Raised when the function is not supported."""
class FunctionInvocationError(ValueError):
"""Raised when the function cannot be invoked with the given arguments."""
pass
class ExperimentalWarning(Warning):
"""Warning for experimental features."""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,984 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""[Preview] Live API client."""
import asyncio
import base64
import contextlib
import json
import logging
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, cast, get_args
import warnings
import google.auth
import google.auth.transport.requests
import pydantic
from websockets import ConnectionClosed
from . import _api_module
from . import _common
from . import _transformers as t
from . import client
from . import types
from ._api_client import BaseApiClient
from ._common import get_value_by_path as getv
from ._common import set_value_by_path as setv
from . import _live_converters as live_converters
from .models import _Content_to_mldev
from .models import _Content_to_vertex
try:
from websockets.asyncio.client import ClientConnection
from websockets.asyncio.client import connect
except ModuleNotFoundError:
# This try/except is for TAP, mypy complains about it which is why we have the type: ignore
from websockets.client import ClientConnection # type: ignore
from websockets.client import connect # type: ignore
logger = logging.getLogger('google_genai.live')
_FUNCTION_RESPONSE_REQUIRES_ID = (
'FunctionResponse request must have an `id` field from the'
' response of a ToolCall.FunctionCalls in Google AI.'
)
class AsyncSession:
"""[Preview] AsyncSession."""
def __init__(
self, api_client: BaseApiClient, websocket: ClientConnection
):
self._api_client = api_client
self._ws = websocket
async def send(
self,
*,
input: Optional[
Union[
types.ContentListUnion,
types.ContentListUnionDict,
types.LiveClientContentOrDict,
types.LiveClientRealtimeInputOrDict,
types.LiveClientToolResponseOrDict,
types.FunctionResponseOrDict,
Sequence[types.FunctionResponseOrDict],
]
] = None,
end_of_turn: Optional[bool] = False,
) -> None:
"""[Deprecated] Send input to the model.
> **Warning**: This method is deprecated and will be removed in a future
version (not before Q3 2025). Please use one of the more specific methods:
`send_client_content`, `send_realtime_input`, or `send_tool_response`
instead.
The method will send the input request to the server.
Args:
input: The input request to the model.
end_of_turn: Whether the input is the last message in a turn.
Example usage:
.. code-block:: python
client = genai.Client(api_key=API_KEY)
async with client.aio.live.connect(model='...') as session:
await session.send(input='Hello world!', end_of_turn=True)
async for message in session.receive():
print(message)
"""
warnings.warn(
'The `session.send` method is deprecated and will be removed in a '
'future version (not before Q3 2025).\n'
'Please use one of the more specific methods: `send_client_content`, '
'`send_realtime_input`, or `send_tool_response` instead.',
DeprecationWarning,
stacklevel=2,
)
client_message = self._parse_client_message(input, end_of_turn)
await self._ws.send(json.dumps(client_message))
async def send_client_content(
self,
*,
turns: Optional[
Union[
types.Content,
types.ContentDict,
list[Union[types.Content, types.ContentDict]]
]
] = None,
turn_complete: bool = True,
) -> None:
"""Send non-realtime, turn based content to the model.
There are two ways to send messages to the live API:
`send_client_content` and `send_realtime_input`.
`send_client_content` messages are added to the model context **in order**.
Having a conversation using `send_client_content` messages is roughly
equivalent to using the `Chat.send_message_stream` method, except that the
state of the `chat` history is stored on the API server.
Because of `send_client_content`'s order guarantee, the model cannot
respond as quickly to `send_client_content` messages as to
`send_realtime_input` messages. This makes the biggest difference when
sending objects that have significant preprocessing time (typically images).
The `send_client_content` message sends a list of `Content` objects,
which offer more options than the `media:Blob` sent by `send_realtime_input`.
The main use-cases for `send_client_content` over `send_realtime_input` are:
- Prefilling a conversation context (including sending anything that can't
be represented as a realtime message), before starting a realtime
conversation.
- Conducting a non-realtime conversation, similar to `client.chat`, using
the Live API.
Caution: Interleaving `send_client_content` and `send_realtime_input`
in the same conversation is not recommended and can lead to unexpected
results.
Args:
turns: A `Content` object or list of `Content` objects (or equivalent
dicts).
turn_complete: if true (the default) the model will reply immediately. If
false, the model will wait for you to send additional client_content,
and will not return until you send `turn_complete=True`.
Example:
```
from google import genai
from google.genai import types
import os
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
else:
MODEL_NAME = 'gemini-2.0-flash-live-001'
client = genai.Client()
async with client.aio.live.connect(
model=MODEL_NAME,
config={"response_modalities": ["TEXT"]}
) as session:
await session.send_client_content(
turns=types.Content(
role='user',
parts=[types.Part(text="Hello world!")]))
async for msg in session.receive():
if msg.text:
print(msg.text)
```
"""
client_content = t.t_client_content(turns, turn_complete)
if self._api_client.vertexai:
client_content_dict = live_converters._LiveClientContent_to_vertex(
api_client=self._api_client, from_object=client_content
)
else:
client_content_dict = live_converters._LiveClientContent_to_mldev(
api_client=self._api_client, from_object=client_content
)
await self._ws.send(json.dumps({'client_content': client_content_dict}))
async def send_realtime_input(
self,
*,
media: Optional[types.BlobImageUnionDict] = None,
audio: Optional[types.BlobOrDict] = None,
audio_stream_end: Optional[bool] = None,
video: Optional[types.BlobImageUnionDict] = None,
text: Optional[str] = None,
activity_start: Optional[types.ActivityStartOrDict] = None,
activity_end: Optional[types.ActivityEndOrDict] = None,
) -> None:
"""Send realtime input to the model, only send one argument per call.
Use `send_realtime_input` for realtime audio chunks and video
frames(images).
With `send_realtime_input` the api will respond to audio automatically
based on voice activity detection (VAD).
`send_realtime_input` is optimized for responsivness at the expense of
deterministic ordering. Audio and video tokens are added to the
context when they become available.
Args:
media: A `Blob`-like object, the realtime media to send.
audio: A `Blob`-like object, the realtime audio chunk to send.
audio_stream_end: Signals that the audio stream has ended.
video: A `Blob`-like object (or image), the realtime video frame to send.
text: The realtime text to send.
activity_start: Marks the start of user activity (manual activity detection).
activity_end: Marks the end of user activity (manual activity detection).
Example:
```
from pathlib import Path
from google import genai
from google.genai import types
import PIL.Image
import os
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
else:
MODEL_NAME = 'gemini-2.0-flash-live-001'
client = genai.Client()
async with client.aio.live.connect(
model=MODEL_NAME,
config={"response_modalities": ["TEXT"]},
) as session:
await session.send_realtime_input(
media=PIL.Image.open('image.jpg'))
audio_bytes = Path('audio.pcm').read_bytes()
await session.send_realtime_input(
media=types.Blob(data=audio_bytes, mime_type='audio/pcm;rate=16000'))
async for msg in session.receive():
if msg.text is not None:
print(f'{msg.text}')
```
"""
kwargs: dict[str, Any] = {}
if media is not None:
kwargs['media'] = media
if audio is not None:
kwargs['audio'] = audio
if audio_stream_end is not None:
kwargs['audio_stream_end'] = audio_stream_end
if video is not None:
kwargs['video'] = video
if text is not None:
kwargs['text'] = text
if activity_start is not None:
kwargs['activity_start'] = activity_start
if activity_end is not None:
kwargs['activity_end'] = activity_end
if len(kwargs) != 1:
raise ValueError(
f'Only one argument can be set, got {len(kwargs)}:'
f' {list(kwargs.keys())}'
)
realtime_input = types.LiveSendRealtimeInputParameters.model_validate(
kwargs
)
if self._api_client.vertexai:
realtime_input_dict = (
live_converters._LiveSendRealtimeInputParameters_to_vertex(
api_client=self._api_client, from_object=realtime_input
)
)
else:
realtime_input_dict = (
live_converters._LiveSendRealtimeInputParameters_to_mldev(
api_client=self._api_client, from_object=realtime_input
)
)
realtime_input_dict = _common.convert_to_dict(realtime_input_dict)
realtime_input_dict = _common.encode_unserializable_types(
realtime_input_dict
)
await self._ws.send(json.dumps({'realtime_input': realtime_input_dict}))
async def send_tool_response(
self,
*,
function_responses: Union[
types.FunctionResponseOrDict,
Sequence[types.FunctionResponseOrDict],
],
) -> None:
"""Send a tool response to the session.
Use `send_tool_response` to reply to `LiveServerToolCall` messages
from the server.
To set the available tools, use the `config.tools` argument
when you connect to the session (`client.live.connect`).
Args:
function_responses: A `FunctionResponse`-like object or list of
`FunctionResponse`-like objects.
Example:
```
from google import genai
from google.genai import types
import os
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
else:
MODEL_NAME = 'gemini-2.0-flash-live-001'
client = genai.Client()
tools = [{'function_declarations': [{'name': 'turn_on_the_lights'}]}]
config = {
"tools": tools,
"response_modalities": ['TEXT']
}
async with client.aio.live.connect(
model=MODEL_NAME,
config=config
) as session:
prompt = "Turn on the lights please"
await session.send_client_content(
turns={"parts": [{'text': prompt}]}
)
async for chunk in session.receive():
if chunk.server_content:
if chunk.text is not None:
print(chunk.text)
elif chunk.tool_call:
print(chunk.tool_call)
print('_'*80)
function_response=types.FunctionResponse(
name='turn_on_the_lights',
response={'result': 'ok'},
id=chunk.tool_call.function_calls[0].id,
)
print(function_response)
await session.send_tool_response(
function_responses=function_response
)
print('_'*80)
"""
tool_response = t.t_tool_response(function_responses)
if self._api_client.vertexai:
tool_response_dict = live_converters._LiveClientToolResponse_to_vertex(
api_client=self._api_client, from_object=tool_response
)
else:
tool_response_dict = live_converters._LiveClientToolResponse_to_mldev(
api_client=self._api_client, from_object=tool_response
)
for response in tool_response_dict.get('functionResponses', []):
if response.get('id') is None:
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
await self._ws.send(json.dumps({'tool_response': tool_response_dict}))
async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
"""Receive model responses from the server.
The method will yield the model responses from the server. The returned
responses will represent a complete model turn. When the returned message
is a function call, the user must call `send` with the function response
to continue the turn.
Yields:
The model responses from the server.
Example usage:
.. code-block:: python
client = genai.Client(api_key=API_KEY)
async with client.aio.live.connect(model='...') as session:
await session.send(input='Hello world!', end_of_turn=True)
async for message in session.receive():
print(message)
"""
# TODO(b/365983264) Handle intermittent issues for the user.
while result := await self._receive():
if result.server_content and result.server_content.turn_complete:
yield result
break
yield result
async def start_stream(
self, *, stream: AsyncIterator[bytes], mime_type: str
) -> AsyncIterator[types.LiveServerMessage]:
"""[Deprecated] Start a live session from a data stream.
> **Warning**: This method is deprecated and will be removed in a future
version (not before Q3 2025). Please use one of the more specific methods:
`send_client_content`, `send_realtime_input`, or `send_tool_response`
instead.
The interaction terminates when the input stream is complete.
This method will start two async tasks. One task will be used to send the
input stream to the model and the other task will be used to receive the
responses from the model.
Args:
stream: An iterator that yields the input data to send to the model.
mime_type: The MIME type of the data in the stream.
Yields:
The audio bytes received from the model and server response messages.
Example usage:
.. code-block:: python
client = genai.Client(api_key=API_KEY)
config = {'response_modalities': ['AUDIO']}
async def audio_stream():
stream = read_audio()
for data in stream:
yield data
async with client.aio.live.connect(model='...', config=config) as session:
async for audio in session.start_stream(stream=audio_stream(),
mime_type='audio/pcm'):
play_audio_chunk(audio.data)
"""
warnings.warn(
'The `AsyncSession.start_stream` method is deprecated, '
'and will be removed in a future release (not before Q3 2025). '
'Please use the `receive` and `send_realtime_input` methods instead.',
DeprecationWarning,
stacklevel=4,
)
stop_event = asyncio.Event()
# Start the send loop. When stream is complete stop_event is set.
asyncio.create_task(self._send_loop(stream, mime_type, stop_event))
recv_task = None
while not stop_event.is_set():
try:
recv_task = asyncio.create_task(self._receive())
await asyncio.wait(
[
recv_task,
asyncio.create_task(stop_event.wait()),
],
return_when=asyncio.FIRST_COMPLETED,
)
if recv_task.done():
yield recv_task.result()
# Give a chance for the send loop to process requests.
await asyncio.sleep(10**-12)
except ConnectionClosed:
break
if recv_task is not None and not recv_task.done():
recv_task.cancel()
# Wait for the task to finish (cancelled or not)
try:
await recv_task
except asyncio.CancelledError:
pass
async def _receive(self) -> types.LiveServerMessage:
parameter_model = types.LiveServerMessage()
try:
raw_response = await self._ws.recv(decode=False)
except TypeError:
raw_response = await self._ws.recv() # type: ignore[assignment]
if raw_response:
try:
response = json.loads(raw_response)
except json.decoder.JSONDecodeError:
raise ValueError(f'Failed to parse response: {raw_response!r}')
else:
response = {}
if self._api_client.vertexai:
response_dict = live_converters._LiveServerMessage_from_vertex(self._api_client, response)
else:
response_dict = live_converters._LiveServerMessage_from_mldev(self._api_client, response)
return types.LiveServerMessage._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)
async def _send_loop(
self,
data_stream: AsyncIterator[bytes],
mime_type: str,
stop_event: asyncio.Event,
) -> None:
async for data in data_stream:
model_input = types.LiveClientRealtimeInput(
media_chunks=[types.Blob(data=data, mime_type=mime_type)]
)
await self.send(input=model_input)
# Give a chance for the receive loop to process responses.
await asyncio.sleep(10**-12)
# Give a chance for the receiver to process the last response.
stop_event.set()
def _parse_client_message(
self,
input: Optional[
Union[
types.ContentListUnion,
types.ContentListUnionDict,
types.LiveClientContentOrDict,
types.LiveClientRealtimeInputOrDict,
types.LiveClientToolResponseOrDict,
types.FunctionResponseOrDict,
Sequence[types.FunctionResponseOrDict],
]
] = None,
end_of_turn: Optional[bool] = False,
) -> types.LiveClientMessageDict:
formatted_input: Any = input
if not input:
logger.info('No input provided; assuming it is the end of the turn.')
return {'client_content': {'turn_complete': True}}
if isinstance(input, str):
formatted_input = [input]
elif isinstance(input, dict) and 'data' in input:
try:
blob_input = types.Blob(**input)
except pydantic.ValidationError:
raise ValueError(
f'Unsupported input type "{type(input)}" or input content "{input}"'
)
if (
isinstance(blob_input, types.Blob)
and isinstance(blob_input.data, bytes)
):
formatted_input = [
blob_input.model_dump(mode='json', exclude_none=True)
]
elif isinstance(input, types.Blob):
formatted_input = [input]
elif isinstance(input, dict) and 'name' in input and 'response' in input:
# ToolResponse.FunctionResponse
if not (self._api_client.vertexai) and 'id' not in input:
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
formatted_input = [input]
if isinstance(formatted_input, Sequence) and any(
isinstance(c, dict) and 'name' in c and 'response' in c
for c in formatted_input
):
# ToolResponse.FunctionResponse
function_responses_input = []
for item in formatted_input:
if isinstance(item, dict):
try:
function_response_input = types.FunctionResponse(**item)
except pydantic.ValidationError:
raise ValueError(
f'Unsupported input type "{type(input)}" or input content'
f' "{input}"'
)
if (
function_response_input.id is None
and not self._api_client.vertexai
):
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
else:
function_response_dict = function_response_input.model_dump(
exclude_none=True, mode='json'
)
function_response_typeddict = types.FunctionResponseDict(
name=function_response_dict.get('name'),
response=function_response_dict.get('response'),
)
if function_response_dict.get('id'):
function_response_typeddict['id'] = function_response_dict.get(
'id'
)
function_responses_input.append(function_response_typeddict)
client_message = types.LiveClientMessageDict(
tool_response=types.LiveClientToolResponseDict(
function_responses=function_responses_input
)
)
elif isinstance(formatted_input, Sequence) and any(
isinstance(c, str) for c in formatted_input
):
to_object: dict[str, Any] = {}
content_input_parts: list[types.PartUnion] = []
for item in formatted_input:
if isinstance(item, get_args(types.PartUnion)):
content_input_parts.append(item)
if self._api_client.vertexai:
contents = [
_Content_to_vertex(self._api_client, item, to_object)
for item in t.t_contents(self._api_client, content_input_parts)
]
else:
contents = [
_Content_to_mldev(self._api_client, item, to_object)
for item in t.t_contents(self._api_client, content_input_parts)
]
content_dict_list: list[types.ContentDict] = []
for item in contents:
try:
content_input = types.Content(**item)
except pydantic.ValidationError:
raise ValueError(
f'Unsupported input type "{type(input)}" or input content'
f' "{input}"'
)
content_dict_list.append(
types.ContentDict(
parts=content_input.model_dump(exclude_none=True, mode='json')[
'parts'
],
role=content_input.role,
)
)
client_message = types.LiveClientMessageDict(
client_content=types.LiveClientContentDict(
turns=content_dict_list, turn_complete=end_of_turn
)
)
elif isinstance(formatted_input, Sequence):
if any((isinstance(b, dict) and 'data' in b) for b in formatted_input):
pass
elif any(isinstance(b, types.Blob) for b in formatted_input):
formatted_input = [
b.model_dump(exclude_none=True, mode='json')
for b in formatted_input
]
else:
raise ValueError(
f'Unsupported input type "{type(input)}" or input content "{input}"'
)
client_message = types.LiveClientMessageDict(
realtime_input=types.LiveClientRealtimeInputDict(
media_chunks=formatted_input
)
)
elif isinstance(formatted_input, dict):
if 'content' in formatted_input or 'turns' in formatted_input:
# TODO(b/365983264) Add validation checks for content_update input_dict.
if 'turns' in formatted_input:
content_turns = formatted_input['turns']
else:
content_turns = formatted_input['content']
client_message = types.LiveClientMessageDict(
client_content=types.LiveClientContentDict(
turns=content_turns,
turn_complete=formatted_input.get('turn_complete'),
)
)
elif 'media_chunks' in formatted_input:
try:
realtime_input = types.LiveClientRealtimeInput(**formatted_input)
except pydantic.ValidationError:
raise ValueError(
f'Unsupported input type "{type(input)}" or input content'
f' "{input}"'
)
client_message = types.LiveClientMessageDict(
realtime_input=types.LiveClientRealtimeInputDict(
media_chunks=realtime_input.model_dump(
exclude_none=True, mode='json'
)['media_chunks']
)
)
elif 'function_responses' in formatted_input:
try:
tool_response_input = types.LiveClientToolResponse(**formatted_input)
except pydantic.ValidationError:
raise ValueError(
f'Unsupported input type "{type(input)}" or input content'
f' "{input}"'
)
client_message = types.LiveClientMessageDict(
tool_response=types.LiveClientToolResponseDict(
function_responses=tool_response_input.model_dump(
exclude_none=True, mode='json'
)['function_responses']
)
)
else:
raise ValueError(
f'Unsupported input type "{type(input)}" or input content "{input}"'
)
elif isinstance(formatted_input, types.LiveClientRealtimeInput):
realtime_input_dict = formatted_input.model_dump(
exclude_none=True, mode='json'
)
client_message = types.LiveClientMessageDict(
realtime_input=types.LiveClientRealtimeInputDict(
media_chunks=realtime_input_dict.get('media_chunks')
)
)
if (
client_message['realtime_input'] is not None
and client_message['realtime_input']['media_chunks'] is not None
and isinstance(
client_message['realtime_input']['media_chunks'][0]['data'], bytes
)
):
formatted_media_chunks: list[types.BlobDict] = []
for item in client_message['realtime_input']['media_chunks']:
if isinstance(item, dict):
try:
blob_input = types.Blob(**item)
except pydantic.ValidationError:
raise ValueError(
f'Unsupported input type "{type(input)}" or input content'
f' "{input}"'
)
if (
isinstance(blob_input, types.Blob)
and isinstance(blob_input.data, bytes)
and blob_input.data is not None
):
formatted_media_chunks.append(
types.BlobDict(
data=base64.b64decode(blob_input.data),
mime_type=blob_input.mime_type,
)
)
client_message['realtime_input'][
'media_chunks'
] = formatted_media_chunks
elif isinstance(formatted_input, types.LiveClientContent):
client_content_dict = formatted_input.model_dump(
exclude_none=True, mode='json'
)
client_message = types.LiveClientMessageDict(
client_content=types.LiveClientContentDict(
turns=client_content_dict.get('turns'),
turn_complete=client_content_dict.get('turn_complete'),
)
)
elif isinstance(formatted_input, types.LiveClientToolResponse):
# ToolResponse.FunctionResponse
if (
not (self._api_client.vertexai)
and formatted_input.function_responses is not None
and not (formatted_input.function_responses[0].id)
):
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
client_message = types.LiveClientMessageDict(
tool_response=types.LiveClientToolResponseDict(
function_responses=formatted_input.model_dump(
exclude_none=True, mode='json'
).get('function_responses')
)
)
elif isinstance(formatted_input, types.FunctionResponse):
if not (self._api_client.vertexai) and not (formatted_input.id):
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
function_response_dict = formatted_input.model_dump(
exclude_none=True, mode='json'
)
function_response_typeddict = types.FunctionResponseDict(
name=function_response_dict.get('name'),
response=function_response_dict.get('response'),
)
if function_response_dict.get('id'):
function_response_typeddict['id'] = function_response_dict.get('id')
client_message = types.LiveClientMessageDict(
tool_response=types.LiveClientToolResponseDict(
function_responses=[function_response_typeddict]
)
)
elif isinstance(formatted_input, Sequence) and isinstance(
formatted_input[0], types.FunctionResponse
):
if not (self._api_client.vertexai) and not (formatted_input[0].id):
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
function_response_list: list[types.FunctionResponseDict] = []
for item in formatted_input:
function_response_dict = item.model_dump(exclude_none=True, mode='json')
function_response_typeddict = types.FunctionResponseDict(
name=function_response_dict.get('name'),
response=function_response_dict.get('response'),
)
if function_response_dict.get('id'):
function_response_typeddict['id'] = function_response_dict.get('id')
function_response_list.append(function_response_typeddict)
client_message = types.LiveClientMessageDict(
tool_response=types.LiveClientToolResponseDict(
function_responses=function_response_list
)
)
else:
raise ValueError(
f'Unsupported input type "{type(input)}" or input content "{input}"'
)
return client_message
async def close(self) -> None:
# Close the websocket connection.
await self._ws.close()
class AsyncLive(_api_module.BaseModule):
"""[Preview] AsyncLive."""
@contextlib.asynccontextmanager
async def connect(
self,
*,
model: str,
config: Optional[types.LiveConnectConfigOrDict] = None,
) -> AsyncIterator[AsyncSession]:
"""[Preview] Connect to the live server.
Note: the live API is currently in preview.
Usage:
.. code-block:: python
client = genai.Client(api_key=API_KEY)
config = {}
async with client.aio.live.connect(model='...', config=config) as session:
await session.send(input='Hello world!', end_of_turn=True)
async for message in session.receive():
print(message)
"""
base_url = self._api_client._websocket_base_url()
if isinstance(base_url, bytes):
base_url = base_url.decode('utf-8')
transformed_model = t.t_model(self._api_client, model)
parameter_model = _t_live_connect_config(self._api_client, config)
if self._api_client.api_key:
api_key = self._api_client.api_key
version = self._api_client._http_options.api_version
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
headers = self._api_client._http_options.headers
request_dict = _common.convert_to_dict(
live_converters._LiveConnectParameters_to_mldev(
api_client=self._api_client,
from_object=types.LiveConnectParameters(
model=transformed_model,
config=parameter_model,
).model_dump(exclude_none=True)
)
)
del request_dict['config']
setv(request_dict, ['setup', 'model'], transformed_model)
request = json.dumps(request_dict)
else:
# Get bearer token through Application Default Credentials.
creds, _ = google.auth.default( # type: ignore[no-untyped-call]
scopes=['https://www.googleapis.com/auth/cloud-platform']
)
# creds.valid is False, and creds.token is None
# Need to refresh credentials to populate those
auth_req = google.auth.transport.requests.Request() # type: ignore[no-untyped-call]
creds.refresh(auth_req)
bearer_token = creds.token
headers = self._api_client._http_options.headers
if headers is not None:
headers.update({
'Authorization': 'Bearer {}'.format(bearer_token),
})
version = self._api_client._http_options.api_version
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
location = self._api_client.location
project = self._api_client.project
if transformed_model.startswith('publishers/'):
transformed_model = (
f'projects/{project}/locations/{location}/' + transformed_model
)
request_dict = _common.convert_to_dict(
live_converters._LiveConnectParameters_to_vertex(
api_client=self._api_client,
from_object=types.LiveConnectParameters(
model=transformed_model,
config=parameter_model,
).model_dump(exclude_none=True)
)
)
del request_dict['config']
if getv(request_dict, ['setup', 'generationConfig', 'responseModalities']) is None:
setv(request_dict, ['setup', 'generationConfig', 'responseModalities'], ['AUDIO'])
request = json.dumps(request_dict)
try:
async with connect(uri, additional_headers=headers) as ws:
await ws.send(request)
logger.info(await ws.recv(decode=False))
yield AsyncSession(api_client=self._api_client, websocket=ws)
except TypeError:
# Try with the older websockets API
async with connect(uri, extra_headers=headers) as ws:
await ws.send(request)
logger.info(await ws.recv())
yield AsyncSession(api_client=self._api_client, websocket=ws)
def _t_live_connect_config(
api_client: BaseApiClient,
config: Optional[types.LiveConnectConfigOrDict],
) -> types.LiveConnectConfig:
# Ensure the config is a LiveConnectConfig.
if config is None:
parameter_model = types.LiveConnectConfig()
elif isinstance(config, dict):
if getv(config, ['system_instruction']) is not None:
converted_system_instruction = t.t_content(
api_client, getv(config, ['system_instruction'])
)
else:
converted_system_instruction = None
parameter_model = types.LiveConnectConfig(**config)
parameter_model.system_instruction = converted_system_instruction
else:
if config.system_instruction is None:
system_instruction = None
else:
system_instruction = t.t_content(
api_client, getv(config, ['system_instruction'])
)
parameter_model = config
parameter_model.system_instruction = system_instruction
if parameter_model.generation_config is not None:
warnings.warn(
'Setting `LiveConnectConfig.generation_config` is deprecated, '
'please set the fields on `LiveConnectConfig` directly. This will '
'become an error in a future version (not before Q3 2025)',
DeprecationWarning,
stacklevel=4,
)
return parameter_model
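# A hedged sketch, not part of the SDK module: shows how
# `_t_live_connect_config` normalizes a plain dict into a `LiveConnectConfig`.
# No API client is needed when `system_instruction` is absent, so `None` is
# passed here purely for illustration.
if __name__ == '__main__':
  cfg = _t_live_connect_config(None, {'response_modalities': ['TEXT']})  # type: ignore[arg-type]
  print(type(cfg).__name__)  # LiveConnectConfig
  print(cfg.response_modalities)  # e.g. [<Modality.TEXT: 'TEXT'>]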

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,651 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Code generated by the Google Gen AI SDK generator DO NOT EDIT.
import logging
from typing import Any, Optional, Union
from urllib.parse import urlencode
from . import _api_module
from . import _common
from . import _transformers as t
from . import types
from ._api_client import BaseApiClient
from ._common import get_value_by_path as getv
from ._common import set_value_by_path as setv
logger = logging.getLogger('google_genai.operations')
def _GetOperationParameters_to_mldev(
api_client: BaseApiClient,
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['operation_name']) is not None:
setv(
to_object,
['_url', 'operationName'],
getv(from_object, ['operation_name']),
)
if getv(from_object, ['config']) is not None:
setv(to_object, ['config'], getv(from_object, ['config']))
return to_object
def _GetOperationParameters_to_vertex(
api_client: BaseApiClient,
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['operation_name']) is not None:
setv(
to_object,
['_url', 'operationName'],
getv(from_object, ['operation_name']),
)
if getv(from_object, ['config']) is not None:
setv(to_object, ['config'], getv(from_object, ['config']))
return to_object
def _FetchPredictOperationParameters_to_vertex(
api_client: BaseApiClient,
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['operation_name']) is not None:
setv(to_object, ['operationName'], getv(from_object, ['operation_name']))
if getv(from_object, ['resource_name']) is not None:
setv(
to_object,
['_url', 'resourceName'],
getv(from_object, ['resource_name']),
)
if getv(from_object, ['config']) is not None:
setv(to_object, ['config'], getv(from_object, ['config']))
return to_object
def _Video_from_mldev(
api_client: BaseApiClient,
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['video', 'uri']) is not None:
setv(to_object, ['uri'], getv(from_object, ['video', 'uri']))
if getv(from_object, ['video', 'encodedVideo']) is not None:
setv(
to_object,
['video_bytes'],
t.t_bytes(api_client, getv(from_object, ['video', 'encodedVideo'])),
)
if getv(from_object, ['encoding']) is not None:
setv(to_object, ['mime_type'], getv(from_object, ['encoding']))
return to_object
def _GeneratedVideo_from_mldev(
api_client: BaseApiClient,
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['_self']) is not None:
setv(
to_object,
['video'],
_Video_from_mldev(api_client, getv(from_object, ['_self']), to_object),
)
return to_object
def _GenerateVideosResponse_from_mldev(
api_client: BaseApiClient,
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['generatedSamples']) is not None:
setv(
to_object,
['generated_videos'],
[
_GeneratedVideo_from_mldev(api_client, item, to_object)
for item in getv(from_object, ['generatedSamples'])
],
)
if getv(from_object, ['raiMediaFilteredCount']) is not None:
setv(
to_object,
['rai_media_filtered_count'],
getv(from_object, ['raiMediaFilteredCount']),
)
if getv(from_object, ['raiMediaFilteredReasons']) is not None:
setv(
to_object,
['rai_media_filtered_reasons'],
getv(from_object, ['raiMediaFilteredReasons']),
)
return to_object
def _GenerateVideosOperation_from_mldev(
api_client: BaseApiClient,
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['name']) is not None:
setv(to_object, ['name'], getv(from_object, ['name']))
if getv(from_object, ['metadata']) is not None:
setv(to_object, ['metadata'], getv(from_object, ['metadata']))
if getv(from_object, ['done']) is not None:
setv(to_object, ['done'], getv(from_object, ['done']))
if getv(from_object, ['error']) is not None:
setv(to_object, ['error'], getv(from_object, ['error']))
if getv(from_object, ['response', 'generateVideoResponse']) is not None:
setv(
to_object,
['response'],
_GenerateVideosResponse_from_mldev(
api_client,
getv(from_object, ['response', 'generateVideoResponse']),
to_object,
),
)
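  # The same payload is also surfaced under 'result' below, so both
  # fields of the operation are populated.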
if getv(from_object, ['response', 'generateVideoResponse']) is not None:
setv(
to_object,
['result'],
_GenerateVideosResponse_from_mldev(
api_client,
getv(from_object, ['response', 'generateVideoResponse']),
to_object,
),
)
return to_object
def _Video_from_vertex(
api_client: BaseApiClient,
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['gcsUri']) is not None:
setv(to_object, ['uri'], getv(from_object, ['gcsUri']))
if getv(from_object, ['bytesBase64Encoded']) is not None:
setv(
to_object,
['video_bytes'],
t.t_bytes(api_client, getv(from_object, ['bytesBase64Encoded'])),
)
if getv(from_object, ['mimeType']) is not None:
setv(to_object, ['mime_type'], getv(from_object, ['mimeType']))
return to_object
def _GeneratedVideo_from_vertex(
api_client: BaseApiClient,
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['_self']) is not None:
setv(
to_object,
['video'],
_Video_from_vertex(api_client, getv(from_object, ['_self']), to_object),
)
return to_object
def _GenerateVideosResponse_from_vertex(
api_client: BaseApiClient,
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['videos']) is not None:
setv(
to_object,
['generated_videos'],
[
_GeneratedVideo_from_vertex(api_client, item, to_object)
for item in getv(from_object, ['videos'])
],
)
if getv(from_object, ['raiMediaFilteredCount']) is not None:
setv(
to_object,
['rai_media_filtered_count'],
getv(from_object, ['raiMediaFilteredCount']),
)
if getv(from_object, ['raiMediaFilteredReasons']) is not None:
setv(
to_object,
['rai_media_filtered_reasons'],
getv(from_object, ['raiMediaFilteredReasons']),
)
return to_object
def _GenerateVideosOperation_from_vertex(
api_client: BaseApiClient,
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['name']) is not None:
setv(to_object, ['name'], getv(from_object, ['name']))
if getv(from_object, ['metadata']) is not None:
setv(to_object, ['metadata'], getv(from_object, ['metadata']))
if getv(from_object, ['done']) is not None:
setv(to_object, ['done'], getv(from_object, ['done']))
if getv(from_object, ['error']) is not None:
setv(to_object, ['error'], getv(from_object, ['error']))
if getv(from_object, ['response']) is not None:
setv(
to_object,
['response'],
_GenerateVideosResponse_from_vertex(
api_client, getv(from_object, ['response']), to_object
),
)
if getv(from_object, ['response']) is not None:
setv(
to_object,
['result'],
_GenerateVideosResponse_from_vertex(
api_client, getv(from_object, ['response']), to_object
),
)
return to_object
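# All converters above follow one pattern: get_value_by_path/set_value_by_path
# (imported as getv/setv) copy values between the wire format's camelCase keys
# and the SDK's snake_case fields, creating intermediate dicts on demand. A
# minimal sketch of the assumed behavior, with a hypothetical payload:
#
#   src = {'videos': [{'gcsUri': 'gs://bucket/clip.mp4'}]}
#   getv(src, ['videos'])  # -> [{'gcsUri': 'gs://bucket/clip.mp4'}]
#   dst: dict = {}
#   setv(dst, ['_url', 'operationName'], 'operations/123')
#   # dst == {'_url': {'operationName': 'operations/123'}}
#
# Values stored under '_url' are later consumed by `.format_map(...)` when the
# request path is built.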
class Operations(_api_module.BaseModule):
def _get_videos_operation(
self,
*,
operation_name: str,
config: Optional[types.GetOperationConfigOrDict] = None,
) -> types.GenerateVideosOperation:
parameter_model = types._GetOperationParameters(
operation_name=operation_name,
config=config,
)
request_url_dict: Optional[dict[str, str]]
if self._api_client.vertexai:
request_dict = _GetOperationParameters_to_vertex(
self._api_client, parameter_model
)
request_url_dict = request_dict.get('_url')
if request_url_dict:
path = '{operationName}'.format_map(request_url_dict)
else:
path = '{operationName}'
else:
request_dict = _GetOperationParameters_to_mldev(
self._api_client, parameter_model
)
request_url_dict = request_dict.get('_url')
if request_url_dict:
path = '{operationName}'.format_map(request_url_dict)
else:
path = '{operationName}'
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
request_dict.pop('config', None)
http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)
response_dict = self._api_client.request(
'get', path, request_dict, http_options
)
if self._api_client.vertexai:
response_dict = _GenerateVideosOperation_from_vertex(
self._api_client, response_dict
)
else:
response_dict = _GenerateVideosOperation_from_mldev(
self._api_client, response_dict
)
return_value = types.GenerateVideosOperation._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)
self._api_client._verify_response(return_value)
return return_value
def _fetch_predict_videos_operation(
self,
*,
operation_name: str,
resource_name: str,
config: Optional[types.FetchPredictOperationConfigOrDict] = None,
) -> types.GenerateVideosOperation:
parameter_model = types._FetchPredictOperationParameters(
operation_name=operation_name,
resource_name=resource_name,
config=config,
)
request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError('This method is only supported in the Vertex AI client.')
else:
request_dict = _FetchPredictOperationParameters_to_vertex(
self._api_client, parameter_model
)
request_url_dict = request_dict.get('_url')
if request_url_dict:
path = '{resourceName}:fetchPredictOperation'.format_map(
request_url_dict
)
else:
path = '{resourceName}:fetchPredictOperation'
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
request_dict.pop('config', None)
http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)
response_dict = self._api_client.request(
'post', path, request_dict, http_options
)
if self._api_client.vertexai:
response_dict = _GenerateVideosOperation_from_vertex(
self._api_client, response_dict
)
return_value = types.GenerateVideosOperation._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)
self._api_client._verify_response(return_value)
return return_value
def get(
self,
operation: types.GenerateVideosOperation,
*,
config: Optional[types.GetOperationConfigOrDict] = None,
) -> types.GenerateVideosOperation:
"""Gets the status of an operation."""
# Currently, only GenerateVideosOperation is supported.
# TODO(b/398040607): Support short form names
operation_name = operation.name
if not operation_name:
raise ValueError('Operation name is empty.')
# TODO(b/398233524): Cast operation types
if self._api_client.vertexai:
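      # Vertex AI polls video operations via `{resource}:fetchPredictOperation`
      # (see _fetch_predict_videos_operation), so recover the parent resource
      # name by trimming '/operations/<id>' off the operation name.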
resource_name = operation_name.rpartition('/operations/')[0]
http_options = types.HttpOptions()
if isinstance(config, dict):
dict_options = config.get('http_options', None)
if dict_options is not None:
http_options = types.HttpOptions(**dict(dict_options))
      elif isinstance(config, types.GetOperationConfig):
http_options = (
config.http_options
if config.http_options is not None
else types.HttpOptions()
)
fetch_operation_config = types.FetchPredictOperationConfig(
http_options=http_options
)
return self._fetch_predict_videos_operation(
operation_name=operation_name,
resource_name=resource_name,
config=fetch_operation_config,
)
else:
return self._get_videos_operation(
operation_name=operation_name,
config=config,
)
class AsyncOperations(_api_module.BaseModule):
async def _get_videos_operation(
self,
*,
operation_name: str,
config: Optional[types.GetOperationConfigOrDict] = None,
) -> types.GenerateVideosOperation:
parameter_model = types._GetOperationParameters(
operation_name=operation_name,
config=config,
)
request_url_dict: Optional[dict[str, str]]
if self._api_client.vertexai:
request_dict = _GetOperationParameters_to_vertex(
self._api_client, parameter_model
)
request_url_dict = request_dict.get('_url')
if request_url_dict:
path = '{operationName}'.format_map(request_url_dict)
else:
path = '{operationName}'
else:
request_dict = _GetOperationParameters_to_mldev(
self._api_client, parameter_model
)
request_url_dict = request_dict.get('_url')
if request_url_dict:
path = '{operationName}'.format_map(request_url_dict)
else:
path = '{operationName}'
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
request_dict.pop('config', None)
http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)
response_dict = await self._api_client.async_request(
'get', path, request_dict, http_options
)
if self._api_client.vertexai:
response_dict = _GenerateVideosOperation_from_vertex(
self._api_client, response_dict
)
else:
response_dict = _GenerateVideosOperation_from_mldev(
self._api_client, response_dict
)
return_value = types.GenerateVideosOperation._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)
self._api_client._verify_response(return_value)
return return_value
async def _fetch_predict_videos_operation(
self,
*,
operation_name: str,
resource_name: str,
config: Optional[types.FetchPredictOperationConfigOrDict] = None,
) -> types.GenerateVideosOperation:
parameter_model = types._FetchPredictOperationParameters(
operation_name=operation_name,
resource_name=resource_name,
config=config,
)
request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError('This method is only supported in the Vertex AI client.')
else:
request_dict = _FetchPredictOperationParameters_to_vertex(
self._api_client, parameter_model
)
request_url_dict = request_dict.get('_url')
if request_url_dict:
path = '{resourceName}:fetchPredictOperation'.format_map(
request_url_dict
)
else:
path = '{resourceName}:fetchPredictOperation'
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
request_dict.pop('config', None)
http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options
request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)
response_dict = await self._api_client.async_request(
'post', path, request_dict, http_options
)
if self._api_client.vertexai:
response_dict = _GenerateVideosOperation_from_vertex(
self._api_client, response_dict
)
return_value = types.GenerateVideosOperation._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)
self._api_client._verify_response(return_value)
return return_value
async def get(
self,
operation: types.GenerateVideosOperation,
*,
config: Optional[types.GetOperationConfigOrDict] = None,
) -> types.GenerateVideosOperation:
"""Gets the status of an operation."""
# Currently, only GenerateVideosOperation is supported.
operation_name = operation.name
if not operation_name:
raise ValueError('Operation name is empty.')
if self._api_client.vertexai:
resource_name = operation_name.rpartition('/operations/')[0]
http_options = types.HttpOptions()
if isinstance(config, dict):
dict_options = config.get('http_options', None)
if dict_options is not None:
http_options = types.HttpOptions(**dict(dict_options))
      elif isinstance(config, types.GetOperationConfig):
http_options = (
config.http_options
if config.http_options is not None
else types.HttpOptions()
)
fetch_operation_config = types.FetchPredictOperationConfig(
http_options=http_options
)
return await self._fetch_predict_videos_operation(
operation_name=operation_name,
resource_name=resource_name,
config=fetch_operation_config,
)
else:
return await self._get_videos_operation(
operation_name=operation_name,
config=config,
)
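Neither Operations class starts work; both only re-fetch the status of an existing long-running operation. A minimal polling sketch follows. Hedged: it assumes the SDK's `genai.Client` entry point and a `client.models.generate_videos` call that returns the `GenerateVideosOperation`; the model name is illustrative only.

import time

from google import genai

client = genai.Client()  # credentials/project are resolved from the environment
operation = client.models.generate_videos(
    model='veo-2.0-generate-001',  # assumed model name, for illustration
    prompt='a timelapse of clouds over a mountain lake',
)
while not operation.done:
    time.sleep(10)  # the SDK does not poll automatically; re-fetch the status
    operation = client.operations.get(operation)
if operation.response is not None:
    print(operation.response.generated_videos)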

View File

@@ -0,0 +1,252 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Pagers for the GenAI List APIs."""
# pylint: disable=protected-access
import copy
from typing import Any, AsyncIterator, Awaitable, Callable, Generic, Iterator, Literal, TypeVar
T = TypeVar('T')
PagedItem = Literal[
'batch_jobs', 'models', 'tuning_jobs', 'files', 'cached_contents'
]
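# A PagedItem value is the attribute name on a list response that carries the
# page's items, e.g. `response.batch_jobs` for the batches pager.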
class _BasePager(Generic[T]):
"""Base pager class for iterating through paginated results."""
def _init_page(
self,
name: PagedItem,
request: Callable[..., Any],
response: Any,
config: Any,
) -> None:
self._name = name
self._request = request
self._page = getattr(response, self._name) or []
self._idx = 0
if not config:
request_config = {}
elif isinstance(config, dict):
request_config = copy.deepcopy(config)
else:
request_config = dict(config)
request_config['page_token'] = getattr(response, 'next_page_token')
self._config = request_config
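    # Default the page size to the size of the first page when the caller
    # did not specify one.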
self._page_size: int = request_config.get('page_size', len(self._page))
def __init__(
self,
name: PagedItem,
request: Callable[..., Any],
response: Any,
config: Any,
):
self._init_page(name, request, response, config)
@property
def page(self) -> list[T]:
"""Returns a subset of the entire list of items.
    For the number of items returned, see `page_size`.
Usage:
.. code-block:: python
batch_jobs_pager = client.batches.list(config={'page_size': 5})
print(f"first page: {batch_jobs_pager.page}")
      # first page: [BatchJob(name='projects/.../locations/.../batchPredictionJobs/1
"""
return self._page
@property
def name(self) -> PagedItem:
"""Returns the type of paged item (for example, ``batch_jobs``).
Usage:
.. code-block:: python
batch_jobs_pager = client.batches.list(config={'page_size': 5})
print(f"name: {batch_jobs_pager.name}")
# name: batch_jobs
"""
return self._name
@property
def page_size(self) -> int:
"""Returns the maximum number of items fetched by the pager at one time.
Usage:
.. code-block:: python
batch_jobs_pager = client.batches.list(config={'page_size': 5})
print(f"page_size: {batch_jobs_pager.page_size}")
# page_size: 5
"""
return self._page_size
@property
def config(self) -> dict[str, Any]:
"""Returns the configuration when making the API request for the next page.
A configuration is a set of optional parameters and arguments that can be
used to customize the API request. For example, the ``page_token`` parameter
contains the token to request the next page.
Usage:
.. code-block:: python
batch_jobs_pager = client.batches.list(config={'page_size': 5})
print(f"config: {batch_jobs_pager.config}")
# config: {'page_size': 5, 'page_token': 'AMEw9yO5jnsGnZJLHSKDFHJJu'}
"""
return self._config
def __len__(self) -> int:
"""Returns the total number of items in the current page."""
return len(self.page)
def __getitem__(self, index: int) -> T:
"""Returns the item at the given index."""
return self.page[index]
def _init_next_page(self, response: Any) -> None:
"""Initializes the next page from the response.
This is an internal method that should be called by subclasses after
fetching the next page.
Args:
response: The response object from the API request.
"""
self._init_page(self.name, self._request, response, self.config)
class Pager(_BasePager[T]):
"""Pager class for iterating through paginated results."""
def __next__(self) -> T:
"""Returns the next item."""
if self._idx >= len(self):
try:
self.next_page()
except IndexError:
raise StopIteration
item = self.page[self._idx]
self._idx += 1
return item
def __iter__(self) -> Iterator[T]:
"""Returns an iterator over the items."""
self._idx = 0
return self
def next_page(self) -> list[T]:
"""Fetches the next page of items. This makes a new API request.
Usage:
.. code-block:: python
batch_jobs_pager = client.batches.list(config={'page_size': 5})
print(f"current page: {batch_jobs_pager.page}")
batch_jobs_pager.next_page()
print(f"next page: {batch_jobs_pager.page}")
# current page: [BatchJob(name='projects/.../batchPredictionJobs/1
# next page: [BatchJob(name='projects/.../batchPredictionJobs/6
"""
if not self.config.get('page_token'):
raise IndexError('No more pages to fetch.')
response = self._request(config=self.config)
self._init_next_page(response)
return self.page
class AsyncPager(_BasePager[T]):
"""AsyncPager class for iterating through paginated results."""
def __init__(
self,
name: PagedItem,
request: Callable[..., Awaitable[Any]],
response: Any,
config: Any,
):
super().__init__(name, request, response, config)
def __aiter__(self) -> AsyncIterator[T]:
"""Returns an async iterator over the items."""
self._idx = 0
return self
async def __anext__(self) -> T:
"""Returns the next item asynchronously."""
if self._idx >= len(self):
try:
await self.next_page()
except IndexError:
raise StopAsyncIteration
item = self.page[self._idx]
self._idx += 1
return item
async def next_page(self) -> list[T]:
"""Fetches the next page of items asynchronously.
This makes a new API request.
Returns:
The next page of items.
Raises:
IndexError: No more pages to fetch.
Usage:
.. code-block:: python
batch_jobs_pager = await client.aio.batches.list(config={'page_size': 5})
print(f"current page: {batch_jobs_pager.page}")
await batch_jobs_pager.next_page()
print(f"next page: {batch_jobs_pager.page}")
# current page: [BatchJob(name='projects/.../batchPredictionJobs/1
# next page: [BatchJob(name='projects/.../batchPredictionJobs/6
"""
if not self.config.get('page_token'):
raise IndexError('No more pages to fetch.')
response = await self._request(config=self.config)
self._init_next_page(response)
return self.page
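Because Pager implements `__iter__`/`__next__` (and AsyncPager `__aiter__`/`__anext__`), callers can treat a paginated list API as one continuous stream: `next_page()` is invoked transparently whenever the cached page is exhausted. A short usage sketch, reusing the `client.batches.list` call from the docstrings above (client construction is assumed):

from google import genai

client = genai.Client()

# Synchronous: iteration crosses page boundaries automatically.
for job in client.batches.list(config={'page_size': 5}):
    print(job.name)

# Asynchronous equivalent (from inside an async function):
#   async for job in await client.aio.batches.list(config={'page_size': 5}):
#       print(job.name)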

View File

@@ -0,0 +1 @@
# see: https://peps.python.org/pep-0561/

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,16 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__version__ = '1.12.1' # x-release-please-version