diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index c9d021f..9c03b28 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -29,6 +29,7 @@ from google.genai import types from typing_extensions import override from .. import version +from ..utils.variant_utils import GoogleLLMVariant from .base_llm import BaseLlm from .base_llm_connection import BaseLlmConnection from .gemini_llm_connection import GeminiLlmConnection @@ -178,8 +179,12 @@ class Gemini(BaseLlm): ) @cached_property - def _api_backend(self) -> str: - return 'vertex' if self.api_client.vertexai else 'ml_dev' + def _api_backend(self) -> GoogleLLMVariant: + return ( + GoogleLLMVariant.VERTEX_AI + if self.api_client.vertexai + else GoogleLLMVariant.GEMINI_API + ) @cached_property def _tracking_headers(self) -> dict[str, str]: @@ -196,7 +201,7 @@ class Gemini(BaseLlm): @cached_property def _live_api_client(self) -> Client: - if self._api_backend == 'vertex': + if self._api_backend == GoogleLLMVariant.VERTEX_AI: # use beta version for vertex api api_version = 'v1beta1' # use default api version for vertex @@ -206,7 +211,7 @@ class Gemini(BaseLlm): ) ) else: - # use v1alpha for ml_dev + # use v1alpha for using API KEY from Google AI Studio api_version = 'v1alpha' return Client( http_options=types.HttpOptions( @@ -239,7 +244,7 @@ class Gemini(BaseLlm): def _preprocess_request(self, llm_request: LlmRequest) -> None: - if llm_request.config and self._api_backend == 'ml_dev': + if llm_request.config and self._api_backend == GoogleLLMVariant.GEMINI_API: # Using API key from Google AI Studio to call model doesn't support labels. llm_request.config.labels = None diff --git a/src/google/adk/tools/_automatic_function_calling_util.py b/src/google/adk/tools/_automatic_function_calling_util.py index 6bd117e..97d89cb 100644 --- a/src/google/adk/tools/_automatic_function_calling_util.py +++ b/src/google/adk/tools/_automatic_function_calling_util.py @@ -31,6 +31,7 @@ from pydantic import create_model from pydantic import fields as pydantic_fields from . import _function_parameter_parse_util +from ..utils.variant_utils import GoogleLLMVariant _py_type_2_schema_type = { 'str': types.Type.STRING, @@ -194,7 +195,7 @@ def _get_return_type(func: Callable) -> Any: def build_function_declaration( func: Union[Callable, BaseModel], ignore_params: Optional[list[str]] = None, - variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI', + variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API, ) -> types.FunctionDeclaration: signature = inspect.signature(func) should_update_signature = False @@ -291,16 +292,9 @@ def build_function_declaration_util( def from_function_with_options( func: Callable, - variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI', + variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API, ) -> 'types.FunctionDeclaration': - supported_variants = ['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] - if variant not in supported_variants: - raise ValueError( - f'Unsupported variant: {variant}. Supported variants are:' - f' {", ".join(supported_variants)}' - ) - parameters_properties = {} for name, param in inspect.signature(func).parameters.items(): if param.kind in ( @@ -330,7 +324,7 @@ def from_function_with_options( declaration.parameters ) ) - if not variant == 'VERTEX_AI': + if variant == GoogleLLMVariant.GEMINI_API: return declaration return_annotation = inspect.signature(func).return_annotation diff --git a/src/google/adk/tools/_function_parameter_parse_util.py b/src/google/adk/tools/_function_parameter_parse_util.py index 964a615..1c5ed8d 100644 --- a/src/google/adk/tools/_function_parameter_parse_util.py +++ b/src/google/adk/tools/_function_parameter_parse_util.py @@ -28,6 +28,8 @@ from typing import Union from google.genai import types import pydantic +from ..utils.variant_utils import GoogleLLMVariant + _py_builtin_type_to_schema_type = { str: types.Type.STRING, int: types.Type.INTEGER, @@ -63,8 +65,10 @@ def _update_for_default_if_mldev(schema: types.Schema): ) -def _raise_if_schema_unsupported(variant: str, schema: types.Schema): - if not variant == 'VERTEX_AI': +def _raise_if_schema_unsupported( + variant: GoogleLLMVariant, schema: types.Schema +): + if variant == GoogleLLMVariant.GEMINI_API: _raise_for_any_of_if_mldev(schema) _update_for_default_if_mldev(schema) @@ -116,7 +120,7 @@ def _is_default_value_compatible( def _parse_schema_from_parameter( - variant: str, param: inspect.Parameter, func_name: str + variant: GoogleLLMVariant, param: inspect.Parameter, func_name: str ) -> types.Schema: """parse schema from parameter. diff --git a/src/google/adk/tools/base_tool.py b/src/google/adk/tools/base_tool.py index 60eea5a..ad698db 100644 --- a/src/google/adk/tools/base_tool.py +++ b/src/google/adk/tools/base_tool.py @@ -15,13 +15,14 @@ from __future__ import annotations from abc import ABC -import os from typing import Any from typing import Optional from typing import TYPE_CHECKING from google.genai import types +from ..utils.variant_utils import get_google_llm_variant +from ..utils.variant_utils import GoogleLLMVariant from .tool_context import ToolContext if TYPE_CHECKING: @@ -118,12 +119,8 @@ class BaseTool(ABC): ) @property - def _api_variant(self) -> str: - use_vertexai = os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in [ - 'true', - '1', - ] - return 'VERTEX_AI' if use_vertexai else 'GOOGLE_AI' + def _api_variant(self) -> GoogleLLMVariant: + return get_google_llm_variant() def _find_tool_with_function_declarations( diff --git a/src/google/adk/utils/variant_utils.py b/src/google/adk/utils/variant_utils.py new file mode 100644 index 0000000..0eef616 --- /dev/null +++ b/src/google/adk/utils/variant_utils.py @@ -0,0 +1,51 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for Google LLM variants. + +This module is for ADK internal use only. +Please do not rely on the implementation details. +""" + +from __future__ import annotations + +from enum import Enum +import os + +_GOOGLE_LLM_VARIANT_VERTEX_AI = 'VERTEX_AI' +_GOOGLE_LLM_VARIANT_GEMINI_API = 'GEMINI_API' + + +class GoogleLLMVariant(Enum): + """ + The Google LLM variant to use. + see https://google.github.io/adk-docs/get-started/quickstart/#set-up-the-model + """ + + VERTEX_AI = _GOOGLE_LLM_VARIANT_VERTEX_AI + """For using credentials from Google Vertex AI""" + GEMINI_API = _GOOGLE_LLM_VARIANT_GEMINI_API + """For using API Key from Google AI Studio""" + + +def get_google_llm_variant() -> str: + return ( + GoogleLLMVariant.VERTEX_AI + if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() + in [ + 'true', + '1', + ] + else GoogleLLMVariant.GEMINI_API + ) diff --git a/tests/unittests/tools/test_build_function_declaration.py b/tests/unittests/tools/test_build_function_declaration.py index 508608c..eb95a6e 100644 --- a/tests/unittests/tools/test_build_function_declaration.py +++ b/tests/unittests/tools/test_build_function_declaration.py @@ -17,22 +17,9 @@ from typing import List from google.adk.tools import _automatic_function_calling_util from google.adk.tools.agent_tool import ToolContext -from google.adk.tools.langchain_tool import LangchainTool # TODO: crewai requires python 3.10 as minimum # from crewai_tools import FileReadTool -from langchain_community.tools import ShellTool from pydantic import BaseModel -import pytest - - -def test_unsupported_variant(): - def simple_function(input_str: str) -> str: - return {'result': input_str} - - with pytest.raises(ValueError): - _automatic_function_calling_util.build_function_declaration( - func=simple_function, variant='Unsupported' - ) def test_string_input():