Files
evo-ai/.venv/lib/python3.10/site-packages/google/genai/_extra_utils.py
2025-04-25 15:30:54 -03:00

404 lines
14 KiB
Python

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Extra utils depending on types that are shared between sync and async modules."""
import inspect
import logging
import sys
import typing
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
import pydantic
from . import _common
from . import errors
from . import types
# `X | Y` annotations are represented by `types.UnionType`, which only exists
# on Python 3.10+. Older interpreters fall back to typing's private union alias
# so that the name `UnionType` is always bound in this module.
if sys.version_info < (3, 10):
  UnionType = typing._UnionGenericAlias  # type: ignore[attr-defined]
else:
  from types import UnionType

# Value returned by get_max_remote_calls_afc when the config does not set one.
_DEFAULT_MAX_REMOTE_CALLS_AFC = 10

logger = logging.getLogger('google_genai.models')
def _create_generate_content_config_model(
    config: types.GenerateContentConfigOrDict,
) -> types.GenerateContentConfig:
  """Returns `config` as a `types.GenerateContentConfig`.

  Args:
    config: Either a dict of config fields or an already-built config object.

  Returns:
    A `types.GenerateContentConfig` built from the dict's items as keyword
    arguments, or `config` itself, unchanged, when it is not a dict.
  """
  if not isinstance(config, dict):
    return config
  return types.GenerateContentConfig(**config)
def format_destination(
    src: str,
    config: Optional[types.CreateBatchJobConfigOrDict] = None,
) -> types.CreateBatchJobConfig:
  """Formats the destination uri based on the source uri.

  Fills in `display_name` and `dest` on the batch job config when they are
  not already set.

  Args:
    src: Source uri. Only consulted when `dest` is unset; it must then be a
      `gs://` uri ending in `.jsonl` or a `bq://` uri.
    config: Optional batch job config (object or dict).

  Returns:
    The batch job config with `display_name` and `dest` populated.

  Raises:
    ValueError: If `dest` is unset and `src` is neither a supported `gs://`
      nor `bq://` uri.
  """
  job_config = types._CreateBatchJobParameters(config=config).config
  if not job_config:
    job_config = types.CreateBatchJobConfig()

  # The same generated name is shared by the display name and, for BigQuery
  # sources, the destination table, so remember it once it has been created.
  name_suffix = None
  if not job_config.display_name:
    name_suffix = _common.timestamped_unique_name()
    job_config.display_name = f'genai_batch_job_{name_suffix}'

  if job_config.dest:
    return job_config

  if src.startswith('gs://') and src.endswith('.jsonl'):
    # If source uri is "gs://bucket/path/to/src.jsonl", then the destination
    # uri prefix will be "gs://bucket/path/to/src/dest".
    stem = src[: -len('.jsonl')]
    job_config.dest = f'{stem}/dest'
  elif src.startswith('bq://'):
    # If source uri is "bq://project.dataset.src", then the destination
    # uri will be "bq://project.dataset.src_dest_TIMESTAMP_UUID".
    if not name_suffix:
      name_suffix = _common.timestamped_unique_name()
    job_config.dest = f'{src}_dest_{name_suffix}'
  else:
    raise ValueError(f'Unsupported source: {src}')
  return job_config
def get_function_map(
    config: Optional[types.GenerateContentConfigOrDict] = None,
    is_caller_method_async: bool = False,
) -> dict[str, Callable[..., Any]]:
  """Returns a function map from the config.

  Args:
    config: Optional generate-content config (object or dict) whose `tools`
      are scanned for callables.
    is_caller_method_async: Whether the caller is an async method. Coroutine
      functions are only accepted as tools when this is true.

  Returns:
    A dict mapping each callable tool's `__name__` to the callable. Tools that
    are not callable are ignored.

  Raises:
    errors.UnsupportedFunctionError: If a tool is a coroutine function and the
      caller is not async.
  """
  callables_by_name: dict[str, Callable[..., Any]] = {}
  if not config:
    return callables_by_name

  tools = _create_generate_content_config_model(config).tools
  for tool in tools or ():
    if not callable(tool):
      continue
    if not is_caller_method_async and inspect.iscoroutinefunction(tool):
      raise errors.UnsupportedFunctionError(
          f'Function {tool.__name__} is a coroutine function, which is not'
          ' supported for automatic function calling. Please manually'
          f' invoke {tool.__name__} to get the function response.'
      )
    callables_by_name[tool.__name__] = tool
  return callables_by_name
def convert_number_values_for_dict_function_call_args(
    args: dict[str, Any],
) -> dict[str, Any]:
  """Converts float values in dict with no decimal to integers.

  Args:
    args: Function call arguments keyed by parameter name.

  Returns:
    A new dict with the same keys, where every value has been passed through
    `convert_number_values_for_function_call_args`.
  """
  converted: dict[str, Any] = {}
  for name, raw_value in args.items():
    converted[name] = convert_number_values_for_function_call_args(raw_value)
  return converted
def convert_number_values_for_function_call_args(
    args: Union[dict[str, object], list[object], object],
) -> Union[dict[str, object], list[object], object]:
  """Converts float values with no decimal to integers.

  Dicts and lists are rebuilt recursively; every other value (including
  floats with a fractional part) is returned unchanged.

  Args:
    args: A single value, or an arbitrarily nested dict/list of values.

  Returns:
    The same structure with whole-number floats replaced by ints.
  """
  if isinstance(args, list):
    return [
        convert_number_values_for_function_call_args(item) for item in args
    ]
  if isinstance(args, dict):
    return {
        name: convert_number_values_for_function_call_args(item)
        for name, item in args.items()
    }
  if isinstance(args, float) and args.is_integer():
    return int(args)
  return args
def is_annotation_pydantic_model(annotation: Any) -> bool:
  """Returns whether `annotation` is a class derived from pydantic.BaseModel.

  Args:
    annotation: A parameter annotation of any kind.

  Returns:
    True only when `annotation` is a class and a subclass of
    `pydantic.BaseModel`; False otherwise, including when the subclass check
    itself raises `TypeError`.
  """
  try:
    if not inspect.isclass(annotation):
      return False
    return issubclass(annotation, pydantic.BaseModel)
  # for python 3.10 and below, inspect.isclass(annotation) has inconsistent
  # results with versions above. for example, inspect.isclass(dict[str, int]) is
  # True in 3.10 and below but False in 3.11 and above.
  except TypeError:
    return False
def convert_if_exist_pydantic_model(
    value: Any, annotation: Any, param_name: str, func_name: str
) -> Any:
  """Converts a function call argument to match its parameter annotation.

  The conversion rules are tried in this order:
    1. dict value + pydantic model annotation: the model is instantiated with
       the dict's items as keyword arguments.
    2. list value + `list[...]` annotation: every item is converted against
       the first type argument.
    3. dict value + `dict[...]` annotation: every value is converted against
       the value type argument; keys are kept as they are.
    4. Union annotation (`typing.Union[...]` or `X | Y`): each candidate
       member is tried in declaration order and the first successful
       conversion is returned.
    5. An int value is accepted as-is for a `float` annotation.
    6. Otherwise the value must be an instance of the annotation.

  Args:
    value: The argument value taken from the function call part.
    annotation: The annotation of the parameter receiving `value`.
    param_name: Parameter name, used only in error messages.
    func_name: Function name, used only in error messages.

  Returns:
    The converted value, or `value` itself when no conversion is needed.

  Raises:
    errors.UnknownFunctionCallArgumentError: If `value` cannot be converted
      to, or is not compatible with, `annotation`.
  """
  if isinstance(value, dict) and is_annotation_pydantic_model(annotation):
    try:
      return annotation(**value)
    except pydantic.ValidationError as e:
      raise errors.UnknownFunctionCallArgumentError(
          f'Failed to parse parameter {param_name} for function'
          f' {func_name} from function call part because function call argument'
          f' value {value} is not compatible with parameter annotation'
          f' {annotation}, due to error {e}'
      ) from e
  if isinstance(value, list) and get_origin(annotation) == list:
    item_type = get_args(annotation)[0]
    return [
        convert_if_exist_pydantic_model(item, item_type, param_name, func_name)
        for item in value
    ]
  if isinstance(value, dict) and get_origin(annotation) == dict:
    _, value_type = get_args(annotation)
    return {
        k: convert_if_exist_pydantic_model(v, value_type, param_name, func_name)
        for k, v in value.items()
    }
  # example 1: typing.Union[int, float]
  # example 2: int | float equivalent to UnionType[int, float]
  if get_origin(annotation) in (Union, UnionType):
    for arg in get_args(annotation):
      if (
          (get_args(arg) and get_origin(arg) is list)
          or isinstance(value, arg)
          or (isinstance(value, dict) and is_annotation_pydantic_model(arg))
      ):
        try:
          return convert_if_exist_pydantic_model(
              value, arg, param_name, func_name
          )
        # do not raise here because there could be multiple pydantic model
        # types in the union type. The recursive call reports validation
        # failures as UnknownFunctionCallArgumentError (see the first branch
        # above), so that error has to be caught as well; otherwise the first
        # non-matching member would abort the search.
        except (
            pydantic.ValidationError,
            errors.UnknownFunctionCallArgumentError,
        ):
          continue
    # if none of the union type is matched, raise error
    raise errors.UnknownFunctionCallArgumentError(
        f'Failed to parse parameter {param_name} for function'
        f' {func_name} from function call part because function call argument'
        f' value {value} cannot be converted to parameter annotation'
        f' {annotation}.'
    )
  # the only exception for value and annotation type to be different is int and
  # float. see convert_number_values_for_function_call_args function for context
  if isinstance(value, int) and annotation is float:
    return value
  if not isinstance(value, annotation):
    raise errors.UnknownFunctionCallArgumentError(
        f'Failed to parse parameter {param_name} for function {func_name} from'
        f' function call part because function call argument value {value} is'
        f' not compatible with parameter annotation {annotation}.'
    )
  return value
def convert_argument_from_function(
    args: dict[str, Any], function: Callable[..., Any]
) -> dict[str, Any]:
  """Converts function call arguments according to `function`'s signature.

  Args:
    args: Function call arguments keyed by parameter name.
    function: The function whose parameter annotations drive the conversion.

  Returns:
    A dict holding the converted value for every parameter of `function` that
    is present in `args`. Keys of `args` that do not name a parameter are
    left out.
  """
  parameters = inspect.signature(function).parameters
  func_name = function.__name__
  return {
      name: convert_if_exist_pydantic_model(
          args[name], parameter.annotation, name, func_name
      )
      for name, parameter in parameters.items()
      if name in args
  }
def invoke_function_from_dict_args(
    args: Dict[str, Any], function_to_invoke: Callable[..., Any]
) -> Any:
  """Invokes `function_to_invoke` with arguments from a function call part.

  Args:
    args: Function call arguments keyed by parameter name.
    function_to_invoke: The function to call with the converted arguments.

  Returns:
    Whatever `function_to_invoke` returns.

  Raises:
    errors.FunctionInvocationError: If `function_to_invoke` raises any
      `Exception`; the original exception is attached as `__cause__`.
  """
  # Conversion runs outside the try block, so argument conversion errors
  # propagate unchanged instead of being reported as invocation failures.
  converted_args = convert_argument_from_function(args, function_to_invoke)
  try:
    return function_to_invoke(**converted_args)
  except Exception as e:
    raise errors.FunctionInvocationError(
        f'Failed to invoke function {function_to_invoke.__name__} with'
        f' converted arguments {converted_args} from model returned function'
        f' call argument {args} because of error {e}'
    ) from e
async def invoke_function_from_dict_args_async(
    args: Dict[str, Any], function_to_invoke: Callable[..., Any]
) -> Any:
  """Awaits `function_to_invoke` with arguments from a function call part.

  Args:
    args: Function call arguments keyed by parameter name.
    function_to_invoke: The function to call; its result is awaited.

  Returns:
    The awaited result of `function_to_invoke`.

  Raises:
    errors.FunctionInvocationError: If calling or awaiting
      `function_to_invoke` raises any `Exception`; the original exception is
      attached as `__cause__`.
  """
  # Conversion runs outside the try block, so argument conversion errors
  # propagate unchanged instead of being reported as invocation failures.
  converted_args = convert_argument_from_function(args, function_to_invoke)
  try:
    return await function_to_invoke(**converted_args)
  except Exception as e:
    raise errors.FunctionInvocationError(
        f'Failed to invoke function {function_to_invoke.__name__} with'
        f' converted arguments {converted_args} from model returned function'
        f' call argument {args} because of error {e}'
    ) from e
def get_function_response_parts(
    response: types.GenerateContentResponse,
    function_map: dict[str, Callable[..., Any]],
) -> list[types.Part]:
  """Returns the function response parts from the response.

  Only the first candidate is inspected. Every part carrying a function call
  with both a name and args is executed, and its outcome is wrapped in a
  function response part: `{'result': ...}` on success, `{'error': ...}` when
  the invocation raises.

  Args:
    response: The model response to scan for function calls.
    function_map: Callables keyed by function name.

  Returns:
    One function response part per executed function call, in part order.
    Empty when the response has no candidates or the first candidate has no
    content parts.

  Raises:
    KeyError: If a function call names a function missing from
      `function_map`.
  """
  func_response_parts: list[types.Part] = []
  # `candidates` may be None or an empty list; indexing [0] is only safe when
  # at least one candidate is present.
  if not response.candidates:
    return func_response_parts
  content = response.candidates[0].content
  if not isinstance(content, types.Content) or content.parts is None:
    return func_response_parts

  for part in content.parts:
    if not part.function_call:
      continue
    func_name = part.function_call.name
    if func_name is None or part.function_call.args is None:
      continue
    func = function_map[func_name]
    args = convert_number_values_for_dict_function_call_args(
        part.function_call.args
    )
    func_response: dict[str, Any]
    try:
      func_response = {'result': invoke_function_from_dict_args(args, func)}
    except Exception as e:  # pylint: disable=broad-except
      # Failures are reported back to the model rather than raised.
      func_response = {'error': str(e)}
    func_response_parts.append(
        types.Part.from_function_response(
            name=func_name, response=func_response
        )
    )
  return func_response_parts
async def get_function_response_parts_async(
    response: types.GenerateContentResponse,
    function_map: dict[str, Callable[..., Any]],
) -> list[types.Part]:
  """Returns the function response parts from the response.

  Only the first candidate is inspected. Every part carrying a function call
  with both a name and args is executed (coroutine functions are awaited,
  other callables are called directly), and its outcome is wrapped in a
  function response part: `{'result': ...}` on success, `{'error': ...}` when
  the invocation raises.

  Args:
    response: The model response to scan for function calls.
    function_map: Callables keyed by function name.

  Returns:
    One function response part per executed function call, in part order.
    Empty when the response has no candidates or the first candidate has no
    content parts.

  Raises:
    KeyError: If a function call names a function missing from
      `function_map`.
  """
  func_response_parts: list[types.Part] = []
  # `candidates` may be None or an empty list; indexing [0] is only safe when
  # at least one candidate is present.
  if not response.candidates:
    return func_response_parts
  content = response.candidates[0].content
  if not isinstance(content, types.Content) or content.parts is None:
    return func_response_parts

  for part in content.parts:
    if not part.function_call:
      continue
    func_name = part.function_call.name
    if func_name is None or part.function_call.args is None:
      continue
    func = function_map[func_name]
    args = convert_number_values_for_dict_function_call_args(
        part.function_call.args
    )
    func_response: dict[str, Any]
    try:
      if inspect.iscoroutinefunction(func):
        func_response = {
            'result': await invoke_function_from_dict_args_async(args, func)
        }
      else:
        func_response = {
            'result': invoke_function_from_dict_args(args, func)
        }
    except Exception as e:  # pylint: disable=broad-except
      # Failures are reported back to the model rather than raised.
      func_response = {'error': str(e)}
    func_response_parts.append(
        types.Part.from_function_response(
            name=func_name, response=func_response
        )
    )
  return func_response_parts
def should_disable_afc(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
  """Returns whether automatic function calling should be disabled.

  Args:
    config: Optional generate-content config (object or dict).

  Returns:
    True when `automatic_function_calling.maximum_remote_calls` is set to a
    value less than or equal to 0, or when
    `automatic_function_calling.disable` is true. False when there is no
    config, no `automatic_function_calling` section, or `disable` is unset.
    A warning is logged when the settings are non-positive or contradictory.
  """
  if not config:
    return False
  config_model = _create_generate_content_config_model(config)
  afc = config_model.automatic_function_calling

  # If max_remote_calls is less or equal to 0, warn and disable AFC.
  if (
      afc
      and afc.maximum_remote_calls is not None
      and int(afc.maximum_remote_calls) <= 0
  ):
    logger.warning(
        'max_remote_calls in automatic_function_calling_config'
        f' {afc.maximum_remote_calls} is'
        ' less than or equal to 0. Disabling automatic function calling.'
        ' Please set max_remote_calls to a positive integer.'
    )
    return True

  # Default to enable AFC if not specified.
  if not afc or afc.disable is None:
    return False

  if (
      afc.disable
      and afc.maximum_remote_calls is not None
      # exclude the case where max_remote_calls is set to 10 by default.
      and 'maximum_remote_calls' in afc.model_fields_set
      and int(afc.maximum_remote_calls) > 0
  ):
    logger.warning(
        '`automatic_function_calling.disable` is set to `True`. And'
        ' `automatic_function_calling.maximum_remote_calls` is a'
        ' positive number'
        f' {afc.maximum_remote_calls}.'
        ' Disabling automatic function calling. If you want to enable'
        ' automatic function calling, please set'
        ' `automatic_function_calling.disable` to `False` or leave it unset,'
        ' and set `automatic_function_calling.maximum_remote_calls` to a'
        ' positive integer or leave'
        ' `automatic_function_calling.maximum_remote_calls` unset.'
    )
  return afc.disable
def get_max_remote_calls_afc(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> int:
  """Returns the maximum remote calls for automatic function calling.

  Args:
    config: Optional generate-content config (object or dict).

  Returns:
    `automatic_function_calling.maximum_remote_calls` as an int, or the
    module default when there is no config or the value is unset.

  Raises:
    ValueError: If automatic function calling is disabled for `config`.
  """
  if not config:
    return _DEFAULT_MAX_REMOTE_CALLS_AFC
  if should_disable_afc(config):
    raise ValueError(
        'automatic function calling is not enabled, but SDK is trying to get'
        ' max remote calls.'
    )
  afc = _create_generate_content_config_model(config).automatic_function_calling
  if not afc or afc.maximum_remote_calls is None:
    return _DEFAULT_MAX_REMOTE_CALLS_AFC
  return int(afc.maximum_remote_calls)
def should_append_afc_history(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
  """Returns whether the automatic function calling history should be kept.

  Args:
    config: Optional generate-content config (object or dict).

  Returns:
    The negation of `automatic_function_calling.ignore_call_history`, or True
    when there is no config or no `automatic_function_calling` section.
  """
  if not config:
    return True
  afc_config = _create_generate_content_config_model(
      config
  ).automatic_function_calling
  if afc_config:
    return not afc_config.ignore_call_history
  return True