adk-python/src/google/adk/tools/_automatic_function_calling_util.py
Liang Wu 2a8ca06c3e chore: remove reference to genai SDK folder.
Added `from __future__ import annotations` to follow the best practice.

PiperOrigin-RevId: 764473253
2025-05-28 16:54:19 -07:00

346 lines
10 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import inspect
from types import FunctionType
from typing import Any
from typing import Callable
from typing import Dict
from typing import Literal
from typing import Optional
from typing import Union
from google.genai import types
import pydantic
from pydantic import BaseModel
from pydantic import create_model
from pydantic import fields as pydantic_fields
from . import function_parameter_parse_util
# Maps a type name to the `types.Type` member used in a declaration schema.
# The table is consulted in two places in this module: for the `type` strings
# found in a JSON schema (`_map_pydantic_type_to_property_schema`) and for the
# `__name__` of a function's return annotation (`_get_return_type`). That is
# why lower-case schema names ('string', 'array', ...), Python builtin names
# ('str', 'list', ...) and capitalised typing names ('Dict', 'List', ...) all
# appear as keys.
_py_type_2_schema_type = {
'str': types.Type.STRING,
'int': types.Type.INTEGER,
'float': types.Type.NUMBER,
'bool': types.Type.BOOLEAN,
'string': types.Type.STRING,
'integer': types.Type.INTEGER,
'number': types.Type.NUMBER,
'boolean': types.Type.BOOLEAN,
'list': types.Type.ARRAY,
'array': types.Type.ARRAY,
'tuple': types.Type.ARRAY,
'object': types.Type.OBJECT,
'Dict': types.Type.OBJECT,
'List': types.Type.ARRAY,
'Tuple': types.Type.ARRAY,
# `Any` carries no type information, so it maps to the unspecified member.
'Any': types.Type.TYPE_UNSPECIFIED,
}
def _get_fields_dict(func: Callable) -> Dict:
  """Builds pydantic field definitions from the parameters of `func`.

  Args:
    func: The callable whose signature is inspected.

  Returns:
    A dict mapping each parameter name to an `(annotation, field)` tuple,
    where `field` is the result of `pydantic.Field(...)`. `*args` and
    `**kwargs` parameters are not included.
  """
  empty = inspect.Parameter.empty
  supported_kinds = (
      inspect.Parameter.POSITIONAL_OR_KEYWORD,
      inspect.Parameter.KEYWORD_ONLY,
      inspect.Parameter.POSITIONAL_ONLY,
  )

  fields_dict = {}
  for name, param in inspect.signature(func).parameters.items():
    # We do not support *args or **kwargs.
    if param.kind not in supported_kinds:
      continue

    # `empty` is a sentinel and is therefore compared by identity. Using `!=`
    # would call the `__ne__` of the annotation/default object itself, which
    # is not guaranteed to return a plain bool (array-like defaults, for
    # example).

    # 1. Use Any rather than None for an unannotated parameter so the type is
    # not auto-inferred from the default value.
    annotation = Any if param.annotation is empty else param.annotation

    # 2. The declared default is forwarded as-is. A parameter without a
    # default needs PydanticUndefined instead of None.
    if param.default is empty:
      default = pydantic_fields.PydanticUndefined
    else:
      default = param.default

    fields_dict[name] = (
        annotation,
        # 3. Parameter descriptions are not supported for now.
        pydantic.Field(default=default, description=None),
    )
  return fields_dict
def _annotate_nullable_fields(schema: Dict):
for _, property_schema in schema.get('properties', {}).items():
# for Optional[T], the pydantic schema is:
# {
# "type": "object",
# "properties": {
# "anyOf": [
# {
# "type": "null"
# },
# {
# "type": "T"
# }
# ]
# }
# }
for type_ in property_schema.get('anyOf', []):
if type_.get('type') == 'null':
property_schema['nullable'] = True
property_schema['anyOf'].remove(type_)
break
def _annotate_required_fields(schema: Dict):
required = [
field_name
for field_name, field_schema in schema.get('properties', {}).items()
if not field_schema.get('nullable') and 'default' not in field_schema
]
schema['required'] = required
def _remove_any_of(schema: Dict):
for _, property_schema in schema.get('properties', {}).items():
union_types = property_schema.pop('anyOf', None)
# Take the first non-null type.
if union_types:
for type_ in union_types:
if type_.get('type') != 'null':
property_schema.update(type_)
def _remove_default(schema: Dict):
for _, property_schema in schema.get('properties', {}).items():
property_schema.pop('default', None)
def _remove_nullable(schema: Dict):
for _, property_schema in schema.get('properties', {}).items():
property_schema.pop('nullable', None)
def _remove_title(schema: Dict):
for _, property_schema in schema.get('properties', {}).items():
property_schema.pop('title', None)
def _get_pydantic_schema(func: Callable) -> Dict:
  """Returns the JSON schema pydantic generates for the parameters of `func`.

  A pydantic model named after the function is created from its parameters
  and the model's JSON schema is returned. A parameter called `tool_context`
  is left out of the model.

  Args:
    func: The callable to describe.

  Returns:
    The result of `model_json_schema()` on the generated model.
  """
  field_definitions = _get_fields_dict(func)
  field_definitions.pop('tool_context', None)
  model = pydantic.create_model(func.__name__, **field_definitions)
  return model.model_json_schema()
def _process_pydantic_schema(vertexai: bool, schema: Dict) -> Dict:
  """Normalises a pydantic JSON schema in place and returns it.

  Nullable and required annotations are always added. When `vertexai` is
  false the `anyOf`, `default`, `nullable` and `title` keys are additionally
  stripped from every property, in that order.

  Args:
    vertexai: When true, the stripping step is skipped.
    schema: The schema dict to process; it is mutated.

  Returns:
    The same `schema` object that was passed in.
  """
  _annotate_nullable_fields(schema)
  _annotate_required_fields(schema)
  if vertexai:
    return schema

  strippers = (
      _remove_any_of,
      _remove_default,
      _remove_nullable,
      _remove_title,
  )
  for strip in strippers:
    strip(schema)
  return schema
def _map_pydantic_type_to_property_schema(property_schema: Dict):
  """Rewrites the `type` entries of one property schema, in place.

  The `type` string of the property, of its array `items` (recursively) and
  of each `anyOf` member is replaced by the matching entry of
  `_py_type_2_schema_type`. Names missing from that table become
  `'TYPE_UNSPECIFIED'`.

  Args:
    property_schema: The schema dict of a single property; it is mutated.
  """
  if 'type' in property_schema:
    property_schema['type'] = _py_type_2_schema_type.get(
        property_schema['type'], 'TYPE_UNSPECIFIED'
    )
    # NOTE(review): this comparison relies on `types.Type` members comparing
    # equal to their string names - confirm against google.genai.
    if property_schema['type'] == 'ARRAY':
      # An array schema does not have to carry an `items` sub-schema (a
      # fixed-length tuple may be described through `prefixItems` instead),
      # so recurse only when an element schema dict is actually present
      # rather than failing with KeyError.
      items_schema = property_schema.get('items')
      if isinstance(items_schema, dict):
        _map_pydantic_type_to_property_schema(items_schema)

  for type_ in property_schema.get('anyOf', []):
    if 'type' in type_:
      type_['type'] = _py_type_2_schema_type.get(
          type_['type'], 'TYPE_UNSPECIFIED'
      )
      # TODO: To investigate. Unclear why a Type is needed with 'anyOf' to
      # avoid google.genai.errors.ClientError: 400 INVALID_ARGUMENT.
      # When several members carry a type, the last one ends up on the
      # property.
      property_schema['type'] = type_['type']
def _map_pydantic_type_to_schema_type(schema: Dict):
  """Applies `_map_pydantic_type_to_property_schema` to every property.

  Args:
    schema: Schema dict whose `properties` entries are rewritten in place. A
      schema without `properties` is left untouched.
  """
  properties = schema.get('properties', {})
  for field_schema in properties.values():
    _map_pydantic_type_to_property_schema(field_schema)
def _get_return_type(func: Callable) -> Any:
  """Returns the schema type matching the return annotation of `func`.

  Args:
    func: The callable whose return annotation is inspected.

  Returns:
    The `_py_type_2_schema_type` entry for the annotation's name, or the name
    itself when the table has no such entry. A function without a return
    annotation therefore yields the name of the `inspect` "empty" sentinel
    unchanged.
  """
  annotation = inspect.signature(func).return_annotation
  if isinstance(annotation, str):
    # With postponed evaluation (`from __future__ import annotations`) the
    # annotation is the source text itself, and str objects have no
    # `__name__`.
    type_name = annotation
  else:
    # Not every annotation object exposes `__name__` (`None` does not, for
    # instance); fall back to its printable form instead of raising
    # AttributeError.
    type_name = getattr(annotation, '__name__', None) or str(annotation)
  return _py_type_2_schema_type.get(type_name, type_name)
def build_function_declaration(
    func: Union[Callable, BaseModel],
    ignore_params: Optional[list[str]] = None,
    variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI',
) -> types.FunctionDeclaration:
  """Builds a function declaration for `func`, hiding selected parameters.

  When none of the names in `ignore_params` occurs in the signature, `func`
  is passed to `from_function_with_options` as-is. Otherwise a stand-in
  without those parameters is built first: a new pydantic model when `func`
  is a class, or a copy of the function with a trimmed `__signature__`
  otherwise.

  Args:
    func: The function or class to describe.
    ignore_params: Parameter names to leave out of the declaration. `None`
      or an empty value means nothing is left out.
    variant: Forwarded to `from_function_with_options`.

  Returns:
    The declaration produced by `from_function_with_options`.
  """
  signature = inspect.signature(func)
  ignored = ignore_params if ignore_params else []
  kept_params = {
      name: param
      for name, param in signature.parameters.items()
      if name not in ignored
  }

  # Nothing was filtered out, so the original object can be used directly.
  if len(kept_params) == len(signature.parameters):
    return from_function_with_options(func, variant)

  if isinstance(func, type):
    field_definitions = {
        name: (param.annotation, param.default)
        for name, param in kept_params.items()
    }
    trimmed = create_model(func.__name__, **field_definitions)
  else:
    trimmed_signature = signature.replace(
        parameters=list(kept_params.values())
    )
    # Copy the function object so the signature override does not leak onto
    # the caller's function.
    trimmed = FunctionType(
        func.__code__,
        func.__globals__,
        func.__name__,
        func.__defaults__,
        func.__closure__,
    )
    trimmed.__signature__ = trimmed_signature

  return from_function_with_options(trimmed, variant)
def build_function_declaration_for_langchain(
    vertexai: bool, name, description, func, param_pydantic_schema
) -> types.FunctionDeclaration:
  """Builds a function declaration from a mapping of parameter schemas.

  Args:
    vertexai: Forwarded to `_process_pydantic_schema` and
      `build_function_declaration_util`.
    name: Name of the declared function.
    description: Description of the declared function.
    func: The underlying callable, forwarded to
      `build_function_declaration_util`.
    param_pydantic_schema: Mapping of parameter name to that parameter's
      schema dict. It is wrapped as the `properties` of an object schema for
      processing, and the per-parameter dicts are modified in place.

  Returns:
    The declaration produced by `build_function_declaration_util`.
  """
  processed_schema = _process_pydantic_schema(
      vertexai, {'properties': param_pydantic_schema}
  )
  # The required-field list is written onto the wrapping object schema, not
  # into `properties`. Popping 'required' from the properties mapping would
  # never find that list and would instead delete a real parameter that
  # happens to be named "required".
  before_param_pydantic_schema = {
      'properties': processed_schema['properties'].copy(),
      'required': processed_schema.get('required', []),
  }
  return build_function_declaration_util(
      vertexai, name, description, func, before_param_pydantic_schema
  )
def build_function_declaration_for_params_for_crewai(
    vertexai: bool, name, description, func, param_pydantic_schema
) -> types.FunctionDeclaration:
  """Builds a function declaration from a complete parameter schema.

  Args:
    vertexai: Forwarded to `_process_pydantic_schema` and
      `build_function_declaration_util`.
    name: Name of the declared function.
    description: Description of the declared function.
    func: The underlying callable, forwarded to
      `build_function_declaration_util`.
    param_pydantic_schema: Schema dict passed directly to
      `_process_pydantic_schema`, which modifies it in place.

  Returns:
    The declaration produced by `build_function_declaration_util`.
  """
  processed_schema = _process_pydantic_schema(vertexai, param_pydantic_schema)
  # Hand over a shallow copy of the processed schema.
  schema_copy = processed_schema.copy()
  return build_function_declaration_util(
      vertexai, name, description, func, schema_copy
  )
def build_function_declaration_util(
    vertexai: bool, name, description, func, before_param_pydantic_schema
) -> types.FunctionDeclaration:
  """Assembles a `types.FunctionDeclaration` from a processed schema.

  Args:
    vertexai: When true and `func` is callable, a response schema derived
      from the return annotation of `func` is attached.
    name: Name of the declared function.
    description: Description of the declared function.
    func: The underlying callable.
    before_param_pydantic_schema: Schema dict whose `properties` become the
      declaration's parameters. Its type names are mapped in place first.

  Returns:
    The declaration. `parameters` is None when there are no properties.
  """
  _map_pydantic_type_to_schema_type(before_param_pydantic_schema)
  properties = before_param_pydantic_schema.get('properties', {})

  if properties:
    parameters = types.Schema(
        type='OBJECT',
        properties=properties,
    )
  else:
    parameters = None

  declaration = types.FunctionDeclaration(
      parameters=parameters,
      description=description,
      name=name,
  )

  if not vertexai or not isinstance(func, Callable):
    return declaration

  declaration.response = types.Schema(
      type=_get_return_type(func),
  )
  return declaration
def from_function_with_options(
    func: Callable,
    variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI',
) -> 'types.FunctionDeclaration':
  """Creates a function declaration from the signature of `func`.

  The name and description come from `func.__name__` and `func.__doc__`.
  Every parameter except `*args` and `**kwargs` is converted to a schema by
  `function_parameter_parse_util`. For the 'VERTEX_AI' variant a response
  schema is also derived from the return annotation, if there is one.

  Args:
    func: The callable to describe.
    variant: One of 'GOOGLE_AI', 'VERTEX_AI' or 'DEFAULT'.

  Returns:
    The populated `types.FunctionDeclaration`.

  Raises:
    ValueError: If `variant` is not one of the supported values.
  """
  supported_variants = ['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT']
  if variant not in supported_variants:
    raise ValueError(
        f'Unsupported variant: {variant}. Supported variants are:'
        f' {", ".join(supported_variants)}'
    )

  named_kinds = (
      inspect.Parameter.POSITIONAL_OR_KEYWORD,
      inspect.Parameter.KEYWORD_ONLY,
      inspect.Parameter.POSITIONAL_ONLY,
  )
  signature = inspect.signature(func)
  parameters_properties = {
      name: function_parameter_parse_util._parse_schema_from_parameter(
          variant, param, func.__name__
      )
      for name, param in signature.parameters.items()
      if param.kind in named_kinds
  }

  declaration = types.FunctionDeclaration(
      name=func.__name__,
      description=func.__doc__,
  )
  if parameters_properties:
    declaration.parameters = types.Schema(
        type='OBJECT',
        properties=parameters_properties,
    )
    declaration.parameters.required = (
        function_parameter_parse_util._get_required_fields(
            declaration.parameters
        )
    )

  # A response schema is only produced for VERTEX_AI, and only when the
  # function declares a return annotation.
  if variant != 'VERTEX_AI':
    return declaration
  return_annotation = signature.return_annotation
  if return_annotation is inspect.Signature.empty:
    return declaration

  # Wrap the return annotation in a synthetic parameter so the regular
  # parameter parser can be reused for it.
  return_parameter = inspect.Parameter(
      'return_value',
      inspect.Parameter.POSITIONAL_OR_KEYWORD,
      annotation=return_annotation,
  )
  declaration.response = (
      function_parameter_parse_util._parse_schema_from_parameter(
          variant,
          return_parameter,
          func.__name__,
      )
  )
  return declaration