refactor: uniform Google LLM variant and parsing logic and make contant value consistent with Google GenAI SDK : 903e0729ce/google/genai/_automatic_function_calling_util.py (L96)

PiperOrigin-RevId: 765639681
This commit is contained in:
Xiang (Sean) Zhou 2025-05-31 13:11:44 -07:00 committed by Copybara-Service
parent 62d7bf58bb
commit 036f954a2a
6 changed files with 76 additions and 38 deletions

View File

@ -29,6 +29,7 @@ from google.genai import types
from typing_extensions import override from typing_extensions import override
from .. import version from .. import version
from ..utils.variant_utils import GoogleLLMVariant
from .base_llm import BaseLlm from .base_llm import BaseLlm
from .base_llm_connection import BaseLlmConnection from .base_llm_connection import BaseLlmConnection
from .gemini_llm_connection import GeminiLlmConnection from .gemini_llm_connection import GeminiLlmConnection
@ -178,8 +179,12 @@ class Gemini(BaseLlm):
) )
@cached_property @cached_property
def _api_backend(self) -> str: def _api_backend(self) -> GoogleLLMVariant:
return 'vertex' if self.api_client.vertexai else 'ml_dev' return (
GoogleLLMVariant.VERTEX_AI
if self.api_client.vertexai
else GoogleLLMVariant.GEMINI_API
)
@cached_property @cached_property
def _tracking_headers(self) -> dict[str, str]: def _tracking_headers(self) -> dict[str, str]:
@ -196,7 +201,7 @@ class Gemini(BaseLlm):
@cached_property @cached_property
def _live_api_client(self) -> Client: def _live_api_client(self) -> Client:
if self._api_backend == 'vertex': if self._api_backend == GoogleLLMVariant.VERTEX_AI:
# use beta version for vertex api # use beta version for vertex api
api_version = 'v1beta1' api_version = 'v1beta1'
# use default api version for vertex # use default api version for vertex
@ -206,7 +211,7 @@ class Gemini(BaseLlm):
) )
) )
else: else:
# use v1alpha for ml_dev # use v1alpha for using API KEY from Google AI Studio
api_version = 'v1alpha' api_version = 'v1alpha'
return Client( return Client(
http_options=types.HttpOptions( http_options=types.HttpOptions(
@ -239,7 +244,7 @@ class Gemini(BaseLlm):
def _preprocess_request(self, llm_request: LlmRequest) -> None: def _preprocess_request(self, llm_request: LlmRequest) -> None:
if llm_request.config and self._api_backend == 'ml_dev': if llm_request.config and self._api_backend == GoogleLLMVariant.GEMINI_API:
# Using API key from Google AI Studio to call model doesn't support labels. # Using API key from Google AI Studio to call model doesn't support labels.
llm_request.config.labels = None llm_request.config.labels = None

View File

@ -31,6 +31,7 @@ from pydantic import create_model
from pydantic import fields as pydantic_fields from pydantic import fields as pydantic_fields
from . import _function_parameter_parse_util from . import _function_parameter_parse_util
from ..utils.variant_utils import GoogleLLMVariant
_py_type_2_schema_type = { _py_type_2_schema_type = {
'str': types.Type.STRING, 'str': types.Type.STRING,
@ -194,7 +195,7 @@ def _get_return_type(func: Callable) -> Any:
def build_function_declaration( def build_function_declaration(
func: Union[Callable, BaseModel], func: Union[Callable, BaseModel],
ignore_params: Optional[list[str]] = None, ignore_params: Optional[list[str]] = None,
variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI', variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API,
) -> types.FunctionDeclaration: ) -> types.FunctionDeclaration:
signature = inspect.signature(func) signature = inspect.signature(func)
should_update_signature = False should_update_signature = False
@ -291,16 +292,9 @@ def build_function_declaration_util(
def from_function_with_options( def from_function_with_options(
func: Callable, func: Callable,
variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI', variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API,
) -> 'types.FunctionDeclaration': ) -> 'types.FunctionDeclaration':
supported_variants = ['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT']
if variant not in supported_variants:
raise ValueError(
f'Unsupported variant: {variant}. Supported variants are:'
f' {", ".join(supported_variants)}'
)
parameters_properties = {} parameters_properties = {}
for name, param in inspect.signature(func).parameters.items(): for name, param in inspect.signature(func).parameters.items():
if param.kind in ( if param.kind in (
@ -330,7 +324,7 @@ def from_function_with_options(
declaration.parameters declaration.parameters
) )
) )
if not variant == 'VERTEX_AI': if variant == GoogleLLMVariant.GEMINI_API:
return declaration return declaration
return_annotation = inspect.signature(func).return_annotation return_annotation = inspect.signature(func).return_annotation

View File

@ -28,6 +28,8 @@ from typing import Union
from google.genai import types from google.genai import types
import pydantic import pydantic
from ..utils.variant_utils import GoogleLLMVariant
_py_builtin_type_to_schema_type = { _py_builtin_type_to_schema_type = {
str: types.Type.STRING, str: types.Type.STRING,
int: types.Type.INTEGER, int: types.Type.INTEGER,
@ -63,8 +65,10 @@ def _update_for_default_if_mldev(schema: types.Schema):
) )
def _raise_if_schema_unsupported(variant: str, schema: types.Schema): def _raise_if_schema_unsupported(
if not variant == 'VERTEX_AI': variant: GoogleLLMVariant, schema: types.Schema
):
if variant == GoogleLLMVariant.GEMINI_API:
_raise_for_any_of_if_mldev(schema) _raise_for_any_of_if_mldev(schema)
_update_for_default_if_mldev(schema) _update_for_default_if_mldev(schema)
@ -116,7 +120,7 @@ def _is_default_value_compatible(
def _parse_schema_from_parameter( def _parse_schema_from_parameter(
variant: str, param: inspect.Parameter, func_name: str variant: GoogleLLMVariant, param: inspect.Parameter, func_name: str
) -> types.Schema: ) -> types.Schema:
"""parse schema from parameter. """parse schema from parameter.

View File

@ -15,13 +15,14 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC from abc import ABC
import os
from typing import Any from typing import Any
from typing import Optional from typing import Optional
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from google.genai import types from google.genai import types
from ..utils.variant_utils import get_google_llm_variant
from ..utils.variant_utils import GoogleLLMVariant
from .tool_context import ToolContext from .tool_context import ToolContext
if TYPE_CHECKING: if TYPE_CHECKING:
@ -118,12 +119,8 @@ class BaseTool(ABC):
) )
@property @property
def _api_variant(self) -> str: def _api_variant(self) -> GoogleLLMVariant:
use_vertexai = os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in [ return get_google_llm_variant()
'true',
'1',
]
return 'VERTEX_AI' if use_vertexai else 'GOOGLE_AI'
def _find_tool_with_function_declarations( def _find_tool_with_function_declarations(

View File

@ -0,0 +1,51 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Google LLM variants.
This module is for ADK internal use only.
Please do not rely on the implementation details.
"""
from __future__ import annotations
from enum import Enum
import os
_GOOGLE_LLM_VARIANT_VERTEX_AI = 'VERTEX_AI'
_GOOGLE_LLM_VARIANT_GEMINI_API = 'GEMINI_API'
class GoogleLLMVariant(Enum):
"""
The Google LLM variant to use.
see https://google.github.io/adk-docs/get-started/quickstart/#set-up-the-model
"""
VERTEX_AI = _GOOGLE_LLM_VARIANT_VERTEX_AI
"""For using credentials from Google Vertex AI"""
GEMINI_API = _GOOGLE_LLM_VARIANT_GEMINI_API
"""For using API Key from Google AI Studio"""
def get_google_llm_variant() -> str:
return (
GoogleLLMVariant.VERTEX_AI
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower()
in [
'true',
'1',
]
else GoogleLLMVariant.GEMINI_API
)

View File

@ -17,22 +17,9 @@ from typing import List
from google.adk.tools import _automatic_function_calling_util from google.adk.tools import _automatic_function_calling_util
from google.adk.tools.agent_tool import ToolContext from google.adk.tools.agent_tool import ToolContext
from google.adk.tools.langchain_tool import LangchainTool
# TODO: crewai requires python 3.10 as minimum # TODO: crewai requires python 3.10 as minimum
# from crewai_tools import FileReadTool # from crewai_tools import FileReadTool
from langchain_community.tools import ShellTool
from pydantic import BaseModel from pydantic import BaseModel
import pytest
def test_unsupported_variant():
def simple_function(input_str: str) -> str:
return {'result': input_str}
with pytest.raises(ValueError):
_automatic_function_calling_util.build_function_declaration(
func=simple_function, variant='Unsupported'
)
def test_string_input(): def test_string_input():