structure saas with tools
This commit is contained in:
23
.venv/lib/python3.10/site-packages/google/genai/__init__.py
Normal file
23
.venv/lib/python3.10/site-packages/google/genai/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""Google Gen AI SDK"""
|
||||
|
||||
from .client import Client
|
||||
from . import version
|
||||
|
||||
__version__ = version.__version__
|
||||
|
||||
__all__ = ['Client']
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1081
.venv/lib/python3.10/site-packages/google/genai/_api_client.py
Normal file
1081
.venv/lib/python3.10/site-packages/google/genai/_api_client.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,29 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""Utilities for the API Modules of the Google Gen AI SDK."""
|
||||
|
||||
from typing import Optional
|
||||
from . import _api_client
|
||||
|
||||
|
||||
class BaseModule:
|
||||
|
||||
def __init__(self, api_client_: _api_client.BaseApiClient):
|
||||
self._api_client = api_client_
|
||||
|
||||
@property
|
||||
def vertexai(self) -> Optional[bool]:
|
||||
return self._api_client.vertexai
|
||||
@@ -0,0 +1,284 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import inspect
|
||||
import sys
|
||||
import types as builtin_types
|
||||
import typing
|
||||
from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union # type: ignore[attr-defined]
|
||||
|
||||
import pydantic
|
||||
|
||||
from . import _extra_utils
|
||||
from . import types
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
VersionedUnionType = builtin_types.UnionType
|
||||
else:
|
||||
VersionedUnionType = typing._UnionGenericAlias # type: ignore[attr-defined]
|
||||
|
||||
_py_builtin_type_to_schema_type = {
|
||||
str: types.Type.STRING,
|
||||
int: types.Type.INTEGER,
|
||||
float: types.Type.NUMBER,
|
||||
bool: types.Type.BOOLEAN,
|
||||
list: types.Type.ARRAY,
|
||||
dict: types.Type.OBJECT,
|
||||
}
|
||||
|
||||
|
||||
def _is_builtin_primitive_or_compound(
|
||||
annotation: inspect.Parameter.annotation, # type: ignore[valid-type]
|
||||
) -> bool:
|
||||
return annotation in _py_builtin_type_to_schema_type.keys()
|
||||
|
||||
|
||||
def _is_default_value_compatible(
|
||||
default_value: Any, annotation: inspect.Parameter.annotation # type: ignore[valid-type]
|
||||
) -> bool:
|
||||
# None type is expected to be handled external to this function
|
||||
if _is_builtin_primitive_or_compound(annotation):
|
||||
return isinstance(default_value, annotation)
|
||||
|
||||
if (
|
||||
isinstance(annotation, _GenericAlias)
|
||||
or isinstance(annotation, builtin_types.GenericAlias)
|
||||
or isinstance(annotation, VersionedUnionType)
|
||||
):
|
||||
origin = get_origin(annotation)
|
||||
if origin in (Union, VersionedUnionType): # type: ignore[comparison-overlap]
|
||||
return any(
|
||||
_is_default_value_compatible(default_value, arg)
|
||||
for arg in get_args(annotation)
|
||||
)
|
||||
|
||||
if origin is dict: # type: ignore[comparison-overlap]
|
||||
return isinstance(default_value, dict)
|
||||
|
||||
if origin is list: # type: ignore[comparison-overlap]
|
||||
if not isinstance(default_value, list):
|
||||
return False
|
||||
# most tricky case, element in list is union type
|
||||
# need to apply any logic within all
|
||||
# see test case test_generic_alias_complex_array_with_default_value
|
||||
# a: typing.List[int | str | float | bool]
|
||||
# default_value: [1, 'a', 1.1, True]
|
||||
return all(
|
||||
any(
|
||||
_is_default_value_compatible(item, arg)
|
||||
for arg in get_args(annotation)
|
||||
)
|
||||
for item in default_value
|
||||
)
|
||||
|
||||
if origin is Literal: # type: ignore[comparison-overlap]
|
||||
return default_value in get_args(annotation)
|
||||
|
||||
# return False for any other unrecognized annotation
|
||||
return False
|
||||
|
||||
|
||||
def _parse_schema_from_parameter(
|
||||
api_option: Literal['VERTEX_AI', 'GEMINI_API'],
|
||||
param: inspect.Parameter,
|
||||
func_name: str,
|
||||
) -> types.Schema:
|
||||
"""parse schema from parameter.
|
||||
|
||||
from the simplest case to the most complex case.
|
||||
"""
|
||||
schema = types.Schema()
|
||||
default_value_error_msg = (
|
||||
f'Default value {param.default} of parameter {param} of function'
|
||||
f' {func_name} is not compatible with the parameter annotation'
|
||||
f' {param.annotation}.'
|
||||
)
|
||||
if _is_builtin_primitive_or_compound(param.annotation):
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
if not _is_default_value_compatible(param.default, param.annotation):
|
||||
raise ValueError(default_value_error_msg)
|
||||
schema.default = param.default
|
||||
schema.type = _py_builtin_type_to_schema_type[param.annotation]
|
||||
return schema
|
||||
if (
|
||||
isinstance(param.annotation, VersionedUnionType)
|
||||
# only parse simple UnionType, example int | str | float | bool
|
||||
# complex UnionType will be invoked in raise branch
|
||||
and all(
|
||||
(_is_builtin_primitive_or_compound(arg) or arg is type(None))
|
||||
for arg in get_args(param.annotation)
|
||||
)
|
||||
):
|
||||
schema.type = _py_builtin_type_to_schema_type[dict]
|
||||
schema.any_of = []
|
||||
unique_types = set()
|
||||
for arg in get_args(param.annotation):
|
||||
if arg.__name__ == 'NoneType': # Optional type
|
||||
schema.nullable = True
|
||||
continue
|
||||
schema_in_any_of = _parse_schema_from_parameter(
|
||||
api_option,
|
||||
inspect.Parameter(
|
||||
'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
|
||||
),
|
||||
func_name,
|
||||
)
|
||||
if (
|
||||
schema_in_any_of.model_dump_json(exclude_none=True)
|
||||
not in unique_types
|
||||
):
|
||||
schema.any_of.append(schema_in_any_of)
|
||||
unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
|
||||
if len(schema.any_of) == 1: # param: list | None -> Array
|
||||
schema.type = schema.any_of[0].type
|
||||
schema.any_of = None
|
||||
if (
|
||||
param.default is not inspect.Parameter.empty
|
||||
and param.default is not None
|
||||
):
|
||||
if not _is_default_value_compatible(param.default, param.annotation):
|
||||
raise ValueError(default_value_error_msg)
|
||||
schema.default = param.default
|
||||
return schema
|
||||
if isinstance(param.annotation, _GenericAlias) or isinstance(
|
||||
param.annotation, builtin_types.GenericAlias
|
||||
):
|
||||
origin = get_origin(param.annotation)
|
||||
args = get_args(param.annotation)
|
||||
if origin is dict:
|
||||
schema.type = _py_builtin_type_to_schema_type[dict]
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
if not _is_default_value_compatible(param.default, param.annotation):
|
||||
raise ValueError(default_value_error_msg)
|
||||
schema.default = param.default
|
||||
return schema
|
||||
if origin is Literal:
|
||||
if not all(isinstance(arg, str) for arg in args):
|
||||
raise ValueError(
|
||||
f'Literal type {param.annotation} must be a list of strings.'
|
||||
)
|
||||
schema.type = _py_builtin_type_to_schema_type[str]
|
||||
schema.enum = list(args)
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
if not _is_default_value_compatible(param.default, param.annotation):
|
||||
raise ValueError(default_value_error_msg)
|
||||
schema.default = param.default
|
||||
return schema
|
||||
if origin is list:
|
||||
schema.type = _py_builtin_type_to_schema_type[list]
|
||||
schema.items = _parse_schema_from_parameter(
|
||||
api_option,
|
||||
inspect.Parameter(
|
||||
'item',
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=args[0],
|
||||
),
|
||||
func_name,
|
||||
)
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
if not _is_default_value_compatible(param.default, param.annotation):
|
||||
raise ValueError(default_value_error_msg)
|
||||
schema.default = param.default
|
||||
return schema
|
||||
if origin is Union:
|
||||
schema.any_of = []
|
||||
schema.type = _py_builtin_type_to_schema_type[dict]
|
||||
unique_types = set()
|
||||
for arg in args:
|
||||
# The first check is for NoneType in Python 3.9, since the __name__
|
||||
# attribute is not available in Python 3.9
|
||||
if type(arg) is type(None) or (
|
||||
hasattr(arg, '__name__') and arg.__name__ == 'NoneType'
|
||||
): # Optional type
|
||||
schema.nullable = True
|
||||
continue
|
||||
schema_in_any_of = _parse_schema_from_parameter(
|
||||
api_option,
|
||||
inspect.Parameter(
|
||||
'item',
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=arg,
|
||||
),
|
||||
func_name,
|
||||
)
|
||||
if (
|
||||
len(param.annotation.__args__) == 2
|
||||
and type(None) in param.annotation.__args__
|
||||
): # Optional type
|
||||
for optional_arg in param.annotation.__args__:
|
||||
if (
|
||||
hasattr(optional_arg, '__origin__')
|
||||
and optional_arg.__origin__ is list
|
||||
):
|
||||
# Optional type with list, for example Optional[list[str]]
|
||||
schema.items = schema_in_any_of.items
|
||||
if (
|
||||
schema_in_any_of.model_dump_json(exclude_none=True)
|
||||
not in unique_types
|
||||
):
|
||||
schema.any_of.append(schema_in_any_of)
|
||||
unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
|
||||
if len(schema.any_of) == 1: # param: Union[List, None] -> Array
|
||||
schema.type = schema.any_of[0].type
|
||||
schema.any_of = None
|
||||
if (
|
||||
param.default is not None
|
||||
and param.default is not inspect.Parameter.empty
|
||||
):
|
||||
if not _is_default_value_compatible(param.default, param.annotation):
|
||||
raise ValueError(default_value_error_msg)
|
||||
schema.default = param.default
|
||||
return schema
|
||||
# all other generic alias will be invoked in raise branch
|
||||
if (
|
||||
# for user defined class, we only support pydantic model
|
||||
_extra_utils.is_annotation_pydantic_model(param.annotation)
|
||||
):
|
||||
if (
|
||||
param.default is not inspect.Parameter.empty
|
||||
and param.default is not None
|
||||
):
|
||||
schema.default = param.default
|
||||
schema.type = _py_builtin_type_to_schema_type[dict]
|
||||
schema.properties = {}
|
||||
for field_name, field_info in param.annotation.model_fields.items():
|
||||
schema.properties[field_name] = _parse_schema_from_parameter(
|
||||
api_option,
|
||||
inspect.Parameter(
|
||||
field_name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=field_info.annotation,
|
||||
),
|
||||
func_name,
|
||||
)
|
||||
schema.required = _get_required_fields(schema)
|
||||
return schema
|
||||
raise ValueError(
|
||||
f'Failed to parse the parameter {param} of function {func_name} for'
|
||||
' automatic function calling.Automatic function calling works best with'
|
||||
' simpler function signature schema,consider manually parse your'
|
||||
f' function declaration for function {func_name}.'
|
||||
)
|
||||
|
||||
|
||||
def _get_required_fields(schema: types.Schema) -> Optional[list[str]]:
|
||||
if not schema.properties:
|
||||
return None
|
||||
return [
|
||||
field_name
|
||||
for field_name, field_schema in schema.properties.items()
|
||||
if not field_schema.nullable and field_schema.default is None
|
||||
]
|
||||
318
.venv/lib/python3.10/site-packages/google/genai/_common.py
Normal file
318
.venv/lib/python3.10/site-packages/google/genai/_common.py
Normal file
@@ -0,0 +1,318 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""Common utilities for the SDK."""
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import enum
|
||||
import functools
|
||||
import typing
|
||||
from typing import Any, Callable, Optional, Union
|
||||
import uuid
|
||||
import warnings
|
||||
|
||||
import pydantic
|
||||
from pydantic import alias_generators
|
||||
|
||||
from . import _api_client
|
||||
from . import errors
|
||||
|
||||
|
||||
def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: Any) -> None:
|
||||
"""Examples:
|
||||
|
||||
set_value_by_path({}, ['a', 'b'], v)
|
||||
-> {'a': {'b': v}}
|
||||
set_value_by_path({}, ['a', 'b[]', c], [v1, v2])
|
||||
-> {'a': {'b': [{'c': v1}, {'c': v2}]}}
|
||||
set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3)
|
||||
-> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}}
|
||||
"""
|
||||
if value is None:
|
||||
return
|
||||
for i, key in enumerate(keys[:-1]):
|
||||
if key.endswith('[]'):
|
||||
key_name = key[:-2]
|
||||
if data is not None and key_name not in data:
|
||||
if isinstance(value, list):
|
||||
data[key_name] = [{} for _ in range(len(value))]
|
||||
else:
|
||||
raise ValueError(
|
||||
f'value {value} must be a list given an array path {key}'
|
||||
)
|
||||
if isinstance(value, list) and data is not None:
|
||||
for j, d in enumerate(data[key_name]):
|
||||
set_value_by_path(d, keys[i + 1 :], value[j])
|
||||
else:
|
||||
if data is not None:
|
||||
for d in data[key_name]:
|
||||
set_value_by_path(d, keys[i + 1 :], value)
|
||||
return
|
||||
elif key.endswith('[0]'):
|
||||
key_name = key[:-3]
|
||||
if data is not None and key_name not in data:
|
||||
data[key_name] = [{}]
|
||||
if data is not None:
|
||||
set_value_by_path(data[key_name][0], keys[i + 1 :], value)
|
||||
return
|
||||
if data is not None:
|
||||
data = data.setdefault(key, {})
|
||||
|
||||
if data is not None:
|
||||
existing_data = data.get(keys[-1])
|
||||
# If there is an existing value, merge, not overwrite.
|
||||
if existing_data is not None:
|
||||
# Don't overwrite existing non-empty value with new empty value.
|
||||
# This is triggered when handling tuning datasets.
|
||||
if not value:
|
||||
pass
|
||||
# Don't fail when overwriting value with same value
|
||||
elif value == existing_data:
|
||||
pass
|
||||
# Instead of overwriting dictionary with another dictionary, merge them.
|
||||
# This is important for handling training and validation datasets in tuning.
|
||||
elif isinstance(existing_data, dict) and isinstance(value, dict):
|
||||
# Merging dictionaries. Consider deep merging in the future.
|
||||
existing_data.update(value)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Cannot set value for an existing key. Key: {keys[-1]};'
|
||||
f' Existing value: {existing_data}; New value: {value}.'
|
||||
)
|
||||
else:
|
||||
data[keys[-1]] = value
|
||||
|
||||
|
||||
def get_value_by_path(data: Any, keys: list[str]) -> Any:
|
||||
"""Examples:
|
||||
|
||||
get_value_by_path({'a': {'b': v}}, ['a', 'b'])
|
||||
-> v
|
||||
get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c'])
|
||||
-> [v1, v2]
|
||||
"""
|
||||
if keys == ['_self']:
|
||||
return data
|
||||
for i, key in enumerate(keys):
|
||||
if not data:
|
||||
return None
|
||||
if key.endswith('[]'):
|
||||
key_name = key[:-2]
|
||||
if key_name in data:
|
||||
return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]]
|
||||
else:
|
||||
return None
|
||||
elif key.endswith('[0]'):
|
||||
key_name = key[:-3]
|
||||
if key_name in data and data[key_name]:
|
||||
return get_value_by_path(data[key_name][0], keys[i + 1 :])
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
if key in data:
|
||||
data = data[key]
|
||||
elif isinstance(data, BaseModel) and hasattr(data, key):
|
||||
data = getattr(data, key)
|
||||
else:
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
def convert_to_dict(obj: object) -> Any:
|
||||
"""Recursively converts a given object to a dictionary.
|
||||
|
||||
If the object is a Pydantic model, it uses the model's `model_dump()` method.
|
||||
|
||||
Args:
|
||||
obj: The object to convert.
|
||||
|
||||
Returns:
|
||||
A dictionary representation of the object, a list of objects if a list is
|
||||
passed, or the object itself if it is not a dictionary, list, or Pydantic
|
||||
model.
|
||||
"""
|
||||
if isinstance(obj, pydantic.BaseModel):
|
||||
return obj.model_dump(exclude_none=True)
|
||||
elif isinstance(obj, dict):
|
||||
return {key: convert_to_dict(value) for key, value in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [convert_to_dict(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
def _remove_extra_fields(
|
||||
model: Any, response: dict[str, object]
|
||||
) -> None:
|
||||
"""Removes extra fields from the response that are not in the model.
|
||||
|
||||
Mutates the response in place.
|
||||
"""
|
||||
|
||||
key_values = list(response.items())
|
||||
|
||||
for key, value in key_values:
|
||||
# Need to convert to snake case to match model fields names
|
||||
# ex: UsageMetadata
|
||||
alias_map = {
|
||||
field_info.alias: key for key, field_info in model.model_fields.items()
|
||||
}
|
||||
|
||||
if key not in model.model_fields and key not in alias_map:
|
||||
response.pop(key)
|
||||
continue
|
||||
|
||||
key = alias_map.get(key, key)
|
||||
|
||||
annotation = model.model_fields[key].annotation
|
||||
|
||||
# Get the BaseModel if Optional
|
||||
if typing.get_origin(annotation) is Union:
|
||||
annotation = typing.get_args(annotation)[0]
|
||||
|
||||
# if dict, assume BaseModel but also check that field type is not dict
|
||||
# example: FunctionCall.args
|
||||
if isinstance(value, dict) and typing.get_origin(annotation) is not dict:
|
||||
_remove_extra_fields(annotation, value)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
# assume a list of dict is list of BaseModel
|
||||
if isinstance(item, dict):
|
||||
_remove_extra_fields(typing.get_args(annotation)[0], item)
|
||||
|
||||
T = typing.TypeVar('T', bound='BaseModel')
|
||||
|
||||
|
||||
class BaseModel(pydantic.BaseModel):
|
||||
|
||||
model_config = pydantic.ConfigDict(
|
||||
alias_generator=alias_generators.to_camel,
|
||||
populate_by_name=True,
|
||||
from_attributes=True,
|
||||
protected_namespaces=(),
|
||||
extra='forbid',
|
||||
# This allows us to use arbitrary types in the model. E.g. PIL.Image.
|
||||
arbitrary_types_allowed=True,
|
||||
ser_json_bytes='base64',
|
||||
val_json_bytes='base64',
|
||||
ignored_types=(typing.TypeVar,)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_response(
|
||||
cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
|
||||
) -> T:
|
||||
# To maintain forward compatibility, we need to remove extra fields from
|
||||
# the response.
|
||||
# We will provide another mechanism to allow users to access these fields.
|
||||
_remove_extra_fields(cls, response)
|
||||
validated_response = cls.model_validate(response)
|
||||
return validated_response
|
||||
|
||||
def to_json_dict(self) -> dict[str, object]:
|
||||
return self.model_dump(exclude_none=True, mode='json')
|
||||
|
||||
|
||||
class CaseInSensitiveEnum(str, enum.Enum):
|
||||
"""Case insensitive enum."""
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: Any) -> Any:
|
||||
try:
|
||||
return cls[value.upper()] # Try to access directly with uppercase
|
||||
except KeyError:
|
||||
try:
|
||||
return cls[value.lower()] # Try to access directly with lowercase
|
||||
except KeyError:
|
||||
warnings.warn(f"{value} is not a valid {cls.__name__}")
|
||||
try:
|
||||
# Creating a enum instance based on the value
|
||||
# We need to use super() to avoid infinite recursion.
|
||||
unknown_enum_val = super().__new__(cls, value)
|
||||
unknown_enum_val._name_ = str(value) # pylint: disable=protected-access
|
||||
unknown_enum_val._value_ = value # pylint: disable=protected-access
|
||||
return unknown_enum_val
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def timestamped_unique_name() -> str:
|
||||
"""Composes a timestamped unique name.
|
||||
|
||||
Returns:
|
||||
A string representing a unique name.
|
||||
"""
|
||||
timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
|
||||
unique_id = uuid.uuid4().hex[0:5]
|
||||
return f'{timestamp}_{unique_id}'
|
||||
|
||||
|
||||
def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
|
||||
"""Converts unserializable types in dict to json.dumps() compatible types.
|
||||
|
||||
This function is called in models.py after calling convert_to_dict(). The
|
||||
convert_to_dict() can convert pydantic object to dict. However, the input to
|
||||
convert_to_dict() is dict mixed of pydantic object and nested dict(the output
|
||||
of converters). So they may be bytes in the dict and they are out of
|
||||
`ser_json_bytes` control in model_dump(mode='json') called in
|
||||
`convert_to_dict`, as well as datetime deserialization in Pydantic json mode.
|
||||
|
||||
Returns:
|
||||
A dictionary with json.dumps() incompatible type (e.g. bytes datetime)
|
||||
to compatible type (e.g. base64 encoded string, isoformat date string).
|
||||
"""
|
||||
processed_data: dict[str, object] = {}
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
for key, value in data.items():
|
||||
if isinstance(value, bytes):
|
||||
processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
|
||||
elif isinstance(value, datetime.datetime):
|
||||
processed_data[key] = value.isoformat()
|
||||
elif isinstance(value, dict):
|
||||
processed_data[key] = encode_unserializable_types(value)
|
||||
elif isinstance(value, list):
|
||||
if all(isinstance(v, bytes) for v in value):
|
||||
processed_data[key] = [
|
||||
base64.urlsafe_b64encode(v).decode('ascii') for v in value
|
||||
]
|
||||
if all(isinstance(v, datetime.datetime) for v in value):
|
||||
processed_data[key] = [v.isoformat() for v in value]
|
||||
else:
|
||||
processed_data[key] = [encode_unserializable_types(v) for v in value]
|
||||
else:
|
||||
processed_data[key] = value
|
||||
return processed_data
|
||||
|
||||
|
||||
def experimental_warning(message: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""Experimental warning, only warns once."""
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
warning_done = False
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal warning_done
|
||||
if not warning_done:
|
||||
warning_done = True
|
||||
warnings.warn(
|
||||
message=message,
|
||||
category=errors.ExperimentalWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
403
.venv/lib/python3.10/site-packages/google/genai/_extra_utils.py
Normal file
403
.venv/lib/python3.10/site-packages/google/genai/_extra_utils.py
Normal file
@@ -0,0 +1,403 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""Extra utils depending on types that are shared between sync and async modules."""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import typing
|
||||
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
|
||||
|
||||
import pydantic
|
||||
|
||||
from . import _common
|
||||
from . import errors
|
||||
from . import types
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from types import UnionType
|
||||
else:
|
||||
UnionType = typing._UnionGenericAlias # type: ignore[attr-defined]
|
||||
|
||||
_DEFAULT_MAX_REMOTE_CALLS_AFC = 10
|
||||
|
||||
logger = logging.getLogger('google_genai.models')
|
||||
|
||||
|
||||
def _create_generate_content_config_model(
|
||||
config: types.GenerateContentConfigOrDict,
|
||||
) -> types.GenerateContentConfig:
|
||||
if isinstance(config, dict):
|
||||
return types.GenerateContentConfig(**config)
|
||||
else:
|
||||
return config
|
||||
|
||||
|
||||
def format_destination(
|
||||
src: str,
|
||||
config: Optional[types.CreateBatchJobConfigOrDict] = None,
|
||||
) -> types.CreateBatchJobConfig:
|
||||
"""Formats the destination uri based on the source uri."""
|
||||
config = (
|
||||
types._CreateBatchJobParameters(config=config).config
|
||||
or types.CreateBatchJobConfig()
|
||||
)
|
||||
|
||||
unique_name = None
|
||||
if not config.display_name:
|
||||
unique_name = _common.timestamped_unique_name()
|
||||
config.display_name = f'genai_batch_job_{unique_name}'
|
||||
|
||||
if not config.dest:
|
||||
if src.startswith('gs://') and src.endswith('.jsonl'):
|
||||
# If source uri is "gs://bucket/path/to/src.jsonl", then the destination
|
||||
# uri prefix will be "gs://bucket/path/to/src/dest".
|
||||
config.dest = f'{src[:-6]}/dest'
|
||||
elif src.startswith('bq://'):
|
||||
# If source uri is "bq://project.dataset.src", then the destination
|
||||
# uri will be "bq://project.dataset.src_dest_TIMESTAMP_UUID".
|
||||
unique_name = unique_name or _common.timestamped_unique_name()
|
||||
config.dest = f'{src}_dest_{unique_name}'
|
||||
else:
|
||||
raise ValueError(f'Unsupported source: {src}')
|
||||
return config
|
||||
|
||||
|
||||
def get_function_map(
|
||||
config: Optional[types.GenerateContentConfigOrDict] = None,
|
||||
is_caller_method_async: bool = False,
|
||||
) -> dict[str, Callable[..., Any]]:
|
||||
"""Returns a function map from the config."""
|
||||
function_map: dict[str, Callable[..., Any]] = {}
|
||||
if not config:
|
||||
return function_map
|
||||
config_model = _create_generate_content_config_model(config)
|
||||
if config_model.tools:
|
||||
for tool in config_model.tools:
|
||||
if callable(tool):
|
||||
if inspect.iscoroutinefunction(tool) and not is_caller_method_async:
|
||||
raise errors.UnsupportedFunctionError(
|
||||
f'Function {tool.__name__} is a coroutine function, which is not'
|
||||
' supported for automatic function calling. Please manually'
|
||||
f' invoke {tool.__name__} to get the function response.'
|
||||
)
|
||||
function_map[tool.__name__] = tool
|
||||
return function_map
|
||||
|
||||
|
||||
def convert_number_values_for_dict_function_call_args(
|
||||
args: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Converts float values in dict with no decimal to integers."""
|
||||
return {
|
||||
key: convert_number_values_for_function_call_args(value)
|
||||
for key, value in args.items()
|
||||
}
|
||||
|
||||
|
||||
def convert_number_values_for_function_call_args(
|
||||
args: Union[dict[str, object], list[object], object],
|
||||
) -> Union[dict[str, object], list[object], object]:
|
||||
"""Converts float values with no decimal to integers."""
|
||||
if isinstance(args, float) and args.is_integer():
|
||||
return int(args)
|
||||
if isinstance(args, dict):
|
||||
return {
|
||||
key: convert_number_values_for_function_call_args(value)
|
||||
for key, value in args.items()
|
||||
}
|
||||
if isinstance(args, list):
|
||||
return [
|
||||
convert_number_values_for_function_call_args(value) for value in args
|
||||
]
|
||||
return args
|
||||
|
||||
|
||||
def is_annotation_pydantic_model(annotation: Any) -> bool:
|
||||
try:
|
||||
return inspect.isclass(annotation) and issubclass(
|
||||
annotation, pydantic.BaseModel
|
||||
)
|
||||
# for python 3.10 and below, inspect.isclass(annotation) has inconsistent
|
||||
# results with versions above. for example, inspect.isclass(dict[str, int]) is
|
||||
# True in 3.10 and below but False in 3.11 and above.
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def convert_if_exist_pydantic_model(
|
||||
value: Any, annotation: Any, param_name: str, func_name: str
|
||||
) -> Any:
|
||||
if isinstance(value, dict) and is_annotation_pydantic_model(annotation):
|
||||
try:
|
||||
return annotation(**value)
|
||||
except pydantic.ValidationError as e:
|
||||
raise errors.UnknownFunctionCallArgumentError(
|
||||
f'Failed to parse parameter {param_name} for function'
|
||||
f' {func_name} from function call part because function call argument'
|
||||
f' value {value} is not compatible with parameter annotation'
|
||||
f' {annotation}, due to error {e}'
|
||||
)
|
||||
if isinstance(value, list) and get_origin(annotation) == list:
|
||||
item_type = get_args(annotation)[0]
|
||||
return [
|
||||
convert_if_exist_pydantic_model(item, item_type, param_name, func_name)
|
||||
for item in value
|
||||
]
|
||||
if isinstance(value, dict) and get_origin(annotation) == dict:
|
||||
_, value_type = get_args(annotation)
|
||||
return {
|
||||
k: convert_if_exist_pydantic_model(v, value_type, param_name, func_name)
|
||||
for k, v in value.items()
|
||||
}
|
||||
# example 1: typing.Union[int, float]
|
||||
# example 2: int | float equivalent to UnionType[int, float]
|
||||
if get_origin(annotation) in (Union, UnionType):
|
||||
for arg in get_args(annotation):
|
||||
if (
|
||||
(get_args(arg) and get_origin(arg) is list)
|
||||
or isinstance(value, arg)
|
||||
or (isinstance(value, dict) and is_annotation_pydantic_model(arg))
|
||||
):
|
||||
try:
|
||||
return convert_if_exist_pydantic_model(
|
||||
value, arg, param_name, func_name
|
||||
)
|
||||
# do not raise here because there could be multiple pydantic model types
|
||||
# in the union type.
|
||||
except pydantic.ValidationError:
|
||||
continue
|
||||
# if none of the union type is matched, raise error
|
||||
raise errors.UnknownFunctionCallArgumentError(
|
||||
f'Failed to parse parameter {param_name} for function'
|
||||
f' {func_name} from function call part because function call argument'
|
||||
f' value {value} cannot be converted to parameter annotation'
|
||||
f' {annotation}.'
|
||||
)
|
||||
# the only exception for value and annotation type to be different is int and
|
||||
# float. see convert_number_values_for_function_call_args function for context
|
||||
if isinstance(value, int) and annotation is float:
|
||||
return value
|
||||
if not isinstance(value, annotation):
|
||||
raise errors.UnknownFunctionCallArgumentError(
|
||||
f'Failed to parse parameter {param_name} for function {func_name} from'
|
||||
f' function call part because function call argument value {value} is'
|
||||
f' not compatible with parameter annotation {annotation}.'
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def convert_argument_from_function(
|
||||
args: dict[str, Any], function: Callable[..., Any]
|
||||
) -> dict[str, Any]:
|
||||
signature = inspect.signature(function)
|
||||
func_name = function.__name__
|
||||
converted_args = {}
|
||||
for param_name, param in signature.parameters.items():
|
||||
if param_name in args:
|
||||
converted_args[param_name] = convert_if_exist_pydantic_model(
|
||||
args[param_name],
|
||||
param.annotation,
|
||||
param_name,
|
||||
func_name,
|
||||
)
|
||||
return converted_args
|
||||
|
||||
|
||||
def invoke_function_from_dict_args(
|
||||
args: Dict[str, Any], function_to_invoke: Callable[..., Any]
|
||||
) -> Any:
|
||||
converted_args = convert_argument_from_function(args, function_to_invoke)
|
||||
try:
|
||||
return function_to_invoke(**converted_args)
|
||||
except Exception as e:
|
||||
raise errors.FunctionInvocationError(
|
||||
f'Failed to invoke function {function_to_invoke.__name__} with'
|
||||
f' converted arguments {converted_args} from model returned function'
|
||||
f' call argument {args} because of error {e}'
|
||||
)
|
||||
|
||||
|
||||
async def invoke_function_from_dict_args_async(
|
||||
args: Dict[str, Any], function_to_invoke: Callable[..., Any]
|
||||
) -> Any:
|
||||
converted_args = convert_argument_from_function(args, function_to_invoke)
|
||||
try:
|
||||
return await function_to_invoke(**converted_args)
|
||||
except Exception as e:
|
||||
raise errors.FunctionInvocationError(
|
||||
f'Failed to invoke function {function_to_invoke.__name__} with'
|
||||
f' converted arguments {converted_args} from model returned function'
|
||||
f' call argument {args} because of error {e}'
|
||||
)
|
||||
|
||||
|
||||
def get_function_response_parts(
|
||||
response: types.GenerateContentResponse,
|
||||
function_map: dict[str, Callable[..., Any]],
|
||||
) -> list[types.Part]:
|
||||
"""Returns the function response parts from the response."""
|
||||
func_response_parts = []
|
||||
if (
|
||||
response.candidates is not None
|
||||
and isinstance(response.candidates[0].content, types.Content)
|
||||
and response.candidates[0].content.parts is not None
|
||||
):
|
||||
for part in response.candidates[0].content.parts:
|
||||
if not part.function_call:
|
||||
continue
|
||||
func_name = part.function_call.name
|
||||
if func_name is not None and part.function_call.args is not None:
|
||||
func = function_map[func_name]
|
||||
args = convert_number_values_for_dict_function_call_args(
|
||||
part.function_call.args
|
||||
)
|
||||
func_response: dict[str, Any]
|
||||
try:
|
||||
func_response = {
|
||||
'result': invoke_function_from_dict_args(args, func)
|
||||
}
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
func_response = {'error': str(e)}
|
||||
func_response_part = types.Part.from_function_response(
|
||||
name=func_name, response=func_response
|
||||
)
|
||||
func_response_parts.append(func_response_part)
|
||||
return func_response_parts
|
||||
|
||||
async def get_function_response_parts_async(
|
||||
response: types.GenerateContentResponse,
|
||||
function_map: dict[str, Callable[..., Any]],
|
||||
) -> list[types.Part]:
|
||||
"""Returns the function response parts from the response."""
|
||||
func_response_parts = []
|
||||
if (
|
||||
response.candidates is not None
|
||||
and isinstance(response.candidates[0].content, types.Content)
|
||||
and response.candidates[0].content.parts is not None
|
||||
):
|
||||
for part in response.candidates[0].content.parts:
|
||||
if not part.function_call:
|
||||
continue
|
||||
func_name = part.function_call.name
|
||||
if func_name is not None and part.function_call.args is not None:
|
||||
func = function_map[func_name]
|
||||
args = convert_number_values_for_dict_function_call_args(
|
||||
part.function_call.args
|
||||
)
|
||||
func_response: dict[str, Any]
|
||||
try:
|
||||
if inspect.iscoroutinefunction(func):
|
||||
func_response = {
|
||||
'result': await invoke_function_from_dict_args_async(args, func)
|
||||
}
|
||||
else:
|
||||
func_response = {
|
||||
'result': invoke_function_from_dict_args(args, func)
|
||||
}
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
func_response = {'error': str(e)}
|
||||
func_response_part = types.Part.from_function_response(
|
||||
name=func_name, response=func_response
|
||||
)
|
||||
func_response_parts.append(func_response_part)
|
||||
return func_response_parts
|
||||
|
||||
|
||||
def should_disable_afc(
|
||||
config: Optional[types.GenerateContentConfigOrDict] = None,
|
||||
) -> bool:
|
||||
"""Returns whether automatic function calling is enabled."""
|
||||
if not config:
|
||||
return False
|
||||
config_model = _create_generate_content_config_model(config)
|
||||
# If max_remote_calls is less or equal to 0, warn and disable AFC.
|
||||
if (
|
||||
config_model
|
||||
and config_model.automatic_function_calling
|
||||
and config_model.automatic_function_calling.maximum_remote_calls
|
||||
is not None
|
||||
and int(config_model.automatic_function_calling.maximum_remote_calls) <= 0
|
||||
):
|
||||
logger.warning(
|
||||
'max_remote_calls in automatic_function_calling_config'
|
||||
f' {config_model.automatic_function_calling.maximum_remote_calls} is'
|
||||
' less than or equal to 0. Disabling automatic function calling.'
|
||||
' Please set max_remote_calls to a positive integer.'
|
||||
)
|
||||
return True
|
||||
|
||||
# Default to enable AFC if not specified.
|
||||
if (
|
||||
not config_model.automatic_function_calling
|
||||
or config_model.automatic_function_calling.disable is None
|
||||
):
|
||||
return False
|
||||
|
||||
if (
|
||||
config_model.automatic_function_calling.disable
|
||||
and config_model.automatic_function_calling.maximum_remote_calls
|
||||
is not None
|
||||
# exclude the case where max_remote_calls is set to 10 by default.
|
||||
and 'maximum_remote_calls'
|
||||
in config_model.automatic_function_calling.model_fields_set
|
||||
and int(config_model.automatic_function_calling.maximum_remote_calls) > 0
|
||||
):
|
||||
logger.warning(
|
||||
'`automatic_function_calling.disable` is set to `True`. And'
|
||||
' `automatic_function_calling.maximum_remote_calls` is a'
|
||||
' positive number'
|
||||
f' {config_model.automatic_function_calling.maximum_remote_calls}.'
|
||||
' Disabling automatic function calling. If you want to enable'
|
||||
' automatic function calling, please set'
|
||||
' `automatic_function_calling.disable` to `False` or leave it unset,'
|
||||
' and set `automatic_function_calling.maximum_remote_calls` to a'
|
||||
' positive integer or leave'
|
||||
' `automatic_function_calling.maximum_remote_calls` unset.'
|
||||
)
|
||||
|
||||
return config_model.automatic_function_calling.disable
|
||||
|
||||
|
||||
def get_max_remote_calls_afc(
|
||||
config: Optional[types.GenerateContentConfigOrDict] = None,
|
||||
) -> int:
|
||||
if not config:
|
||||
return _DEFAULT_MAX_REMOTE_CALLS_AFC
|
||||
"""Returns the remaining remote calls for automatic function calling."""
|
||||
if should_disable_afc(config):
|
||||
raise ValueError(
|
||||
'automatic function calling is not enabled, but SDK is trying to get'
|
||||
' max remote calls.'
|
||||
)
|
||||
config_model = _create_generate_content_config_model(config)
|
||||
if (
|
||||
not config_model.automatic_function_calling
|
||||
or config_model.automatic_function_calling.maximum_remote_calls is None
|
||||
):
|
||||
return _DEFAULT_MAX_REMOTE_CALLS_AFC
|
||||
return int(config_model.automatic_function_calling.maximum_remote_calls)
|
||||
|
||||
|
||||
def should_append_afc_history(
|
||||
config: Optional[types.GenerateContentConfigOrDict] = None,
|
||||
) -> bool:
|
||||
if not config:
|
||||
return True
|
||||
config_model = _create_generate_content_config_model(config)
|
||||
if not config_model.automatic_function_calling:
|
||||
return True
|
||||
return not config_model.automatic_function_calling.ignore_call_history
|
||||
2487
.venv/lib/python3.10/site-packages/google/genai/_live_converters.py
Normal file
2487
.venv/lib/python3.10/site-packages/google/genai/_live_converters.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,597 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""Replay API client."""
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import datetime
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import google.auth
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from . import errors
|
||||
from ._api_client import BaseApiClient
|
||||
from ._api_client import HttpRequest
|
||||
from ._api_client import HttpResponse
|
||||
from ._common import BaseModel
|
||||
from .types import HttpOptions, HttpOptionsOrDict
|
||||
from .types import GenerateVideosOperation
|
||||
|
||||
|
||||
def _redact_version_numbers(version_string: str) -> str:
|
||||
"""Redacts version numbers in the form x.y.z from a string."""
|
||||
return re.sub(r'\d+\.\d+\.\d+', '{VERSION_NUMBER}', version_string)
|
||||
|
||||
|
||||
def _redact_language_label(language_label: str) -> str:
|
||||
"""Removed because replay requests are used for all languages."""
|
||||
return re.sub(r'gl-python/', '{LANGUAGE_LABEL}/', language_label)
|
||||
|
||||
|
||||
def _redact_request_headers(headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Redacts headers that should not be recorded."""
|
||||
redacted_headers = {}
|
||||
for header_name, header_value in headers.items():
|
||||
if header_name.lower() == 'x-goog-api-key':
|
||||
redacted_headers[header_name] = '{REDACTED}'
|
||||
elif header_name.lower() == 'user-agent':
|
||||
redacted_headers[header_name] = _redact_language_label(
|
||||
_redact_version_numbers(header_value)
|
||||
)
|
||||
elif header_name.lower() == 'x-goog-api-client':
|
||||
redacted_headers[header_name] = _redact_language_label(
|
||||
_redact_version_numbers(header_value)
|
||||
)
|
||||
elif header_name.lower() == 'x-goog-user-project':
|
||||
continue
|
||||
elif header_name.lower() == 'authorization':
|
||||
continue
|
||||
else:
|
||||
redacted_headers[header_name] = header_value
|
||||
return redacted_headers
|
||||
|
||||
|
||||
def _redact_request_url(url: str) -> str:
|
||||
# Redact all the url parts before the resource name, so the test can work
|
||||
# against any project, location, version, or whether it's EasyGCP.
|
||||
result = re.sub(
|
||||
r'.*/projects/[^/]+/locations/[^/]+/',
|
||||
'{VERTEX_URL_PREFIX}/',
|
||||
url,
|
||||
)
|
||||
result = re.sub(
|
||||
r'.*-aiplatform.googleapis.com/[^/]+/',
|
||||
'{VERTEX_URL_PREFIX}/',
|
||||
result,
|
||||
)
|
||||
result = re.sub(
|
||||
r'.*aiplatform.googleapis.com/[^/]+/',
|
||||
'{VERTEX_URL_PREFIX}/',
|
||||
result,
|
||||
)
|
||||
result = re.sub(
|
||||
r'https://generativelanguage.googleapis.com/[^/]+',
|
||||
'{MLDEV_URL_PREFIX}',
|
||||
result,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _redact_project_location_path(path: str) -> str:
|
||||
# Redact a field in the request that is known to vary based on project and
|
||||
# location.
|
||||
if 'projects/' in path and 'locations/' in path:
|
||||
result = re.sub(
|
||||
r'projects/[^/]+/locations/[^/]+/',
|
||||
'{PROJECT_AND_LOCATION_PATH}/',
|
||||
path,
|
||||
)
|
||||
return result
|
||||
else:
|
||||
return path
|
||||
|
||||
|
||||
def _redact_request_body(body: dict[str, object]) -> None:
|
||||
"""Redacts fields in the request body in place."""
|
||||
for key, value in body.items():
|
||||
if isinstance(value, str):
|
||||
body[key] = _redact_project_location_path(value)
|
||||
|
||||
|
||||
def redact_http_request(http_request: HttpRequest) -> None:
|
||||
http_request.headers = _redact_request_headers(http_request.headers)
|
||||
http_request.url = _redact_request_url(http_request.url)
|
||||
if not isinstance(http_request.data, bytes):
|
||||
_redact_request_body(http_request.data)
|
||||
|
||||
|
||||
def _current_file_path_and_line() -> str:
|
||||
"""Prints the current file path and line number."""
|
||||
current_frame = inspect.currentframe()
|
||||
if (
|
||||
current_frame is not None
|
||||
and current_frame.f_back is not None
|
||||
and current_frame.f_back.f_back is not None
|
||||
):
|
||||
frame = current_frame.f_back.f_back
|
||||
filepath = inspect.getfile(frame)
|
||||
lineno = frame.f_lineno
|
||||
return f'File: {filepath}, Line: {lineno}'
|
||||
return ''
|
||||
|
||||
|
||||
def _debug_print(message: str) -> None:
|
||||
print(
|
||||
'DEBUG (test',
|
||||
os.environ.get('PYTEST_CURRENT_TEST'),
|
||||
')',
|
||||
_current_file_path_and_line(),
|
||||
':\n ',
|
||||
message,
|
||||
)
|
||||
|
||||
|
||||
class ReplayRequest(BaseModel):
|
||||
"""Represents a single request in a replay."""
|
||||
|
||||
method: str
|
||||
url: str
|
||||
headers: dict[str, str]
|
||||
body_segments: list[dict[str, object]]
|
||||
|
||||
|
||||
class ReplayResponse(BaseModel):
|
||||
"""Represents a single response in a replay."""
|
||||
|
||||
status_code: int = 200
|
||||
headers: dict[str, str]
|
||||
body_segments: list[dict[str, object]]
|
||||
byte_segments: Optional[list[bytes]] = None
|
||||
sdk_response_segments: list[dict[str, object]]
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
# Remove headers that are not deterministic so the replay files don't change
|
||||
# every time they are recorded.
|
||||
self.headers.pop('Date', None)
|
||||
self.headers.pop('Server-Timing', None)
|
||||
|
||||
|
||||
class ReplayInteraction(BaseModel):
|
||||
"""Represents a single interaction, request and response in a replay."""
|
||||
|
||||
request: ReplayRequest
|
||||
response: ReplayResponse
|
||||
|
||||
|
||||
class ReplayFile(BaseModel):
|
||||
"""Represents a recorded session."""
|
||||
|
||||
replay_id: str
|
||||
interactions: list[ReplayInteraction]
|
||||
|
||||
|
||||
class ReplayApiClient(BaseApiClient):
|
||||
"""For integration testing, send recorded response or records a response."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode: Literal['record', 'replay', 'auto', 'api'],
|
||||
replay_id: str,
|
||||
replays_directory: Optional[str] = None,
|
||||
vertexai: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
credentials: Optional[google.auth.credentials.Credentials] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
http_options: Optional[HttpOptions] = None,
|
||||
):
|
||||
super().__init__(
|
||||
vertexai=vertexai,
|
||||
api_key=api_key,
|
||||
credentials=credentials,
|
||||
project=project,
|
||||
location=location,
|
||||
http_options=http_options,
|
||||
)
|
||||
self.replays_directory = replays_directory
|
||||
if not self.replays_directory:
|
||||
self.replays_directory = os.environ.get(
|
||||
'GOOGLE_GENAI_REPLAYS_DIRECTORY', None
|
||||
)
|
||||
# Valid replay modes are replay-only or record-and-replay.
|
||||
self.replay_session: Union[ReplayFile, None] = None
|
||||
self._mode = mode
|
||||
self._replay_id = replay_id
|
||||
|
||||
def initialize_replay_session(self, replay_id: str) -> None:
|
||||
self._replay_id = replay_id
|
||||
self._initialize_replay_session()
|
||||
|
||||
def _get_replay_file_path(self) -> str:
|
||||
return self._generate_file_path_from_replay_id(
|
||||
self.replays_directory, self._replay_id
|
||||
)
|
||||
|
||||
def _should_call_api(self) -> bool:
|
||||
return self._mode in ['record', 'api'] or (
|
||||
self._mode == 'auto'
|
||||
and not os.path.isfile(self._get_replay_file_path())
|
||||
)
|
||||
|
||||
def _should_update_replay(self) -> bool:
|
||||
return self._should_call_api() and self._mode != 'api'
|
||||
|
||||
def _initialize_replay_session_if_not_loaded(self) -> None:
|
||||
if not self.replay_session:
|
||||
self._initialize_replay_session()
|
||||
|
||||
def _initialize_replay_session(self) -> None:
|
||||
_debug_print('Test is using replay id: ' + self._replay_id)
|
||||
self._replay_index = 0
|
||||
self._sdk_response_index = 0
|
||||
replay_file_path = self._get_replay_file_path()
|
||||
# This should not be triggered from the constructor.
|
||||
replay_file_exists = os.path.isfile(replay_file_path)
|
||||
if self._mode == 'replay' and not replay_file_exists:
|
||||
raise ValueError(
|
||||
'Replay files do not exist for replay id: ' + self._replay_id
|
||||
)
|
||||
|
||||
if self._mode in ['replay', 'auto'] and replay_file_exists:
|
||||
with open(replay_file_path, 'r') as f:
|
||||
self.replay_session = ReplayFile.model_validate(json.loads(f.read()))
|
||||
|
||||
if self._should_update_replay():
|
||||
self.replay_session = ReplayFile(
|
||||
replay_id=self._replay_id, interactions=[]
|
||||
)
|
||||
|
||||
def _generate_file_path_from_replay_id(self, replay_directory: Optional[str], replay_id: str) -> str:
|
||||
session_parts = replay_id.split('/')
|
||||
if len(session_parts) < 3:
|
||||
raise ValueError(
|
||||
f'{replay_id}: Session ID must be in the format of'
|
||||
' module/function/[vertex|mldev]'
|
||||
)
|
||||
if replay_directory is None:
|
||||
path_parts = []
|
||||
else:
|
||||
path_parts = [replay_directory]
|
||||
path_parts.extend(session_parts)
|
||||
return os.path.join(*path_parts) + '.json'
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._should_update_replay() or not self.replay_session:
|
||||
return
|
||||
replay_file_path = self._get_replay_file_path()
|
||||
os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
|
||||
with open(replay_file_path, 'w') as f:
|
||||
f.write(self.replay_session.model_dump_json(exclude_unset=True, indent=2))
|
||||
self.replay_session = None
|
||||
|
||||
def _record_interaction(
|
||||
self,
|
||||
http_request: HttpRequest,
|
||||
http_response: Union[HttpResponse, errors.APIError, bytes],
|
||||
) -> None:
|
||||
if not self._should_update_replay():
|
||||
return
|
||||
redact_http_request(http_request)
|
||||
request = ReplayRequest(
|
||||
method=http_request.method,
|
||||
url=http_request.url,
|
||||
headers=http_request.headers,
|
||||
body_segments=[http_request.data],
|
||||
)
|
||||
if isinstance(http_response, HttpResponse):
|
||||
response = ReplayResponse(
|
||||
headers=dict(http_response.headers),
|
||||
body_segments=list(http_response.segments()),
|
||||
byte_segments=[
|
||||
seg[:100] + b'...' for seg in http_response.byte_segments()
|
||||
],
|
||||
status_code=http_response.status_code,
|
||||
sdk_response_segments=[],
|
||||
)
|
||||
elif isinstance(http_response, errors.APIError):
|
||||
response = ReplayResponse(
|
||||
headers=dict(http_response.response.headers),
|
||||
body_segments=[http_response._to_replay_record()],
|
||||
status_code=http_response.code,
|
||||
sdk_response_segments=[],
|
||||
)
|
||||
elif isinstance(http_response, bytes):
|
||||
response = ReplayResponse(
|
||||
headers={},
|
||||
body_segments=[],
|
||||
byte_segments=[http_response],
|
||||
sdk_response_segments=[],
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unsupported http_response type: ' + str(type(http_response))
|
||||
)
|
||||
if self.replay_session is None:
|
||||
raise ValueError('No replay session found.')
|
||||
self.replay_session.interactions.append(
|
||||
ReplayInteraction(request=request, response=response)
|
||||
)
|
||||
|
||||
def _match_request(
|
||||
self,
|
||||
http_request: HttpRequest,
|
||||
interaction: ReplayInteraction,
|
||||
) -> None:
|
||||
assert http_request.url == interaction.request.url
|
||||
assert http_request.headers == interaction.request.headers, (
|
||||
'Request headers mismatch:\n'
|
||||
f'Actual: {http_request.headers}\n'
|
||||
f'Expected: {interaction.request.headers}'
|
||||
)
|
||||
assert http_request.method == interaction.request.method
|
||||
|
||||
# Sanitize the request body, rewrite any fields that vary.
|
||||
request_data_copy = copy.deepcopy(http_request.data)
|
||||
# Both the request and recorded request must be redacted before comparing
|
||||
# so that the comparison is fair.
|
||||
if not isinstance(request_data_copy, bytes):
|
||||
_redact_request_body(request_data_copy)
|
||||
|
||||
actual_request_body = [request_data_copy]
|
||||
expected_request_body = interaction.request.body_segments
|
||||
assert actual_request_body == expected_request_body, (
|
||||
'Request body mismatch:\n'
|
||||
f'Actual: {actual_request_body}\n'
|
||||
f'Expected: {expected_request_body}'
|
||||
)
|
||||
|
||||
def _build_response_from_replay(self, http_request: HttpRequest) -> HttpResponse:
|
||||
redact_http_request(http_request)
|
||||
|
||||
if self.replay_session is None:
|
||||
raise ValueError('No replay session found.')
|
||||
interaction = self.replay_session.interactions[self._replay_index]
|
||||
# Replay is on the right side of the assert so the diff makes more sense.
|
||||
self._match_request(http_request, interaction)
|
||||
self._replay_index += 1
|
||||
self._sdk_response_index = 0
|
||||
errors.APIError.raise_for_response(interaction.response)
|
||||
return HttpResponse(
|
||||
headers=interaction.response.headers,
|
||||
response_stream=[
|
||||
json.dumps(segment)
|
||||
for segment in interaction.response.body_segments
|
||||
],
|
||||
byte_stream=interaction.response.byte_segments,
|
||||
)
|
||||
|
||||
def _verify_response(self, response_model: BaseModel) -> None:
|
||||
if self._mode == 'api':
|
||||
return
|
||||
if not self.replay_session:
|
||||
raise ValueError('No replay session found.')
|
||||
# replay_index is advanced in _build_response_from_replay, so we need to -1.
|
||||
interaction = self.replay_session.interactions[self._replay_index - 1]
|
||||
if self._should_update_replay():
|
||||
if isinstance(response_model, list):
|
||||
response_model = response_model[0]
|
||||
if response_model and 'http_headers' in response_model.model_fields:
|
||||
response_model.http_headers.pop('Date', None) # type: ignore[attr-defined]
|
||||
interaction.response.sdk_response_segments.append(
|
||||
response_model.model_dump(exclude_none=True)
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(response_model, list):
|
||||
response_model = response_model[0]
|
||||
print('response_model: ', response_model.model_dump(exclude_none=True))
|
||||
if isinstance(response_model, GenerateVideosOperation):
|
||||
actual = response_model.model_dump(
|
||||
exclude={'result'}, exclude_none=True, mode='json'
|
||||
)
|
||||
else:
|
||||
actual = response_model.model_dump(exclude_none=True, mode='json')
|
||||
expected = interaction.response.sdk_response_segments[
|
||||
self._sdk_response_index
|
||||
]
|
||||
assert (
|
||||
actual == expected
|
||||
), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
|
||||
self._sdk_response_index += 1
|
||||
|
||||
def _request(
|
||||
self,
|
||||
http_request: HttpRequest,
|
||||
stream: bool = False,
|
||||
) -> HttpResponse:
|
||||
self._initialize_replay_session_if_not_loaded()
|
||||
if self._should_call_api():
|
||||
_debug_print('api mode request: %s' % http_request)
|
||||
try:
|
||||
result = super()._request(http_request, stream)
|
||||
except errors.APIError as e:
|
||||
self._record_interaction(http_request, e)
|
||||
raise e
|
||||
if stream:
|
||||
result_segments = []
|
||||
for segment in result.segments():
|
||||
result_segments.append(json.dumps(segment))
|
||||
result = HttpResponse(result.headers, result_segments)
|
||||
self._record_interaction(http_request, result)
|
||||
# Need to return a RecordedResponse that rebuilds the response
|
||||
# segments since the stream has been consumed.
|
||||
else:
|
||||
self._record_interaction(http_request, result)
|
||||
_debug_print('api mode result: %s' % result.json)
|
||||
return result
|
||||
else:
|
||||
return self._build_response_from_replay(http_request)
|
||||
|
||||
async def _async_request(
|
||||
self,
|
||||
http_request: HttpRequest,
|
||||
stream: bool = False,
|
||||
) -> HttpResponse:
|
||||
self._initialize_replay_session_if_not_loaded()
|
||||
if self._should_call_api():
|
||||
_debug_print('api mode request: %s' % http_request)
|
||||
try:
|
||||
result = await super()._async_request(http_request, stream)
|
||||
except errors.APIError as e:
|
||||
self._record_interaction(http_request, e)
|
||||
raise e
|
||||
if stream:
|
||||
result_segments = []
|
||||
async for segment in result.async_segments():
|
||||
result_segments.append(json.dumps(segment))
|
||||
result = HttpResponse(result.headers, result_segments)
|
||||
self._record_interaction(http_request, result)
|
||||
# Need to return a RecordedResponse that rebuilds the response
|
||||
# segments since the stream has been consumed.
|
||||
else:
|
||||
self._record_interaction(http_request, result)
|
||||
_debug_print('api mode result: %s' % result.json)
|
||||
return result
|
||||
else:
|
||||
return self._build_response_from_replay(http_request)
|
||||
|
||||
def upload_file(
|
||||
self,
|
||||
file_path: Union[str, io.IOBase],
|
||||
upload_url: str,
|
||||
upload_size: int,
|
||||
*,
|
||||
http_options: Optional[HttpOptionsOrDict] = None,
|
||||
) -> HttpResponse:
|
||||
if isinstance(file_path, io.IOBase):
|
||||
offset = file_path.tell()
|
||||
content = file_path.read()
|
||||
file_path.seek(offset, os.SEEK_SET)
|
||||
request = HttpRequest(
|
||||
method='POST',
|
||||
url='',
|
||||
data={'bytes': base64.b64encode(content).decode('utf-8')},
|
||||
headers={}
|
||||
)
|
||||
else:
|
||||
request = HttpRequest(
|
||||
method='POST', url='', data={'file_path': file_path}, headers={}
|
||||
)
|
||||
if self._should_call_api():
|
||||
result: Union[str, HttpResponse]
|
||||
try:
|
||||
result = super().upload_file(
|
||||
file_path, upload_url, upload_size, http_options=http_options
|
||||
)
|
||||
except HTTPError as e:
|
||||
result = HttpResponse(
|
||||
dict(e.response.headers), [json.dumps({'reason': e.response.reason})]
|
||||
)
|
||||
result.status_code = e.response.status_code
|
||||
raise e
|
||||
self._record_interaction(request, result)
|
||||
return result
|
||||
else:
|
||||
return self._build_response_from_replay(request)
|
||||
|
||||
async def async_upload_file(
|
||||
self,
|
||||
file_path: Union[str, io.IOBase],
|
||||
upload_url: str,
|
||||
upload_size: int,
|
||||
*,
|
||||
http_options: Optional[HttpOptionsOrDict] = None,
|
||||
) -> HttpResponse:
|
||||
if isinstance(file_path, io.IOBase):
|
||||
offset = file_path.tell()
|
||||
content = file_path.read()
|
||||
file_path.seek(offset, os.SEEK_SET)
|
||||
request = HttpRequest(
|
||||
method='POST',
|
||||
url='',
|
||||
data={'bytes': base64.b64encode(content).decode('utf-8')},
|
||||
headers={},
|
||||
)
|
||||
else:
|
||||
request = HttpRequest(
|
||||
method='POST', url='', data={'file_path': file_path}, headers={}
|
||||
)
|
||||
if self._should_call_api():
|
||||
result: HttpResponse
|
||||
try:
|
||||
result = await super().async_upload_file(
|
||||
file_path, upload_url, upload_size, http_options=http_options
|
||||
)
|
||||
except HTTPError as e:
|
||||
result = HttpResponse(
|
||||
dict(e.response.headers), [json.dumps({'reason': e.response.reason})]
|
||||
)
|
||||
result.status_code = e.response.status_code
|
||||
raise e
|
||||
self._record_interaction(request, result)
|
||||
return result
|
||||
else:
|
||||
return self._build_response_from_replay(request)
|
||||
|
||||
def download_file(
|
||||
self, path: str, *, http_options: Optional[HttpOptionsOrDict] = None
|
||||
) -> Union[HttpResponse, bytes, Any]:
|
||||
self._initialize_replay_session_if_not_loaded()
|
||||
request = self._build_request(
|
||||
'get', path=path, request_dict={}, http_options=http_options
|
||||
)
|
||||
if self._should_call_api():
|
||||
try:
|
||||
result = super().download_file(path, http_options=http_options)
|
||||
except HTTPError as e:
|
||||
result = HttpResponse(
|
||||
dict(e.response.headers), [json.dumps({'reason': e.response.reason})]
|
||||
)
|
||||
result.status_code = e.response.status_code
|
||||
raise e
|
||||
self._record_interaction(request, result)
|
||||
return result
|
||||
else:
|
||||
return self._build_response_from_replay(request).byte_stream[0]
|
||||
|
||||
async def async_download_file(
|
||||
self, path: str, *, http_options: Optional[HttpOptionsOrDict] = None
|
||||
) -> Any:
|
||||
self._initialize_replay_session_if_not_loaded()
|
||||
request = self._build_request(
|
||||
'get', path=path, request_dict={}, http_options=http_options
|
||||
)
|
||||
if self._should_call_api():
|
||||
try:
|
||||
result = await super().async_download_file(
|
||||
path, http_options=http_options
|
||||
)
|
||||
except HTTPError as e:
|
||||
result = HttpResponse(
|
||||
dict(e.response.headers), [json.dumps({'reason': e.response.reason})]
|
||||
)
|
||||
result.status_code = e.response.status_code
|
||||
raise e
|
||||
self._record_interaction(request, result)
|
||||
return result
|
||||
else:
|
||||
return self._build_response_from_replay(request).byte_stream[0]
|
||||
@@ -0,0 +1,149 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from .api_client import BaseApiClient
|
||||
|
||||
|
||||
@patch('genai.api_client.BaseApiClient._build_request')
|
||||
@patch('genai.api_client.BaseApiClient._request')
|
||||
def test_request_streamed_non_blocking(mock_request, mock_build_request):
|
||||
api_client = BaseApiClient(api_key='test_api_key')
|
||||
http_method = 'GET'
|
||||
path = 'test/path'
|
||||
request_dict = {'key': 'value'}
|
||||
|
||||
mock_http_request = MagicMock()
|
||||
mock_build_request.return_value = mock_http_request
|
||||
|
||||
def delayed_segments():
|
||||
chunks = ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
|
||||
for chunk in chunks:
|
||||
time.sleep(0.1) # 100ms delay
|
||||
yield chunk
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.segments.side_effect = delayed_segments
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
chunks = []
|
||||
start_time = time.time()
|
||||
for chunk in api_client.request_streamed(http_method, path, request_dict):
|
||||
chunks.append(chunk)
|
||||
assert len(chunks) <= 3
|
||||
end_time = time.time()
|
||||
|
||||
mock_build_request.assert_called_once_with(
|
||||
http_method, path, request_dict, None
|
||||
)
|
||||
mock_request.assert_called_once_with(mock_http_request, stream=True)
|
||||
assert chunks == ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
|
||||
assert end_time - start_time > 0.3
|
||||
|
||||
|
||||
@patch('genai.api_client.BaseApiClient._build_request')
|
||||
@patch('genai.api_client.BaseApiClient._async_request')
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request(mock_async_request, mock_build_request):
|
||||
api_client = ApiClient(api_key='test_api_key')
|
||||
http_method = 'GET'
|
||||
path = 'test/path'
|
||||
request_dict = {'key': 'value'}
|
||||
|
||||
mock_http_request = MagicMock()
|
||||
mock_build_request.return_value = mock_http_request
|
||||
|
||||
class MockResponse:
|
||||
|
||||
def __init__(self, text):
|
||||
self.text = text
|
||||
|
||||
async def delayed_response(http_request, stream):
|
||||
await asyncio.sleep(0.1) # 100ms delay
|
||||
return MockResponse('value')
|
||||
|
||||
mock_async_request.side_effect = delayed_response
|
||||
|
||||
async_coroutine1 = api_client.async_request(http_method, path, request_dict)
|
||||
async_coroutine2 = api_client.async_request(http_method, path, request_dict)
|
||||
async_coroutine3 = api_client.async_request(http_method, path, request_dict)
|
||||
|
||||
start_time = time.time()
|
||||
results = await asyncio.gather(
|
||||
async_coroutine1, async_coroutine2, async_coroutine3
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
mock_build_request.assert_called_with(http_method, path, request_dict, None)
|
||||
assert mock_build_request.call_count == 3
|
||||
mock_async_request.assert_called_with(
|
||||
http_request=mock_http_request, stream=False
|
||||
)
|
||||
assert mock_async_request.call_count == 3
|
||||
assert results == ['value', 'value', 'value']
|
||||
assert 0.1 <= end_time - start_time < 0.15
|
||||
|
||||
|
||||
@patch('genai.api_client.BaseApiClient._build_request')
|
||||
@patch('genai.api_client.BaseApiClient._async_request')
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_streamed_non_blocking(
|
||||
mock_async_request, mock_build_request
|
||||
):
|
||||
api_client = ApiClient(api_key='test_api_key')
|
||||
http_method = 'GET'
|
||||
path = 'test/path'
|
||||
request_dict = {'key': 'value'}
|
||||
|
||||
mock_http_request = MagicMock()
|
||||
mock_build_request.return_value = mock_http_request
|
||||
|
||||
class MockResponse:
|
||||
|
||||
def __init__(self, segments):
|
||||
self._segments = segments
|
||||
|
||||
# should mock async generator here but source code combines sync and async streaming in one segment method.
|
||||
# TODO: fix the above
|
||||
def segments(self):
|
||||
for segment in self._segments:
|
||||
time.sleep(0.1) # 100ms delay
|
||||
yield segment
|
||||
|
||||
async def delayed_response(http_request, stream):
|
||||
return MockResponse(['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}'])
|
||||
|
||||
mock_async_request.side_effect = delayed_response
|
||||
|
||||
chunks = []
|
||||
start_time = time.time()
|
||||
async for chunk in await api_client.async_request_streamed(
|
||||
http_method, path, request_dict
|
||||
):
|
||||
chunks.append(chunk)
|
||||
assert len(chunks) <= 3
|
||||
end_time = time.time()
|
||||
|
||||
mock_build_request.assert_called_once_with(
|
||||
http_method, path, request_dict, None
|
||||
)
|
||||
mock_async_request.assert_called_once_with(
|
||||
http_request=mock_http_request, stream=True
|
||||
)
|
||||
assert chunks == ['{"chunk": 1}', '{"chunk": 2}', '{"chunk": 3}']
|
||||
assert end_time - start_time > 0.3
|
||||
1113
.venv/lib/python3.10/site-packages/google/genai/_transformers.py
Normal file
1113
.venv/lib/python3.10/site-packages/google/genai/_transformers.py
Normal file
File diff suppressed because it is too large
Load Diff
1151
.venv/lib/python3.10/site-packages/google/genai/batches.py
Normal file
1151
.venv/lib/python3.10/site-packages/google/genai/batches.py
Normal file
File diff suppressed because it is too large
Load Diff
1956
.venv/lib/python3.10/site-packages/google/genai/caches.py
Normal file
1956
.venv/lib/python3.10/site-packages/google/genai/caches.py
Normal file
File diff suppressed because it is too large
Load Diff
532
.venv/lib/python3.10/site-packages/google/genai/chats.py
Normal file
532
.venv/lib/python3.10/site-packages/google/genai/chats.py
Normal file
@@ -0,0 +1,532 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from collections.abc import Iterator
|
||||
import sys
|
||||
from typing import AsyncIterator, Awaitable, Optional, Union, get_args
|
||||
|
||||
from . import _transformers as t
|
||||
from . import types
|
||||
from .models import AsyncModels, Models
|
||||
from .types import Content, ContentOrDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeGuard
|
||||
else:
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
|
||||
def _validate_content(content: Content) -> bool:
|
||||
if not content.parts:
|
||||
return False
|
||||
for part in content.parts:
|
||||
if part == Part():
|
||||
return False
|
||||
if part.text is not None and part.text == "":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _validate_contents(contents: list[Content]) -> bool:
|
||||
if not contents:
|
||||
return False
|
||||
for content in contents:
|
||||
if not _validate_content(content):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _validate_response(response: GenerateContentResponse) -> bool:
|
||||
if not response.candidates:
|
||||
return False
|
||||
if not response.candidates[0].content:
|
||||
return False
|
||||
return _validate_content(response.candidates[0].content)
|
||||
|
||||
|
||||
def _extract_curated_history(
|
||||
comprehensive_history: list[Content],
|
||||
) -> list[Content]:
|
||||
"""Extracts the curated (valid) history from a comprehensive history.
|
||||
|
||||
The comprehensive history contains all turns (user input and model responses),
|
||||
including any invalid or rejected model outputs. This function filters
|
||||
that history to return only the valid turns.
|
||||
|
||||
A "turn" starts with one user input (a single content) and then follows by
|
||||
corresponding model response (which may consist of multiple contents).
|
||||
Turns are assumed to alternate: user input, model output, user input, model
|
||||
output, etc.
|
||||
|
||||
Args:
|
||||
comprehensive_history: A list representing the complete chat history.
|
||||
Including invalid turns.
|
||||
|
||||
Returns:
|
||||
curated history, which is a list of valid turns.
|
||||
"""
|
||||
if not comprehensive_history:
|
||||
return []
|
||||
curated_history = []
|
||||
length = len(comprehensive_history)
|
||||
i = 0
|
||||
current_input = comprehensive_history[i]
|
||||
if current_input.role != "user":
|
||||
raise ValueError("History must start with a user turn.")
|
||||
while i < length:
|
||||
if comprehensive_history[i].role not in ["user", "model"]:
|
||||
raise ValueError(
|
||||
f"Role must be user or model, but got {comprehensive_history[i].role}"
|
||||
)
|
||||
|
||||
if comprehensive_history[i].role == "user":
|
||||
current_input = comprehensive_history[i]
|
||||
i += 1
|
||||
else:
|
||||
current_output = []
|
||||
is_valid = True
|
||||
while i < length and comprehensive_history[i].role == "model":
|
||||
current_output.append(comprehensive_history[i])
|
||||
if is_valid and not _validate_content(comprehensive_history[i]):
|
||||
is_valid = False
|
||||
i += 1
|
||||
if is_valid:
|
||||
curated_history.append(current_input)
|
||||
curated_history.extend(current_output)
|
||||
return curated_history
|
||||
|
||||
|
||||
class _BaseChat:
|
||||
"""Base chat session."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
config: Optional[GenerateContentConfigOrDict] = None,
|
||||
history: list[ContentOrDict],
|
||||
):
|
||||
self._model = model
|
||||
self._config = config
|
||||
content_models = []
|
||||
for content in history:
|
||||
if not isinstance(content, Content):
|
||||
content_model = Content.model_validate(content)
|
||||
else:
|
||||
content_model = content
|
||||
content_models.append(content_model)
|
||||
self._comprehensive_history = content_models
|
||||
"""Comprehensive history is the full history of the chat, including turns of the invalid contents from the model and their associated inputs.
|
||||
"""
|
||||
self._curated_history = _extract_curated_history(content_models)
|
||||
"""Curated history is the set of valid turns that will be used in the subsequent send requests.
|
||||
"""
|
||||
|
||||
def record_history(
|
||||
self,
|
||||
user_input: Content,
|
||||
model_output: list[Content],
|
||||
automatic_function_calling_history: list[Content],
|
||||
is_valid: bool,
|
||||
) -> None:
|
||||
"""Records the chat history.
|
||||
|
||||
Maintaining both comprehensive and curated histories.
|
||||
|
||||
Args:
|
||||
user_input: The user's input content.
|
||||
model_output: A list of `Content` from the model's response. This can be
|
||||
an empty list if the model produced no output.
|
||||
automatic_function_calling_history: A list of `Content` representing the
|
||||
history of automatic function calls, including the user input as the
|
||||
first entry.
|
||||
is_valid: A boolean flag indicating whether the current model output is
|
||||
considered valid.
|
||||
"""
|
||||
input_contents = (
|
||||
# Because the AFC input contains the entire curated chat history in
|
||||
# addition to the new user input, we need to truncate the AFC history
|
||||
# to deduplicate the existing chat history.
|
||||
automatic_function_calling_history[len(self._curated_history):]
|
||||
if automatic_function_calling_history
|
||||
else [user_input]
|
||||
)
|
||||
# Appends an empty content when model returns empty response, so that the
|
||||
# history is always alternating between user and model.
|
||||
output_contents = (
|
||||
model_output if model_output else [Content(role="model", parts=[])]
|
||||
)
|
||||
self._comprehensive_history.extend(input_contents)
|
||||
self._comprehensive_history.extend(output_contents)
|
||||
if is_valid:
|
||||
self._curated_history.extend(input_contents)
|
||||
self._curated_history.extend(output_contents)
|
||||
|
||||
def get_history(self, curated: bool = False) -> list[Content]:
|
||||
"""Returns the chat history.
|
||||
|
||||
Args:
|
||||
curated: A boolean flag indicating whether to return the curated (valid)
|
||||
history or the comprehensive (all turns) history. Defaults to False
|
||||
(returns the comprehensive history).
|
||||
|
||||
Returns:
|
||||
A list of `Content` objects representing the chat history.
|
||||
"""
|
||||
if curated:
|
||||
return self._curated_history
|
||||
else:
|
||||
return self._comprehensive_history
|
||||
|
||||
|
||||
def _is_part_type(
|
||||
contents: Union[list[PartUnionDict], PartUnionDict],
|
||||
) -> TypeGuard[t.ContentType]:
|
||||
if isinstance(contents, list):
|
||||
return all(_is_part_type(part) for part in contents)
|
||||
else:
|
||||
allowed_part_types = get_args(types.PartUnion)
|
||||
if type(contents) in allowed_part_types:
|
||||
return True
|
||||
else:
|
||||
# Some images don't pass isinstance(item, PIL.Image.Image)
|
||||
# For example <class 'PIL.JpegImagePlugin.JpegImageFile'>
|
||||
if types.PIL_Image is not None and isinstance(contents, types.PIL_Image):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class Chat(_BaseChat):
|
||||
"""Chat session."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
modules: Models,
|
||||
model: str,
|
||||
config: Optional[GenerateContentConfigOrDict] = None,
|
||||
history: list[ContentOrDict],
|
||||
):
|
||||
self._modules = modules
|
||||
super().__init__(
|
||||
model=model,
|
||||
config=config,
|
||||
history=history,
|
||||
)
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
message: Union[list[PartUnionDict], PartUnionDict],
|
||||
config: Optional[GenerateContentConfigOrDict] = None,
|
||||
) -> GenerateContentResponse:
|
||||
"""Sends the conversation history with the additional message and returns the model's response.
|
||||
|
||||
Args:
|
||||
message: The message to send to the model.
|
||||
config: Optional config to override the default Chat config for this
|
||||
request.
|
||||
|
||||
Returns:
|
||||
The model's response.
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
chat = client.chats.create(model='gemini-1.5-flash')
|
||||
response = chat.send_message('tell me a story')
|
||||
"""
|
||||
|
||||
if not _is_part_type(message):
|
||||
raise ValueError(
|
||||
f"Message must be a valid part type: {types.PartUnion} or"
|
||||
f" {types.PartUnionDict}, got {type(message)}"
|
||||
)
|
||||
input_content = t.t_content(self._modules._api_client, message)
|
||||
response = self._modules.generate_content(
|
||||
model=self._model,
|
||||
contents=self._curated_history + [input_content], # type: ignore[arg-type]
|
||||
config=config if config else self._config,
|
||||
)
|
||||
model_output = (
|
||||
[response.candidates[0].content]
|
||||
if response.candidates and response.candidates[0].content
|
||||
else []
|
||||
)
|
||||
automatic_function_calling_history = (
|
||||
response.automatic_function_calling_history
|
||||
if response.automatic_function_calling_history
|
||||
else []
|
||||
)
|
||||
self.record_history(
|
||||
user_input=input_content,
|
||||
model_output=model_output,
|
||||
automatic_function_calling_history=automatic_function_calling_history,
|
||||
is_valid=_validate_response(response),
|
||||
)
|
||||
return response
|
||||
|
||||
def send_message_stream(
|
||||
self,
|
||||
message: Union[list[PartUnionDict], PartUnionDict],
|
||||
config: Optional[GenerateContentConfigOrDict] = None,
|
||||
) -> Iterator[GenerateContentResponse]:
|
||||
"""Sends the conversation history with the additional message and yields the model's response in chunks.
|
||||
|
||||
Args:
|
||||
message: The message to send to the model.
|
||||
config: Optional config to override the default Chat config for this
|
||||
request.
|
||||
|
||||
Yields:
|
||||
The model's response in chunks.
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
chat = client.chats.create(model='gemini-1.5-flash')
|
||||
for chunk in chat.send_message_stream('tell me a story'):
|
||||
print(chunk.text)
|
||||
"""
|
||||
|
||||
if not _is_part_type(message):
|
||||
raise ValueError(
|
||||
f"Message must be a valid part type: {types.PartUnion} or"
|
||||
f" {types.PartUnionDict}, got {type(message)}"
|
||||
)
|
||||
input_content = t.t_content(self._modules._api_client, message)
|
||||
output_contents = []
|
||||
finish_reason = None
|
||||
is_valid = True
|
||||
chunk = None
|
||||
if isinstance(self._modules, Models):
|
||||
for chunk in self._modules.generate_content_stream(
|
||||
model=self._model,
|
||||
contents=self._curated_history + [input_content], # type: ignore[arg-type]
|
||||
config=config if config else self._config,
|
||||
):
|
||||
if not _validate_response(chunk):
|
||||
is_valid = False
|
||||
if chunk.candidates and chunk.candidates[0].content:
|
||||
output_contents.append(chunk.candidates[0].content)
|
||||
if chunk.candidates and chunk.candidates[0].finish_reason:
|
||||
finish_reason = chunk.candidates[0].finish_reason
|
||||
yield chunk
|
||||
automatic_function_calling_history = (
|
||||
chunk.automatic_function_calling_history
|
||||
if chunk.automatic_function_calling_history
|
||||
else []
|
||||
)
|
||||
self.record_history(
|
||||
user_input=input_content,
|
||||
model_output=output_contents,
|
||||
automatic_function_calling_history=automatic_function_calling_history,
|
||||
is_valid=is_valid
|
||||
and output_contents is not None
|
||||
and finish_reason is not None,
|
||||
)
|
||||
|
||||
|
||||
class Chats:
|
||||
"""A util class to create chat sessions."""
|
||||
|
||||
def __init__(self, modules: Models):
|
||||
self._modules = modules
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
config: Optional[GenerateContentConfigOrDict] = None,
|
||||
history: Optional[list[ContentOrDict]] = None,
|
||||
) -> Chat:
|
||||
"""Creates a new chat session.
|
||||
|
||||
Args:
|
||||
model: The model to use for the chat.
|
||||
config: The configuration to use for the generate content request.
|
||||
history: The history to use for the chat.
|
||||
|
||||
Returns:
|
||||
A new chat session.
|
||||
"""
|
||||
return Chat(
|
||||
modules=self._modules,
|
||||
model=model,
|
||||
config=config,
|
||||
history=history if history else [],
|
||||
)
|
||||
|
||||
|
||||
class AsyncChat(_BaseChat):
|
||||
"""Async chat session."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
modules: AsyncModels,
|
||||
model: str,
|
||||
config: Optional[GenerateContentConfigOrDict] = None,
|
||||
history: list[ContentOrDict],
|
||||
):
|
||||
self._modules = modules
|
||||
super().__init__(
|
||||
model=model,
|
||||
config=config,
|
||||
history=history,
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Union[list[PartUnionDict], PartUnionDict],
|
||||
config: Optional[GenerateContentConfigOrDict] = None,
|
||||
) -> GenerateContentResponse:
|
||||
"""Sends the conversation history with the additional message and returns model's response.
|
||||
|
||||
Args:
|
||||
message: The message to send to the model.
|
||||
config: Optional config to override the default Chat config for this
|
||||
request.
|
||||
|
||||
Returns:
|
||||
The model's response.
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
chat = client.aio.chats.create(model='gemini-1.5-flash')
|
||||
response = await chat.send_message('tell me a story')
|
||||
"""
|
||||
if not _is_part_type(message):
|
||||
raise ValueError(
|
||||
f"Message must be a valid part type: {types.PartUnion} or"
|
||||
f" {types.PartUnionDict}, got {type(message)}"
|
||||
)
|
||||
input_content = t.t_content(self._modules._api_client, message)
|
||||
response = await self._modules.generate_content(
|
||||
model=self._model,
|
||||
contents=self._curated_history + [input_content], # type: ignore[arg-type]
|
||||
config=config if config else self._config,
|
||||
)
|
||||
model_output = (
|
||||
[response.candidates[0].content]
|
||||
if response.candidates and response.candidates[0].content
|
||||
else []
|
||||
)
|
||||
automatic_function_calling_history = (
|
||||
response.automatic_function_calling_history
|
||||
if response.automatic_function_calling_history
|
||||
else []
|
||||
)
|
||||
self.record_history(
|
||||
user_input=input_content,
|
||||
model_output=model_output,
|
||||
automatic_function_calling_history=automatic_function_calling_history,
|
||||
is_valid=_validate_response(response),
|
||||
)
|
||||
return response
|
||||
|
||||
async def send_message_stream(
|
||||
self,
|
||||
message: Union[list[PartUnionDict], PartUnionDict],
|
||||
config: Optional[GenerateContentConfigOrDict] = None,
|
||||
) -> AsyncIterator[GenerateContentResponse]:
|
||||
"""Sends the conversation history with the additional message and yields the model's response in chunks.
|
||||
|
||||
Args:
|
||||
message: The message to send to the model.
|
||||
config: Optional config to override the default Chat config for this
|
||||
request.
|
||||
|
||||
Yields:
|
||||
The model's response in chunks.
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
chat = client.aio.chats.create(model='gemini-1.5-flash')
|
||||
async for chunk in await chat.send_message_stream('tell me a story'):
|
||||
print(chunk.text)
|
||||
"""
|
||||
|
||||
if not _is_part_type(message):
|
||||
raise ValueError(
|
||||
f"Message must be a valid part type: {types.PartUnion} or"
|
||||
f" {types.PartUnionDict}, got {type(message)}"
|
||||
)
|
||||
input_content = t.t_content(self._modules._api_client, message)
|
||||
|
||||
async def async_generator(): # type: ignore[no-untyped-def]
|
||||
output_contents = []
|
||||
finish_reason = None
|
||||
is_valid = True
|
||||
chunk = None
|
||||
async for chunk in await self._modules.generate_content_stream( # type: ignore[attr-defined]
|
||||
model=self._model,
|
||||
contents=self._curated_history + [input_content], # type: ignore[arg-type]
|
||||
config=config if config else self._config,
|
||||
):
|
||||
if not _validate_response(chunk):
|
||||
is_valid = False
|
||||
if chunk.candidates and chunk.candidates[0].content:
|
||||
output_contents.append(chunk.candidates[0].content)
|
||||
if chunk.candidates and chunk.candidates[0].finish_reason:
|
||||
finish_reason = chunk.candidates[0].finish_reason
|
||||
yield chunk
|
||||
|
||||
if not output_contents or finish_reason is None:
|
||||
is_valid = False
|
||||
|
||||
self.record_history(
|
||||
user_input=input_content,
|
||||
model_output=output_contents,
|
||||
automatic_function_calling_history=chunk.automatic_function_calling_history if chunk.automatic_function_calling_history else [],
|
||||
is_valid=is_valid,
|
||||
)
|
||||
return async_generator() # type: ignore[no-untyped-call, no-any-return]
|
||||
|
||||
|
||||
class AsyncChats:
|
||||
"""A util class to create async chat sessions."""
|
||||
|
||||
def __init__(self, modules: AsyncModels):
|
||||
self._modules = modules
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
config: Optional[GenerateContentConfigOrDict] = None,
|
||||
history: Optional[list[ContentOrDict]] = None,
|
||||
) -> AsyncChat:
|
||||
"""Creates a new chat session.
|
||||
|
||||
Args:
|
||||
model: The model to use for the chat.
|
||||
config: The configuration to use for the generate content request.
|
||||
history: The history to use for the chat.
|
||||
|
||||
Returns:
|
||||
A new chat session.
|
||||
"""
|
||||
return AsyncChat(
|
||||
modules=self._modules,
|
||||
model=model,
|
||||
config=config,
|
||||
history=history if history else [],
|
||||
)
|
||||
290
.venv/lib/python3.10/site-packages/google/genai/client.py
Normal file
290
.venv/lib/python3.10/site-packages/google/genai/client.py
Normal file
@@ -0,0 +1,290 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import google.auth
|
||||
import pydantic
|
||||
|
||||
from ._api_client import BaseApiClient
|
||||
from ._replay_api_client import ReplayApiClient
|
||||
from .batches import AsyncBatches, Batches
|
||||
from .caches import AsyncCaches, Caches
|
||||
from .chats import AsyncChats, Chats
|
||||
from .files import AsyncFiles, Files
|
||||
from .live import AsyncLive
|
||||
from .models import AsyncModels, Models
|
||||
from .operations import AsyncOperations, Operations
|
||||
from .tunings import AsyncTunings, Tunings
|
||||
from .types import HttpOptions, HttpOptionsDict
|
||||
|
||||
|
||||
class AsyncClient:
|
||||
"""Client for making asynchronous (non-blocking) requests."""
|
||||
|
||||
def __init__(self, api_client: BaseApiClient):
|
||||
|
||||
self._api_client = api_client
|
||||
self._models = AsyncModels(self._api_client)
|
||||
self._tunings = AsyncTunings(self._api_client)
|
||||
self._caches = AsyncCaches(self._api_client)
|
||||
self._batches = AsyncBatches(self._api_client)
|
||||
self._files = AsyncFiles(self._api_client)
|
||||
self._live = AsyncLive(self._api_client)
|
||||
self._operations = AsyncOperations(self._api_client)
|
||||
|
||||
@property
|
||||
def models(self) -> AsyncModels:
|
||||
return self._models
|
||||
|
||||
@property
|
||||
def tunings(self) -> AsyncTunings:
|
||||
return self._tunings
|
||||
|
||||
@property
|
||||
def caches(self) -> AsyncCaches:
|
||||
return self._caches
|
||||
|
||||
@property
|
||||
def batches(self) -> AsyncBatches:
|
||||
return self._batches
|
||||
|
||||
@property
|
||||
def chats(self) -> AsyncChats:
|
||||
return AsyncChats(modules=self.models)
|
||||
|
||||
@property
|
||||
def files(self) -> AsyncFiles:
|
||||
return self._files
|
||||
|
||||
@property
|
||||
def live(self) -> AsyncLive:
|
||||
return self._live
|
||||
|
||||
@property
|
||||
def operations(self) -> AsyncOperations:
|
||||
return self._operations
|
||||
|
||||
class DebugConfig(pydantic.BaseModel):
|
||||
"""Configuration options that change client network behavior when testing."""
|
||||
|
||||
client_mode: Optional[str] = pydantic.Field(
|
||||
default_factory=lambda: os.getenv('GOOGLE_GENAI_CLIENT_MODE', None)
|
||||
)
|
||||
|
||||
replays_directory: Optional[str] = pydantic.Field(
|
||||
default_factory=lambda: os.getenv('GOOGLE_GENAI_REPLAYS_DIRECTORY', None)
|
||||
)
|
||||
|
||||
replay_id: Optional[str] = pydantic.Field(
|
||||
default_factory=lambda: os.getenv('GOOGLE_GENAI_REPLAY_ID', None)
|
||||
)
|
||||
|
||||
|
||||
class Client:
|
||||
"""Client for making synchronous requests.
|
||||
|
||||
Use this client to make a request to the Gemini Developer API or Vertex AI
|
||||
API and then wait for the response.
|
||||
|
||||
To initialize the client, provide the required arguments either directly
|
||||
or by using environment variables. Gemini API users and Vertex AI users in
|
||||
express mode can provide API key by providing input argument
|
||||
`api_key="your-api-key"` or by defining `GOOGLE_API_KEY="your-api-key"` as an
|
||||
environment variable
|
||||
|
||||
Vertex AI API users can provide inputs argument as `vertexai=True,
|
||||
project="your-project-id", location="us-central1"` or by defining
|
||||
`GOOGLE_GENAI_USE_VERTEXAI=true`, `GOOGLE_CLOUD_PROJECT` and
|
||||
`GOOGLE_CLOUD_LOCATION` environment variables.
|
||||
|
||||
Attributes:
|
||||
api_key: The `API key <https://ai.google.dev/gemini-api/docs/api-key>`_ to
|
||||
use for authentication. Applies to the Gemini Developer API only.
|
||||
vertexai: Indicates whether the client should use the Vertex AI
|
||||
API endpoints. Defaults to False (uses Gemini Developer API endpoints).
|
||||
Applies to the Vertex AI API only.
|
||||
credentials: The credentials to use for authentication when calling the
|
||||
Vertex AI APIs. Credentials can be obtained from environment variables and
|
||||
default credentials. For more information, see
|
||||
`Set up Application Default Credentials
|
||||
<https://cloud.google.com/docs/authentication/provide-credentials-adc>`_.
|
||||
Applies to the Vertex AI API only.
|
||||
project: The `Google Cloud project ID <https://cloud.google.com/vertex-ai/docs/start/cloud-environment>`_ to
|
||||
use for quota. Can be obtained from environment variables (for example,
|
||||
``GOOGLE_CLOUD_PROJECT``). Applies to the Vertex AI API only.
|
||||
location: The `location <https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations>`_
|
||||
to send API requests to (for example, ``us-central1``). Can be obtained
|
||||
from environment variables. Applies to the Vertex AI API only.
|
||||
debug_config: Config settings that control network behavior of the client.
|
||||
This is typically used when running test code.
|
||||
http_options: Http options to use for the client. These options will be
|
||||
applied to all requests made by the client. Example usage:
|
||||
`client = genai.Client(http_options=types.HttpOptions(api_version='v1'))`.
|
||||
|
||||
Usage for the Gemini Developer API:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from google import genai
|
||||
|
||||
client = genai.Client(api_key='my-api-key')
|
||||
|
||||
Usage for the Vertex AI API:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from google import genai
|
||||
|
||||
client = genai.Client(
|
||||
vertexai=True, project='my-project-id', location='us-central1'
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vertexai: Optional[bool] = None,
|
||||
api_key: Optional[str] = None,
|
||||
credentials: Optional[google.auth.credentials.Credentials] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
debug_config: Optional[DebugConfig] = None,
|
||||
http_options: Optional[Union[HttpOptions, HttpOptionsDict]] = None,
|
||||
):
|
||||
"""Initializes the client.
|
||||
|
||||
Args:
|
||||
vertexai (bool): Indicates whether the client should use the Vertex AI
|
||||
API endpoints. Defaults to False (uses Gemini Developer API endpoints).
|
||||
Applies to the Vertex AI API only.
|
||||
api_key (str): The `API key
|
||||
<https://ai.google.dev/gemini-api/docs/api-key>`_ to use for
|
||||
authentication. Applies to the Gemini Developer API only.
|
||||
credentials (google.auth.credentials.Credentials): The credentials to use
|
||||
for authentication when calling the Vertex AI APIs. Credentials can be
|
||||
obtained from environment variables and default credentials. For more
|
||||
information, see `Set up Application Default Credentials
|
||||
<https://cloud.google.com/docs/authentication/provide-credentials-adc>`_.
|
||||
Applies to the Vertex AI API only.
|
||||
project (str): The `Google Cloud project ID
|
||||
<https://cloud.google.com/vertex-ai/docs/start/cloud-environment>`_ to
|
||||
use for quota. Can be obtained from environment variables (for example,
|
||||
``GOOGLE_CLOUD_PROJECT``). Applies to the Vertex AI API only.
|
||||
location (str): The `location
|
||||
<https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations>`_
|
||||
to send API requests to (for example, ``us-central1``). Can be obtained
|
||||
from environment variables. Applies to the Vertex AI API only.
|
||||
debug_config (DebugConfig): Config settings that control network behavior
|
||||
of the client. This is typically used when running test code.
|
||||
http_options (Union[HttpOptions, HttpOptionsDict]): Http options to use
|
||||
for the client.
|
||||
"""
|
||||
|
||||
self._debug_config = debug_config or DebugConfig()
|
||||
if isinstance(http_options, dict):
|
||||
http_options = HttpOptions(**http_options)
|
||||
|
||||
self._api_client = self._get_api_client(
|
||||
vertexai=vertexai,
|
||||
api_key=api_key,
|
||||
credentials=credentials,
|
||||
project=project,
|
||||
location=location,
|
||||
debug_config=self._debug_config,
|
||||
http_options=http_options,
|
||||
)
|
||||
|
||||
self._aio = AsyncClient(self._api_client)
|
||||
self._models = Models(self._api_client)
|
||||
self._tunings = Tunings(self._api_client)
|
||||
self._caches = Caches(self._api_client)
|
||||
self._batches = Batches(self._api_client)
|
||||
self._files = Files(self._api_client)
|
||||
self._operations = Operations(self._api_client)
|
||||
|
||||
@staticmethod
|
||||
def _get_api_client(
|
||||
vertexai: Optional[bool] = None,
|
||||
api_key: Optional[str] = None,
|
||||
credentials: Optional[google.auth.credentials.Credentials] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
debug_config: Optional[DebugConfig] = None,
|
||||
http_options: Optional[HttpOptions] = None,
|
||||
) -> BaseApiClient:
|
||||
if debug_config and debug_config.client_mode in [
|
||||
'record',
|
||||
'replay',
|
||||
'auto',
|
||||
]:
|
||||
return ReplayApiClient(
|
||||
mode=debug_config.client_mode, # type: ignore[arg-type]
|
||||
replay_id=debug_config.replay_id, # type: ignore[arg-type]
|
||||
replays_directory=debug_config.replays_directory,
|
||||
vertexai=vertexai, # type: ignore[arg-type]
|
||||
api_key=api_key,
|
||||
credentials=credentials,
|
||||
project=project,
|
||||
location=location,
|
||||
http_options=http_options,
|
||||
)
|
||||
|
||||
return BaseApiClient(
|
||||
vertexai=vertexai,
|
||||
api_key=api_key,
|
||||
credentials=credentials,
|
||||
project=project,
|
||||
location=location,
|
||||
http_options=http_options,
|
||||
)
|
||||
|
||||
@property
|
||||
def chats(self) -> Chats:
|
||||
return Chats(modules=self.models)
|
||||
|
||||
@property
|
||||
def aio(self) -> AsyncClient:
|
||||
return self._aio
|
||||
|
||||
@property
|
||||
def models(self) -> Models:
|
||||
return self._models
|
||||
|
||||
@property
|
||||
def tunings(self) -> Tunings:
|
||||
return self._tunings
|
||||
|
||||
@property
|
||||
def caches(self) -> Caches:
|
||||
return self._caches
|
||||
|
||||
@property
|
||||
def batches(self) -> Batches:
|
||||
return self._batches
|
||||
|
||||
@property
|
||||
def files(self) -> Files:
|
||||
return self._files
|
||||
|
||||
@property
|
||||
def operations(self) -> Operations:
|
||||
return self._operations
|
||||
|
||||
@property
|
||||
def vertexai(self) -> bool:
|
||||
"""Returns whether the client is using the Vertex AI API."""
|
||||
return self._api_client.vertexai or False
|
||||
163
.venv/lib/python3.10/site-packages/google/genai/errors.py
Normal file
163
.venv/lib/python3.10/site-packages/google/genai/errors.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""Error classes for the GenAI SDK."""
|
||||
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
import httpx
|
||||
import json
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .replay_api_client import ReplayResponse
|
||||
|
||||
|
||||
class APIError(Exception):
|
||||
"""General errors raised by the GenAI API."""
|
||||
code: int
|
||||
response: Union['ReplayResponse', httpx.Response]
|
||||
|
||||
status: Optional[str] = None
|
||||
message: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
code: int,
|
||||
response_json: Any,
|
||||
response: Optional[Union['ReplayResponse', httpx.Response]] = None,
|
||||
):
|
||||
self.response = response
|
||||
self.details = response_json
|
||||
self.message = self._get_message(response_json)
|
||||
self.status = self._get_status(response_json)
|
||||
self.code = code if code else self._get_code(response_json)
|
||||
|
||||
super().__init__(f'{self.code} {self.status}. {self.details}')
|
||||
|
||||
def _get_status(self, response_json: Any) -> Any:
|
||||
return response_json.get(
|
||||
'status', response_json.get('error', {}).get('status', None)
|
||||
)
|
||||
|
||||
def _get_message(self, response_json: Any) -> Any:
|
||||
return response_json.get(
|
||||
'message', response_json.get('error', {}).get('message', None)
|
||||
)
|
||||
|
||||
def _get_code(self, response_json: Any) -> Any:
|
||||
return response_json.get(
|
||||
'code', response_json.get('error', {}).get('code', None)
|
||||
)
|
||||
|
||||
def _to_replay_record(self) -> dict[str, Any]:
|
||||
"""Returns a dictionary representation of the error for replay recording.
|
||||
|
||||
details is not included since it may expose internal information in the
|
||||
replay file.
|
||||
"""
|
||||
return {
|
||||
'error': {
|
||||
'code': self.code,
|
||||
'message': self.message,
|
||||
'status': self.status,
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def raise_for_response(
|
||||
cls, response: Union['ReplayResponse', httpx.Response]
|
||||
) -> None:
|
||||
"""Raises an error with detailed error message if the response has an error status."""
|
||||
if response.status_code == 200:
|
||||
return
|
||||
|
||||
if isinstance(response, httpx.Response):
|
||||
try:
|
||||
response.read()
|
||||
response_json = response.json()
|
||||
except json.decoder.JSONDecodeError:
|
||||
message = response.text
|
||||
response_json = {
|
||||
'message': message,
|
||||
'status': response.reason_phrase,
|
||||
}
|
||||
else:
|
||||
response_json = response.body_segments[0].get('error', {})
|
||||
|
||||
status_code = response.status_code
|
||||
if 400 <= status_code < 500:
|
||||
raise ClientError(status_code, response_json, response)
|
||||
elif 500 <= status_code < 600:
|
||||
raise ServerError(status_code, response_json, response)
|
||||
else:
|
||||
raise cls(status_code, response_json, response)
|
||||
|
||||
@classmethod
|
||||
async def raise_for_async_response(
|
||||
cls, response: Union['ReplayResponse', httpx.Response]
|
||||
) -> None:
|
||||
"""Raises an error with detailed error message if the response has an error status."""
|
||||
if response.status_code == 200:
|
||||
return
|
||||
if isinstance(response, httpx.Response):
|
||||
try:
|
||||
await response.aread()
|
||||
response_json = response.json()
|
||||
except json.decoder.JSONDecodeError:
|
||||
message = response.text
|
||||
response_json = {
|
||||
'message': message,
|
||||
'status': response.reason_phrase,
|
||||
}
|
||||
else:
|
||||
response_json = response.body_segments[0].get('error', {})
|
||||
|
||||
status_code = response.status_code
|
||||
if 400 <= status_code < 500:
|
||||
raise ClientError(status_code, response_json, response)
|
||||
elif 500 <= status_code < 600:
|
||||
raise ServerError(status_code, response_json, response)
|
||||
else:
|
||||
raise cls(status_code, response_json, response)
|
||||
|
||||
|
||||
class ClientError(APIError):
|
||||
"""Client error raised by the GenAI API."""
|
||||
pass
|
||||
|
||||
|
||||
class ServerError(APIError):
|
||||
"""Server error raised by the GenAI API."""
|
||||
pass
|
||||
|
||||
|
||||
class UnknownFunctionCallArgumentError(ValueError):
|
||||
"""Raised when the function call argument cannot be converted to the parameter annotation."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedFunctionError(ValueError):
|
||||
"""Raised when the function is not supported."""
|
||||
|
||||
|
||||
class FunctionInvocationError(ValueError):
|
||||
"""Raised when the function cannot be invoked with the given arguments."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ExperimentalWarning(Warning):
|
||||
"""Warning for experimental features."""
|
||||
1253
.venv/lib/python3.10/site-packages/google/genai/files.py
Normal file
1253
.venv/lib/python3.10/site-packages/google/genai/files.py
Normal file
File diff suppressed because it is too large
Load Diff
984
.venv/lib/python3.10/site-packages/google/genai/live.py
Normal file
984
.venv/lib/python3.10/site-packages/google/genai/live.py
Normal file
@@ -0,0 +1,984 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""[Preview] Live API client."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, cast, get_args
|
||||
import warnings
|
||||
|
||||
import google.auth
|
||||
import pydantic
|
||||
from websockets import ConnectionClosed
|
||||
|
||||
from . import _api_module
|
||||
from . import _common
|
||||
from . import _transformers as t
|
||||
from . import client
|
||||
from . import types
|
||||
from ._api_client import BaseApiClient
|
||||
from ._common import get_value_by_path as getv
|
||||
from ._common import set_value_by_path as setv
|
||||
from . import _live_converters as live_converters
|
||||
from .models import _Content_to_mldev
|
||||
from .models import _Content_to_vertex
|
||||
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
from websockets.asyncio.client import connect
|
||||
except ModuleNotFoundError:
|
||||
# This try/except is for TAP, mypy complains about it which is why we have the type: ignore
|
||||
from websockets.client import ClientConnection # type: ignore
|
||||
from websockets.client import connect # type: ignore
|
||||
|
||||
logger = logging.getLogger('google_genai.live')
|
||||
|
||||
_FUNCTION_RESPONSE_REQUIRES_ID = (
|
||||
'FunctionResponse request must have an `id` field from the'
|
||||
' response of a ToolCall.FunctionalCalls in Google AI.'
|
||||
)
|
||||
|
||||
|
||||
class AsyncSession:
|
||||
"""[Preview] AsyncSession."""
|
||||
|
||||
def __init__(
|
||||
self, api_client: BaseApiClient, websocket: ClientConnection
|
||||
):
|
||||
self._api_client = api_client
|
||||
self._ws = websocket
|
||||
|
||||
async def send(
|
||||
self,
|
||||
*,
|
||||
input: Optional[
|
||||
Union[
|
||||
types.ContentListUnion,
|
||||
types.ContentListUnionDict,
|
||||
types.LiveClientContentOrDict,
|
||||
types.LiveClientRealtimeInputOrDict,
|
||||
types.LiveClientToolResponseOrDict,
|
||||
types.FunctionResponseOrDict,
|
||||
Sequence[types.FunctionResponseOrDict],
|
||||
]
|
||||
] = None,
|
||||
end_of_turn: Optional[bool] = False,
|
||||
) -> None:
|
||||
"""[Deprecated] Send input to the model.
|
||||
|
||||
> **Warning**: This method is deprecated and will be removed in a future
|
||||
version (not before Q3 2025). Please use one of the more specific methods:
|
||||
`send_client_content`, `send_realtime_input`, or `send_tool_response`
|
||||
instead.
|
||||
|
||||
The method will send the input request to the server.
|
||||
|
||||
Args:
|
||||
input: The input request to the model.
|
||||
end_of_turn: Whether the input is the last message in a turn.
|
||||
|
||||
Example usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
client = genai.Client(api_key=API_KEY)
|
||||
|
||||
async with client.aio.live.connect(model='...') as session:
|
||||
await session.send(input='Hello world!', end_of_turn=True)
|
||||
async for message in session.receive():
|
||||
print(message)
|
||||
"""
|
||||
warnings.warn(
|
||||
'The `session.send` method is deprecated and will be removed in a '
|
||||
'future version (not before Q3 2025).\n'
|
||||
'Please use one of the more specific methods: `send_client_content`, '
|
||||
'`send_realtime_input`, or `send_tool_response` instead.',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
client_message = self._parse_client_message(input, end_of_turn)
|
||||
await self._ws.send(json.dumps(client_message))
|
||||
|
||||
async def send_client_content(
|
||||
self,
|
||||
*,
|
||||
turns: Optional[
|
||||
Union[
|
||||
types.Content,
|
||||
types.ContentDict,
|
||||
list[Union[types.Content, types.ContentDict]]
|
||||
]
|
||||
] = None,
|
||||
turn_complete: bool = True,
|
||||
) -> None:
|
||||
"""Send non-realtime, turn based content to the model.
|
||||
|
||||
There are two ways to send messages to the live API:
|
||||
`send_client_content` and `send_realtime_input`.
|
||||
|
||||
`send_client_content` messages are added to the model context **in order**.
|
||||
Having a conversation using `send_client_content` messages is roughly
|
||||
equivalent to using the `Chat.send_message_stream` method, except that the
|
||||
state of the `chat` history is stored on the API server.
|
||||
|
||||
Because of `send_client_content`'s order guarantee, the model cannot
|
||||
respond as quickly to `send_client_content` messages as to
|
||||
`send_realtime_input` messages. This makes the biggest difference when
|
||||
sending objects that have significant preprocessing time (typically images).
|
||||
|
||||
The `send_client_content` message sends a list of `Content` objects,
|
||||
which has more options than the `media:Blob` sent by `send_realtime_input`.
|
||||
|
||||
The main use-cases for `send_client_content` over `send_realtime_input` are:
|
||||
|
||||
- Prefilling a conversation context (including sending anything that can't
|
||||
be represented as a realtime message), before starting a realtime
|
||||
conversation.
|
||||
- Conducting a non-realtime conversation, similar to `client.chat`, using
|
||||
the live api.
|
||||
|
||||
Caution: Interleaving `send_client_content` and `send_realtime_input`
|
||||
in the same conversation is not recommended and can lead to unexpected
|
||||
results.
|
||||
|
||||
Args:
|
||||
turns: A `Content` object or list of `Content` objects (or equivalent
|
||||
dicts).
|
||||
turn_complete: if true (the default) the model will reply immediately. If
|
||||
false, the model will wait for you to send additional client_content,
|
||||
and will not return until you send `turn_complete=True`.
|
||||
|
||||
Example:
|
||||
```
|
||||
import google.genai
|
||||
from google.genai import types
|
||||
import os
|
||||
|
||||
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
|
||||
MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
|
||||
else:
|
||||
MODEL_NAME = 'gemini-2.0-flash-live-001';
|
||||
|
||||
client = genai.Client()
|
||||
async with client.aio.live.connect(
|
||||
model=MODEL_NAME,
|
||||
config={"response_modalities": ["TEXT"]}
|
||||
) as session:
|
||||
await session.send_client_content(
|
||||
turns=types.Content(
|
||||
role='user',
|
||||
parts=[types.Part(text="Hello world!")]))
|
||||
async for msg in session.receive():
|
||||
if msg.text:
|
||||
print(msg.text)
|
||||
```
|
||||
"""
|
||||
client_content = t.t_client_content(turns, turn_complete)
|
||||
|
||||
if self._api_client.vertexai:
|
||||
client_content_dict = live_converters._LiveClientContent_to_vertex(
|
||||
api_client=self._api_client, from_object=client_content
|
||||
)
|
||||
else:
|
||||
client_content_dict = live_converters._LiveClientContent_to_mldev(
|
||||
api_client=self._api_client, from_object=client_content
|
||||
)
|
||||
|
||||
await self._ws.send(json.dumps({'client_content': client_content_dict}))
|
||||
|
||||
async def send_realtime_input(
|
||||
self,
|
||||
*,
|
||||
media: Optional[types.BlobImageUnionDict] = None,
|
||||
audio: Optional[types.BlobOrDict] = None,
|
||||
audio_stream_end: Optional[bool] = None,
|
||||
video: Optional[types.BlobImageUnionDict] = None,
|
||||
text: Optional[str] = None,
|
||||
activity_start: Optional[types.ActivityStartOrDict] = None,
|
||||
activity_end: Optional[types.ActivityEndOrDict] = None,
|
||||
) -> None:
|
||||
"""Send realtime input to the model, only send one argument per call.
|
||||
|
||||
Use `send_realtime_input` for realtime audio chunks and video
|
||||
frames(images).
|
||||
|
||||
With `send_realtime_input` the api will respond to audio automatically
|
||||
based on voice activity detection (VAD).
|
||||
|
||||
`send_realtime_input` is optimized for responsivness at the expense of
|
||||
deterministic ordering. Audio and video tokens are added to the
|
||||
context when they become available.
|
||||
|
||||
Args:
|
||||
media: A `Blob`-like object, the realtime media to send.
|
||||
|
||||
Example:
|
||||
```
|
||||
from pathlib import Path
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
import PIL.Image
|
||||
|
||||
import os
|
||||
|
||||
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
|
||||
MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
|
||||
else:
|
||||
MODEL_NAME = 'gemini-2.0-flash-live-001';
|
||||
|
||||
|
||||
client = genai.Client()
|
||||
|
||||
async with client.aio.live.connect(
|
||||
model=MODEL_NAME,
|
||||
config={"response_modalities": ["TEXT"]},
|
||||
) as session:
|
||||
await session.send_realtime_input(
|
||||
media=PIL.Image.open('image.jpg'))
|
||||
|
||||
audio_bytes = Path('audio.pcm').read_bytes()
|
||||
await session.send_realtime_input(
|
||||
media=types.Blob(data=audio_bytes, mime_type='audio/pcm;rate=16000'))
|
||||
|
||||
async for msg in session.receive():
|
||||
if msg.text is not None:
|
||||
print(f'{msg.text}')
|
||||
```
|
||||
"""
|
||||
kwargs:dict[str, Any] = {}
|
||||
if media is not None:
|
||||
kwargs['media'] = media
|
||||
if audio is not None:
|
||||
kwargs['audio'] = audio
|
||||
if audio_stream_end is not None:
|
||||
kwargs['audio_stream_end'] = audio_stream_end
|
||||
if video is not None:
|
||||
kwargs['video'] = video
|
||||
if text is not None:
|
||||
kwargs['text'] = text
|
||||
if activity_start is not None:
|
||||
kwargs['activity_start'] = activity_start
|
||||
if activity_end is not None:
|
||||
kwargs['activity_end'] = activity_end
|
||||
|
||||
if len(kwargs) != 1:
|
||||
raise ValueError(
|
||||
f'Only one argument can be set, got {len(kwargs)}:'
|
||||
f' {list(kwargs.keys())}'
|
||||
)
|
||||
realtime_input = types.LiveSendRealtimeInputParameters.model_validate(
|
||||
kwargs
|
||||
)
|
||||
|
||||
if self._api_client.vertexai:
|
||||
realtime_input_dict = (
|
||||
live_converters._LiveSendRealtimeInputParameters_to_vertex(
|
||||
api_client=self._api_client, from_object=realtime_input
|
||||
)
|
||||
)
|
||||
else:
|
||||
realtime_input_dict = (
|
||||
live_converters._LiveSendRealtimeInputParameters_to_mldev(
|
||||
api_client=self._api_client, from_object=realtime_input
|
||||
)
|
||||
)
|
||||
realtime_input_dict = _common.convert_to_dict(realtime_input_dict)
|
||||
realtime_input_dict = _common.encode_unserializable_types(
|
||||
realtime_input_dict
|
||||
)
|
||||
await self._ws.send(json.dumps({'realtime_input': realtime_input_dict}))
|
||||
|
||||
async def send_tool_response(
|
||||
self,
|
||||
*,
|
||||
function_responses: Union[
|
||||
types.FunctionResponseOrDict,
|
||||
Sequence[types.FunctionResponseOrDict],
|
||||
],
|
||||
) -> None:
|
||||
"""Send a tool response to the session.
|
||||
|
||||
Use `send_tool_response` to reply to `LiveServerToolCall` messages
|
||||
from the server.
|
||||
|
||||
To set the available tools, use the `config.tools` argument
|
||||
when you connect to the session (`client.live.connect`).
|
||||
|
||||
Args:
|
||||
function_responses: A `FunctionResponse`-like object or list of
|
||||
`FunctionResponse`-like objects.
|
||||
|
||||
Example:
|
||||
```
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
import os
|
||||
|
||||
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
|
||||
MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
|
||||
else:
|
||||
MODEL_NAME = 'gemini-2.0-flash-live-001';
|
||||
|
||||
client = genai.Client()
|
||||
|
||||
tools = [{'function_declarations': [{'name': 'turn_on_the_lights'}]}]
|
||||
config = {
|
||||
"tools": tools,
|
||||
"response_modalities": ['TEXT']
|
||||
}
|
||||
|
||||
async with client.aio.live.connect(
|
||||
model='models/gemini-2.0-flash-live-001',
|
||||
config=config
|
||||
) as session:
|
||||
prompt = "Turn on the lights please"
|
||||
await session.send_client_content(
|
||||
turns={"parts": [{'text': prompt}]}
|
||||
)
|
||||
|
||||
async for chunk in session.receive():
|
||||
if chunk.server_content:
|
||||
if chunk.text is not None:
|
||||
print(chunk.text)
|
||||
elif chunk.tool_call:
|
||||
print(chunk.tool_call)
|
||||
print('_'*80)
|
||||
function_response=types.FunctionResponse(
|
||||
name='turn_on_the_lights',
|
||||
response={'result': 'ok'},
|
||||
id=chunk.tool_call.function_calls[0].id,
|
||||
)
|
||||
print(function_response)
|
||||
await session.send_tool_response(
|
||||
function_responses=function_response
|
||||
)
|
||||
|
||||
print('_'*80)
|
||||
"""
|
||||
tool_response = t.t_tool_response(function_responses)
|
||||
if self._api_client.vertexai:
|
||||
tool_response_dict = live_converters._LiveClientToolResponse_to_vertex(
|
||||
api_client=self._api_client, from_object=tool_response
|
||||
)
|
||||
else:
|
||||
tool_response_dict = live_converters._LiveClientToolResponse_to_mldev(
|
||||
api_client=self._api_client, from_object=tool_response
|
||||
)
|
||||
for response in tool_response_dict.get('functionResponses', []):
|
||||
if response.get('id') is None:
|
||||
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
||||
|
||||
await self._ws.send(json.dumps({'tool_response': tool_response_dict}))
|
||||
|
||||
async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
|
||||
"""Receive model responses from the server.
|
||||
|
||||
The method will yield the model responses from the server. The returned
|
||||
responses will represent a complete model turn. When the returned message
|
||||
is function call, user must call `send` with the function response to
|
||||
continue the turn.
|
||||
|
||||
Yields:
|
||||
The model responses from the server.
|
||||
|
||||
Example usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
client = genai.Client(api_key=API_KEY)
|
||||
|
||||
async with client.aio.live.connect(model='...') as session:
|
||||
await session.send(input='Hello world!', end_of_turn=True)
|
||||
async for message in session.receive():
|
||||
print(message)
|
||||
"""
|
||||
# TODO(b/365983264) Handle intermittent issues for the user.
|
||||
while result := await self._receive():
|
||||
if result.server_content and result.server_content.turn_complete:
|
||||
yield result
|
||||
break
|
||||
yield result
|
||||
|
||||
async def start_stream(
|
||||
self, *, stream: AsyncIterator[bytes], mime_type: str
|
||||
) -> AsyncIterator[types.LiveServerMessage]:
|
||||
"""[Deprecated] Start a live session from a data stream.
|
||||
|
||||
> **Warning**: This method is deprecated and will be removed in a future
|
||||
version (not before Q2 2025). Please use one of the more specific methods:
|
||||
`send_client_content`, `send_realtime_input`, or `send_tool_response`
|
||||
instead.
|
||||
|
||||
The interaction terminates when the input stream is complete.
|
||||
This method will start two async tasks. One task will be used to send the
|
||||
input stream to the model and the other task will be used to receive the
|
||||
responses from the model.
|
||||
|
||||
Args:
|
||||
stream: An iterator that yields the model response.
|
||||
mime_type: The MIME type of the data in the stream.
|
||||
|
||||
Yields:
|
||||
The audio bytes received from the model and server response messages.
|
||||
|
||||
Example usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
client = genai.Client(api_key=API_KEY)
|
||||
config = {'response_modalities': ['AUDIO']}
|
||||
async def audio_stream():
|
||||
stream = read_audio()
|
||||
for data in stream:
|
||||
yield data
|
||||
async with client.aio.live.connect(model='...', config=config) as session:
|
||||
for audio in session.start_stream(stream = audio_stream(),
|
||||
mime_type = 'audio/pcm'):
|
||||
play_audio_chunk(audio.data)
|
||||
"""
|
||||
warnings.warn(
|
||||
'Setting `AsyncSession.start_stream` is deprecated, '
|
||||
'and will be removed in a future release (not before Q3 2025). '
|
||||
'Please use the `receive`, and `send_realtime_input`, methods instead.',
|
||||
DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
stop_event = asyncio.Event()
|
||||
# Start the send loop. When stream is complete stop_event is set.
|
||||
asyncio.create_task(self._send_loop(stream, mime_type, stop_event))
|
||||
recv_task = None
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
recv_task = asyncio.create_task(self._receive())
|
||||
await asyncio.wait(
|
||||
[
|
||||
recv_task,
|
||||
asyncio.create_task(stop_event.wait()),
|
||||
],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
if recv_task.done():
|
||||
yield recv_task.result()
|
||||
# Give a chance for the send loop to process requests.
|
||||
await asyncio.sleep(10**-12)
|
||||
except ConnectionClosed:
|
||||
break
|
||||
if recv_task is not None and not recv_task.done():
|
||||
recv_task.cancel()
|
||||
# Wait for the task to finish (cancelled or not)
|
||||
try:
|
||||
await recv_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _receive(self) -> types.LiveServerMessage:
|
||||
parameter_model = types.LiveServerMessage()
|
||||
try:
|
||||
raw_response = await self._ws.recv(decode=False)
|
||||
except TypeError:
|
||||
raw_response = await self._ws.recv() # type: ignore[assignment]
|
||||
if raw_response:
|
||||
try:
|
||||
response = json.loads(raw_response)
|
||||
except json.decoder.JSONDecodeError:
|
||||
raise ValueError(f'Failed to parse response: {raw_response!r}')
|
||||
else:
|
||||
response = {}
|
||||
|
||||
if self._api_client.vertexai:
|
||||
response_dict = live_converters._LiveServerMessage_from_vertex(self._api_client, response)
|
||||
else:
|
||||
response_dict = live_converters._LiveServerMessage_from_mldev(self._api_client, response)
|
||||
|
||||
return types.LiveServerMessage._from_response(
|
||||
response=response_dict, kwargs=parameter_model.model_dump()
|
||||
)
|
||||
|
||||
async def _send_loop(
|
||||
self,
|
||||
data_stream: AsyncIterator[bytes],
|
||||
mime_type: str,
|
||||
stop_event: asyncio.Event,
|
||||
) -> None:
|
||||
async for data in data_stream:
|
||||
model_input = types.LiveClientRealtimeInput(
|
||||
media_chunks=[types.Blob(data=data, mime_type=mime_type)]
|
||||
)
|
||||
await self.send(input=model_input)
|
||||
# Give a chance for the receive loop to process responses.
|
||||
await asyncio.sleep(10**-12)
|
||||
# Give a chance for the receiver to process the last response.
|
||||
stop_event.set()
|
||||
|
||||
def _parse_client_message(
|
||||
self,
|
||||
input: Optional[
|
||||
Union[
|
||||
types.ContentListUnion,
|
||||
types.ContentListUnionDict,
|
||||
types.LiveClientContentOrDict,
|
||||
types.LiveClientRealtimeInputOrDict,
|
||||
types.LiveClientToolResponseOrDict,
|
||||
types.FunctionResponseOrDict,
|
||||
Sequence[types.FunctionResponseOrDict],
|
||||
]
|
||||
] = None,
|
||||
end_of_turn: Optional[bool] = False,
|
||||
) -> types.LiveClientMessageDict:
|
||||
|
||||
formatted_input: Any = input
|
||||
|
||||
if not input:
|
||||
logging.info('No input provided. Assume it is the end of turn.')
|
||||
return {'client_content': {'turn_complete': True}}
|
||||
if isinstance(input, str):
|
||||
formatted_input = [input]
|
||||
elif isinstance(input, dict) and 'data' in input:
|
||||
try:
|
||||
blob_input = types.Blob(**input)
|
||||
except pydantic.ValidationError:
|
||||
raise ValueError(
|
||||
f'Unsupported input type "{type(input)}" or input content "{input}"'
|
||||
)
|
||||
if (
|
||||
isinstance(blob_input, types.Blob)
|
||||
and isinstance(blob_input.data, bytes)
|
||||
):
|
||||
formatted_input = [
|
||||
blob_input.model_dump(mode='json', exclude_none=True)
|
||||
]
|
||||
elif isinstance(input, types.Blob):
|
||||
formatted_input = [input]
|
||||
elif isinstance(input, dict) and 'name' in input and 'response' in input:
|
||||
# ToolResponse.FunctionResponse
|
||||
if not (self._api_client.vertexai) and 'id' not in input:
|
||||
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
||||
formatted_input = [input]
|
||||
|
||||
if isinstance(formatted_input, Sequence) and any(
|
||||
isinstance(c, dict) and 'name' in c and 'response' in c
|
||||
for c in formatted_input
|
||||
):
|
||||
# ToolResponse.FunctionResponse
|
||||
function_responses_input = []
|
||||
for item in formatted_input:
|
||||
if isinstance(item, dict):
|
||||
try:
|
||||
function_response_input = types.FunctionResponse(**item)
|
||||
except pydantic.ValidationError:
|
||||
raise ValueError(
|
||||
f'Unsupported input type "{type(input)}" or input content'
|
||||
f' "{input}"'
|
||||
)
|
||||
if (
|
||||
function_response_input.id is None
|
||||
and not self._api_client.vertexai
|
||||
):
|
||||
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
||||
else:
|
||||
function_response_dict = function_response_input.model_dump(
|
||||
exclude_none=True, mode='json'
|
||||
)
|
||||
function_response_typeddict = types.FunctionResponseDict(
|
||||
name=function_response_dict.get('name'),
|
||||
response=function_response_dict.get('response'),
|
||||
)
|
||||
if function_response_dict.get('id'):
|
||||
function_response_typeddict['id'] = function_response_dict.get(
|
||||
'id'
|
||||
)
|
||||
function_responses_input.append(function_response_typeddict)
|
||||
client_message = types.LiveClientMessageDict(
|
||||
tool_response=types.LiveClientToolResponseDict(
|
||||
function_responses=function_responses_input
|
||||
)
|
||||
)
|
||||
elif isinstance(formatted_input, Sequence) and any(
|
||||
isinstance(c, str) for c in formatted_input
|
||||
):
|
||||
to_object: dict[str, Any] = {}
|
||||
content_input_parts: list[types.PartUnion] = []
|
||||
for item in formatted_input:
|
||||
if isinstance(item, get_args(types.PartUnion)):
|
||||
content_input_parts.append(item)
|
||||
if self._api_client.vertexai:
|
||||
contents = [
|
||||
_Content_to_vertex(self._api_client, item, to_object)
|
||||
for item in t.t_contents(self._api_client, content_input_parts)
|
||||
]
|
||||
else:
|
||||
contents = [
|
||||
_Content_to_mldev(self._api_client, item, to_object)
|
||||
for item in t.t_contents(self._api_client, content_input_parts)
|
||||
]
|
||||
|
||||
content_dict_list: list[types.ContentDict] = []
|
||||
for item in contents:
|
||||
try:
|
||||
content_input = types.Content(**item)
|
||||
except pydantic.ValidationError:
|
||||
raise ValueError(
|
||||
f'Unsupported input type "{type(input)}" or input content'
|
||||
f' "{input}"'
|
||||
)
|
||||
content_dict_list.append(
|
||||
types.ContentDict(
|
||||
parts=content_input.model_dump(exclude_none=True, mode='json')[
|
||||
'parts'
|
||||
],
|
||||
role=content_input.role,
|
||||
)
|
||||
)
|
||||
|
||||
client_message = types.LiveClientMessageDict(
|
||||
client_content=types.LiveClientContentDict(
|
||||
turns=content_dict_list, turn_complete=end_of_turn
|
||||
)
|
||||
)
|
||||
elif isinstance(formatted_input, Sequence):
|
||||
if any((isinstance(b, dict) and 'data' in b) for b in formatted_input):
|
||||
pass
|
||||
elif any(isinstance(b, types.Blob) for b in formatted_input):
|
||||
formatted_input = [
|
||||
b.model_dump(exclude_none=True, mode='json')
|
||||
for b in formatted_input
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Unsupported input type "{type(input)}" or input content "{input}"'
|
||||
)
|
||||
|
||||
client_message = types.LiveClientMessageDict(
|
||||
realtime_input=types.LiveClientRealtimeInputDict(
|
||||
media_chunks=formatted_input
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(formatted_input, dict):
|
||||
if 'content' in formatted_input or 'turns' in formatted_input:
|
||||
# TODO(b/365983264) Add validation checks for content_update input_dict.
|
||||
if 'turns' in formatted_input:
|
||||
content_turns = formatted_input['turns']
|
||||
else:
|
||||
content_turns = formatted_input['content']
|
||||
client_message = types.LiveClientMessageDict(
|
||||
client_content=types.LiveClientContentDict(
|
||||
turns=content_turns,
|
||||
turn_complete=formatted_input.get('turn_complete'),
|
||||
)
|
||||
)
|
||||
elif 'media_chunks' in formatted_input:
|
||||
try:
|
||||
realtime_input = types.LiveClientRealtimeInput(**formatted_input)
|
||||
except pydantic.ValidationError:
|
||||
raise ValueError(
|
||||
f'Unsupported input type "{type(input)}" or input content'
|
||||
f' "{input}"'
|
||||
)
|
||||
client_message = types.LiveClientMessageDict(
|
||||
realtime_input=types.LiveClientRealtimeInputDict(
|
||||
media_chunks=realtime_input.model_dump(
|
||||
exclude_none=True, mode='json'
|
||||
)['media_chunks']
|
||||
)
|
||||
)
|
||||
elif 'function_responses' in formatted_input:
|
||||
try:
|
||||
tool_response_input = types.LiveClientToolResponse(**formatted_input)
|
||||
except pydantic.ValidationError:
|
||||
raise ValueError(
|
||||
f'Unsupported input type "{type(input)}" or input content'
|
||||
f' "{input}"'
|
||||
)
|
||||
client_message = types.LiveClientMessageDict(
|
||||
tool_response=types.LiveClientToolResponseDict(
|
||||
function_responses=tool_response_input.model_dump(
|
||||
exclude_none=True, mode='json'
|
||||
)['function_responses']
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Unsupported input type "{type(input)}" or input content "{input}"'
|
||||
)
|
||||
elif isinstance(formatted_input, types.LiveClientRealtimeInput):
|
||||
realtime_input_dict = formatted_input.model_dump(
|
||||
exclude_none=True, mode='json'
|
||||
)
|
||||
client_message = types.LiveClientMessageDict(
|
||||
realtime_input=types.LiveClientRealtimeInputDict(
|
||||
media_chunks=realtime_input_dict.get('media_chunks')
|
||||
)
|
||||
)
|
||||
if (
|
||||
client_message['realtime_input'] is not None
|
||||
and client_message['realtime_input']['media_chunks'] is not None
|
||||
and isinstance(
|
||||
client_message['realtime_input']['media_chunks'][0]['data'], bytes
|
||||
)
|
||||
):
|
||||
formatted_media_chunks: list[types.BlobDict] = []
|
||||
for item in client_message['realtime_input']['media_chunks']:
|
||||
if isinstance(item, dict):
|
||||
try:
|
||||
blob_input = types.Blob(**item)
|
||||
except pydantic.ValidationError:
|
||||
raise ValueError(
|
||||
f'Unsupported input type "{type(input)}" or input content'
|
||||
f' "{input}"'
|
||||
)
|
||||
if (
|
||||
isinstance(blob_input, types.Blob)
|
||||
and isinstance(blob_input.data, bytes)
|
||||
and blob_input.data is not None
|
||||
):
|
||||
formatted_media_chunks.append(
|
||||
types.BlobDict(
|
||||
data=base64.b64decode(blob_input.data),
|
||||
mime_type=blob_input.mime_type,
|
||||
)
|
||||
)
|
||||
|
||||
client_message['realtime_input'][
|
||||
'media_chunks'
|
||||
] = formatted_media_chunks
|
||||
|
||||
elif isinstance(formatted_input, types.LiveClientContent):
|
||||
client_content_dict = formatted_input.model_dump(
|
||||
exclude_none=True, mode='json'
|
||||
)
|
||||
client_message = types.LiveClientMessageDict(
|
||||
client_content=types.LiveClientContentDict(
|
||||
turns=client_content_dict.get('turns'),
|
||||
turn_complete=client_content_dict.get('turn_complete'),
|
||||
)
|
||||
)
|
||||
elif isinstance(formatted_input, types.LiveClientToolResponse):
|
||||
# ToolResponse.FunctionResponse
|
||||
if (
|
||||
not (self._api_client.vertexai)
|
||||
and formatted_input.function_responses is not None
|
||||
and not (formatted_input.function_responses[0].id)
|
||||
):
|
||||
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
||||
client_message = types.LiveClientMessageDict(
|
||||
tool_response=types.LiveClientToolResponseDict(
|
||||
function_responses=formatted_input.model_dump(
|
||||
exclude_none=True, mode='json'
|
||||
).get('function_responses')
|
||||
)
|
||||
)
|
||||
elif isinstance(formatted_input, types.FunctionResponse):
|
||||
if not (self._api_client.vertexai) and not (formatted_input.id):
|
||||
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
||||
function_response_dict = formatted_input.model_dump(
|
||||
exclude_none=True, mode='json'
|
||||
)
|
||||
function_response_typeddict = types.FunctionResponseDict(
|
||||
name=function_response_dict.get('name'),
|
||||
response=function_response_dict.get('response'),
|
||||
)
|
||||
if function_response_dict.get('id'):
|
||||
function_response_typeddict['id'] = function_response_dict.get('id')
|
||||
client_message = types.LiveClientMessageDict(
|
||||
tool_response=types.LiveClientToolResponseDict(
|
||||
function_responses=[function_response_typeddict]
|
||||
)
|
||||
)
|
||||
elif isinstance(formatted_input, Sequence) and isinstance(
|
||||
formatted_input[0], types.FunctionResponse
|
||||
):
|
||||
if not (self._api_client.vertexai) and not (formatted_input[0].id):
|
||||
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
||||
function_response_list: list[types.FunctionResponseDict] = []
|
||||
for item in formatted_input:
|
||||
function_response_dict = item.model_dump(exclude_none=True, mode='json')
|
||||
function_response_typeddict = types.FunctionResponseDict(
|
||||
name=function_response_dict.get('name'),
|
||||
response=function_response_dict.get('response'),
|
||||
)
|
||||
if function_response_dict.get('id'):
|
||||
function_response_typeddict['id'] = function_response_dict.get('id')
|
||||
function_response_list.append(function_response_typeddict)
|
||||
client_message = types.LiveClientMessageDict(
|
||||
tool_response=types.LiveClientToolResponseDict(
|
||||
function_responses=function_response_list
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Unsupported input type "{type(input)}" or input content "{input}"'
|
||||
)
|
||||
|
||||
return client_message
|
||||
|
||||
async def close(self) -> None:
|
||||
# Close the websocket connection.
|
||||
await self._ws.close()
|
||||
|
||||
|
||||
class AsyncLive(_api_module.BaseModule):
|
||||
"""[Preview] AsyncLive."""
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def connect(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
config: Optional[types.LiveConnectConfigOrDict] = None,
|
||||
) -> AsyncIterator[AsyncSession]:
|
||||
"""[Preview] Connect to the live server.
|
||||
|
||||
Note: the live API is currently in preview.
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
client = genai.Client(api_key=API_KEY)
|
||||
config = {}
|
||||
async with client.aio.live.connect(model='...', config=config) as session:
|
||||
await session.send(input='Hello world!', end_of_turn=True)
|
||||
async for message in session.receive():
|
||||
print(message)
|
||||
"""
|
||||
base_url = self._api_client._websocket_base_url()
|
||||
if isinstance(base_url, bytes):
|
||||
base_url = base_url.decode('utf-8')
|
||||
transformed_model = t.t_model(self._api_client, model)
|
||||
|
||||
parameter_model = _t_live_connect_config(self._api_client, config)
|
||||
|
||||
if self._api_client.api_key:
|
||||
api_key = self._api_client.api_key
|
||||
version = self._api_client._http_options.api_version
|
||||
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
|
||||
headers = self._api_client._http_options.headers
|
||||
|
||||
request_dict = _common.convert_to_dict(
|
||||
live_converters._LiveConnectParameters_to_mldev(
|
||||
api_client=self._api_client,
|
||||
from_object=types.LiveConnectParameters(
|
||||
model=transformed_model,
|
||||
config=parameter_model,
|
||||
).model_dump(exclude_none=True)
|
||||
)
|
||||
)
|
||||
del request_dict['config']
|
||||
|
||||
setv(request_dict, ['setup', 'model'], transformed_model)
|
||||
|
||||
request = json.dumps(request_dict)
|
||||
else:
|
||||
# Get bearer token through Application Default Credentials.
|
||||
creds, _ = google.auth.default( # type: ignore[no-untyped-call]
|
||||
scopes=['https://www.googleapis.com/auth/cloud-platform']
|
||||
)
|
||||
|
||||
# creds.valid is False, and creds.token is None
|
||||
# Need to refresh credentials to populate those
|
||||
auth_req = google.auth.transport.requests.Request() # type: ignore[no-untyped-call]
|
||||
creds.refresh(auth_req)
|
||||
bearer_token = creds.token
|
||||
headers = self._api_client._http_options.headers
|
||||
if headers is not None:
|
||||
headers.update({
|
||||
'Authorization': 'Bearer {}'.format(bearer_token),
|
||||
})
|
||||
version = self._api_client._http_options.api_version
|
||||
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
|
||||
location = self._api_client.location
|
||||
project = self._api_client.project
|
||||
if transformed_model.startswith('publishers/'):
|
||||
transformed_model = (
|
||||
f'projects/{project}/locations/{location}/' + transformed_model
|
||||
)
|
||||
request_dict = _common.convert_to_dict(
|
||||
live_converters._LiveConnectParameters_to_vertex(
|
||||
api_client=self._api_client,
|
||||
from_object=types.LiveConnectParameters(
|
||||
model=transformed_model,
|
||||
config=parameter_model,
|
||||
).model_dump(exclude_none=True)
|
||||
)
|
||||
)
|
||||
del request_dict['config']
|
||||
|
||||
if getv(request_dict, ['setup', 'generationConfig', 'responseModalities']) is None:
|
||||
setv(request_dict, ['setup', 'generationConfig', 'responseModalities'], ['AUDIO'])
|
||||
|
||||
request = json.dumps(request_dict)
|
||||
|
||||
try:
|
||||
async with connect(uri, additional_headers=headers) as ws:
|
||||
await ws.send(request)
|
||||
logger.info(await ws.recv(decode=False))
|
||||
|
||||
yield AsyncSession(api_client=self._api_client, websocket=ws)
|
||||
except TypeError:
|
||||
# Try with the older websockets API
|
||||
async with connect(uri, extra_headers=headers) as ws:
|
||||
await ws.send(request)
|
||||
logger.info(await ws.recv())
|
||||
|
||||
yield AsyncSession(api_client=self._api_client, websocket=ws)
|
||||
|
||||
|
||||
def _t_live_connect_config(
|
||||
api_client: BaseApiClient,
|
||||
config: Optional[types.LiveConnectConfigOrDict],
|
||||
) -> types.LiveConnectConfig:
|
||||
# Ensure the config is a LiveConnectConfig.
|
||||
if config is None:
|
||||
parameter_model = types.LiveConnectConfig()
|
||||
elif isinstance(config, dict):
|
||||
if getv(config, ['system_instruction']) is not None:
|
||||
converted_system_instruction = t.t_content(
|
||||
api_client, getv(config, ['system_instruction'])
|
||||
)
|
||||
else:
|
||||
converted_system_instruction = None
|
||||
parameter_model = types.LiveConnectConfig(**config)
|
||||
parameter_model.system_instruction = converted_system_instruction
|
||||
else:
|
||||
if config.system_instruction is None:
|
||||
system_instruction = None
|
||||
else:
|
||||
system_instruction = t.t_content(
|
||||
api_client, getv(config, ['system_instruction'])
|
||||
)
|
||||
parameter_model = config
|
||||
parameter_model.system_instruction = system_instruction
|
||||
|
||||
if parameter_model.generation_config is not None:
|
||||
warnings.warn(
|
||||
'Setting `LiveConnectConfig.generation_config` is deprecated, '
|
||||
'please set the fields on `LiveConnectConfig` directly. This will '
|
||||
'become an error in a future version (not before Q3 2025)',
|
||||
DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
|
||||
return parameter_model
|
||||
6927
.venv/lib/python3.10/site-packages/google/genai/models.py
Normal file
6927
.venv/lib/python3.10/site-packages/google/genai/models.py
Normal file
File diff suppressed because it is too large
Load Diff
651
.venv/lib/python3.10/site-packages/google/genai/operations.py
Normal file
651
.venv/lib/python3.10/site-packages/google/genai/operations.py
Normal file
@@ -0,0 +1,651 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
# Code generated by the Google Gen AI SDK generator DO NOT EDIT.
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.parse import urlencode
|
||||
from . import _api_module
|
||||
from . import _common
|
||||
from . import _transformers as t
|
||||
from . import types
|
||||
from ._api_client import BaseApiClient
|
||||
from ._common import get_value_by_path as getv
|
||||
from ._common import set_value_by_path as setv
|
||||
|
||||
logger = logging.getLogger('google_genai.operations')
|
||||
|
||||
|
||||
def _GetOperationParameters_to_mldev(
|
||||
api_client: BaseApiClient,
|
||||
from_object: Union[dict[str, Any], object],
|
||||
parent_object: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
to_object: dict[str, Any] = {}
|
||||
if getv(from_object, ['operation_name']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['_url', 'operationName'],
|
||||
getv(from_object, ['operation_name']),
|
||||
)
|
||||
|
||||
if getv(from_object, ['config']) is not None:
|
||||
setv(to_object, ['config'], getv(from_object, ['config']))
|
||||
|
||||
return to_object
|
||||
|
||||
|
||||
def _GetOperationParameters_to_vertex(
|
||||
api_client: BaseApiClient,
|
||||
from_object: Union[dict[str, Any], object],
|
||||
parent_object: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
to_object: dict[str, Any] = {}
|
||||
if getv(from_object, ['operation_name']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['_url', 'operationName'],
|
||||
getv(from_object, ['operation_name']),
|
||||
)
|
||||
|
||||
if getv(from_object, ['config']) is not None:
|
||||
setv(to_object, ['config'], getv(from_object, ['config']))
|
||||
|
||||
return to_object
|
||||
|
||||
|
||||
def _FetchPredictOperationParameters_to_vertex(
|
||||
api_client: BaseApiClient,
|
||||
from_object: Union[dict[str, Any], object],
|
||||
parent_object: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
to_object: dict[str, Any] = {}
|
||||
if getv(from_object, ['operation_name']) is not None:
|
||||
setv(to_object, ['operationName'], getv(from_object, ['operation_name']))
|
||||
|
||||
if getv(from_object, ['resource_name']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['_url', 'resourceName'],
|
||||
getv(from_object, ['resource_name']),
|
||||
)
|
||||
|
||||
if getv(from_object, ['config']) is not None:
|
||||
setv(to_object, ['config'], getv(from_object, ['config']))
|
||||
|
||||
return to_object
|
||||
|
||||
|
||||
def _Video_from_mldev(
|
||||
api_client: BaseApiClient,
|
||||
from_object: Union[dict[str, Any], object],
|
||||
parent_object: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
to_object: dict[str, Any] = {}
|
||||
if getv(from_object, ['video', 'uri']) is not None:
|
||||
setv(to_object, ['uri'], getv(from_object, ['video', 'uri']))
|
||||
|
||||
if getv(from_object, ['video', 'encodedVideo']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['video_bytes'],
|
||||
t.t_bytes(api_client, getv(from_object, ['video', 'encodedVideo'])),
|
||||
)
|
||||
|
||||
if getv(from_object, ['encoding']) is not None:
|
||||
setv(to_object, ['mime_type'], getv(from_object, ['encoding']))
|
||||
|
||||
return to_object
|
||||
|
||||
|
||||
def _GeneratedVideo_from_mldev(
|
||||
api_client: BaseApiClient,
|
||||
from_object: Union[dict[str, Any], object],
|
||||
parent_object: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
to_object: dict[str, Any] = {}
|
||||
if getv(from_object, ['_self']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['video'],
|
||||
_Video_from_mldev(api_client, getv(from_object, ['_self']), to_object),
|
||||
)
|
||||
|
||||
return to_object
|
||||
|
||||
|
||||
def _GenerateVideosResponse_from_mldev(
|
||||
api_client: BaseApiClient,
|
||||
from_object: Union[dict[str, Any], object],
|
||||
parent_object: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
to_object: dict[str, Any] = {}
|
||||
if getv(from_object, ['generatedSamples']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['generated_videos'],
|
||||
[
|
||||
_GeneratedVideo_from_mldev(api_client, item, to_object)
|
||||
for item in getv(from_object, ['generatedSamples'])
|
||||
],
|
||||
)
|
||||
|
||||
if getv(from_object, ['raiMediaFilteredCount']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['rai_media_filtered_count'],
|
||||
getv(from_object, ['raiMediaFilteredCount']),
|
||||
)
|
||||
|
||||
if getv(from_object, ['raiMediaFilteredReasons']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['rai_media_filtered_reasons'],
|
||||
getv(from_object, ['raiMediaFilteredReasons']),
|
||||
)
|
||||
|
||||
return to_object
|
||||
|
||||
|
||||
def _GenerateVideosOperation_from_mldev(
|
||||
api_client: BaseApiClient,
|
||||
from_object: Union[dict[str, Any], object],
|
||||
parent_object: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
to_object: dict[str, Any] = {}
|
||||
if getv(from_object, ['name']) is not None:
|
||||
setv(to_object, ['name'], getv(from_object, ['name']))
|
||||
|
||||
if getv(from_object, ['metadata']) is not None:
|
||||
setv(to_object, ['metadata'], getv(from_object, ['metadata']))
|
||||
|
||||
if getv(from_object, ['done']) is not None:
|
||||
setv(to_object, ['done'], getv(from_object, ['done']))
|
||||
|
||||
if getv(from_object, ['error']) is not None:
|
||||
setv(to_object, ['error'], getv(from_object, ['error']))
|
||||
|
||||
if getv(from_object, ['response', 'generateVideoResponse']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['response'],
|
||||
_GenerateVideosResponse_from_mldev(
|
||||
api_client,
|
||||
getv(from_object, ['response', 'generateVideoResponse']),
|
||||
to_object,
|
||||
),
|
||||
)
|
||||
|
||||
if getv(from_object, ['response', 'generateVideoResponse']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['result'],
|
||||
_GenerateVideosResponse_from_mldev(
|
||||
api_client,
|
||||
getv(from_object, ['response', 'generateVideoResponse']),
|
||||
to_object,
|
||||
),
|
||||
)
|
||||
|
||||
return to_object
|
||||
|
||||
|
||||
def _Video_from_vertex(
|
||||
api_client: BaseApiClient,
|
||||
from_object: Union[dict[str, Any], object],
|
||||
parent_object: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
to_object: dict[str, Any] = {}
|
||||
if getv(from_object, ['gcsUri']) is not None:
|
||||
setv(to_object, ['uri'], getv(from_object, ['gcsUri']))
|
||||
|
||||
if getv(from_object, ['bytesBase64Encoded']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['video_bytes'],
|
||||
t.t_bytes(api_client, getv(from_object, ['bytesBase64Encoded'])),
|
||||
)
|
||||
|
||||
if getv(from_object, ['mimeType']) is not None:
|
||||
setv(to_object, ['mime_type'], getv(from_object, ['mimeType']))
|
||||
|
||||
return to_object
|
||||
|
||||
|
||||
def _GeneratedVideo_from_vertex(
|
||||
api_client: BaseApiClient,
|
||||
from_object: Union[dict[str, Any], object],
|
||||
parent_object: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
to_object: dict[str, Any] = {}
|
||||
if getv(from_object, ['_self']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['video'],
|
||||
_Video_from_vertex(api_client, getv(from_object, ['_self']), to_object),
|
||||
)
|
||||
|
||||
return to_object
|
||||
|
||||
|
||||
def _GenerateVideosResponse_from_vertex(
|
||||
api_client: BaseApiClient,
|
||||
from_object: Union[dict[str, Any], object],
|
||||
parent_object: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
to_object: dict[str, Any] = {}
|
||||
if getv(from_object, ['videos']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['generated_videos'],
|
||||
[
|
||||
_GeneratedVideo_from_vertex(api_client, item, to_object)
|
||||
for item in getv(from_object, ['videos'])
|
||||
],
|
||||
)
|
||||
|
||||
if getv(from_object, ['raiMediaFilteredCount']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['rai_media_filtered_count'],
|
||||
getv(from_object, ['raiMediaFilteredCount']),
|
||||
)
|
||||
|
||||
if getv(from_object, ['raiMediaFilteredReasons']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['rai_media_filtered_reasons'],
|
||||
getv(from_object, ['raiMediaFilteredReasons']),
|
||||
)
|
||||
|
||||
return to_object
|
||||
|
||||
|
||||
def _GenerateVideosOperation_from_vertex(
|
||||
api_client: BaseApiClient,
|
||||
from_object: Union[dict[str, Any], object],
|
||||
parent_object: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
to_object: dict[str, Any] = {}
|
||||
if getv(from_object, ['name']) is not None:
|
||||
setv(to_object, ['name'], getv(from_object, ['name']))
|
||||
|
||||
if getv(from_object, ['metadata']) is not None:
|
||||
setv(to_object, ['metadata'], getv(from_object, ['metadata']))
|
||||
|
||||
if getv(from_object, ['done']) is not None:
|
||||
setv(to_object, ['done'], getv(from_object, ['done']))
|
||||
|
||||
if getv(from_object, ['error']) is not None:
|
||||
setv(to_object, ['error'], getv(from_object, ['error']))
|
||||
|
||||
if getv(from_object, ['response']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['response'],
|
||||
_GenerateVideosResponse_from_vertex(
|
||||
api_client, getv(from_object, ['response']), to_object
|
||||
),
|
||||
)
|
||||
|
||||
if getv(from_object, ['response']) is not None:
|
||||
setv(
|
||||
to_object,
|
||||
['result'],
|
||||
_GenerateVideosResponse_from_vertex(
|
||||
api_client, getv(from_object, ['response']), to_object
|
||||
),
|
||||
)
|
||||
|
||||
return to_object
|
||||
|
||||
|
||||
class Operations(_api_module.BaseModule):
|
||||
|
||||
def _get_videos_operation(
|
||||
self,
|
||||
*,
|
||||
operation_name: str,
|
||||
config: Optional[types.GetOperationConfigOrDict] = None,
|
||||
) -> types.GenerateVideosOperation:
|
||||
parameter_model = types._GetOperationParameters(
|
||||
operation_name=operation_name,
|
||||
config=config,
|
||||
)
|
||||
|
||||
request_url_dict: Optional[dict[str, str]]
|
||||
|
||||
if self._api_client.vertexai:
|
||||
request_dict = _GetOperationParameters_to_vertex(
|
||||
self._api_client, parameter_model
|
||||
)
|
||||
request_url_dict = request_dict.get('_url')
|
||||
if request_url_dict:
|
||||
path = '{operationName}'.format_map(request_url_dict)
|
||||
else:
|
||||
path = '{operationName}'
|
||||
else:
|
||||
request_dict = _GetOperationParameters_to_mldev(
|
||||
self._api_client, parameter_model
|
||||
)
|
||||
request_url_dict = request_dict.get('_url')
|
||||
if request_url_dict:
|
||||
path = '{operationName}'.format_map(request_url_dict)
|
||||
else:
|
||||
path = '{operationName}'
|
||||
query_params = request_dict.get('_query')
|
||||
if query_params:
|
||||
path = f'{path}?{urlencode(query_params)}'
|
||||
# TODO: remove the hack that pops config.
|
||||
request_dict.pop('config', None)
|
||||
|
||||
http_options: Optional[types.HttpOptions] = None
|
||||
if (
|
||||
parameter_model.config is not None
|
||||
and parameter_model.config.http_options is not None
|
||||
):
|
||||
http_options = parameter_model.config.http_options
|
||||
|
||||
request_dict = _common.convert_to_dict(request_dict)
|
||||
request_dict = _common.encode_unserializable_types(request_dict)
|
||||
|
||||
response_dict = self._api_client.request(
|
||||
'get', path, request_dict, http_options
|
||||
)
|
||||
|
||||
if self._api_client.vertexai:
|
||||
response_dict = _GenerateVideosOperation_from_vertex(
|
||||
self._api_client, response_dict
|
||||
)
|
||||
|
||||
else:
|
||||
response_dict = _GenerateVideosOperation_from_mldev(
|
||||
self._api_client, response_dict
|
||||
)
|
||||
|
||||
return_value = types.GenerateVideosOperation._from_response(
|
||||
response=response_dict, kwargs=parameter_model.model_dump()
|
||||
)
|
||||
self._api_client._verify_response(return_value)
|
||||
return return_value
|
||||
|
||||
def _fetch_predict_videos_operation(
|
||||
self,
|
||||
*,
|
||||
operation_name: str,
|
||||
resource_name: str,
|
||||
config: Optional[types.FetchPredictOperationConfigOrDict] = None,
|
||||
) -> types.GenerateVideosOperation:
|
||||
parameter_model = types._FetchPredictOperationParameters(
|
||||
operation_name=operation_name,
|
||||
resource_name=resource_name,
|
||||
config=config,
|
||||
)
|
||||
|
||||
request_url_dict: Optional[dict[str, str]]
|
||||
if not self._api_client.vertexai:
|
||||
raise ValueError('This method is only supported in the Vertex AI client.')
|
||||
else:
|
||||
request_dict = _FetchPredictOperationParameters_to_vertex(
|
||||
self._api_client, parameter_model
|
||||
)
|
||||
request_url_dict = request_dict.get('_url')
|
||||
if request_url_dict:
|
||||
path = '{resourceName}:fetchPredictOperation'.format_map(
|
||||
request_url_dict
|
||||
)
|
||||
else:
|
||||
path = '{resourceName}:fetchPredictOperation'
|
||||
|
||||
query_params = request_dict.get('_query')
|
||||
if query_params:
|
||||
path = f'{path}?{urlencode(query_params)}'
|
||||
# TODO: remove the hack that pops config.
|
||||
request_dict.pop('config', None)
|
||||
|
||||
http_options: Optional[types.HttpOptions] = None
|
||||
if (
|
||||
parameter_model.config is not None
|
||||
and parameter_model.config.http_options is not None
|
||||
):
|
||||
http_options = parameter_model.config.http_options
|
||||
|
||||
request_dict = _common.convert_to_dict(request_dict)
|
||||
request_dict = _common.encode_unserializable_types(request_dict)
|
||||
|
||||
response_dict = self._api_client.request(
|
||||
'post', path, request_dict, http_options
|
||||
)
|
||||
|
||||
if self._api_client.vertexai:
|
||||
response_dict = _GenerateVideosOperation_from_vertex(
|
||||
self._api_client, response_dict
|
||||
)
|
||||
|
||||
return_value = types.GenerateVideosOperation._from_response(
|
||||
response=response_dict, kwargs=parameter_model.model_dump()
|
||||
)
|
||||
self._api_client._verify_response(return_value)
|
||||
return return_value
|
||||
|
||||
def get(
|
||||
self,
|
||||
operation: types.GenerateVideosOperation,
|
||||
*,
|
||||
config: Optional[types.GetOperationConfigOrDict] = None,
|
||||
) -> types.GenerateVideosOperation:
|
||||
"""Gets the status of an operation."""
|
||||
# Currently, only GenerateVideosOperation is supported.
|
||||
# TODO(b/398040607): Support short form names
|
||||
operation_name = operation.name
|
||||
if not operation_name:
|
||||
raise ValueError('Operation name is empty.')
|
||||
|
||||
# TODO(b/398233524): Cast operation types
|
||||
if self._api_client.vertexai:
|
||||
resource_name = operation_name.rpartition('/operations/')[0]
|
||||
http_options = types.HttpOptions()
|
||||
if isinstance(config, dict):
|
||||
dict_options = config.get('http_options', None)
|
||||
if dict_options is not None:
|
||||
http_options = types.HttpOptions(**dict(dict_options))
|
||||
elif isinstance(config, types.GetOperationConfig) and config is not None:
|
||||
http_options = (
|
||||
config.http_options
|
||||
if config.http_options is not None
|
||||
else types.HttpOptions()
|
||||
)
|
||||
fetch_operation_config = types.FetchPredictOperationConfig(
|
||||
http_options=http_options
|
||||
)
|
||||
return self._fetch_predict_videos_operation(
|
||||
operation_name=operation_name,
|
||||
resource_name=resource_name,
|
||||
config=fetch_operation_config,
|
||||
)
|
||||
else:
|
||||
return self._get_videos_operation(
|
||||
operation_name=operation_name,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
class AsyncOperations(_api_module.BaseModule):
|
||||
|
||||
async def _get_videos_operation(
|
||||
self,
|
||||
*,
|
||||
operation_name: str,
|
||||
config: Optional[types.GetOperationConfigOrDict] = None,
|
||||
) -> types.GenerateVideosOperation:
|
||||
parameter_model = types._GetOperationParameters(
|
||||
operation_name=operation_name,
|
||||
config=config,
|
||||
)
|
||||
|
||||
request_url_dict: Optional[dict[str, str]]
|
||||
|
||||
if self._api_client.vertexai:
|
||||
request_dict = _GetOperationParameters_to_vertex(
|
||||
self._api_client, parameter_model
|
||||
)
|
||||
request_url_dict = request_dict.get('_url')
|
||||
if request_url_dict:
|
||||
path = '{operationName}'.format_map(request_url_dict)
|
||||
else:
|
||||
path = '{operationName}'
|
||||
else:
|
||||
request_dict = _GetOperationParameters_to_mldev(
|
||||
self._api_client, parameter_model
|
||||
)
|
||||
request_url_dict = request_dict.get('_url')
|
||||
if request_url_dict:
|
||||
path = '{operationName}'.format_map(request_url_dict)
|
||||
else:
|
||||
path = '{operationName}'
|
||||
query_params = request_dict.get('_query')
|
||||
if query_params:
|
||||
path = f'{path}?{urlencode(query_params)}'
|
||||
# TODO: remove the hack that pops config.
|
||||
request_dict.pop('config', None)
|
||||
|
||||
http_options: Optional[types.HttpOptions] = None
|
||||
if (
|
||||
parameter_model.config is not None
|
||||
and parameter_model.config.http_options is not None
|
||||
):
|
||||
http_options = parameter_model.config.http_options
|
||||
|
||||
request_dict = _common.convert_to_dict(request_dict)
|
||||
request_dict = _common.encode_unserializable_types(request_dict)
|
||||
|
||||
response_dict = await self._api_client.async_request(
|
||||
'get', path, request_dict, http_options
|
||||
)
|
||||
|
||||
if self._api_client.vertexai:
|
||||
response_dict = _GenerateVideosOperation_from_vertex(
|
||||
self._api_client, response_dict
|
||||
)
|
||||
|
||||
else:
|
||||
response_dict = _GenerateVideosOperation_from_mldev(
|
||||
self._api_client, response_dict
|
||||
)
|
||||
|
||||
return_value = types.GenerateVideosOperation._from_response(
|
||||
response=response_dict, kwargs=parameter_model.model_dump()
|
||||
)
|
||||
self._api_client._verify_response(return_value)
|
||||
return return_value
|
||||
|
||||
async def _fetch_predict_videos_operation(
|
||||
self,
|
||||
*,
|
||||
operation_name: str,
|
||||
resource_name: str,
|
||||
config: Optional[types.FetchPredictOperationConfigOrDict] = None,
|
||||
) -> types.GenerateVideosOperation:
|
||||
parameter_model = types._FetchPredictOperationParameters(
|
||||
operation_name=operation_name,
|
||||
resource_name=resource_name,
|
||||
config=config,
|
||||
)
|
||||
|
||||
request_url_dict: Optional[dict[str, str]]
|
||||
if not self._api_client.vertexai:
|
||||
raise ValueError('This method is only supported in the Vertex AI client.')
|
||||
else:
|
||||
request_dict = _FetchPredictOperationParameters_to_vertex(
|
||||
self._api_client, parameter_model
|
||||
)
|
||||
request_url_dict = request_dict.get('_url')
|
||||
if request_url_dict:
|
||||
path = '{resourceName}:fetchPredictOperation'.format_map(
|
||||
request_url_dict
|
||||
)
|
||||
else:
|
||||
path = '{resourceName}:fetchPredictOperation'
|
||||
|
||||
query_params = request_dict.get('_query')
|
||||
if query_params:
|
||||
path = f'{path}?{urlencode(query_params)}'
|
||||
# TODO: remove the hack that pops config.
|
||||
request_dict.pop('config', None)
|
||||
|
||||
http_options: Optional[types.HttpOptions] = None
|
||||
if (
|
||||
parameter_model.config is not None
|
||||
and parameter_model.config.http_options is not None
|
||||
):
|
||||
http_options = parameter_model.config.http_options
|
||||
|
||||
request_dict = _common.convert_to_dict(request_dict)
|
||||
request_dict = _common.encode_unserializable_types(request_dict)
|
||||
|
||||
response_dict = await self._api_client.async_request(
|
||||
'post', path, request_dict, http_options
|
||||
)
|
||||
|
||||
if self._api_client.vertexai:
|
||||
response_dict = _GenerateVideosOperation_from_vertex(
|
||||
self._api_client, response_dict
|
||||
)
|
||||
|
||||
return_value = types.GenerateVideosOperation._from_response(
|
||||
response=response_dict, kwargs=parameter_model.model_dump()
|
||||
)
|
||||
self._api_client._verify_response(return_value)
|
||||
return return_value
|
||||
|
||||
async def get(
|
||||
self,
|
||||
operation: types.GenerateVideosOperation,
|
||||
*,
|
||||
config: Optional[types.GetOperationConfigOrDict] = None,
|
||||
) -> types.GenerateVideosOperation:
|
||||
"""Gets the status of an operation."""
|
||||
# Currently, only GenerateVideosOperation is supported.
|
||||
operation_name = operation.name
|
||||
if not operation_name:
|
||||
raise ValueError('Operation name is empty.')
|
||||
|
||||
if self._api_client.vertexai:
|
||||
resource_name = operation_name.rpartition('/operations/')[0]
|
||||
http_options = types.HttpOptions()
|
||||
if isinstance(config, dict):
|
||||
dict_options = config.get('http_options', None)
|
||||
if dict_options is not None:
|
||||
http_options = types.HttpOptions(**dict(dict_options))
|
||||
elif isinstance(config, types.GetOperationConfig) and config is not None:
|
||||
http_options = (
|
||||
config.http_options
|
||||
if config.http_options is not None
|
||||
else types.HttpOptions()
|
||||
)
|
||||
fetch_operation_config = types.FetchPredictOperationConfig(
|
||||
http_options=http_options
|
||||
)
|
||||
return await self._fetch_predict_videos_operation(
|
||||
operation_name=operation_name,
|
||||
resource_name=resource_name,
|
||||
config=fetch_operation_config,
|
||||
)
|
||||
else:
|
||||
return await self._get_videos_operation(
|
||||
operation_name=operation_name,
|
||||
config=config,
|
||||
)
|
||||
252
.venv/lib/python3.10/site-packages/google/genai/pagers.py
Normal file
252
.venv/lib/python3.10/site-packages/google/genai/pagers.py
Normal file
@@ -0,0 +1,252 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""Pagers for the GenAI List APIs."""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
import copy
|
||||
from typing import Any, AsyncIterator,Awaitable, Callable, Generic, Iterator, Literal, TypeVar
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
PagedItem = Literal[
|
||||
'batch_jobs', 'models', 'tuning_jobs', 'files', 'cached_contents'
|
||||
]
|
||||
|
||||
|
||||
class _BasePager(Generic[T]):
|
||||
"""Base pager class for iterating through paginated results."""
|
||||
|
||||
def _init_page(
|
||||
self,
|
||||
name: PagedItem,
|
||||
request: Callable[..., Any],
|
||||
response: Any,
|
||||
config: Any,
|
||||
) -> None:
|
||||
self._name = name
|
||||
self._request = request
|
||||
|
||||
self._page = getattr(response, self._name) or []
|
||||
self._idx = 0
|
||||
|
||||
if not config:
|
||||
request_config = {}
|
||||
elif isinstance(config, dict):
|
||||
request_config = copy.deepcopy(config)
|
||||
else:
|
||||
request_config = dict(config)
|
||||
request_config['page_token'] = getattr(response, 'next_page_token')
|
||||
self._config = request_config
|
||||
|
||||
self._page_size: int = request_config.get('page_size', len(self._page))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: PagedItem,
|
||||
request: Callable[..., Any],
|
||||
response: Any,
|
||||
config: Any,
|
||||
):
|
||||
self._init_page(name, request, response, config)
|
||||
|
||||
@property
|
||||
def page(self) -> list[T]:
|
||||
"""Returns a subset of the entire list of items.
|
||||
|
||||
For the number of items returned, see `pageSize()`.
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
batch_jobs_pager = client.batches.list(config={'page_size': 5})
|
||||
print(f"first page: {batch_jobs_pager.page}")
|
||||
# first page: [BatchJob(name='projects/./locations/./batchPredictionJobs/1
|
||||
"""
|
||||
|
||||
return self._page
|
||||
|
||||
@property
|
||||
def name(self) -> PagedItem:
|
||||
"""Returns the type of paged item (for example, ``batch_jobs``).
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
batch_jobs_pager = client.batches.list(config={'page_size': 5})
|
||||
print(f"name: {batch_jobs_pager.name}")
|
||||
# name: batch_jobs
|
||||
"""
|
||||
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def page_size(self) -> int:
|
||||
"""Returns the maximum number of items fetched by the pager at one time.
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
batch_jobs_pager = client.batches.list(config={'page_size': 5})
|
||||
print(f"page_size: {batch_jobs_pager.page_size}")
|
||||
# page_size: 5
|
||||
"""
|
||||
|
||||
return self._page_size
|
||||
|
||||
@property
|
||||
def config(self) -> dict[str, Any]:
|
||||
"""Returns the configuration when making the API request for the next page.
|
||||
|
||||
A configuration is a set of optional parameters and arguments that can be
|
||||
used to customize the API request. For example, the ``page_token`` parameter
|
||||
contains the token to request the next page.
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
batch_jobs_pager = client.batches.list(config={'page_size': 5})
|
||||
print(f"config: {batch_jobs_pager.config}")
|
||||
# config: {'page_size': 5, 'page_token': 'AMEw9yO5jnsGnZJLHSKDFHJJu'}
|
||||
"""
|
||||
|
||||
return self._config
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the total number of items in the current page."""
|
||||
return len(self.page)
|
||||
|
||||
def __getitem__(self, index: int) -> T:
|
||||
"""Returns the item at the given index."""
|
||||
return self.page[index]
|
||||
|
||||
def _init_next_page(self, response: Any) -> None:
|
||||
"""Initializes the next page from the response.
|
||||
|
||||
This is an internal method that should be called by subclasses after
|
||||
fetching the next page.
|
||||
|
||||
Args:
|
||||
response: The response object from the API request.
|
||||
"""
|
||||
self._init_page(self.name, self._request, response, self.config)
|
||||
|
||||
|
||||
class Pager(_BasePager[T]):
|
||||
"""Pager class for iterating through paginated results."""
|
||||
|
||||
def __next__(self) -> T:
|
||||
"""Returns the next item."""
|
||||
if self._idx >= len(self):
|
||||
try:
|
||||
self.next_page()
|
||||
except IndexError:
|
||||
raise StopIteration
|
||||
|
||||
item = self.page[self._idx]
|
||||
self._idx += 1
|
||||
return item
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
"""Returns an iterator over the items."""
|
||||
self._idx = 0
|
||||
return self
|
||||
|
||||
def next_page(self) -> list[T]:
|
||||
"""Fetches the next page of items. This makes a new API request.
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
batch_jobs_pager = client.batches.list(config={'page_size': 5})
|
||||
print(f"current page: {batch_jobs_pager.page}")
|
||||
batch_jobs_pager.next_page()
|
||||
print(f"next page: {batch_jobs_pager.page}")
|
||||
# current page: [BatchJob(name='projects/.../batchPredictionJobs/1
|
||||
# next page: [BatchJob(name='projects/.../batchPredictionJobs/6
|
||||
"""
|
||||
|
||||
if not self.config.get('page_token'):
|
||||
raise IndexError('No more pages to fetch.')
|
||||
|
||||
response = self._request(config=self.config)
|
||||
self._init_next_page(response)
|
||||
return self.page
|
||||
|
||||
|
||||
class AsyncPager(_BasePager[T]):
|
||||
"""AsyncPager class for iterating through paginated results."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: PagedItem,
|
||||
request: Callable[..., Awaitable[Any]],
|
||||
response: Any,
|
||||
config: Any,
|
||||
):
|
||||
super().__init__(name, request, response, config)
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[T]:
|
||||
"""Returns an async iterator over the items."""
|
||||
self._idx = 0
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> T:
|
||||
"""Returns the next item asynchronously."""
|
||||
if self._idx >= len(self):
|
||||
try:
|
||||
await self.next_page()
|
||||
except IndexError:
|
||||
raise StopAsyncIteration
|
||||
|
||||
item = self.page[self._idx]
|
||||
self._idx += 1
|
||||
return item
|
||||
|
||||
async def next_page(self) -> list[T]:
|
||||
"""Fetches the next page of items asynchronously.
|
||||
|
||||
This makes a new API request.
|
||||
|
||||
Returns:
|
||||
The next page of items.
|
||||
|
||||
Raises:
|
||||
IndexError: No more pages to fetch.
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
batch_jobs_pager = await client.aio.batches.list(config={'page_size': 5})
|
||||
print(f"current page: {batch_jobs_pager.page}")
|
||||
await batch_jobs_pager.next_page()
|
||||
print(f"next page: {batch_jobs_pager.page}")
|
||||
# current page: [BatchJob(name='projects/.../batchPredictionJobs/1
|
||||
# next page: [BatchJob(name='projects/.../batchPredictionJobs/6
|
||||
"""
|
||||
|
||||
if not self.config.get('page_token'):
|
||||
raise IndexError('No more pages to fetch.')
|
||||
|
||||
response = await self._request(config=self.config)
|
||||
self._init_next_page(response)
|
||||
return self.page
|
||||
1
.venv/lib/python3.10/site-packages/google/genai/py.typed
Normal file
1
.venv/lib/python3.10/site-packages/google/genai/py.typed
Normal file
@@ -0,0 +1 @@
|
||||
# see: https://peps.python.org/pep-0561/
|
||||
1562
.venv/lib/python3.10/site-packages/google/genai/tunings.py
Normal file
1562
.venv/lib/python3.10/site-packages/google/genai/tunings.py
Normal file
File diff suppressed because it is too large
Load Diff
10838
.venv/lib/python3.10/site-packages/google/genai/types.py
Normal file
10838
.venv/lib/python3.10/site-packages/google/genai/types.py
Normal file
File diff suppressed because it is too large
Load Diff
16
.venv/lib/python3.10/site-packages/google/genai/version.py
Normal file
16
.venv/lib/python3.10/site-packages/google/genai/version.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
__version__ = '1.12.1' # x-release-please-version
|
||||
Reference in New Issue
Block a user