mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2026-02-05 06:16:24 -06:00
fix: fix parameter schema generation for gemini
this fixes https://github.com/google/adk-python/issues/1055 and https://github.com/google/adk-python/issues/881 PiperOrigin-RevId: 766288394
This commit is contained in:
committed by
Copybara-Service
parent
f7cb66620b
commit
5a67a946d2
@@ -20,7 +20,6 @@ from .operation_parser import OperationParser
|
||||
from .rest_api_tool import AuthPreparationState
|
||||
from .rest_api_tool import RestApiTool
|
||||
from .rest_api_tool import snake_to_lower_camel
|
||||
from .rest_api_tool import to_gemini_schema
|
||||
from .tool_auth_handler import ToolAuthHandler
|
||||
|
||||
__all__ = [
|
||||
@@ -30,7 +29,6 @@ __all__ = [
|
||||
'OpenAPIToolset',
|
||||
'OperationParser',
|
||||
'RestApiTool',
|
||||
'to_gemini_schema',
|
||||
'snake_to_lower_camel',
|
||||
'AuthPreparationState',
|
||||
'ToolAuthHandler',
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
@@ -23,8 +25,8 @@ from pydantic import BaseModel
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ..._gemini_schema_util import _to_snake_case
|
||||
from ..common.common import ApiParameter
|
||||
from ..common.common import to_snake_case
|
||||
from .operation_parser import OperationParser
|
||||
|
||||
|
||||
@@ -112,7 +114,7 @@ class OpenApiSpecParser:
|
||||
# If operation ID is missing, assign an operation id based on path
|
||||
# and method
|
||||
if "operationId" not in operation_dict:
|
||||
temp_id = to_snake_case(f"{path}_{method}")
|
||||
temp_id = _to_snake_case(f"{path}_{method}")
|
||||
operation_dict["operationId"] = temp_id
|
||||
|
||||
url = OperationEndpoint(base_url=base_url, path=path, method=method)
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
@@ -25,9 +27,9 @@ from fastapi.openapi.models import Operation
|
||||
from fastapi.openapi.models import Parameter
|
||||
from fastapi.openapi.models import Schema
|
||||
|
||||
from ..._gemini_schema_util import _to_snake_case
|
||||
from ..common.common import ApiParameter
|
||||
from ..common.common import PydocHelper
|
||||
from ..common.common import to_snake_case
|
||||
|
||||
|
||||
class OperationParser:
|
||||
@@ -189,7 +191,7 @@ class OperationParser:
|
||||
operation_id = self._operation.operationId
|
||||
if not operation_id:
|
||||
raise ValueError('Operation ID is missing')
|
||||
return to_snake_case(operation_id)[:60]
|
||||
return _to_snake_case(operation_id)[:60]
|
||||
|
||||
def get_return_type_hint(self) -> str:
|
||||
"""Returns the return type hint string (like 'str', 'int', etc.)."""
|
||||
|
||||
@@ -12,45 +12,36 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from fastapi.openapi.models import Operation
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from google.genai.types import Schema
|
||||
import requests
|
||||
from typing_extensions import override
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ....tools.base_tool import BaseTool
|
||||
from ..._gemini_schema_util import _to_gemini_schema
|
||||
from ..._gemini_schema_util import _to_snake_case
|
||||
from ...base_tool import BaseTool
|
||||
from ...tool_context import ToolContext
|
||||
from ..auth.auth_helpers import credential_to_param
|
||||
from ..auth.auth_helpers import dict_to_auth_scheme
|
||||
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
|
||||
from ..common.common import ApiParameter
|
||||
from ..common.common import to_snake_case
|
||||
from .openapi_spec_parser import OperationEndpoint
|
||||
from .openapi_spec_parser import ParsedOperation
|
||||
from .operation_parser import OperationParser
|
||||
from .tool_auth_handler import ToolAuthHandler
|
||||
|
||||
# Not supported by the Gemini API
|
||||
_OPENAPI_SCHEMA_IGNORE_FIELDS = (
|
||||
"title",
|
||||
"default",
|
||||
"format",
|
||||
"additional_properties",
|
||||
"ref",
|
||||
"def",
|
||||
)
|
||||
|
||||
|
||||
def snake_to_lower_camel(snake_case_string: str):
|
||||
"""Converts a snake_case string to a lower_camel_case string.
|
||||
@@ -70,117 +61,6 @@ def snake_to_lower_camel(snake_case_string: str):
|
||||
])
|
||||
|
||||
|
||||
# TODO: Switch to Gemini `from_json_schema` util when it is released
|
||||
# in Gemini SDK.
|
||||
def normalize_json_schema_type(
|
||||
json_schema_type: Optional[Union[str, Sequence[str]]],
|
||||
) -> tuple[Optional[str], bool]:
|
||||
"""Converts a JSON Schema Type into Gemini Schema type.
|
||||
|
||||
Adopted and modified from Gemini SDK. This gets the first available schema
|
||||
type from JSON Schema, and use it to mark Gemini schema type. If JSON Schema
|
||||
contains a list of types, the first non null type is used.
|
||||
|
||||
Remove this after switching to Gemini `from_json_schema`.
|
||||
"""
|
||||
if json_schema_type is None:
|
||||
return None, False
|
||||
if isinstance(json_schema_type, str):
|
||||
if json_schema_type == "null":
|
||||
return None, True
|
||||
return json_schema_type, False
|
||||
|
||||
non_null_types = []
|
||||
nullable = False
|
||||
# If json schema type is an array, pick the first non null type.
|
||||
for type_value in json_schema_type:
|
||||
if type_value == "null":
|
||||
nullable = True
|
||||
else:
|
||||
non_null_types.append(type_value)
|
||||
non_null_type = non_null_types[0] if non_null_types else None
|
||||
return non_null_type, nullable
|
||||
|
||||
|
||||
# TODO: Switch to Gemini `from_json_schema` util when it is released
|
||||
# in Gemini SDK.
|
||||
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
|
||||
"""Converts an OpenAPI schema dictionary to a Gemini Schema object.
|
||||
|
||||
Args:
|
||||
openapi_schema: The OpenAPI schema dictionary.
|
||||
|
||||
Returns:
|
||||
A Pydantic Schema object. Returns None if input is None.
|
||||
Raises TypeError if input is not a dict.
|
||||
"""
|
||||
if openapi_schema is None:
|
||||
return None
|
||||
|
||||
if not isinstance(openapi_schema, dict):
|
||||
raise TypeError("openapi_schema must be a dictionary")
|
||||
|
||||
pydantic_schema_data = {}
|
||||
|
||||
# Adding this to force adding a type to an empty dict
|
||||
# This avoid "... one_of or any_of must specify a type" error
|
||||
if not openapi_schema.get("type"):
|
||||
openapi_schema["type"] = "object"
|
||||
|
||||
for key, value in openapi_schema.items():
|
||||
snake_case_key = to_snake_case(key)
|
||||
# Check if the snake_case_key exists in the Schema model's fields.
|
||||
if snake_case_key in Schema.model_fields:
|
||||
if snake_case_key in _OPENAPI_SCHEMA_IGNORE_FIELDS:
|
||||
# Ignore these fields as Gemini backend doesn't recognize them, and will
|
||||
# throw exception if they appear in the schema.
|
||||
# Format: properties[expiration].format: only 'enum' and 'date-time' are
|
||||
# supported for STRING type
|
||||
continue
|
||||
elif snake_case_key == "type":
|
||||
schema_type, nullable = normalize_json_schema_type(
|
||||
openapi_schema.get("type", None)
|
||||
)
|
||||
# Adding this to force adding a type to an empty dict
|
||||
# This avoid "... one_of or any_of must specify a type" error
|
||||
pydantic_schema_data["type"] = schema_type if schema_type else "object"
|
||||
pydantic_schema_data["type"] = pydantic_schema_data["type"].upper()
|
||||
if nullable:
|
||||
pydantic_schema_data["nullable"] = True
|
||||
elif snake_case_key == "properties" and isinstance(value, dict):
|
||||
pydantic_schema_data[snake_case_key] = {
|
||||
k: to_gemini_schema(v) for k, v in value.items()
|
||||
}
|
||||
elif snake_case_key == "items" and isinstance(value, dict):
|
||||
pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
|
||||
elif snake_case_key == "any_of" and isinstance(value, list):
|
||||
pydantic_schema_data[snake_case_key] = [
|
||||
to_gemini_schema(item) for item in value
|
||||
]
|
||||
# Important: Handle cases where the OpenAPI schema might contain lists
|
||||
# or other structures that need to be recursively processed.
|
||||
elif isinstance(value, list) and snake_case_key not in (
|
||||
"enum",
|
||||
"required",
|
||||
"property_ordering",
|
||||
):
|
||||
new_list = []
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
new_list.append(to_gemini_schema(item))
|
||||
else:
|
||||
new_list.append(item)
|
||||
pydantic_schema_data[snake_case_key] = new_list
|
||||
elif isinstance(value, dict) and snake_case_key not in ("properties"):
|
||||
# Handle dictionary which is neither properties or items
|
||||
pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
|
||||
else:
|
||||
# Simple value assignment (int, str, bool, etc.)
|
||||
pydantic_schema_data[snake_case_key] = value
|
||||
|
||||
return Schema(**pydantic_schema_data)
|
||||
|
||||
|
||||
AuthPreparationState = Literal["pending", "done"]
|
||||
|
||||
|
||||
@@ -273,7 +153,7 @@ class RestApiTool(BaseTool):
|
||||
parsed.operation, parsed.parameters, parsed.return_value
|
||||
)
|
||||
|
||||
tool_name = to_snake_case(operation_parser.get_function_name())
|
||||
tool_name = _to_snake_case(operation_parser.get_function_name())
|
||||
generated = cls(
|
||||
name=tool_name,
|
||||
description=parsed.operation.description
|
||||
@@ -306,7 +186,7 @@ class RestApiTool(BaseTool):
|
||||
def _get_declaration(self) -> FunctionDeclaration:
|
||||
"""Returns the function declaration in the Gemini Schema format."""
|
||||
schema_dict = self._operation_parser.get_json_schema()
|
||||
parameters = to_gemini_schema(schema_dict)
|
||||
parameters = _to_gemini_schema(schema_dict)
|
||||
function_decl = FunctionDeclaration(
|
||||
name=self.name, description=self.description, parameters=parameters
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user