structure saas with tools
This commit is contained in:
403
.venv/lib/python3.10/site-packages/google/genai/_extra_utils.py
Normal file
403
.venv/lib/python3.10/site-packages/google/genai/_extra_utils.py
Normal file
@@ -0,0 +1,403 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""Extra utils depending on types that are shared between sync and async modules."""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import typing
|
||||
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
|
||||
|
||||
import pydantic
|
||||
|
||||
from . import _common
|
||||
from . import errors
|
||||
from . import types
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from types import UnionType
|
||||
else:
|
||||
UnionType = typing._UnionGenericAlias # type: ignore[attr-defined]
|
||||
|
||||
_DEFAULT_MAX_REMOTE_CALLS_AFC = 10
|
||||
|
||||
logger = logging.getLogger('google_genai.models')
|
||||
|
||||
|
||||
def _create_generate_content_config_model(
|
||||
config: types.GenerateContentConfigOrDict,
|
||||
) -> types.GenerateContentConfig:
|
||||
if isinstance(config, dict):
|
||||
return types.GenerateContentConfig(**config)
|
||||
else:
|
||||
return config
|
||||
|
||||
|
||||
def format_destination(
|
||||
src: str,
|
||||
config: Optional[types.CreateBatchJobConfigOrDict] = None,
|
||||
) -> types.CreateBatchJobConfig:
|
||||
"""Formats the destination uri based on the source uri."""
|
||||
config = (
|
||||
types._CreateBatchJobParameters(config=config).config
|
||||
or types.CreateBatchJobConfig()
|
||||
)
|
||||
|
||||
unique_name = None
|
||||
if not config.display_name:
|
||||
unique_name = _common.timestamped_unique_name()
|
||||
config.display_name = f'genai_batch_job_{unique_name}'
|
||||
|
||||
if not config.dest:
|
||||
if src.startswith('gs://') and src.endswith('.jsonl'):
|
||||
# If source uri is "gs://bucket/path/to/src.jsonl", then the destination
|
||||
# uri prefix will be "gs://bucket/path/to/src/dest".
|
||||
config.dest = f'{src[:-6]}/dest'
|
||||
elif src.startswith('bq://'):
|
||||
# If source uri is "bq://project.dataset.src", then the destination
|
||||
# uri will be "bq://project.dataset.src_dest_TIMESTAMP_UUID".
|
||||
unique_name = unique_name or _common.timestamped_unique_name()
|
||||
config.dest = f'{src}_dest_{unique_name}'
|
||||
else:
|
||||
raise ValueError(f'Unsupported source: {src}')
|
||||
return config
|
||||
|
||||
|
||||
def get_function_map(
|
||||
config: Optional[types.GenerateContentConfigOrDict] = None,
|
||||
is_caller_method_async: bool = False,
|
||||
) -> dict[str, Callable[..., Any]]:
|
||||
"""Returns a function map from the config."""
|
||||
function_map: dict[str, Callable[..., Any]] = {}
|
||||
if not config:
|
||||
return function_map
|
||||
config_model = _create_generate_content_config_model(config)
|
||||
if config_model.tools:
|
||||
for tool in config_model.tools:
|
||||
if callable(tool):
|
||||
if inspect.iscoroutinefunction(tool) and not is_caller_method_async:
|
||||
raise errors.UnsupportedFunctionError(
|
||||
f'Function {tool.__name__} is a coroutine function, which is not'
|
||||
' supported for automatic function calling. Please manually'
|
||||
f' invoke {tool.__name__} to get the function response.'
|
||||
)
|
||||
function_map[tool.__name__] = tool
|
||||
return function_map
|
||||
|
||||
|
||||
def convert_number_values_for_dict_function_call_args(
|
||||
args: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Converts float values in dict with no decimal to integers."""
|
||||
return {
|
||||
key: convert_number_values_for_function_call_args(value)
|
||||
for key, value in args.items()
|
||||
}
|
||||
|
||||
|
||||
def convert_number_values_for_function_call_args(
|
||||
args: Union[dict[str, object], list[object], object],
|
||||
) -> Union[dict[str, object], list[object], object]:
|
||||
"""Converts float values with no decimal to integers."""
|
||||
if isinstance(args, float) and args.is_integer():
|
||||
return int(args)
|
||||
if isinstance(args, dict):
|
||||
return {
|
||||
key: convert_number_values_for_function_call_args(value)
|
||||
for key, value in args.items()
|
||||
}
|
||||
if isinstance(args, list):
|
||||
return [
|
||||
convert_number_values_for_function_call_args(value) for value in args
|
||||
]
|
||||
return args
|
||||
|
||||
|
||||
def is_annotation_pydantic_model(annotation: Any) -> bool:
|
||||
try:
|
||||
return inspect.isclass(annotation) and issubclass(
|
||||
annotation, pydantic.BaseModel
|
||||
)
|
||||
# for python 3.10 and below, inspect.isclass(annotation) has inconsistent
|
||||
# results with versions above. for example, inspect.isclass(dict[str, int]) is
|
||||
# True in 3.10 and below but False in 3.11 and above.
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def convert_if_exist_pydantic_model(
|
||||
value: Any, annotation: Any, param_name: str, func_name: str
|
||||
) -> Any:
|
||||
if isinstance(value, dict) and is_annotation_pydantic_model(annotation):
|
||||
try:
|
||||
return annotation(**value)
|
||||
except pydantic.ValidationError as e:
|
||||
raise errors.UnknownFunctionCallArgumentError(
|
||||
f'Failed to parse parameter {param_name} for function'
|
||||
f' {func_name} from function call part because function call argument'
|
||||
f' value {value} is not compatible with parameter annotation'
|
||||
f' {annotation}, due to error {e}'
|
||||
)
|
||||
if isinstance(value, list) and get_origin(annotation) == list:
|
||||
item_type = get_args(annotation)[0]
|
||||
return [
|
||||
convert_if_exist_pydantic_model(item, item_type, param_name, func_name)
|
||||
for item in value
|
||||
]
|
||||
if isinstance(value, dict) and get_origin(annotation) == dict:
|
||||
_, value_type = get_args(annotation)
|
||||
return {
|
||||
k: convert_if_exist_pydantic_model(v, value_type, param_name, func_name)
|
||||
for k, v in value.items()
|
||||
}
|
||||
# example 1: typing.Union[int, float]
|
||||
# example 2: int | float equivalent to UnionType[int, float]
|
||||
if get_origin(annotation) in (Union, UnionType):
|
||||
for arg in get_args(annotation):
|
||||
if (
|
||||
(get_args(arg) and get_origin(arg) is list)
|
||||
or isinstance(value, arg)
|
||||
or (isinstance(value, dict) and is_annotation_pydantic_model(arg))
|
||||
):
|
||||
try:
|
||||
return convert_if_exist_pydantic_model(
|
||||
value, arg, param_name, func_name
|
||||
)
|
||||
# do not raise here because there could be multiple pydantic model types
|
||||
# in the union type.
|
||||
except pydantic.ValidationError:
|
||||
continue
|
||||
# if none of the union type is matched, raise error
|
||||
raise errors.UnknownFunctionCallArgumentError(
|
||||
f'Failed to parse parameter {param_name} for function'
|
||||
f' {func_name} from function call part because function call argument'
|
||||
f' value {value} cannot be converted to parameter annotation'
|
||||
f' {annotation}.'
|
||||
)
|
||||
# the only exception for value and annotation type to be different is int and
|
||||
# float. see convert_number_values_for_function_call_args function for context
|
||||
if isinstance(value, int) and annotation is float:
|
||||
return value
|
||||
if not isinstance(value, annotation):
|
||||
raise errors.UnknownFunctionCallArgumentError(
|
||||
f'Failed to parse parameter {param_name} for function {func_name} from'
|
||||
f' function call part because function call argument value {value} is'
|
||||
f' not compatible with parameter annotation {annotation}.'
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def convert_argument_from_function(
|
||||
args: dict[str, Any], function: Callable[..., Any]
|
||||
) -> dict[str, Any]:
|
||||
signature = inspect.signature(function)
|
||||
func_name = function.__name__
|
||||
converted_args = {}
|
||||
for param_name, param in signature.parameters.items():
|
||||
if param_name in args:
|
||||
converted_args[param_name] = convert_if_exist_pydantic_model(
|
||||
args[param_name],
|
||||
param.annotation,
|
||||
param_name,
|
||||
func_name,
|
||||
)
|
||||
return converted_args
|
||||
|
||||
|
||||
def invoke_function_from_dict_args(
|
||||
args: Dict[str, Any], function_to_invoke: Callable[..., Any]
|
||||
) -> Any:
|
||||
converted_args = convert_argument_from_function(args, function_to_invoke)
|
||||
try:
|
||||
return function_to_invoke(**converted_args)
|
||||
except Exception as e:
|
||||
raise errors.FunctionInvocationError(
|
||||
f'Failed to invoke function {function_to_invoke.__name__} with'
|
||||
f' converted arguments {converted_args} from model returned function'
|
||||
f' call argument {args} because of error {e}'
|
||||
)
|
||||
|
||||
|
||||
async def invoke_function_from_dict_args_async(
|
||||
args: Dict[str, Any], function_to_invoke: Callable[..., Any]
|
||||
) -> Any:
|
||||
converted_args = convert_argument_from_function(args, function_to_invoke)
|
||||
try:
|
||||
return await function_to_invoke(**converted_args)
|
||||
except Exception as e:
|
||||
raise errors.FunctionInvocationError(
|
||||
f'Failed to invoke function {function_to_invoke.__name__} with'
|
||||
f' converted arguments {converted_args} from model returned function'
|
||||
f' call argument {args} because of error {e}'
|
||||
)
|
||||
|
||||
|
||||
def get_function_response_parts(
|
||||
response: types.GenerateContentResponse,
|
||||
function_map: dict[str, Callable[..., Any]],
|
||||
) -> list[types.Part]:
|
||||
"""Returns the function response parts from the response."""
|
||||
func_response_parts = []
|
||||
if (
|
||||
response.candidates is not None
|
||||
and isinstance(response.candidates[0].content, types.Content)
|
||||
and response.candidates[0].content.parts is not None
|
||||
):
|
||||
for part in response.candidates[0].content.parts:
|
||||
if not part.function_call:
|
||||
continue
|
||||
func_name = part.function_call.name
|
||||
if func_name is not None and part.function_call.args is not None:
|
||||
func = function_map[func_name]
|
||||
args = convert_number_values_for_dict_function_call_args(
|
||||
part.function_call.args
|
||||
)
|
||||
func_response: dict[str, Any]
|
||||
try:
|
||||
func_response = {
|
||||
'result': invoke_function_from_dict_args(args, func)
|
||||
}
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
func_response = {'error': str(e)}
|
||||
func_response_part = types.Part.from_function_response(
|
||||
name=func_name, response=func_response
|
||||
)
|
||||
func_response_parts.append(func_response_part)
|
||||
return func_response_parts
|
||||
|
||||
async def get_function_response_parts_async(
|
||||
response: types.GenerateContentResponse,
|
||||
function_map: dict[str, Callable[..., Any]],
|
||||
) -> list[types.Part]:
|
||||
"""Returns the function response parts from the response."""
|
||||
func_response_parts = []
|
||||
if (
|
||||
response.candidates is not None
|
||||
and isinstance(response.candidates[0].content, types.Content)
|
||||
and response.candidates[0].content.parts is not None
|
||||
):
|
||||
for part in response.candidates[0].content.parts:
|
||||
if not part.function_call:
|
||||
continue
|
||||
func_name = part.function_call.name
|
||||
if func_name is not None and part.function_call.args is not None:
|
||||
func = function_map[func_name]
|
||||
args = convert_number_values_for_dict_function_call_args(
|
||||
part.function_call.args
|
||||
)
|
||||
func_response: dict[str, Any]
|
||||
try:
|
||||
if inspect.iscoroutinefunction(func):
|
||||
func_response = {
|
||||
'result': await invoke_function_from_dict_args_async(args, func)
|
||||
}
|
||||
else:
|
||||
func_response = {
|
||||
'result': invoke_function_from_dict_args(args, func)
|
||||
}
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
func_response = {'error': str(e)}
|
||||
func_response_part = types.Part.from_function_response(
|
||||
name=func_name, response=func_response
|
||||
)
|
||||
func_response_parts.append(func_response_part)
|
||||
return func_response_parts
|
||||
|
||||
|
||||
def should_disable_afc(
|
||||
config: Optional[types.GenerateContentConfigOrDict] = None,
|
||||
) -> bool:
|
||||
"""Returns whether automatic function calling is enabled."""
|
||||
if not config:
|
||||
return False
|
||||
config_model = _create_generate_content_config_model(config)
|
||||
# If max_remote_calls is less or equal to 0, warn and disable AFC.
|
||||
if (
|
||||
config_model
|
||||
and config_model.automatic_function_calling
|
||||
and config_model.automatic_function_calling.maximum_remote_calls
|
||||
is not None
|
||||
and int(config_model.automatic_function_calling.maximum_remote_calls) <= 0
|
||||
):
|
||||
logger.warning(
|
||||
'max_remote_calls in automatic_function_calling_config'
|
||||
f' {config_model.automatic_function_calling.maximum_remote_calls} is'
|
||||
' less than or equal to 0. Disabling automatic function calling.'
|
||||
' Please set max_remote_calls to a positive integer.'
|
||||
)
|
||||
return True
|
||||
|
||||
# Default to enable AFC if not specified.
|
||||
if (
|
||||
not config_model.automatic_function_calling
|
||||
or config_model.automatic_function_calling.disable is None
|
||||
):
|
||||
return False
|
||||
|
||||
if (
|
||||
config_model.automatic_function_calling.disable
|
||||
and config_model.automatic_function_calling.maximum_remote_calls
|
||||
is not None
|
||||
# exclude the case where max_remote_calls is set to 10 by default.
|
||||
and 'maximum_remote_calls'
|
||||
in config_model.automatic_function_calling.model_fields_set
|
||||
and int(config_model.automatic_function_calling.maximum_remote_calls) > 0
|
||||
):
|
||||
logger.warning(
|
||||
'`automatic_function_calling.disable` is set to `True`. And'
|
||||
' `automatic_function_calling.maximum_remote_calls` is a'
|
||||
' positive number'
|
||||
f' {config_model.automatic_function_calling.maximum_remote_calls}.'
|
||||
' Disabling automatic function calling. If you want to enable'
|
||||
' automatic function calling, please set'
|
||||
' `automatic_function_calling.disable` to `False` or leave it unset,'
|
||||
' and set `automatic_function_calling.maximum_remote_calls` to a'
|
||||
' positive integer or leave'
|
||||
' `automatic_function_calling.maximum_remote_calls` unset.'
|
||||
)
|
||||
|
||||
return config_model.automatic_function_calling.disable
|
||||
|
||||
|
||||
def get_max_remote_calls_afc(
|
||||
config: Optional[types.GenerateContentConfigOrDict] = None,
|
||||
) -> int:
|
||||
if not config:
|
||||
return _DEFAULT_MAX_REMOTE_CALLS_AFC
|
||||
"""Returns the remaining remote calls for automatic function calling."""
|
||||
if should_disable_afc(config):
|
||||
raise ValueError(
|
||||
'automatic function calling is not enabled, but SDK is trying to get'
|
||||
' max remote calls.'
|
||||
)
|
||||
config_model = _create_generate_content_config_model(config)
|
||||
if (
|
||||
not config_model.automatic_function_calling
|
||||
or config_model.automatic_function_calling.maximum_remote_calls is None
|
||||
):
|
||||
return _DEFAULT_MAX_REMOTE_CALLS_AFC
|
||||
return int(config_model.automatic_function_calling.maximum_remote_calls)
|
||||
|
||||
|
||||
def should_append_afc_history(
|
||||
config: Optional[types.GenerateContentConfigOrDict] = None,
|
||||
) -> bool:
|
||||
if not config:
|
||||
return True
|
||||
config_model = _create_generate_content_config_model(config)
|
||||
if not config_model.automatic_function_calling:
|
||||
return True
|
||||
return not config_model.automatic_function_calling.ignore_call_history
|
||||
Reference in New Issue
Block a user