404 lines
14 KiB
Python
404 lines
14 KiB
Python
# Copyright 2024 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
"""Extra utils depending on types that are shared between sync and async modules."""
|
|
|
|
import inspect
|
|
import logging
|
|
import sys
|
|
import typing
|
|
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
|
|
|
|
import pydantic
|
|
|
|
from . import _common
|
|
from . import errors
|
|
from . import types
|
|
|
|
# `UnionType` is the runtime type of PEP 604 unions (`int | float`). It is
# compared against `get_origin(annotation)` when function-call arguments are
# matched to union-typed parameters.
if sys.version_info >= (3, 10):
  from types import UnionType
else:
  # `types.UnionType` does not exist before Python 3.10, so the name is bound
  # to a private typing class instead.
  # NOTE(review): presumably this only keeps the name defined on older
  # interpreters -- confirm it is never expected to match there.
  UnionType = typing._UnionGenericAlias  # type: ignore[attr-defined]

# Value returned by `get_max_remote_calls_afc` when the config (or its
# `maximum_remote_calls` field) is unset.
_DEFAULT_MAX_REMOTE_CALLS_AFC = 10

# Module-level logger; note the fixed logger name rather than `__name__`.
logger = logging.getLogger('google_genai.models')
|
|
|
|
|
|
def _create_generate_content_config_model(
    config: types.GenerateContentConfigOrDict,
) -> types.GenerateContentConfig:
  """Coerces a dict-or-model config into a `GenerateContentConfig`.

  Args:
    config: Either a plain dict of config fields or an existing config object.

  Returns:
    A `types.GenerateContentConfig` built from the dict's items when `config`
    is a dict; otherwise `config` itself (the same object, not a copy).
  """
  if not isinstance(config, dict):
    return config
  return types.GenerateContentConfig(**config)
|
|
|
|
|
|
def format_destination(
    src: str,
    config: Optional[types.CreateBatchJobConfigOrDict] = None,
) -> types.CreateBatchJobConfig:
  """Fills in a batch job's display name and destination from its source uri.

  Args:
    src: Source uri. Only `gs://....jsonl` and `bq://...` sources can have a
      destination derived for them.
    config: Optional batch job config (model or dict).

  Returns:
    The normalized config, with `display_name` and `dest` populated when they
    were empty.

  Raises:
    ValueError: If `dest` is empty and `src` is neither a `gs://` jsonl uri nor
      a `bq://` uri.
  """
  job_config = types._CreateBatchJobParameters(config=config).config
  if not job_config:
    job_config = types.CreateBatchJobConfig()

  # The same generated name is shared by the display name and the BigQuery
  # destination when both have to be synthesized.
  suffix = None
  if not job_config.display_name:
    suffix = _common.timestamped_unique_name()
    job_config.display_name = f'genai_batch_job_{suffix}'

  if job_config.dest:
    return job_config

  if src.startswith('gs://') and src.endswith('.jsonl'):
    # "gs://bucket/path/to/src.jsonl" -> "gs://bucket/path/to/src/dest".
    job_config.dest = src[: -len('.jsonl')] + '/dest'
  elif src.startswith('bq://'):
    # "bq://project.dataset.src" -> "bq://project.dataset.src_dest_<unique>".
    if not suffix:
      suffix = _common.timestamped_unique_name()
    job_config.dest = f'{src}_dest_{suffix}'
  else:
    raise ValueError(f'Unsupported source: {src}')
  return job_config
|
|
|
|
|
|
def get_function_map(
    config: Optional[types.GenerateContentConfigOrDict] = None,
    is_caller_method_async: bool = False,
) -> dict[str, Callable[..., Any]]:
  """Collects the callable tools of `config`, keyed by their `__name__`.

  Args:
    config: Generate-content config (model or dict); may be None.
    is_caller_method_async: Whether the caller is an async method. Coroutine
      functions are only accepted as tools when this is True.

  Returns:
    A new dict mapping each callable tool's `__name__` to the callable.
    Non-callable tools are ignored. Empty when there is no config or no tools.

  Raises:
    errors.UnsupportedFunctionError: If a tool is a coroutine function and the
      caller is not async.
  """
  if not config:
    return {}
  tools = _create_generate_content_config_model(config).tools
  if not tools:
    return {}

  callables_by_name: dict[str, Callable[..., Any]] = {}
  for candidate in tools:
    if not callable(candidate):
      continue
    if not is_caller_method_async and inspect.iscoroutinefunction(candidate):
      raise errors.UnsupportedFunctionError(
          f'Function {candidate.__name__} is a coroutine function, which is not'
          ' supported for automatic function calling. Please manually'
          f' invoke {candidate.__name__} to get the function response.'
      )
    callables_by_name[candidate.__name__] = candidate
  return callables_by_name
|
|
|
|
|
|
def convert_number_values_for_dict_function_call_args(
    args: dict[str, Any],
) -> dict[str, Any]:
  """Returns a copy of `args` with whole-number floats turned into ints.

  Each value is passed through `convert_number_values_for_function_call_args`,
  so nested dicts and lists are converted as well. Keys are left untouched.
  """
  converted: dict[str, Any] = {}
  for name, raw_value in args.items():
    converted[name] = convert_number_values_for_function_call_args(raw_value)
  return converted
|
|
|
|
|
|
def convert_number_values_for_function_call_args(
    args: Union[dict[str, object], list[object], object],
) -> Union[dict[str, object], list[object], object]:
  """Recursively replaces floats that hold whole numbers with ints.

  Dicts and lists are rebuilt with their values/items converted; any other
  object (including floats with a fractional part) is returned unchanged.
  """
  if isinstance(args, dict):
    return {
        name: convert_number_values_for_function_call_args(item)
        for name, item in args.items()
    }
  if isinstance(args, list):
    return list(map(convert_number_values_for_function_call_args, args))
  if isinstance(args, float) and args.is_integer():
    return int(args)
  return args
|
|
|
|
|
|
def is_annotation_pydantic_model(annotation: Any) -> bool:
  """Returns True when `annotation` is a class derived from `pydantic.BaseModel`."""
  if not inspect.isclass(annotation):
    return False
  try:
    return issubclass(annotation, pydantic.BaseModel)
  except TypeError:
    # On Python 3.10 and below inspect.isclass() is True for parameterized
    # generics such as dict[str, int] (False on 3.11+), and issubclass() then
    # rejects them with TypeError.
    return False
|
|
|
|
|
|
def convert_if_exist_pydantic_model(
|
|
value: Any, annotation: Any, param_name: str, func_name: str
|
|
) -> Any:
|
|
if isinstance(value, dict) and is_annotation_pydantic_model(annotation):
|
|
try:
|
|
return annotation(**value)
|
|
except pydantic.ValidationError as e:
|
|
raise errors.UnknownFunctionCallArgumentError(
|
|
f'Failed to parse parameter {param_name} for function'
|
|
f' {func_name} from function call part because function call argument'
|
|
f' value {value} is not compatible with parameter annotation'
|
|
f' {annotation}, due to error {e}'
|
|
)
|
|
if isinstance(value, list) and get_origin(annotation) == list:
|
|
item_type = get_args(annotation)[0]
|
|
return [
|
|
convert_if_exist_pydantic_model(item, item_type, param_name, func_name)
|
|
for item in value
|
|
]
|
|
if isinstance(value, dict) and get_origin(annotation) == dict:
|
|
_, value_type = get_args(annotation)
|
|
return {
|
|
k: convert_if_exist_pydantic_model(v, value_type, param_name, func_name)
|
|
for k, v in value.items()
|
|
}
|
|
# example 1: typing.Union[int, float]
|
|
# example 2: int | float equivalent to UnionType[int, float]
|
|
if get_origin(annotation) in (Union, UnionType):
|
|
for arg in get_args(annotation):
|
|
if (
|
|
(get_args(arg) and get_origin(arg) is list)
|
|
or isinstance(value, arg)
|
|
or (isinstance(value, dict) and is_annotation_pydantic_model(arg))
|
|
):
|
|
try:
|
|
return convert_if_exist_pydantic_model(
|
|
value, arg, param_name, func_name
|
|
)
|
|
# do not raise here because there could be multiple pydantic model types
|
|
# in the union type.
|
|
except pydantic.ValidationError:
|
|
continue
|
|
# if none of the union type is matched, raise error
|
|
raise errors.UnknownFunctionCallArgumentError(
|
|
f'Failed to parse parameter {param_name} for function'
|
|
f' {func_name} from function call part because function call argument'
|
|
f' value {value} cannot be converted to parameter annotation'
|
|
f' {annotation}.'
|
|
)
|
|
# the only exception for value and annotation type to be different is int and
|
|
# float. see convert_number_values_for_function_call_args function for context
|
|
if isinstance(value, int) and annotation is float:
|
|
return value
|
|
if not isinstance(value, annotation):
|
|
raise errors.UnknownFunctionCallArgumentError(
|
|
f'Failed to parse parameter {param_name} for function {func_name} from'
|
|
f' function call part because function call argument value {value} is'
|
|
f' not compatible with parameter annotation {annotation}.'
|
|
)
|
|
return value
|
|
|
|
|
|
def convert_argument_from_function(
    args: dict[str, Any], function: Callable[..., Any]
) -> dict[str, Any]:
  """Converts the entries of `args` that match parameters of `function`.

  Each matching value is converted against the parameter's annotation with
  `convert_if_exist_pydantic_model`. Entries of `args` whose key is not a
  parameter of `function` are left out of the result. The result follows the
  order of the function's signature.
  """
  parameters = inspect.signature(function).parameters
  func_name = function.__name__
  return {
      name: convert_if_exist_pydantic_model(
          args[name], parameter.annotation, name, func_name
      )
      for name, parameter in parameters.items()
      if name in args
  }
|
|
|
|
|
|
def invoke_function_from_dict_args(
    args: Dict[str, Any], function_to_invoke: Callable[..., Any]
) -> Any:
  """Calls `function_to_invoke` with `args` converted to its parameter types.

  Argument conversion happens before the guarded call, so conversion errors
  propagate unchanged.

  Returns:
    Whatever `function_to_invoke` returns.

  Raises:
    errors.FunctionInvocationError: If the function itself raises.
  """
  call_kwargs = convert_argument_from_function(args, function_to_invoke)
  try:
    outcome = function_to_invoke(**call_kwargs)
  except Exception as exc:
    message = (
        f'Failed to invoke function {function_to_invoke.__name__} with'
        f' converted arguments {call_kwargs} from model returned function'
        f' call argument {args} because of error {exc}'
    )
    raise errors.FunctionInvocationError(message)
  else:
    return outcome
|
|
|
|
|
|
async def invoke_function_from_dict_args_async(
    args: Dict[str, Any], function_to_invoke: Callable[..., Any]
) -> Any:
  """Awaits `function_to_invoke` with `args` converted to its parameter types.

  Argument conversion happens before the guarded call, so conversion errors
  propagate unchanged.

  Returns:
    The awaited result of `function_to_invoke`.

  Raises:
    errors.FunctionInvocationError: If calling or awaiting the function raises.
  """
  call_kwargs = convert_argument_from_function(args, function_to_invoke)
  try:
    outcome = await function_to_invoke(**call_kwargs)
  except Exception as exc:
    message = (
        f'Failed to invoke function {function_to_invoke.__name__} with'
        f' converted arguments {call_kwargs} from model returned function'
        f' call argument {args} because of error {exc}'
    )
    raise errors.FunctionInvocationError(message)
  else:
    return outcome
|
|
|
|
|
|
def get_function_response_parts(
    response: types.GenerateContentResponse,
    function_map: dict[str, Callable[..., Any]],
) -> list[types.Part]:
  """Invokes the function calls of the first candidate and wraps the results.

  Only the first candidate is inspected. Parts without a function call, or
  whose call has no name or no args, are skipped.

  Args:
    response: Model response that may contain function call parts.
    function_map: Callables keyed by function name.

  Returns:
    One function-response part per invoked call. The response payload is
    `{'result': <return value>}` on success and `{'error': <message>}` when
    argument conversion or the function itself raises. Empty when the response
    has no candidates or no content parts.

  Raises:
    KeyError: If a called function name is missing from `function_map`.
  """
  func_response_parts: list[types.Part] = []
  # `candidates` may be None or an empty list; indexing [0] on an empty list
  # would raise IndexError, so both are treated as "nothing to do".
  if not response.candidates:
    return func_response_parts
  content = response.candidates[0].content
  if not isinstance(content, types.Content) or content.parts is None:
    return func_response_parts

  for part in content.parts:
    function_call = part.function_call
    if not function_call:
      continue
    func_name = function_call.name
    if func_name is None or function_call.args is None:
      continue
    func = function_map[func_name]
    args = convert_number_values_for_dict_function_call_args(
        function_call.args
    )
    func_response: dict[str, Any]
    try:
      func_response = {'result': invoke_function_from_dict_args(args, func)}
    # Deliberately broad: any failure is reported back to the model as an
    # error payload instead of aborting the loop.
    except Exception as e:  # pylint: disable=broad-except
      func_response = {'error': str(e)}
    func_response_parts.append(
        types.Part.from_function_response(
            name=func_name, response=func_response
        )
    )
  return func_response_parts
|
|
|
|
async def get_function_response_parts_async(
    response: types.GenerateContentResponse,
    function_map: dict[str, Callable[..., Any]],
) -> list[types.Part]:
  """Invokes the function calls of the first candidate and wraps the results.

  Coroutine functions are awaited; regular functions are called directly.
  Only the first candidate is inspected. Parts without a function call, or
  whose call has no name or no args, are skipped.

  Args:
    response: Model response that may contain function call parts.
    function_map: Callables (sync or coroutine functions) keyed by name.

  Returns:
    One function-response part per invoked call. The response payload is
    `{'result': <return value>}` on success and `{'error': <message>}` when
    argument conversion or the function itself raises. Empty when the response
    has no candidates or no content parts.

  Raises:
    KeyError: If a called function name is missing from `function_map`.
  """
  func_response_parts: list[types.Part] = []
  # `candidates` may be None or an empty list; indexing [0] on an empty list
  # would raise IndexError, so both are treated as "nothing to do".
  if not response.candidates:
    return func_response_parts
  content = response.candidates[0].content
  if not isinstance(content, types.Content) or content.parts is None:
    return func_response_parts

  for part in content.parts:
    function_call = part.function_call
    if not function_call:
      continue
    func_name = function_call.name
    if func_name is None or function_call.args is None:
      continue
    func = function_map[func_name]
    args = convert_number_values_for_dict_function_call_args(
        function_call.args
    )
    func_response: dict[str, Any]
    try:
      if inspect.iscoroutinefunction(func):
        func_response = {
            'result': await invoke_function_from_dict_args_async(args, func)
        }
      else:
        func_response = {'result': invoke_function_from_dict_args(args, func)}
    # Deliberately broad: any failure is reported back to the model as an
    # error payload instead of aborting the loop.
    except Exception as e:  # pylint: disable=broad-except
      func_response = {'error': str(e)}
    func_response_parts.append(
        types.Part.from_function_response(
            name=func_name, response=func_response
        )
    )
  return func_response_parts
|
|
|
|
|
|
def should_disable_afc(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
  """Returns whether automatic function calling (AFC) should be disabled.

  AFC stays enabled (False is returned) when there is no config, no
  `automatic_function_calling` section, or `disable` is left unset. A
  `maximum_remote_calls` value of 0 or less disables AFC regardless of
  `disable`, and logs a warning. Otherwise the `disable` flag is returned; a
  warning is logged when it is set together with an explicitly provided
  positive `maximum_remote_calls`.
  """
  if not config:
    return False
  afc = _create_generate_content_config_model(config).automatic_function_calling
  if not afc:
    # Default to enable AFC if not specified.
    return False

  max_calls = afc.maximum_remote_calls
  # If max_remote_calls is less or equal to 0, warn and disable AFC.
  if max_calls is not None and int(max_calls) <= 0:
    logger.warning(
        'max_remote_calls in automatic_function_calling_config'
        f' {max_calls} is'
        ' less than or equal to 0. Disabling automatic function calling.'
        ' Please set max_remote_calls to a positive integer.'
    )
    return True

  if afc.disable is None:
    return False

  # exclude the case where max_remote_calls is set to 10 by default.
  explicitly_positive = (
      max_calls is not None
      and 'maximum_remote_calls' in afc.model_fields_set
      and int(max_calls) > 0
  )
  if afc.disable and explicitly_positive:
    logger.warning(
        '`automatic_function_calling.disable` is set to `True`. And'
        ' `automatic_function_calling.maximum_remote_calls` is a'
        ' positive number'
        f' {max_calls}.'
        ' Disabling automatic function calling. If you want to enable'
        ' automatic function calling, please set'
        ' `automatic_function_calling.disable` to `False` or leave it unset,'
        ' and set `automatic_function_calling.maximum_remote_calls` to a'
        ' positive integer or leave'
        ' `automatic_function_calling.maximum_remote_calls` unset.'
    )

  return afc.disable
|
|
|
|
|
|
def get_max_remote_calls_afc(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> int:
  """Returns the maximum number of remote calls for automatic function calling.

  Args:
    config: Generate-content config (model or dict); may be None.

  Returns:
    `automatic_function_calling.maximum_remote_calls` as an int, or
    `_DEFAULT_MAX_REMOTE_CALLS_AFC` when the config, its
    `automatic_function_calling` section, or that field is unset.

  Raises:
    ValueError: If `should_disable_afc(config)` reports that automatic function
      calling is disabled.
  """
  if not config:
    return _DEFAULT_MAX_REMOTE_CALLS_AFC
  if should_disable_afc(config):
    raise ValueError(
        'automatic function calling is not enabled, but SDK is trying to get'
        ' max remote calls.'
    )
  afc = _create_generate_content_config_model(config).automatic_function_calling
  if not afc or afc.maximum_remote_calls is None:
    return _DEFAULT_MAX_REMOTE_CALLS_AFC
  return int(afc.maximum_remote_calls)
|
|
|
|
|
|
def should_append_afc_history(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
  """Returns whether the automatic function calling history should be appended.

  True unless `automatic_function_calling.ignore_call_history` is truthy; a
  missing config or missing `automatic_function_calling` section yields True.
  """
  if not config:
    return True
  afc = _create_generate_content_config_model(config).automatic_function_calling
  return True if not afc else not afc.ignore_call_history
|