fix: fix parameter schema generation for Gemini

This fixes https://github.com/google/adk-python/issues/1055
and https://github.com/google/adk-python/issues/881.

PiperOrigin-RevId: 766288394
This commit is contained in:
Xiang (Sean) Zhou
2025-06-02 12:02:26 -07:00
committed by Copybara-Service
parent f7cb66620b
commit 5a67a946d2
12 changed files with 582 additions and 457 deletions
@@ -20,7 +20,6 @@ from .operation_parser import OperationParser
from .rest_api_tool import AuthPreparationState
from .rest_api_tool import RestApiTool
from .rest_api_tool import snake_to_lower_camel
from .rest_api_tool import to_gemini_schema
from .tool_auth_handler import ToolAuthHandler
__all__ = [
@@ -30,7 +29,6 @@ __all__ = [
'OpenAPIToolset',
'OperationParser',
'RestApiTool',
'to_gemini_schema',
'snake_to_lower_camel',
'AuthPreparationState',
'ToolAuthHandler',
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
from typing import Any
from typing import Dict
@@ -23,8 +25,8 @@ from pydantic import BaseModel
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ..._gemini_schema_util import _to_snake_case
from ..common.common import ApiParameter
from ..common.common import to_snake_case
from .operation_parser import OperationParser
@@ -112,7 +114,7 @@ class OpenApiSpecParser:
# If operation ID is missing, assign an operation id based on path
# and method
if "operationId" not in operation_dict:
temp_id = to_snake_case(f"{path}_{method}")
temp_id = _to_snake_case(f"{path}_{method}")
operation_dict["operationId"] = temp_id
url = OperationEndpoint(base_url=base_url, path=path, method=method)
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import inspect
from textwrap import dedent
from typing import Any
@@ -25,9 +27,9 @@ from fastapi.openapi.models import Operation
from fastapi.openapi.models import Parameter
from fastapi.openapi.models import Schema
from ..._gemini_schema_util import _to_snake_case
from ..common.common import ApiParameter
from ..common.common import PydocHelper
from ..common.common import to_snake_case
class OperationParser:
@@ -189,7 +191,7 @@ class OperationParser:
operation_id = self._operation.operationId
if not operation_id:
raise ValueError('Operation ID is missing')
return to_snake_case(operation_id)[:60]
return _to_snake_case(operation_id)[:60]
def get_return_type_hint(self) -> str:
"""Returns the return type hint string (like 'str', 'int', etc.)."""
@@ -12,45 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from fastapi.openapi.models import Operation
from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
import requests
from typing_extensions import override
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ....tools.base_tool import BaseTool
from ..._gemini_schema_util import _to_gemini_schema
from ..._gemini_schema_util import _to_snake_case
from ...base_tool import BaseTool
from ...tool_context import ToolContext
from ..auth.auth_helpers import credential_to_param
from ..auth.auth_helpers import dict_to_auth_scheme
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
from ..common.common import ApiParameter
from ..common.common import to_snake_case
from .openapi_spec_parser import OperationEndpoint
from .openapi_spec_parser import ParsedOperation
from .operation_parser import OperationParser
from .tool_auth_handler import ToolAuthHandler
# Schema field names that are not supported by the Gemini API.
# `to_gemini_schema` compares each incoming key against this tuple *after*
# converting it with `to_snake_case`, and drops the key on a match, so the
# entries here must be written in snake_case.
_OPENAPI_SCHEMA_IGNORE_FIELDS = (
"title",
"default",
"format",
"additional_properties",
"ref",
"def",
)
def snake_to_lower_camel(snake_case_string: str):
"""Converts a snake_case string to a lower_camel_case string.
@@ -70,117 +61,6 @@ def snake_to_lower_camel(snake_case_string: str):
])
# TODO: Switch to Gemini `from_json_schema` util when it is released
# in Gemini SDK.
def normalize_json_schema_type(
    json_schema_type: Optional[Union[str, Sequence[str]]],
) -> tuple[Optional[str], bool]:
  """Maps a JSON Schema `type` value onto a single type plus a nullable flag.

  Adopted and modified from the Gemini SDK. Remove this after switching to
  Gemini `from_json_schema`.

  Args:
    json_schema_type: The JSON Schema `type` value: None, a single type name,
      or a sequence of type names.

  Returns:
    A `(type_name, nullable)` pair. `type_name` is the given string, or the
    first entry of the sequence that is not "null"; it is None when no such
    type exists. `nullable` is True when "null" was given, either alone or as
    one of the sequence entries.
  """
  if json_schema_type is None:
    return None, False

  if isinstance(json_schema_type, str):
    is_null = json_schema_type == "null"
    return (None if is_null else json_schema_type), is_null

  # A sequence of types: keep the first concrete one and remember whether
  # "null" was listed alongside it.
  entries = list(json_schema_type)
  concrete = [entry for entry in entries if entry != "null"]
  first_concrete = concrete[0] if concrete else None
  return first_concrete, len(concrete) < len(entries)
# TODO: Switch to Gemini `from_json_schema` util when it is released
# in Gemini SDK.
def to_gemini_schema(
    openapi_schema: Optional[Dict[str, Any]] = None,
) -> Optional[Schema]:
  """Converts an OpenAPI schema dictionary to a Gemini Schema object.

  Keys are converted to snake_case and kept only when they name a field of
  the `Schema` model and are not listed in `_OPENAPI_SCHEMA_IGNORE_FIELDS`.
  Nested schemas (`properties`, `items`, `any_of`, and any other dict or
  list-of-dict value) are converted recursively. The input dictionary is not
  modified.

  Args:
    openapi_schema: The OpenAPI schema dictionary.

  Returns:
    A Pydantic Schema object, or None if the input is None.

  Raises:
    TypeError: If the input is neither None nor a dict.
  """
  if openapi_schema is None:
    return None

  if not isinstance(openapi_schema, dict):
    raise TypeError("openapi_schema must be a dictionary")

  # Work on a shallow copy so that the default type injected below does not
  # leak into the caller's dictionary.
  schema_dict = dict(openapi_schema)

  # Adding this to force adding a type to an empty dict
  # This avoid "... one_of or any_of must specify a type" error
  if not schema_dict.get("type"):
    schema_dict["type"] = "object"

  pydantic_schema_data = {}
  for key, value in schema_dict.items():
    snake_case_key = to_snake_case(key)

    # Skip anything that is not a field of the Schema model.
    if snake_case_key not in Schema.model_fields:
      continue

    if snake_case_key in _OPENAPI_SCHEMA_IGNORE_FIELDS:
      # Ignore these fields as Gemini backend doesn't recognize them, and will
      # throw exception if they appear in the schema.
      # Format: properties[expiration].format: only 'enum' and 'date-time' are
      # supported for STRING type
      continue

    if snake_case_key == "type":
      schema_type, nullable = normalize_json_schema_type(
          schema_dict.get("type")
      )
      # Fall back to "object" when only "null" (or nothing usable) was given,
      # again to avoid the "... must specify a type" error.
      pydantic_schema_data["type"] = (schema_type or "object").upper()
      if nullable:
        pydantic_schema_data["nullable"] = True
    elif snake_case_key == "properties" and isinstance(value, dict):
      pydantic_schema_data[snake_case_key] = {
          k: to_gemini_schema(v) for k, v in value.items()
      }
    elif snake_case_key == "items" and isinstance(value, dict):
      pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
    elif snake_case_key == "any_of" and isinstance(value, list):
      pydantic_schema_data[snake_case_key] = [
          to_gemini_schema(item) for item in value
      ]
    elif isinstance(value, list) and snake_case_key not in (
        "enum",
        "required",
        "property_ordering",
    ):
      # Other list-valued fields may hold nested schemas; convert the dict
      # entries and keep everything else untouched.
      pydantic_schema_data[snake_case_key] = [
          to_gemini_schema(item) if isinstance(item, dict) else item
          for item in value
      ]
    elif isinstance(value, dict):
      # Any remaining dict value is a nested schema; a dict under
      # "properties" or "items" was already handled above.
      pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
    else:
      # Simple value assignment (int, str, bool, etc.)
      pydantic_schema_data[snake_case_key] = value

  return Schema(**pydantic_schema_data)
AuthPreparationState = Literal["pending", "done"]
@@ -273,7 +153,7 @@ class RestApiTool(BaseTool):
parsed.operation, parsed.parameters, parsed.return_value
)
tool_name = to_snake_case(operation_parser.get_function_name())
tool_name = _to_snake_case(operation_parser.get_function_name())
generated = cls(
name=tool_name,
description=parsed.operation.description
@@ -306,7 +186,7 @@ class RestApiTool(BaseTool):
def _get_declaration(self) -> FunctionDeclaration:
"""Returns the function declaration in the Gemini Schema format."""
schema_dict = self._operation_parser.get_json_schema()
parameters = to_gemini_schema(schema_dict)
parameters = _to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)