Agent Development Kit(ADK)

An easy-to-use and powerful framework to build AI agents.
This commit is contained in:
hangfei
2025-04-08 17:22:09 +00:00
parent f92478bd5c
commit 9827820143
299 changed files with 44398 additions and 2 deletions

View File

@@ -0,0 +1,51 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=g-bad-import-order
from .base_tool import BaseTool
from ..auth.auth_tool import AuthToolArguments
from .apihub_tool.apihub_toolset import APIHubToolset
from .built_in_code_execution_tool import built_in_code_execution
from .google_search_tool import google_search
from .vertex_ai_search_tool import VertexAiSearchTool
from .example_tool import ExampleTool
from .exit_loop_tool import exit_loop
from .function_tool import FunctionTool
from .get_user_choice_tool import get_user_choice_tool as get_user_choice
from .load_artifacts_tool import load_artifacts_tool as load_artifacts
from .load_memory_tool import load_memory_tool as load_memory
from .long_running_tool import LongRunningFunctionTool
from .preload_memory_tool import preload_memory_tool as preload_memory
from .tool_context import ToolContext
from .transfer_to_agent_tool import transfer_to_agent
__all__ = [
'APIHubToolset',
'AuthToolArguments',
'BaseTool',
'built_in_code_execution',
'google_search',
'VertexAiSearchTool',
'ExampleTool',
'exit_loop',
'FunctionTool',
'get_user_choice',
'load_artifacts',
'load_memory',
'LongRunningFunctionTool',
'preload_memory',
'ToolContext',
'transfer_to_agent',
]

View File

@@ -0,0 +1,346 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Forked from google3/third_party/py/google/genai/_automatic_function_calling_util.py temporarily."""
import inspect
from types import FunctionType
from typing import Any
from typing import Callable
from typing import Dict
from typing import Literal
from typing import Optional
from typing import Union
from google.genai import types
import pydantic
from pydantic import BaseModel
from pydantic import create_model
from pydantic import fields as pydantic_fields
from . import function_parameter_parse_util
_py_type_2_schema_type = {
'str': types.Type.STRING,
'int': types.Type.INTEGER,
'float': types.Type.NUMBER,
'bool': types.Type.BOOLEAN,
'string': types.Type.STRING,
'integer': types.Type.INTEGER,
'number': types.Type.NUMBER,
'boolean': types.Type.BOOLEAN,
'list': types.Type.ARRAY,
'array': types.Type.ARRAY,
'tuple': types.Type.ARRAY,
'object': types.Type.OBJECT,
'Dict': types.Type.OBJECT,
'List': types.Type.ARRAY,
'Tuple': types.Type.ARRAY,
'Any': types.Type.TYPE_UNSPECIFIED,
}
def _get_fields_dict(func: Callable) -> Dict:
param_signature = dict(inspect.signature(func).parameters)
fields_dict = {
name: (
# 1. We infer the argument type here: use Any rather than None so
# it will not try to auto-infer the type based on the default value.
(
param.annotation
if param.annotation != inspect.Parameter.empty
else Any
),
pydantic.Field(
# 2. We do not support default values for now.
default=(
param.default
if param.default != inspect.Parameter.empty
# ! Need to use Undefined instead of None
else pydantic_fields.PydanticUndefined
),
# 3. Do not support parameter description for now.
description=None,
),
)
for name, param in param_signature.items()
# We do not support *args or **kwargs
if param.kind
in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.POSITIONAL_ONLY,
)
}
return fields_dict
def _annotate_nullable_fields(schema: Dict):
for _, property_schema in schema.get('properties', {}).items():
# for Optional[T], the pydantic schema is:
# {
# "type": "object",
# "properties": {
# "anyOf": [
# {
# "type": "null"
# },
# {
# "type": "T"
# }
# ]
# }
# }
for type_ in property_schema.get('anyOf', []):
if type_.get('type') == 'null':
property_schema['nullable'] = True
property_schema['anyOf'].remove(type_)
break
def _annotate_required_fields(schema: Dict):
required = [
field_name
for field_name, field_schema in schema.get('properties', {}).items()
if not field_schema.get('nullable') and 'default' not in field_schema
]
schema['required'] = required
def _remove_any_of(schema: Dict):
for _, property_schema in schema.get('properties', {}).items():
union_types = property_schema.pop('anyOf', None)
# Take the first non-null type.
if union_types:
for type_ in union_types:
if type_.get('type') != 'null':
property_schema.update(type_)
def _remove_default(schema: Dict):
for _, property_schema in schema.get('properties', {}).items():
property_schema.pop('default', None)
def _remove_nullable(schema: Dict):
for _, property_schema in schema.get('properties', {}).items():
property_schema.pop('nullable', None)
def _remove_title(schema: Dict):
for _, property_schema in schema.get('properties', {}).items():
property_schema.pop('title', None)
def _get_pydantic_schema(func: Callable) -> Dict:
fields_dict = _get_fields_dict(func)
if 'tool_context' in fields_dict.keys():
fields_dict.pop('tool_context')
return pydantic.create_model(func.__name__, **fields_dict).model_json_schema()
def _process_pydantic_schema(vertexai: bool, schema: Dict) -> Dict:
_annotate_nullable_fields(schema)
_annotate_required_fields(schema)
if not vertexai:
_remove_any_of(schema)
_remove_default(schema)
_remove_nullable(schema)
_remove_title(schema)
return schema
def _map_pydantic_type_to_property_schema(property_schema: Dict):
if 'type' in property_schema:
property_schema['type'] = _py_type_2_schema_type.get(
property_schema['type'], 'TYPE_UNSPECIFIED'
)
if property_schema['type'] == 'ARRAY':
_map_pydantic_type_to_property_schema(property_schema['items'])
for type_ in property_schema.get('anyOf', []):
if 'type' in type_:
type_['type'] = _py_type_2_schema_type.get(
type_['type'], 'TYPE_UNSPECIFIED'
)
# TODO: To investigate. Unclear why a Type is needed with 'anyOf' to
# avoid google.genai.errors.ClientError: 400 INVALID_ARGUMENT.
property_schema['type'] = type_['type']
def _map_pydantic_type_to_schema_type(schema: Dict):
for _, property_schema in schema.get('properties', {}).items():
_map_pydantic_type_to_property_schema(property_schema)
def _get_return_type(func: Callable) -> Any:
return _py_type_2_schema_type.get(
inspect.signature(func).return_annotation.__name__,
inspect.signature(func).return_annotation.__name__,
)
def build_function_declaration(
func: Union[Callable, BaseModel],
ignore_params: Optional[list[str]] = None,
variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI',
) -> types.FunctionDeclaration:
signature = inspect.signature(func)
should_update_signature = False
new_func = None
if not ignore_params:
ignore_params = []
for name, _ in signature.parameters.items():
if name in ignore_params:
should_update_signature = True
break
if should_update_signature:
new_params = [
param
for name, param in signature.parameters.items()
if name not in ignore_params
]
if isinstance(func, type):
fields = {
name: (param.annotation, param.default)
for name, param in signature.parameters.items()
if name not in ignore_params
}
new_func = create_model(func.__name__, **fields)
else:
new_sig = signature.replace(parameters=new_params)
new_func = FunctionType(
func.__code__,
func.__globals__,
func.__name__,
func.__defaults__,
func.__closure__,
)
new_func.__signature__ = new_sig
return (
from_function_with_options(func, variant)
if not should_update_signature
else from_function_with_options(new_func, variant)
)
def build_function_declaration_for_langchain(
vertexai: bool, name, description, func, param_pydantic_schema
) -> types.FunctionDeclaration:
param_pydantic_schema = _process_pydantic_schema(
vertexai, {'properties': param_pydantic_schema}
)['properties']
param_copy = param_pydantic_schema.copy()
required_fields = param_copy.pop('required', [])
before_param_pydantic_schema = {
'properties': param_copy,
'required': required_fields,
}
return build_function_declaration_util(
vertexai, name, description, func, before_param_pydantic_schema
)
def build_function_declaration_for_params_for_crewai(
vertexai: bool, name, description, func, param_pydantic_schema
) -> types.FunctionDeclaration:
param_pydantic_schema = _process_pydantic_schema(
vertexai, param_pydantic_schema
)
param_copy = param_pydantic_schema.copy()
return build_function_declaration_util(
vertexai, name, description, func, param_copy
)
def build_function_declaration_util(
vertexai: bool, name, description, func, before_param_pydantic_schema
) -> types.FunctionDeclaration:
_map_pydantic_type_to_schema_type(before_param_pydantic_schema)
properties = before_param_pydantic_schema.get('properties', {})
function_declaration = types.FunctionDeclaration(
parameters=types.Schema(
type='OBJECT',
properties=properties,
)
if properties
else None,
description=description,
name=name,
)
if vertexai and isinstance(func, Callable):
return_pydantic_schema = _get_return_type(func)
function_declaration.response = types.Schema(
type=return_pydantic_schema,
)
return function_declaration
def from_function_with_options(
func: Callable,
variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI',
) -> 'types.FunctionDeclaration':
supported_variants = ['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT']
if variant not in supported_variants:
raise ValueError(
f'Unsupported variant: {variant}. Supported variants are:'
f' {", ".join(supported_variants)}'
)
parameters_properties = {}
for name, param in inspect.signature(func).parameters.items():
if param.kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.POSITIONAL_ONLY,
):
schema = function_parameter_parse_util._parse_schema_from_parameter(
variant, param, func.__name__
)
parameters_properties[name] = schema
declaration = types.FunctionDeclaration(
name=func.__name__,
description=func.__doc__,
)
if parameters_properties:
declaration.parameters = types.Schema(
type='OBJECT',
properties=parameters_properties,
)
if variant == 'VERTEX_AI':
declaration.parameters.required = (
function_parameter_parse_util._get_required_fields(
declaration.parameters
)
)
if not variant == 'VERTEX_AI':
return declaration
return_annotation = inspect.signature(func).return_annotation
if return_annotation is inspect._empty:
return declaration
declaration.response = (
function_parameter_parse_util._parse_schema_from_parameter(
variant,
inspect.Parameter(
'return_value',
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=return_annotation,
),
func.__name__,
)
)
return declaration

View File

@@ -0,0 +1,176 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from typing import TYPE_CHECKING
from google.genai import types
from pydantic import model_validator
from typing_extensions import override
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..runners import Runner
from ..sessions.in_memory_session_service import InMemorySessionService
from . import _automatic_function_calling_util
from .base_tool import BaseTool
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..agents.base_agent import BaseAgent
from ..agents.llm_agent import LlmAgent
class AgentTool(BaseTool):
"""A tool that wraps an agent.
This tool allows an agent to be called as a tool within a larger application.
The agent's input schema is used to define the tool's input parameters, and
the agent's output is returned as the tool's result.
Attributes:
agent: The agent to wrap.
skip_summarization: Whether to skip summarization of the agent output.
"""
def __init__(self, agent: BaseAgent):
self.agent = agent
self.skip_summarization: bool = False
"""Whether to skip summarization of the agent output."""
super().__init__(name=agent.name, description=agent.description)
@model_validator(mode='before')
@classmethod
def populate_name(cls, data: Any) -> Any:
data['name'] = data['agent'].name
return data
@override
def _get_declaration(self) -> types.FunctionDeclaration:
from ..agents.llm_agent import LlmAgent
if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
result = _automatic_function_calling_util.build_function_declaration(
func=self.agent.input_schema, variant=self._api_variant
)
else:
result = types.FunctionDeclaration(
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
'request': types.Schema(
type=types.Type.STRING,
),
},
required=['request'],
),
description=self.agent.description,
name=self.name,
)
result.name = self.name
return result
@override
async def run_async(
self,
*,
args: dict[str, Any],
tool_context: ToolContext,
) -> Any:
from ..agents.llm_agent import LlmAgent
if self.skip_summarization:
tool_context.actions.skip_summarization = True
if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
input_value = self.agent.input_schema.model_validate(args)
else:
input_value = args['request']
if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
if isinstance(input_value, dict):
input_value = self.agent.input_schema.model_validate(input_value)
if not isinstance(input_value, self.agent.input_schema):
raise ValueError(
f'Input value {input_value} is not of type'
f' `{self.agent.input_schema}`.'
)
content = types.Content(
role='user',
parts=[
types.Part.from_text(
text=input_value.model_dump_json(exclude_none=True)
)
],
)
else:
content = types.Content(
role='user',
parts=[types.Part.from_text(text=input_value)],
)
runner = Runner(
app_name=self.agent.name,
agent=self.agent,
# TODO(kech): Remove the access to the invocation context.
# It seems we don't need re-use artifact_service if we forward below.
artifact_service=tool_context._invocation_context.artifact_service,
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
)
session = runner.session_service.create_session(
app_name=self.agent.name,
user_id='tmp_user',
state=tool_context.state.to_dict(),
)
last_event = None
async for event in runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
):
# Forward state delta to parent session.
if event.actions.state_delta:
tool_context.state.update(event.actions.state_delta)
last_event = event
if runner.artifact_service:
# Forward all artifacts to parent session.
for artifact_name in runner.artifact_service.list_artifact_keys(
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,
):
if artifact := runner.artifact_service.load_artifact(
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,
filename=artifact_name,
):
tool_context.save_artifact(filename=artifact_name, artifact=artifact)
if (
not last_event
or not last_event.content
or not last_event.content.parts
or not last_event.content.parts[0].text
):
return ''
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
tool_result = self.agent.output_schema.model_validate_json(
last_event.content.parts[0].text
).model_dump(exclude_none=True)
else:
tool_result = last_event.content.parts[0].text
return tool_result

View File

@@ -0,0 +1,19 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .apihub_toolset import APIHubToolset
__all__ = [
'APIHubToolset',
]

View File

@@ -0,0 +1,209 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional
import yaml
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
from ..openapi_tool.common.common import to_snake_case
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from .clients.apihub_client import APIHubClient
class APIHubToolset:
"""APIHubTool generates tools from a given API Hub resource.
Examples:
```
apihub_toolset = APIHubToolset(
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
service_account_json="...",
)
# Get all available tools
agent = LlmAgent(tools=apihub_toolset.get_tools())
# Get a specific tool
agent = LlmAgent(tools=[
...
apihub_toolset.get_tool('my_tool'),
])
```
**apihub_resource_name** is the resource name from API Hub. It must include
API name, and can optionally include API version and spec name.
- If apihub_resource_name includes a spec resource name, the content of that
spec will be used for generating the tools.
- If apihub_resource_name includes only an api or a version name, the
first spec of the first version of that API will be used.
"""
def __init__(
self,
*,
# Parameters for fetching API Hub resource
apihub_resource_name: str,
access_token: Optional[str] = None,
service_account_json: Optional[str] = None,
# Parameters for the toolset itself
name: str = '',
description: str = '',
# Parameters for generating tools
lazy_load_spec=False,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
# Optionally, you can provide a custom API Hub client
apihub_client: Optional[APIHubClient] = None,
):
"""Initializes the APIHubTool with the given parameters.
Examples:
```
apihub_toolset = APIHubToolset(
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
service_account_json="...",
)
# Get all available tools
agent = LlmAgent(tools=apihub_toolset.get_tools())
# Get a specific tool
agent = LlmAgent(tools=[
...
apihub_toolset.get_tool('my_tool'),
])
```
**apihub_resource_name** is the resource name from API Hub. It must include
API name, and can optionally include API version and spec name.
- If apihub_resource_name includes a spec resource name, the content of that
spec will be used for generating the tools.
- If apihub_resource_name includes only an api or a version name, the
first spec of the first version of that API will be used.
Example:
* projects/xxx/locations/us-central1/apis/apiname/...
* https://console.cloud.google.com/apigee/api-hub/apis/apiname?project=xxx
Args:
apihub_resource_name: The resource name of the API in API Hub.
Example: `projects/test-project/locations/us-central1/apis/test-api`.
access_token: Google Access token. Generate with gcloud cli `gcloud auth
auth print-access-token`. Used for fetching API Specs from API Hub.
service_account_json: The service account config as a json string.
Required if not using default service credential. It is used for
creating the API Hub client and fetching the API Specs from API Hub.
apihub_client: Optional custom API Hub client.
name: Name of the toolset. Optional.
description: Description of the toolset. Optional.
auth_scheme: Auth scheme that applies to all the tool in the toolset.
auth_credential: Auth credential that applies to all the tool in the
toolset.
lazy_load_spec: If True, the spec will be loaded lazily when needed.
Otherwise, the spec will be loaded immediately and the tools will be
generated during initialization.
"""
self.name = name
self.description = description
self.apihub_resource_name = apihub_resource_name
self.lazy_load_spec = lazy_load_spec
self.apihub_client = apihub_client or APIHubClient(
access_token=access_token,
service_account_json=service_account_json,
)
self.generated_tools: Dict[str, RestApiTool] = {}
self.auth_scheme = auth_scheme
self.auth_credential = auth_credential
if not self.lazy_load_spec:
self._prepare_tools()
def get_tool(self, name: str) -> Optional[RestApiTool]:
"""Retrieves a specific tool by its name.
Example:
```
apihub_tool = apihub_toolset.get_tool('my_tool')
```
Args:
name: The name of the tool to retrieve.
Returns:
The tool with the given name, or None if no such tool exists.
"""
if not self._are_tools_ready():
self._prepare_tools()
return self.generated_tools[name] if name in self.generated_tools else None
def get_tools(self) -> List[RestApiTool]:
"""Retrieves all available tools.
Returns:
A list of all available RestApiTool objects.
"""
if not self._are_tools_ready():
self._prepare_tools()
return list(self.generated_tools.values())
def _are_tools_ready(self) -> bool:
return not self.lazy_load_spec or self.generated_tools
def _prepare_tools(self) -> str:
"""Fetches the spec from API Hub and generates the tools.
Returns:
True if the tools are ready, False otherwise.
"""
# For each API, get the first version and the first spec of that version.
spec = self.apihub_client.get_spec_content(self.apihub_resource_name)
self.generated_tools: Dict[str, RestApiTool] = {}
tools = self._parse_spec_to_tools(spec)
for tool in tools:
self.generated_tools[tool.name] = tool
def _parse_spec_to_tools(self, spec_str: str) -> List[RestApiTool]:
"""Parses the spec string to a list of RestApiTool.
Args:
spec_str: The spec string to parse.
Returns:
A list of RestApiTool objects.
"""
spec_dict = yaml.safe_load(spec_str)
if not spec_dict:
return []
self.name = self.name or to_snake_case(
spec_dict.get('info', {}).get('title', 'unnamed')
)
self.description = self.description or spec_dict.get('info', {}).get(
'description', ''
)
tools = OpenAPIToolset(
spec_dict=spec_dict,
auth_credential=self.auth_credential,
auth_scheme=self.auth_scheme,
).get_tools()
return tools

View File

@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,332 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
import base64
import json
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import parse_qs, urlparse
from google.auth import default as default_service_credential
from google.auth.transport.requests import Request
from google.oauth2 import service_account
import requests
class BaseAPIHubClient(ABC):
"""Base class for API Hub clients."""
@abstractmethod
def get_spec_content(self, resource_name: str) -> str:
"""From a given resource name, get the soec in the API Hub."""
raise NotImplementedError()
class APIHubClient(BaseAPIHubClient):
"""Client for interacting with the API Hub service."""
def __init__(
self,
*,
access_token: Optional[str] = None,
service_account_json: Optional[str] = None,
):
"""Initializes the APIHubClient.
You must set either access_token or service_account_json. This
credential is used for sending request to API Hub API.
Args:
access_token: Google Access token. Generate with gcloud cli `gcloud auth
print-access-token`. Useful for local testing.
service_account_json: The service account configuration as a dictionary.
Required if not using default service credential.
"""
self.root_url = "https://apihub.googleapis.com/v1"
self.credential_cache = None
self.access_token, self.service_account = None, None
if access_token:
self.access_token = access_token
elif service_account_json:
self.service_account = service_account_json
def get_spec_content(self, path: str) -> str:
"""From a given path, get the first spec available in the API Hub.
- If path includes /apis/apiname, get the first spec of that API
- If path includes /apis/apiname/versions/versionname, get the first spec
of that API Version
- If path includes /apis/apiname/versions/versionname/specs/specname, return
that spec
Path can be resource name (projects/xxx/locations/us-central1/apis/apiname),
and URL from the UI
(https://console.cloud.google.com/apigee/api-hub/apis/apiname?project=xxx)
Args:
path: The path to the API, API Version, or API Spec.
Returns:
The content of the first spec available in the API Hub.
"""
apihub_resource_name, api_version_resource_name, api_spec_resource_name = (
self._extract_resource_name(path)
)
if apihub_resource_name and not api_version_resource_name:
api = self.get_api(apihub_resource_name)
versions = api.get("versions", [])
if not versions:
raise ValueError(
f"No versions found in API Hub resource: {apihub_resource_name}"
)
api_version_resource_name = versions[0]
if api_version_resource_name and not api_spec_resource_name:
api_version = self.get_api_version(api_version_resource_name)
spec_resource_names = api_version.get("specs", [])
if not spec_resource_names:
raise ValueError(
f"No specs found in API Hub version: {api_version_resource_name}"
)
api_spec_resource_name = spec_resource_names[0]
if api_spec_resource_name:
spec_content = self._fetch_spec(api_spec_resource_name)
return spec_content
raise ValueError("No API Hub resource found in path: {path}")
def list_apis(self, project: str, location: str) -> List[Dict[str, Any]]:
"""Lists all APIs in the specified project and location.
Args:
project: The Google Cloud project name.
location: The location of the API Hub resources (e.g., 'us-central1').
Returns:
A list of API dictionaries, or an empty list if an error occurs.
"""
url = f"{self.root_url}/projects/{project}/locations/{location}/apis"
headers = {
"accept": "application/json, text/plain, */*",
"Authorization": f"Bearer {self._get_access_token()}",
}
response = requests.get(url, headers=headers)
response.raise_for_status()
apis = response.json().get("apis", [])
return apis
def get_api(self, api_resource_name: str) -> Dict[str, Any]:
"""Get API detail by API name.
Args:
api_resource_name: Resource name of this API, like
projects/xxx/locations/us-central1/apis/apiname
Returns:
An API and details in a dict.
"""
url = f"{self.root_url}/{api_resource_name}"
headers = {
"accept": "application/json, text/plain, */*",
"Authorization": f"Bearer {self._get_access_token()}",
}
response = requests.get(url, headers=headers)
response.raise_for_status()
apis = response.json()
return apis
def get_api_version(self, api_version_name: str) -> Dict[str, Any]:
"""Gets details of a specific API version.
Args:
api_version_name: The resource name of the API version.
Returns:
The API version details as a dictionary, or an empty dictionary if an
error occurs.
"""
url = f"{self.root_url}/{api_version_name}"
headers = {
"accept": "application/json, text/plain, */*",
"Authorization": f"Bearer {self._get_access_token()}",
}
response = requests.get(url, headers=headers)
response.raise_for_status()
return response.json()
def _fetch_spec(self, api_spec_resource_name: str) -> str:
"""Retrieves the content of a specific API specification.
Args:
api_spec_resource_name: The resource name of the API spec.
Returns:
The decoded content of the specification as a string, or an empty string
if an error occurs.
"""
url = f"{self.root_url}/{api_spec_resource_name}:contents"
headers = {
"accept": "application/json, text/plain, */*",
"Authorization": f"Bearer {self._get_access_token()}",
}
response = requests.get(url, headers=headers)
response.raise_for_status()
content_base64 = response.json().get("contents", "")
if content_base64:
content_decoded = base64.b64decode(content_base64).decode("utf-8")
return content_decoded
else:
return ""
def _extract_resource_name(self, url_or_path: str) -> Tuple[str, str, str]:
"""Extracts the resource names of an API, API Version, and API Spec from a given URL or path.
Args:
url_or_path: The URL (UI or resource) or path string.
Returns:
A dictionary containing the resource names:
{
"api_resource_name": "projects/*/locations/*/apis/*",
"api_version_resource_name":
"projects/*/locations/*/apis/*/versions/*",
"api_spec_resource_name":
"projects/*/locations/*/apis/*/versions/*/specs/*"
}
or raises ValueError if extraction fails.
Raises:
ValueError: If the URL or path is invalid or if required components
(project, location, api) are missing.
"""
query_params = None
try:
parsed_url = urlparse(url_or_path)
path = parsed_url.path
query_params = parse_qs(parsed_url.query)
# This is a path from UI. Remove unnecessary prefix.
if "api-hub/" in path:
path = path.split("api-hub")[1]
except Exception:
path = url_or_path
path_segments = [segment for segment in path.split("/") if segment]
project = None
location = None
api_id = None
version_id = None
spec_id = None
if "projects" in path_segments:
project_index = path_segments.index("projects")
if project_index + 1 < len(path_segments):
project = path_segments[project_index + 1]
elif query_params and "project" in query_params:
project = query_params["project"][0]
if not project:
raise ValueError(
"Project ID not found in URL or path in APIHubClient. Input path is"
f" '{url_or_path}'. Please make sure there is either"
" '/projects/PROJECT_ID' in the path or 'project=PROJECT_ID' query"
" param in the input."
)
if "locations" in path_segments:
location_index = path_segments.index("locations")
if location_index + 1 < len(path_segments):
location = path_segments[location_index + 1]
if not location:
raise ValueError(
"Location not found in URL or path in APIHubClient. Input path is"
f" '{url_or_path}'. Please make sure there is either"
" '/location/LOCATION_ID' in the path."
)
if "apis" in path_segments:
api_index = path_segments.index("apis")
if api_index + 1 < len(path_segments):
api_id = path_segments[api_index + 1]
if not api_id:
raise ValueError(
"API id not found in URL or path in APIHubClient. Input path is"
f" '{url_or_path}'. Please make sure there is either"
" '/apis/API_ID' in the path."
)
if "versions" in path_segments:
version_index = path_segments.index("versions")
if version_index + 1 < len(path_segments):
version_id = path_segments[version_index + 1]
if "specs" in path_segments:
spec_index = path_segments.index("specs")
if spec_index + 1 < len(path_segments):
spec_id = path_segments[spec_index + 1]
api_resource_name = f"projects/{project}/locations/{location}/apis/{api_id}"
api_version_resource_name = (
f"{api_resource_name}/versions/{version_id}" if version_id else None
)
api_spec_resource_name = (
f"{api_version_resource_name}/specs/{spec_id}"
if version_id and spec_id
else None
)
return (
api_resource_name,
api_version_resource_name,
api_spec_resource_name,
)
def _get_access_token(self) -> str:
"""Gets the access token for the service account.
Returns:
The access token.
"""
if self.access_token:
return self.access_token
if self.credential_cache and not self.credential_cache.expired:
return self.credential_cache.token
if self.service_account:
try:
credentials = service_account.Credentials.from_service_account_info(
json.loads(self.service_account),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid service account JSON: {e}") from e
else:
try:
credentials, _ = default_service_credential()
except:
credentials = None
if not credentials:
raise ValueError(
"Please provide a service account or an access token to API Hub"
" client."
)
credentials.refresh(Request())
self.credential_cache = credentials
return credentials.token

View File

@@ -0,0 +1,115 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Optional
import google.auth
from google.auth import default as default_service_credential
import google.auth.transport.requests
from google.cloud import secretmanager
from google.oauth2 import service_account
class SecretManagerClient:
"""A client for interacting with Google Cloud Secret Manager.
This class provides a simplified interface for retrieving secrets from
Secret Manager, handling authentication using either a service account
JSON keyfile (passed as a string) or a pre-existing authorization token.
Attributes:
_credentials: Google Cloud credentials object (ServiceAccountCredentials
or Credentials).
_client: Secret Manager client instance.
"""
def __init__(
self,
service_account_json: Optional[str] = None,
auth_token: Optional[str] = None,
):
"""Initializes the SecretManagerClient.
Args:
service_account_json: The content of a service account JSON keyfile (as
a string), not the file path. Must be valid JSON.
auth_token: An existing Google Cloud authorization token.
Raises:
ValueError: If neither `service_account_json` nor `auth_token` is
provided,
or if both are provided. Also raised if the service_account_json
is not valid JSON.
google.auth.exceptions.GoogleAuthError: If authentication fails.
"""
if service_account_json:
try:
credentials = service_account.Credentials.from_service_account_info(
json.loads(service_account_json)
)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid service account JSON: {e}") from e
elif auth_token:
credentials = google.auth.credentials.Credentials(
token=auth_token,
refresh_token=None,
token_uri=None,
client_id=None,
client_secret=None,
)
request = google.auth.transport.requests.Request()
credentials.refresh(request)
else:
try:
credentials, _ = default_service_credential()
except Exception as e:
raise ValueError(
"'service_account_json' or 'auth_token' are both missing, and"
f" error occurred while trying to use default credentials: {e}"
) from e
if not credentials:
raise ValueError(
"Must provide either 'service_account_json' or 'auth_token', not both"
" or neither."
)
self._credentials = credentials
self._client = secretmanager.SecretManagerServiceClient(
credentials=self._credentials
)
def get_secret(self, resource_name: str) -> str:
"""Retrieves a secret from Google Cloud Secret Manager.
Args:
resource_name: The full resource name of the secret, in the format
"projects/*/secrets/*/versions/*". Usually you want the "latest"
version, e.g.,
"projects/my-project/secrets/my-secret/versions/latest".
Returns:
The secret payload as a string.
Raises:
google.api_core.exceptions.GoogleAPIError: If the Secret Manager API
returns an error (e.g., secret not found, permission denied).
Exception: For other unexpected errors.
"""
try:
response = self._client.access_secret_version(name=resource_name)
return response.payload.data.decode("UTF-8")
except Exception as e:
raise e # Re-raise the exception to allow for handling by the caller
# Consider logging the exception here before re-raising.

View File

@@ -0,0 +1,19 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .application_integration_toolset import ApplicationIntegrationToolset
__all__ = [
'ApplicationIntegrationToolset',
]

View File

@@ -0,0 +1,230 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import List
from typing import Optional
from fastapi.openapi.models import HTTPBearer
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from ...auth.auth_credential import AuthCredential
from ...auth.auth_credential import AuthCredentialTypes
from ...auth.auth_credential import ServiceAccount
from ...auth.auth_credential import ServiceAccountCredential
# TODO(cheliu): Apply a common toolset interface
class ApplicationIntegrationToolset:
"""ApplicationIntegrationToolset generates tools from a given Application
Integration or Integration Connector resource.
Example Usage:
```
# Get all available tools for an integration with api trigger
application_integration_toolset = ApplicationIntegrationToolset(
project="test-project",
location="us-central1"
integration="test-integration",
trigger="api_trigger/test_trigger",
service_account_credentials={...},
)
# Get all available tools for a connection using entity operations and
# actions
# Note: Find the list of supported entity operations and actions for a
connection
# using integration connector apis:
#
https://cloud.google.com/integration-connectors/docs/reference/rest/v1/projects.locations.connections.connectionSchemaMetadata
application_integration_toolset = ApplicationIntegrationToolset(
project="test-project",
location="us-central1"
connection="test-connection",
entity_operations=["EntityId1": ["LIST","CREATE"], "EntityId2": []],
#empty list for actions means all operations on the entity are supported
actions=["action1"],
service_account_credentials={...},
)
# Get all available tools
agent = LlmAgent(tools=[
...
*application_integration_toolset.get_tools(),
])
```
"""
def __init__(
self,
project: str,
location: str,
integration: Optional[str] = None,
trigger: Optional[str] = None,
connection: Optional[str] = None,
entity_operations: Optional[str] = None,
actions: Optional[str] = None,
# Optional parameter for the toolset. This is prepended to the generated
# tool/python function name.
tool_name: Optional[str] = "",
# Optional parameter for the toolset. This is appended to the generated
# tool/python function description.
tool_instructions: Optional[str] = "",
service_account_json: Optional[str] = None,
):
"""Initializes the ApplicationIntegrationToolset.
Example Usage:
```
# Get all available tools for an integration with api trigger
application_integration_toolset = ApplicationIntegrationToolset(
project="test-project",
location="us-central1"
integration="test-integration",
trigger="api_trigger/test_trigger",
service_account_credentials={...},
)
# Get all available tools for a connection using entity operations and
# actions
# Note: Find the list of supported entity operations and actions for a
connection
# using integration connector apis:
#
https://cloud.google.com/integration-connectors/docs/reference/rest/v1/projects.locations.connections.connectionSchemaMetadata
application_integration_toolset = ApplicationIntegrationToolset(
project="test-project",
location="us-central1"
connection="test-connection",
entity_operations=["EntityId1": ["LIST","CREATE"], "EntityId2": []],
#empty list for actions means all operations on the entity are supported
actions=["action1"],
service_account_credentials={...},
)
# Get all available tools
agent = LlmAgent(tools=[
...
*application_integration_toolset.get_tools(),
])
```
Args:
project: The GCP project ID.
location: The GCP location.
integration: The integration name.
trigger: The trigger name.
connection: The connection name.
entity_operations: The entity operations supported by the connection.
actions: The actions supported by the connection.
tool_name: The name of the tool.
tool_instructions: The instructions for the tool.
service_account_json: The service account configuration as a dictionary.
Required if not using default service credential. Used for fetching
the Application Integration or Integration Connector resource.
Raises:
ValueError: If neither integration and trigger nor connection and
(entity_operations or actions) is provided.
Exception: If there is an error during the initialization of the
integration or connection client.
"""
self.project = project
self.location = location
self.integration = integration
self.trigger = trigger
self.connection = connection
self.entity_operations = entity_operations
self.actions = actions
self.tool_name = tool_name
self.tool_instructions = tool_instructions
self.service_account_json = service_account_json
self.generated_tools: Dict[str, RestApiTool] = {}
integration_client = IntegrationClient(
project,
location,
integration,
trigger,
connection,
entity_operations,
actions,
service_account_json,
)
if integration and trigger:
spec = integration_client.get_openapi_spec_for_integration()
elif connection and (entity_operations or actions):
connections_client = ConnectionsClient(
project, location, connection, service_account_json
)
connection_details = connections_client.get_connection_details()
tool_instructions += (
"ALWAYS use serviceName = "
+ connection_details["serviceName"]
+ ", host = "
+ connection_details["host"]
+ " and the connection name = "
+ f"projects/{project}/locations/{location}/connections/{connection} when"
" using this tool"
+ ". DONOT ask the user for these values as you already have those."
)
spec = integration_client.get_openapi_spec_for_connection(
tool_name,
tool_instructions,
)
else:
raise ValueError(
"Either (integration and trigger) or (connection and"
" (entity_operations or actions)) should be provided."
)
self._parse_spec_to_tools(spec)
def _parse_spec_to_tools(self, spec_dict):
"""Parses the spec dict to a list of RestApiTool."""
if self.service_account_json:
sa_credential = ServiceAccountCredential.model_validate_json(
self.service_account_json
)
service_account = ServiceAccount(
service_account_credential=sa_credential,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
auth_scheme, auth_credential = service_account_scheme_credential(
config=service_account
)
else:
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
use_default_credential=True,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
),
)
auth_scheme = HTTPBearer(bearerFormat="JWT")
tools = OpenAPIToolset(
spec_dict=spec_dict,
auth_credential=auth_credential,
auth_scheme=auth_scheme,
).get_tools()
for tool in tools:
self.generated_tools[tool.name] = tool
def get_tools(self) -> List[RestApiTool]:
return list(self.generated_tools.values())

View File

@@ -0,0 +1,903 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import Any, Dict, List, Optional, Tuple
import google.auth
from google.auth import default as default_service_credential
from google.auth.transport.requests import Request
from google.oauth2 import service_account
import requests
class ConnectionsClient:
"""Utility class for interacting with Google Cloud Connectors API."""
def __init__(
self,
project: str,
location: str,
connection: str,
service_account_json: Optional[str] = None,
):
"""Initializes the ConnectionsClient.
Args:
project: The Google Cloud project ID.
location: The Google Cloud location (e.g., us-central1).
connection: The connection name.
service_account_json: The service account configuration as a dictionary.
Required if not using default service credential. Used for fetching
connection details.
"""
self.project = project
self.location = location
self.connection = connection
self.connector_url = "https://connectors.googleapis.com"
self.service_account_json = service_account_json
self.credential_cache = None
def get_connection_details(self) -> Dict[str, Any]:
"""Retrieves service details (service name and host) for a given connection.
Also returns if auth override is enabled for the connection.
Returns:
tuple: A tuple containing (service_name, host).
Raises:
PermissionError: If there are credential issues.
ValueError: If there's a request error.
Exception: For any other unexpected errors.
"""
url = f"{self.connector_url}/v1/projects/{self.project}/locations/{self.location}/connections/{self.connection}?view=BASIC"
response = self._execute_api_call(url)
connection_data = response.json()
service_name = connection_data.get("serviceDirectory", "")
host = connection_data.get("host", "")
if host:
service_name = connection_data.get("tlsServiceDirectory", "")
auth_override_enabled = connection_data.get("authOverrideEnabled", False)
return {
"serviceName": service_name,
"host": host,
"authOverrideEnabled": auth_override_enabled,
}
def get_entity_schema_and_operations(
self, entity: str
) -> Tuple[Dict[str, Any], List[str]]:
"""Retrieves the JSON schema for a given entity in a connection.
Args:
entity (str): The entity name.
Returns:
tuple: A tuple containing (schema, operations).
Raises:
PermissionError: If there are credential issues.
ValueError: If there's a request or processing error.
Exception: For any other unexpected errors.
"""
url = f"{self.connector_url}/v1/projects/{self.project}/locations/{self.location}/connections/{self.connection}/connectionSchemaMetadata:getEntityType?entityId={entity}"
response = self._execute_api_call(url)
operation_id = response.json().get("name")
if not operation_id:
raise ValueError(
f"Failed to get entity schema and operations for entity: {entity}"
)
operation_response = self._poll_operation(operation_id)
schema = operation_response.get("response", {}).get("jsonSchema", {})
operations = operation_response.get("response", {}).get("operations", [])
return schema, operations
def get_action_schema(self, action: str) -> Dict[str, Any]:
"""Retrieves the input and output JSON schema for a given action in a connection.
Args:
action (str): The action name.
Returns:
tuple: A tuple containing (input_schema, output_schema).
Raises:
PermissionError: If there are credential issues.
ValueError: If there's a request or processing error.
Exception: For any other unexpected errors.
"""
url = f"{self.connector_url}/v1/projects/{self.project}/locations/{self.location}/connections/{self.connection}/connectionSchemaMetadata:getAction?actionId={action}"
response = self._execute_api_call(url)
operation_id = response.json().get("name")
if not operation_id:
raise ValueError(f"Failed to get action schema for action: {action}")
operation_response = self._poll_operation(operation_id)
input_schema = operation_response.get("response", {}).get(
"inputJsonSchema", {}
)
output_schema = operation_response.get("response", {}).get(
"outputJsonSchema", {}
)
description = operation_response.get("response", {}).get("description", "")
display_name = operation_response.get("response", {}).get("displayName", "")
return {
"inputSchema": input_schema,
"outputSchema": output_schema,
"description": description,
"displayName": display_name,
}
@staticmethod
def get_connector_base_spec() -> Dict[str, Any]:
return {
"openapi": "3.0.1",
"info": {
"title": "ExecuteConnection",
"description": "This tool can execute a query on connection",
"version": "4",
},
"servers": [{"url": "https://integrations.googleapis.com"}],
"security": [
{"google_auth": ["https://www.googleapis.com/auth/cloud-platform"]}
],
"paths": {},
"components": {
"schemas": {
"operation": {
"type": "string",
"default": "LIST_ENTITIES",
"description": (
"Operation to execute. Possible values are"
" LIST_ENTITIES, GET_ENTITY, CREATE_ENTITY,"
" UPDATE_ENTITY, DELETE_ENTITY in case of entities."
" EXECUTE_ACTION in case of actions. and EXECUTE_QUERY"
" in case of custom queries."
),
},
"entityId": {
"type": "string",
"description": "Name of the entity",
},
"connectorInputPayload": {"type": "object"},
"filterClause": {
"type": "string",
"default": "",
"description": "WHERE clause in SQL query",
},
"pageSize": {
"type": "integer",
"default": 50,
"description": (
"Number of entities to return in the response"
),
},
"pageToken": {
"type": "string",
"default": "",
"description": (
"Page token to return the next page of entities"
),
},
"connectionName": {
"type": "string",
"default": "",
"description": (
"Connection resource name to run the query for"
),
},
"serviceName": {
"type": "string",
"default": "",
"description": "Service directory for the connection",
},
"host": {
"type": "string",
"default": "",
"description": "Host name incase of tls service directory",
},
"entity": {
"type": "string",
"default": "Issues",
"description": "Entity to run the query for",
},
"action": {
"type": "string",
"default": "ExecuteCustomQuery",
"description": "Action to run the query for",
},
"query": {
"type": "string",
"default": "",
"description": "Custom Query to execute on the connection",
},
"dynamicAuthConfig": {
"type": "object",
"default": {},
"description": "Dynamic auth config for the connection",
},
"timeout": {
"type": "integer",
"default": 120,
"description": (
"Timeout in seconds for execution of custom query"
),
},
"connectorOutputPayload": {"type": "object"},
"nextPageToken": {"type": "string"},
"execute-connector_Response": {
"required": ["connectorOutputPayload"],
"type": "object",
"properties": {
"connectorOutputPayload": {
"$ref": (
"#/components/schemas/connectorOutputPayload"
)
},
"nextPageToken": {
"$ref": "#/components/schemas/nextPageToken"
},
},
},
},
"securitySchemes": {
"google_auth": {
"type": "oauth2",
"flows": {
"implicit": {
"authorizationUrl": (
"https://accounts.google.com/o/oauth2/auth"
),
"scopes": {
"https://www.googleapis.com/auth/cloud-platform": (
"Auth for google cloud services"
)
},
}
},
}
},
},
}
@staticmethod
def get_action_operation(
action: str,
operation: str,
action_display_name: str,
tool_name: str = "",
tool_instructions: str = "",
) -> Dict[str, Any]:
description = (
f"Use this tool with" f' action = "{action}" and'
) + f' operation = "{operation}" only. Dont ask these values from user.'
if operation == "EXECUTE_QUERY":
description = (
(f"Use this tool with" f' action = "{action}" and')
+ f' operation = "{operation}" only. Dont ask these values from user.'
" Use pageSize = 50 and timeout = 120 until user specifies a"
" different value otherwise. If user provides a query in natural"
" language, convert it to SQL query and then execute it using the"
" tool."
)
return {
"post": {
"summary": f"{action_display_name}",
"description": f"{description} {tool_instructions}",
"operationId": f"{tool_name}_{action_display_name}",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": (
f"#/components/schemas/{action_display_name}_Request"
)
}
}
}
},
"responses": {
"200": {
"description": "Success response",
"content": {
"application/json": {
"schema": {
"$ref": (
f"#/components/schemas/{action_display_name}_Response"
),
}
}
},
}
},
}
}
@staticmethod
def list_operation(
entity: str,
schema_as_string: str = "",
tool_name: str = "",
tool_instructions: str = "",
) -> Dict[str, Any]:
return {
"post": {
"summary": f"List {entity}",
"description": (
f"Returns all entities of type {entity}. Use this tool with"
+ f' entity = "{entity}" and'
+ ' operation = "LIST_ENTITIES" only. Dont ask these values'
" from"
+ ' user. Always use ""'
+ ' as filter clause and ""'
+ " as page token and 50 as page size until user specifies a"
" different value otherwise. Use single quotes for strings in"
f" filter clause. {tool_instructions}"
),
"operationId": f"{tool_name}_list_{entity}",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": (
f"#/components/schemas/list_{entity}_Request"
)
}
}
}
},
"responses": {
"200": {
"description": "Success response",
"content": {
"application/json": {
"schema": {
"description": (
f"Returns a list of {entity} of json"
f" schema: {schema_as_string}"
),
"$ref": (
"#/components/schemas/execute-connector_Response"
),
}
}
},
}
},
}
}
@staticmethod
def get_operation(
entity: str,
schema_as_string: str = "",
tool_name: str = "",
tool_instructions: str = "",
) -> Dict[str, Any]:
return {
"post": {
"summary": f"Get {entity}",
"description": (
(
f"Returns the details of the {entity}. Use this tool with"
f' entity = "{entity}" and'
)
+ ' operation = "GET_ENTITY" only. Dont ask these values from'
f" user. {tool_instructions}"
),
"operationId": f"{tool_name}_get_{entity}",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": f"#/components/schemas/get_{entity}_Request"
}
}
}
},
"responses": {
"200": {
"description": "Success response",
"content": {
"application/json": {
"schema": {
"description": (
f"Returns {entity} of json schema:"
f" {schema_as_string}"
),
"$ref": (
"#/components/schemas/execute-connector_Response"
),
}
}
},
}
},
}
}
@staticmethod
def create_operation(
entity: str, tool_name: str = "", tool_instructions: str = ""
) -> Dict[str, Any]:
return {
"post": {
"summary": f"Create {entity}",
"description": (
(
f"Creates a new entity of type {entity}. Use this tool with"
f' entity = "{entity}" and'
)
+ ' operation = "CREATE_ENTITY" only. Dont ask these values'
" from"
+ " user. Follow the schema of the entity provided in the"
f" instructions to create {entity}. {tool_instructions}"
),
"operationId": f"{tool_name}_create_{entity}",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": (
f"#/components/schemas/create_{entity}_Request"
)
}
}
}
},
"responses": {
"200": {
"description": "Success response",
"content": {
"application/json": {
"schema": {
"$ref": (
"#/components/schemas/execute-connector_Response"
)
}
}
},
}
},
}
}
@staticmethod
def update_operation(
entity: str, tool_name: str = "", tool_instructions: str = ""
) -> Dict[str, Any]:
return {
"post": {
"summary": f"Update {entity}",
"description": (
(
f"Updates an entity of type {entity}. Use this tool with"
f' entity = "{entity}" and'
)
+ ' operation = "UPDATE_ENTITY" only. Dont ask these values'
" from"
+ " user. Use entityId to uniquely identify the entity to"
" update. Follow the schema of the entity provided in the"
f" instructions to update {entity}. {tool_instructions}"
),
"operationId": f"{tool_name}_update_{entity}",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": (
f"#/components/schemas/update_{entity}_Request"
)
}
}
}
},
"responses": {
"200": {
"description": "Success response",
"content": {
"application/json": {
"schema": {
"$ref": (
"#/components/schemas/execute-connector_Response"
)
}
}
},
}
},
}
}
@staticmethod
def delete_operation(
entity: str, tool_name: str = "", tool_instructions: str = ""
) -> Dict[str, Any]:
return {
"post": {
"summary": f"Delete {entity}",
"description": (
(
f"Deletes an entity of type {entity}. Use this tool with"
f' entity = "{entity}" and'
)
+ ' operation = "DELETE_ENTITY" only. Dont ask these values'
" from"
f" user. {tool_instructions}"
),
"operationId": f"{tool_name}_delete_{entity}",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": (
f"#/components/schemas/delete_{entity}_Request"
)
}
}
}
},
"responses": {
"200": {
"description": "Success response",
"content": {
"application/json": {
"schema": {
"$ref": (
"#/components/schemas/execute-connector_Response"
)
}
}
},
}
},
}
}
@staticmethod
def create_operation_request(entity: str) -> Dict[str, Any]:
return {
"type": "object",
"required": [
"connectorInputPayload",
"operation",
"connectionName",
"serviceName",
"host",
"entity",
],
"properties": {
"connectorInputPayload": {
"$ref": f"#/components/schemas/connectorInputPayload_{entity}"
},
"operation": {"$ref": "#/components/schemas/operation"},
"connectionName": {"$ref": "#/components/schemas/connectionName"},
"serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"},
},
}
@staticmethod
def update_operation_request(entity: str) -> Dict[str, Any]:
return {
"type": "object",
"required": [
"connectorInputPayload",
"entityId",
"operation",
"connectionName",
"serviceName",
"host",
"entity",
],
"properties": {
"connectorInputPayload": {
"$ref": f"#/components/schemas/connectorInputPayload_{entity}"
},
"entityId": {"$ref": "#/components/schemas/entityId"},
"operation": {"$ref": "#/components/schemas/operation"},
"connectionName": {"$ref": "#/components/schemas/connectionName"},
"serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"},
},
}
@staticmethod
def get_operation_request() -> Dict[str, Any]:
return {
"type": "object",
"required": [
"entityId",
"operation",
"connectionName",
"serviceName",
"host",
"entity",
],
"properties": {
"entityId": {"$ref": "#/components/schemas/entityId"},
"operation": {"$ref": "#/components/schemas/operation"},
"connectionName": {"$ref": "#/components/schemas/connectionName"},
"serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"},
},
}
@staticmethod
def delete_operation_request() -> Dict[str, Any]:
return {
"type": "object",
"required": [
"entityId",
"operation",
"connectionName",
"serviceName",
"host",
"entity",
],
"properties": {
"entityId": {"$ref": "#/components/schemas/entityId"},
"operation": {"$ref": "#/components/schemas/operation"},
"connectionName": {"$ref": "#/components/schemas/connectionName"},
"serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"},
},
}
@staticmethod
def list_operation_request() -> Dict[str, Any]:
return {
"type": "object",
"required": [
"operation",
"connectionName",
"serviceName",
"host",
"entity",
],
"properties": {
"filterClause": {"$ref": "#/components/schemas/filterClause"},
"pageSize": {"$ref": "#/components/schemas/pageSize"},
"pageToken": {"$ref": "#/components/schemas/pageToken"},
"operation": {"$ref": "#/components/schemas/operation"},
"connectionName": {"$ref": "#/components/schemas/connectionName"},
"serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"},
"entity": {"$ref": "#/components/schemas/entity"},
},
}
@staticmethod
def action_request(action: str) -> Dict[str, Any]:
return {
"type": "object",
"required": [
"operation",
"connectionName",
"serviceName",
"host",
"action",
"connectorInputPayload",
],
"properties": {
"operation": {"$ref": "#/components/schemas/operation"},
"connectionName": {"$ref": "#/components/schemas/connectionName"},
"serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"},
"action": {"$ref": "#/components/schemas/action"},
"connectorInputPayload": {
"$ref": f"#/components/schemas/connectorInputPayload_{action}"
},
},
}
@staticmethod
def action_response(action: str) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"connectorOutputPayload": {
"$ref": f"#/components/schemas/connectorOutputPayload_{action}"
},
},
}
@staticmethod
def execute_custom_query_request() -> Dict[str, Any]:
return {
"type": "object",
"required": [
"operation",
"connectionName",
"serviceName",
"host",
"action",
"query",
"timeout",
"pageSize",
],
"properties": {
"operation": {"$ref": "#/components/schemas/operation"},
"connectionName": {"$ref": "#/components/schemas/connectionName"},
"serviceName": {"$ref": "#/components/schemas/serviceName"},
"host": {"$ref": "#/components/schemas/host"},
"action": {"$ref": "#/components/schemas/action"},
"query": {"$ref": "#/components/schemas/query"},
"timeout": {"$ref": "#/components/schemas/timeout"},
"pageSize": {"$ref": "#/components/schemas/pageSize"},
},
}
def connector_payload(self, json_schema: Dict[str, Any]) -> Dict[str, Any]:
return self._convert_json_schema_to_openapi_schema(json_schema)
def _convert_json_schema_to_openapi_schema(self, json_schema):
"""Converts a JSON schema dictionary to an OpenAPI schema dictionary, handling variable types, properties, items, nullable, and description.
Args:
json_schema (dict): The input JSON schema dictionary.
Returns:
dict: The converted OpenAPI schema dictionary.
"""
openapi_schema = {}
if "description" in json_schema:
openapi_schema["description"] = json_schema["description"]
if "type" in json_schema:
if isinstance(json_schema["type"], list):
if "null" in json_schema["type"]:
openapi_schema["nullable"] = True
other_types = [t for t in json_schema["type"] if t != "null"]
if other_types:
openapi_schema["type"] = other_types[0]
else:
openapi_schema["type"] = json_schema["type"][0]
else:
openapi_schema["type"] = json_schema["type"]
if openapi_schema.get("type") == "object" and "properties" in json_schema:
openapi_schema["properties"] = {}
for prop_name, prop_schema in json_schema["properties"].items():
openapi_schema["properties"][prop_name] = (
self._convert_json_schema_to_openapi_schema(prop_schema)
)
elif openapi_schema.get("type") == "array" and "items" in json_schema:
if isinstance(json_schema["items"], list):
openapi_schema["items"] = [
self._convert_json_schema_to_openapi_schema(item)
for item in json_schema["items"]
]
else:
openapi_schema["items"] = self._convert_json_schema_to_openapi_schema(
json_schema["items"]
)
return openapi_schema
def _get_access_token(self) -> str:
"""Gets the access token for the service account.
Returns:
The access token.
"""
if self.credential_cache and not self.credential_cache.expired:
return self.credential_cache.token
if self.service_account_json:
credentials = service_account.Credentials.from_service_account_info(
json.loads(self.service_account_json),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
try:
credentials, _ = default_service_credential()
except:
credentials = None
if not credentials:
raise ValueError(
"Please provide a service account that has the required permissions"
" to access the connection."
)
credentials.refresh(Request())
self.credential_cache = credentials
return credentials.token
def _execute_api_call(self, url):
"""Executes an API call to the given URL.
Args:
url (str): The URL to call.
Returns:
requests.Response: The response object from the API call.
Raises:
PermissionError: If there are credential issues.
ValueError: If there's a request error.
Exception: For any other unexpected errors.
"""
try:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self._get_access_token()}",
}
response = requests.get(url, headers=headers)
response.raise_for_status()
return response
except google.auth.exceptions.DefaultCredentialsError as e:
raise PermissionError(f"Credentials error: {e}") from e
except requests.exceptions.RequestException as e:
if (
"404" in str(e)
or "Not found" in str(e)
or "400" in str(e)
or "Bad request" in str(e)
):
raise ValueError(
"Invalid request. Please check the provided"
f" values of project({self.project}), location({self.location}),"
f" connection({self.connection})."
) from e
raise ValueError(f"Request error: {e}") from e
except Exception as e:
raise Exception(f"An unexpected error occurred: {e}") from e
def _poll_operation(self, operation_id: str) -> Dict[str, Any]:
"""Polls an operation until it is done.
Args:
operation_id: The ID of the operation to poll.
Returns:
The final response of the operation.
Raises:
PermissionError: If there are credential issues.
ValueError: If there's a request error.
Exception: For any other unexpected errors.
"""
operation_done: bool = False
operation_response: Dict[str, Any] = {}
while not operation_done:
get_operation_url = f"{self.connector_url}/v1/{operation_id}"
response = self._execute_api_call(get_operation_url)
operation_response = response.json()
operation_done = operation_response.get("done", False)
time.sleep(1)
return operation_response

View File

@@ -0,0 +1,253 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Optional
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
import google.auth
from google.auth import default as default_service_credential
import google.auth.transport.requests
from google.auth.transport.requests import Request
from google.oauth2 import service_account
import requests
class IntegrationClient:
"""A client for interacting with Google Cloud Application Integration.
This class provides methods for retrieving OpenAPI spec for an integration or
a connection.
"""
def __init__(
self,
project: str,
location: str,
integration: Optional[str] = None,
trigger: Optional[str] = None,
connection: Optional[str] = None,
entity_operations: Optional[dict[str, list[str]]] = None,
actions: Optional[list[str]] = None,
service_account_json: Optional[str] = None,
):
"""Initializes the ApplicationIntegrationClient.
Args:
project: The Google Cloud project ID.
location: The Google Cloud location (e.g., us-central1).
integration: The integration name.
trigger: The trigger ID for the integration.
connection: The connection name.
entity_operations: A dictionary mapping entity names to a list of
operations (e.g., LIST, CREATE, UPDATE, DELETE, GET).
actions: List of actions.
service_account_json: The service account configuration as a dictionary.
Required if not using default service credential. Used for fetching
connection details.
"""
self.project = project
self.location = location
self.integration = integration
self.trigger = trigger
self.connection = connection
self.entity_operations = (
entity_operations if entity_operations is not None else {}
)
self.actions = actions if actions is not None else []
self.service_account_json = service_account_json
self.credential_cache = None
def get_openapi_spec_for_integration(self):
"""Gets the OpenAPI spec for the integration.
Returns:
dict: The OpenAPI spec as a dictionary.
Raises:
PermissionError: If there are credential issues.
ValueError: If there's a request error or processing error.
Exception: For any other unexpected errors.
"""
try:
url = f"https://{self.location}-integrations.googleapis.com/v1/projects/{self.project}/locations/{self.location}:generateOpenApiSpec"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self._get_access_token()}",
}
data = {
"apiTriggerResources": [
{
"integrationResource": self.integration,
"triggerId": [self.trigger],
},
],
"fileFormat": "JSON",
}
response = requests.post(url, headers=headers, json=data)
response.raise_for_status()
spec = response.json().get("openApiSpec", {})
return json.loads(spec)
except google.auth.exceptions.DefaultCredentialsError as e:
raise PermissionError(f"Credentials error: {e}") from e
except requests.exceptions.RequestException as e:
if (
"404" in str(e)
or "Not found" in str(e)
or "400" in str(e)
or "Bad request" in str(e)
):
raise ValueError(
"Invalid request. Please check the provided values of"
f" project({self.project}), location({self.location}),"
f" integration({self.integration}) and trigger({self.trigger})."
) from e
raise ValueError(f"Request error: {e}") from e
except Exception as e:
raise Exception(f"An unexpected error occurred: {e}") from e
def get_openapi_spec_for_connection(self, tool_name="", tool_instructions=""):
"""Gets the OpenAPI spec for the connection.
Returns:
dict: The OpenAPI spec as a dictionary.
Raises:
ValueError: If there's an error retrieving the OpenAPI spec.
PermissionError: If there are credential issues.
Exception: For any other unexpected errors.
"""
# Application Integration needs to be provisioned in the same region as connection and an integration with name "ExecuteConnection" and trigger "api_trigger/ExecuteConnection" should be created as per the documentation.
integration_name = "ExecuteConnection"
connections_client = ConnectionsClient(
self.project,
self.location,
self.connection,
self.service_account_json,
)
if not self.entity_operations and not self.actions:
raise ValueError(
"No entity operations or actions provided. Please provide at least"
" one of them."
)
connector_spec = connections_client.get_connector_base_spec()
for entity, operations in self.entity_operations.items():
schema, supported_operations = (
connections_client.get_entity_schema_and_operations(entity)
)
if not operations:
operations = supported_operations
json_schema_as_string = json.dumps(schema)
entity_lower = entity
connector_spec["components"]["schemas"][
f"connectorInputPayload_{entity_lower}"
] = connections_client.connector_payload(schema)
for operation in operations:
operation_lower = operation.lower()
path = f"/v2/projects/{self.project}/locations/{self.location}/integrations/{integration_name}:execute?triggerId=api_trigger/{integration_name}#{operation_lower}_{entity_lower}"
if operation_lower == "create":
connector_spec["paths"][path] = connections_client.create_operation(
entity_lower, tool_name, tool_instructions
)
connector_spec["components"]["schemas"][
f"create_{entity_lower}_Request"
] = connections_client.create_operation_request(entity_lower)
elif operation_lower == "update":
connector_spec["paths"][path] = connections_client.update_operation(
entity_lower, tool_name, tool_instructions
)
connector_spec["components"]["schemas"][
f"update_{entity_lower}_Request"
] = connections_client.update_operation_request(entity_lower)
elif operation_lower == "delete":
connector_spec["paths"][path] = connections_client.delete_operation(
entity_lower, tool_name, tool_instructions
)
connector_spec["components"]["schemas"][
f"delete_{entity_lower}_Request"
] = connections_client.delete_operation_request()
elif operation_lower == "list":
connector_spec["paths"][path] = connections_client.list_operation(
entity_lower, json_schema_as_string, tool_name, tool_instructions
)
connector_spec["components"]["schemas"][
f"list_{entity_lower}_Request"
] = connections_client.list_operation_request()
elif operation_lower == "get":
connector_spec["paths"][path] = connections_client.get_operation(
entity_lower, json_schema_as_string, tool_name, tool_instructions
)
connector_spec["components"]["schemas"][
f"get_{entity_lower}_Request"
] = connections_client.get_operation_request()
else:
raise ValueError(
f"Invalid operation: {operation} for entity: {entity}"
)
for action in self.actions:
action_details = connections_client.get_action_schema(action)
input_schema = action_details["inputSchema"]
output_schema = action_details["outputSchema"]
action_display_name = action_details["displayName"]
operation = "EXECUTE_ACTION"
if action == "ExecuteCustomQuery":
connector_spec["components"]["schemas"][
f"{action}_Request"
] = connections_client.execute_custom_query_request()
operation = "EXECUTE_QUERY"
else:
connector_spec["components"]["schemas"][
f"{action_display_name}_Request"
] = connections_client.action_request(action_display_name)
connector_spec["components"]["schemas"][
f"connectorInputPayload_{action_display_name}"
] = connections_client.connector_payload(input_schema)
connector_spec["components"]["schemas"][
f"connectorOutputPayload_{action_display_name}"
] = connections_client.connector_payload(output_schema)
connector_spec["components"]["schemas"][
f"{action_display_name}_Response"
] = connections_client.action_response(action_display_name)
path = f"/v2/projects/{self.project}/locations/{self.location}/integrations/{integration_name}:execute?triggerId=api_trigger/{integration_name}#{action}"
connector_spec["paths"][path] = connections_client.get_action_operation(
action, operation, action_display_name, tool_name, tool_instructions
)
return connector_spec
def _get_access_token(self) -> str:
"""Gets the access token for the service account or using default credentials.
Returns:
The access token.
"""
if self.credential_cache and not self.credential_cache.expired:
return self.credential_cache.token
if self.service_account_json:
credentials = service_account.Credentials.from_service_account_info(
json.loads(self.service_account_json),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
try:
credentials, _ = default_service_credential()
except:
credentials = None
if not credentials:
raise ValueError(
"Please provide a service account that has the required permissions"
" to access the connection."
)
credentials.refresh(Request())
self.credential_cache = credentials
return credentials.token

View File

@@ -0,0 +1,144 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from abc import ABC
import os
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from deprecated import deprecated
from google.genai import types
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..models.llm_request import LlmRequest
class BaseTool(ABC):
"""The base class for all tools."""
name: str
"""The name of the tool."""
description: str
"""The description of the tool."""
is_long_running: bool = False
"""Whether the tool is a long running operation, which typically returns a
resource id first and finishes the operation later."""
def __init__(self, *, name, description, is_long_running: bool = False):
self.name = name
self.description = description
self.is_long_running = is_long_running
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
"""Gets the OpenAPI specification of this tool in the form of a FunctionDeclaration.
NOTE
- Required if subclass uses the default implementation of
`process_llm_request` to add function declaration to LLM request.
- Otherwise, can be skipped, e.g. for a built-in GoogleSearch tool for
Gemini.
Returns:
The FunctionDeclaration of this tool, or None if it doesn't need to be
added to LlmRequest.config.
"""
return None
async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext
) -> Any:
"""Runs the tool with the given arguments and context.
NOTE
- Required if this tool needs to run at the client side.
- Otherwise, can be skipped, e.g. for a built-in GoogleSearch tool for
Gemini.
Args:
args: The LLM-filled arguments.
ctx: The context of the tool.
Returns:
The result of running the tool.
"""
raise NotImplementedError(f'{type(self)} is not implemented')
async def process_llm_request(
self, *, tool_context: ToolContext, llm_request: LlmRequest
) -> None:
"""Processes the outgoing LLM request for this tool.
Use cases:
- Most common use case is adding this tool to the LLM request.
- Some tools may just preprocess the LLM request before it's sent out.
Args:
tool_context: The context of the tool.
llm_request: The outgoing LLM request, mutable this method.
"""
if (function_declaration := self._get_declaration()) is None:
return
llm_request.tools_dict[self.name] = self
if tool_with_function_declarations := _find_tool_with_function_declarations(
llm_request
):
if tool_with_function_declarations.function_declarations is None:
tool_with_function_declarations.function_declarations = []
tool_with_function_declarations.function_declarations.append(
function_declaration
)
else:
llm_request.config = (
types.GenerateContentConfig()
if not llm_request.config
else llm_request.config
)
llm_request.config.tools = (
[] if not llm_request.config.tools else llm_request.config.tools
)
llm_request.config.tools.append(
types.Tool(function_declarations=[function_declaration])
)
@property
def _api_variant(self) -> str:
use_vertexai = os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in [
'true',
'1',
]
return 'VERTEX_AI' if use_vertexai else 'GOOGLE_AI'
def _find_tool_with_function_declarations(
llm_request: LlmRequest,
) -> Optional[types.Tool]:
# TODO: add individual tool with declaration and merge in google_llm.py
if not llm_request.config or not llm_request.config.tools:
return None
return next(
(
tool
for tool in llm_request.config.tools
if isinstance(tool, types.Tool) and tool.function_declarations
),
None,
)

View File

@@ -0,0 +1,59 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from .base_tool import BaseTool
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..models import LlmRequest
class BuiltInCodeExecutionTool(BaseTool):
"""A built-in code execution tool that is automatically invoked by Gemini 2 models.
This tool operates internally within the model and does not require or perform
local code execution.
"""
def __init__(self):
# Name and description are not used because this is a model built-in tool.
super().__init__(name='code_execution', description='code_execution')
@override
async def process_llm_request(
self,
*,
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
if llm_request.model and llm_request.model.startswith('gemini-2'):
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
llm_request.config.tools.append(
types.Tool(code_execution=types.ToolCodeExecution())
)
else:
raise ValueError(
f'Code execution tool is not supported for model {llm_request.model}'
)
built_in_code_execution = BuiltInCodeExecutionTool()

View File

@@ -0,0 +1,72 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from google.genai import types
from typing_extensions import override
from . import _automatic_function_calling_util
from .function_tool import FunctionTool
try:
from crewai.tools import BaseTool as CrewaiBaseTool
except ImportError as e:
import sys
if sys.version_info < (3, 10):
raise ImportError(
"Crewai Tools require Python 3.10+. Please upgrade your Python version."
) from e
else:
raise ImportError(
"Crewai Tools require pip install 'google-adk[extensions]'."
) from e
class CrewaiTool(FunctionTool):
"""Use this class to wrap a CrewAI tool.
If the original tool name and description are not suitable, you can override
them in the constructor.
"""
tool: CrewaiBaseTool
"""The wrapped CrewAI tool."""
def __init__(self, tool: CrewaiBaseTool, *, name: str, description: str):
super().__init__(tool.run)
self.tool = tool
if name:
self.name = name
elif tool.name:
# Right now, CrewAI tool name contains white spaces. White spaces are
# not supported in our framework. So we replace them with "_".
self.name = tool.name.replace(" ", "_").lower()
if description:
self.description = description
elif tool.description:
self.description = tool.description
@override
def _get_declaration(self) -> types.FunctionDeclaration:
"""Build the function declaration for the tool."""
function_declaration = _automatic_function_calling_util.build_function_declaration_for_params_for_crewai(
False,
self.name,
self.description,
self.func,
self.tool.args_schema.model_json_schema(),
)
return function_declaration

View File

@@ -0,0 +1,62 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Union
from pydantic import TypeAdapter
from typing_extensions import override
from ..examples import example_util
from ..examples.base_example_provider import BaseExampleProvider
from ..examples.example import Example
from .base_tool import BaseTool
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..models.llm_request import LlmRequest
class ExampleTool(BaseTool):
"""A tool that adds (few-shot) examples to the LLM request.
Attributes:
examples: The examples to add to the LLM request.
"""
def __init__(self, examples: Union[list[Example], BaseExampleProvider]):
# Name and description are not used because this tool only changes
# llm_request.
super().__init__(name='example_tool', description='example tool')
self.examples = (
TypeAdapter(list[Example]).validate_python(examples)
if isinstance(examples, list)
else examples
)
@override
async def process_llm_request(
self, *, tool_context: ToolContext, llm_request: LlmRequest
) -> None:
parts = tool_context.user_content.parts
if not parts or not parts[0].text:
return
llm_request.append_instructions([
example_util.build_example_si(
self.examples, parts[0].text, llm_request.model
)
])

View File

@@ -0,0 +1,23 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tool_context import ToolContext
def exit_loop(tool_context: ToolContext):
"""Exits the loop.
Call this function only when you are instructed to do so.
"""
tool_context.actions.escalate = True

View File

@@ -0,0 +1,307 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import logging
import types as typing_types
from typing import _GenericAlias
from typing import Any
from typing import get_args
from typing import get_origin
from typing import Literal
from typing import Union
from google.genai import types
import pydantic
_py_builtin_type_to_schema_type = {
str: types.Type.STRING,
int: types.Type.INTEGER,
float: types.Type.NUMBER,
bool: types.Type.BOOLEAN,
list: types.Type.ARRAY,
dict: types.Type.OBJECT,
}
logger = logging.getLogger(__name__)
def _is_builtin_primitive_or_compound(
annotation: inspect.Parameter.annotation,
) -> bool:
return annotation in _py_builtin_type_to_schema_type.keys()
def _raise_for_any_of_if_mldev(schema: types.Schema):
if schema.any_of:
raise ValueError(
'AnyOf is not supported in function declaration schema for Google AI.'
)
def _update_for_default_if_mldev(schema: types.Schema):
if schema.default is not None:
# TODO(kech): Remove this walkaround once mldev supports default value.
schema.default = None
logger.warning(
'Default value is not supported in function declaration schema for'
' Google AI.'
)
def _raise_if_schema_unsupported(variant: str, schema: types.Schema):
if not variant == 'VERTEX_AI':
_raise_for_any_of_if_mldev(schema)
_update_for_default_if_mldev(schema)
def _is_default_value_compatible(
default_value: Any, annotation: inspect.Parameter.annotation
) -> bool:
# None type is expected to be handled external to this function
if _is_builtin_primitive_or_compound(annotation):
return isinstance(default_value, annotation)
if (
isinstance(annotation, _GenericAlias)
or isinstance(annotation, typing_types.GenericAlias)
or isinstance(annotation, typing_types.UnionType)
):
origin = get_origin(annotation)
if origin in (Union, typing_types.UnionType):
return any(
_is_default_value_compatible(default_value, arg)
for arg in get_args(annotation)
)
if origin is dict:
return isinstance(default_value, dict)
if origin is list:
if not isinstance(default_value, list):
return False
# most tricky case, element in list is union type
# need to apply any logic within all
# see test case test_generic_alias_complex_array_with_default_value
# a: typing.List[int | str | float | bool]
# default_value: [1, 'a', 1.1, True]
return all(
any(
_is_default_value_compatible(item, arg)
for arg in get_args(annotation)
)
for item in default_value
)
if origin is Literal:
return default_value in get_args(annotation)
# return False for any other unrecognized annotation
# let caller handle the raise
return False
def _parse_schema_from_parameter(
variant: str, param: inspect.Parameter, func_name: str
) -> types.Schema:
"""parse schema from parameter.
from the simplest case to the most complex case.
"""
schema = types.Schema()
default_value_error_msg = (
f'Default value {param.default} of parameter {param} of function'
f' {func_name} is not compatible with the parameter annotation'
f' {param.annotation}.'
)
if _is_builtin_primitive_or_compound(param.annotation):
if param.default is not inspect.Parameter.empty:
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
schema.type = _py_builtin_type_to_schema_type[param.annotation]
_raise_if_schema_unsupported(variant, schema)
return schema
if (
get_origin(param.annotation) is Union
# only parse simple UnionType, example int | str | float | bool
# complex types.UnionType will be invoked in raise branch
and all(
(_is_builtin_primitive_or_compound(arg) or arg is type(None))
for arg in get_args(param.annotation)
)
):
schema.type = types.Type.OBJECT
schema.any_of = []
unique_types = set()
for arg in get_args(param.annotation):
if arg.__name__ == 'NoneType': # Optional type
schema.nullable = True
continue
schema_in_any_of = _parse_schema_from_parameter(
variant,
inspect.Parameter(
'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
),
func_name,
)
if (
schema_in_any_of.model_dump_json(exclude_none=True)
not in unique_types
):
schema.any_of.append(schema_in_any_of)
unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
if len(schema.any_of) == 1: # param: list | None -> Array
schema.type = schema.any_of[0].type
schema.any_of = None
if (
param.default is not inspect.Parameter.empty
and param.default is not None
):
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
if isinstance(param.annotation, _GenericAlias) or isinstance(
param.annotation, typing_types.GenericAlias
):
origin = get_origin(param.annotation)
args = get_args(param.annotation)
if origin is dict:
schema.type = types.Type.OBJECT
if param.default is not inspect.Parameter.empty:
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
if origin is Literal:
if not all(isinstance(arg, str) for arg in args):
raise ValueError(
f'Literal type {param.annotation} must be a list of strings.'
)
schema.type = types.Type.STRING
schema.enum = list(args)
if param.default is not inspect.Parameter.empty:
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
if origin is list:
schema.type = types.Type.ARRAY
schema.items = _parse_schema_from_parameter(
variant,
inspect.Parameter(
'item',
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=args[0],
),
func_name,
)
if param.default is not inspect.Parameter.empty:
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
if origin is Union:
schema.any_of = []
schema.type = types.Type.OBJECT
unique_types = set()
for arg in args:
if arg.__name__ == 'NoneType': # Optional type
schema.nullable = True
continue
schema_in_any_of = _parse_schema_from_parameter(
variant,
inspect.Parameter(
'item',
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=arg,
),
func_name,
)
if (
len(param.annotation.__args__) == 2
and type(None) in param.annotation.__args__
): # Optional type
for optional_arg in param.annotation.__args__:
if (
hasattr(optional_arg, '__origin__')
and optional_arg.__origin__ is list
):
# Optional type with list, for example Optional[list[str]]
schema.items = schema_in_any_of.items
if (
schema_in_any_of.model_dump_json(exclude_none=True)
not in unique_types
):
schema.any_of.append(schema_in_any_of)
unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
if len(schema.any_of) == 1: # param: Union[List, None] -> Array
schema.type = schema.any_of[0].type
schema.any_of = None
if (
param.default is not None
and param.default is not inspect.Parameter.empty
):
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
# all other generic alias will be invoked in raise branch
if (
inspect.isclass(param.annotation)
# for user defined class, we only support pydantic model
and issubclass(param.annotation, pydantic.BaseModel)
):
if (
param.default is not inspect.Parameter.empty
and param.default is not None
):
schema.default = param.default
schema.type = types.Type.OBJECT
schema.properties = {}
for field_name, field_info in param.annotation.model_fields.items():
schema.properties[field_name] = _parse_schema_from_parameter(
variant,
inspect.Parameter(
field_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=field_info.annotation,
),
func_name,
)
_raise_if_schema_unsupported(variant, schema)
return schema
raise ValueError(
f'Failed to parse the parameter {param} of function {func_name} for'
' automatic function calling.Automatic function calling works best with'
' simpler function signature schema,consider manually parse your'
f' function declaration for function {func_name}.'
)
def _get_required_fields(schema: types.Schema) -> list[str]:
if not schema.properties:
return
return [
field_name
for field_name, field_schema in schema.properties.items()
if not field_schema.nullable and field_schema.default is None
]

View File

@@ -0,0 +1,87 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any
from typing import Callable
from typing import Optional
from google.genai import types
from typing_extensions import override
from ._automatic_function_calling_util import build_function_declaration
from .base_tool import BaseTool
from .tool_context import ToolContext
class FunctionTool(BaseTool):
"""A tool that wraps a user-defined Python function.
Attributes:
func: The function to wrap.
"""
def __init__(self, func: Callable[..., Any]):
super().__init__(name=func.__name__, description=func.__doc__)
self.func = func
@override
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
function_decl = types.FunctionDeclaration.model_validate(
build_function_declaration(
func=self.func,
# The model doesn't understand the function context.
# input_stream is for streaming tool
ignore_params=['tool_context', 'input_stream'],
variant=self._api_variant,
)
)
return function_decl
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext
) -> Any:
args_to_call = args.copy()
signature = inspect.signature(self.func)
if 'tool_context' in signature.parameters:
args_to_call['tool_context'] = tool_context
if inspect.iscoroutinefunction(self.func):
return await self.func(**args_to_call) or {}
else:
return self.func(**args_to_call) or {}
# TODO(hangfei): fix call live for function stream.
async def _call_live(
self,
*,
args: dict[str, Any],
tool_context: ToolContext,
invocation_context,
) -> Any:
args_to_call = args.copy()
signature = inspect.signature(self.func)
if (
self.name in invocation_context.active_streaming_tools
and invocation_context.active_streaming_tools[self.name].stream
):
args_to_call['input_stream'] = invocation_context.active_streaming_tools[
self.name
].stream
if 'tool_context' in signature.parameters:
args_to_call['tool_context'] = tool_context
async for item in self.func(**args_to_call):
yield item

View File

@@ -0,0 +1,28 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from .long_running_tool import LongRunningFunctionTool
from .tool_context import ToolContext
def get_user_choice(
options: list[str], tool_context: ToolContext
) -> Optional[str]:
"""Provides the options to the user and asks them to choose one."""
tool_context.actions.skip_summarization = True
return None
get_user_choice_tool = LongRunningFunctionTool(func=get_user_choice)

View File

@@ -0,0 +1,14 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .google_api_tool_sets import calendar_tool_set

View File

@@ -0,0 +1,59 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import Optional
from google.genai.types import FunctionDeclaration
from typing_extensions import override
from ...auth import AuthCredential
from ...auth import AuthCredentialTypes
from ...auth import OAuth2Auth
from .. import BaseTool
from ..openapi_tool import RestApiTool
from ..tool_context import ToolContext
class GoogleApiTool(BaseTool):
def __init__(self, rest_api_tool: RestApiTool):
super().__init__(
name=rest_api_tool.name,
description=rest_api_tool.description,
is_long_running=rest_api_tool.is_long_running,
)
self.rest_api_tool = rest_api_tool
@override
def _get_declaration(self) -> FunctionDeclaration:
return self.rest_api_tool._get_declaration()
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]:
return await self.rest_api_tool.run_async(
args=args, tool_context=tool_context
)
def configure_auth(self, client_id: str, client_secret: str):
self.rest_api_tool.auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
oauth2=OAuth2Auth(
client_id=client_id,
client_secret=client_secret,
),
)

View File

@@ -0,0 +1,107 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from typing import Any
from typing import Dict
from typing import Final
from typing import List
from typing import Optional
from typing import Type
from ...auth import OpenIdConnectWithConfig
from ..openapi_tool import OpenAPIToolset
from ..openapi_tool import RestApiTool
from .google_api_tool import GoogleApiTool
from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
class GoogleApiToolSet:
def __init__(self, tools: List[RestApiTool]):
self.tools: Final[List[GoogleApiTool]] = [
GoogleApiTool(tool) for tool in tools
]
def get_tools(self) -> List[GoogleApiTool]:
"""Get all tools in the toolset."""
return self.tools
def get_tool(self, tool_name: str) -> Optional[GoogleApiTool]:
"""Get a tool by name."""
matching_tool = filter(lambda t: t.name == tool_name, self.tools)
return next(matching_tool, None)
@staticmethod
def _load_tool_set_with_oidc_auth(
spec_file: str = None,
spec_dict: Dict[str, Any] = None,
scopes: list[str] = None,
) -> Optional[OpenAPIToolset]:
spec_str = None
if spec_file:
# Get the frame of the caller
caller_frame = inspect.stack()[1]
# Get the filename of the caller
caller_filename = caller_frame.filename
# Get the directory of the caller
caller_dir = os.path.dirname(os.path.abspath(caller_filename))
# Join the directory path with the filename
yaml_path = os.path.join(caller_dir, spec_file)
with open(yaml_path, 'r', encoding='utf-8') as file:
spec_str = file.read()
tool_set = OpenAPIToolset(
spec_dict=spec_dict,
spec_str=spec_str,
spec_str_type='yaml',
auth_scheme=OpenIdConnectWithConfig(
authorization_endpoint=(
'https://accounts.google.com/o/oauth2/v2/auth'
),
token_endpoint='https://oauth2.googleapis.com/token',
userinfo_endpoint=(
'https://openidconnect.googleapis.com/v1/userinfo'
),
revocation_endpoint='https://oauth2.googleapis.com/revoke',
token_endpoint_auth_methods_supported=[
'client_secret_post',
'client_secret_basic',
],
grant_types_supported=['authorization_code'],
scopes=scopes,
),
)
return tool_set
def configure_auth(self, client_id: str, client_secret: str):
for tool in self.tools:
tool.configure_auth(client_id, client_secret)
@classmethod
def load_tool_set(
cl: Type['GoogleApiToolSet'],
api_name: str,
api_version: str,
) -> 'GoogleApiToolSet':
spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
scope = list(
spec_dict['components']['securitySchemes']['oauth2']['flows'][
'authorizationCode'
]['scopes'].keys()
)[0]
return cl(
cl._load_tool_set_with_oidc_auth(
spec_dict=spec_dict, scopes=[scope]
).get_tools()
)

View File

@@ -0,0 +1,55 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from .google_api_tool_set import GoogleApiToolSet
logger = logging.getLogger(__name__)
calendar_tool_set = GoogleApiToolSet.load_tool_set(
api_name="calendar",
api_version="v3",
)
bigquery_tool_set = GoogleApiToolSet.load_tool_set(
api_name="bigquery",
api_version="v2",
)
gmail_tool_set = GoogleApiToolSet.load_tool_set(
api_name="gmail",
api_version="v1",
)
youtube_tool_set = GoogleApiToolSet.load_tool_set(
api_name="youtube",
api_version="v3",
)
slides_tool_set = GoogleApiToolSet.load_tool_set(
api_name="slides",
api_version="v1",
)
sheets_tool_set = GoogleApiToolSet.load_tool_set(
api_name="sheets",
api_version="v4",
)
docs_tool_set = GoogleApiToolSet.load_tool_set(
api_name="docs",
api_version="v1",
)

View File

@@ -0,0 +1,521 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
# Google API client
from googleapiclient.discovery import build
from googleapiclient.discovery import Resource
from googleapiclient.errors import HttpError
# Configure logging
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
class GoogleApiToOpenApiConverter:
"""Converts Google API Discovery documents to OpenAPI v3 format."""
def __init__(self, api_name: str, api_version: str):
"""Initialize the converter with the API name and version.
Args:
api_name: The name of the Google API (e.g., "calendar")
api_version: The version of the API (e.g., "v3")
"""
self.api_name = api_name
self.api_version = api_version
self.google_api_resource = None
self.google_api_spec = None
self.openapi_spec = {
"openapi": "3.0.0",
"info": {},
"servers": [],
"paths": {},
"components": {"schemas": {}, "securitySchemes": {}},
}
def fetch_google_api_spec(self) -> None:
"""Fetches the Google API specification using discovery service."""
try:
logger.info(
"Fetching Google API spec for %s %s", self.api_name, self.api_version
)
# Build a resource object for the specified API
self.google_api_resource = build(self.api_name, self.api_version)
# Access the underlying API discovery document
self.google_api_spec = self.google_api_resource._rootDesc
if not self.google_api_spec:
raise ValueError("Failed to retrieve API specification")
logger.info("Successfully fetched %s API specification", self.api_name)
except HttpError as e:
logger.error("HTTP Error: %s", e)
raise
except Exception as e:
logger.error("Error fetching API spec: %s", e)
raise
def convert(self) -> Dict[str, Any]:
"""Convert the Google API spec to OpenAPI v3 format.
Returns:
Dict containing the converted OpenAPI v3 specification
"""
if not self.google_api_spec:
self.fetch_google_api_spec()
# Convert basic API information
self._convert_info()
# Convert server information
self._convert_servers()
# Convert authentication/authorization schemes
self._convert_security_schemes()
# Convert schemas (models)
self._convert_schemas()
# Convert endpoints/paths
self._convert_resources(self.google_api_spec.get("resources", {}))
# Convert top-level methods, if any
self._convert_methods(self.google_api_spec.get("methods", {}), "/")
return self.openapi_spec
def _convert_info(self) -> None:
"""Convert basic API information."""
self.openapi_spec["info"] = {
"title": self.google_api_spec.get("title", f"{self.api_name} API"),
"description": self.google_api_spec.get("description", ""),
"version": self.google_api_spec.get("version", self.api_version),
"contact": {},
"termsOfService": self.google_api_spec.get("documentationLink", ""),
}
# Add documentation links if available
docs_link = self.google_api_spec.get("documentationLink")
if docs_link:
self.openapi_spec["externalDocs"] = {
"description": "API Documentation",
"url": docs_link,
}
def _convert_servers(self) -> None:
"""Convert server information."""
base_url = self.google_api_spec.get(
"rootUrl", ""
) + self.google_api_spec.get("servicePath", "")
# Remove trailing slash if present
if base_url.endswith("/"):
base_url = base_url[:-1]
self.openapi_spec["servers"] = [{
"url": base_url,
"description": f"{self.api_name} {self.api_version} API",
}]
def _convert_security_schemes(self) -> None:
"""Convert authentication and authorization schemes."""
auth = self.google_api_spec.get("auth", {})
oauth2 = auth.get("oauth2", {})
if oauth2:
# Handle OAuth2
scopes = oauth2.get("scopes", {})
formatted_scopes = {}
for scope, scope_info in scopes.items():
formatted_scopes[scope] = scope_info.get("description", "")
self.openapi_spec["components"]["securitySchemes"]["oauth2"] = {
"type": "oauth2",
"description": "OAuth 2.0 authentication",
"flows": {
"authorizationCode": {
"authorizationUrl": (
"https://accounts.google.com/o/oauth2/auth"
),
"tokenUrl": "https://oauth2.googleapis.com/token",
"scopes": formatted_scopes,
}
},
}
# Add API key authentication (most Google APIs support this)
self.openapi_spec["components"]["securitySchemes"]["apiKey"] = {
"type": "apiKey",
"in": "query",
"name": "key",
"description": "API key for accessing this API",
}
# Create global security requirement
self.openapi_spec["security"] = [
{"oauth2": list(formatted_scopes.keys())} if oauth2 else {},
{"apiKey": []},
]
def _convert_schemas(self) -> None:
"""Convert schema definitions (models)."""
schemas = self.google_api_spec.get("schemas", {})
for schema_name, schema_def in schemas.items():
converted_schema = self._convert_schema_object(schema_def)
self.openapi_spec["components"]["schemas"][schema_name] = converted_schema
def _convert_schema_object(
self, schema_def: Dict[str, Any]
) -> Dict[str, Any]:
"""Recursively convert a Google API schema object to OpenAPI schema.
Args:
schema_def: Google API schema definition
Returns:
Converted OpenAPI schema object
"""
result = {}
# Convert the type
if "type" in schema_def:
gtype = schema_def["type"]
if gtype == "object":
result["type"] = "object"
# Handle properties
if "properties" in schema_def:
result["properties"] = {}
for prop_name, prop_def in schema_def["properties"].items():
result["properties"][prop_name] = self._convert_schema_object(
prop_def
)
# Handle required fields
required_fields = []
for prop_name, prop_def in schema_def.get("properties", {}).items():
if prop_def.get("required", False):
required_fields.append(prop_name)
if required_fields:
result["required"] = required_fields
elif gtype == "array":
result["type"] = "array"
if "items" in schema_def:
result["items"] = self._convert_schema_object(schema_def["items"])
elif gtype == "any":
# OpenAPI doesn't have direct "any" type
# Use oneOf with multiple options as alternative
result["oneOf"] = [
{"type": "object"},
{"type": "array"},
{"type": "string"},
{"type": "number"},
{"type": "boolean"},
{"type": "null"},
]
else:
# Handle other primitive types
result["type"] = gtype
# Handle references
if "$ref" in schema_def:
ref = schema_def["$ref"]
# Google refs use "#" at start, OpenAPI uses "#/components/schemas/"
if ref.startswith("#"):
ref = ref.replace("#", "#/components/schemas/")
else:
ref = "#/components/schemas/" + ref
result["$ref"] = ref
# Handle format
if "format" in schema_def:
result["format"] = schema_def["format"]
# Handle enum values
if "enum" in schema_def:
result["enum"] = schema_def["enum"]
# Handle description
if "description" in schema_def:
result["description"] = schema_def["description"]
# Handle pattern
if "pattern" in schema_def:
result["pattern"] = schema_def["pattern"]
# Handle default value
if "default" in schema_def:
result["default"] = schema_def["default"]
return result
def _convert_resources(
self, resources: Dict[str, Any], parent_path: str = ""
) -> None:
"""Recursively convert all resources and their methods.
Args:
resources: Dictionary of resources from the Google API spec
parent_path: The parent path prefix for nested resources
"""
for resource_name, resource_data in resources.items():
# Process methods for this resource
resource_path = f"{parent_path}/{resource_name}"
methods = resource_data.get("methods", {})
self._convert_methods(methods, resource_path)
# Process nested resources recursively
nested_resources = resource_data.get("resources", {})
if nested_resources:
self._convert_resources(nested_resources, resource_path)
def _convert_methods(
self, methods: Dict[str, Any], resource_path: str
) -> None:
"""Convert methods for a specific resource path.
Args:
methods: Dictionary of methods from the Google API spec
resource_path: The path of the resource these methods belong to
"""
for method_name, method_data in methods.items():
http_method = method_data.get("httpMethod", "GET").lower()
# Determine the actual endpoint path
# Google often has the format something like 'users.messages.list'
rest_path = method_data.get("path", "/")
if not rest_path.startswith("/"):
rest_path = "/" + rest_path
path_params = self._extract_path_parameters(rest_path)
# Create path entry if it doesn't exist
if rest_path not in self.openapi_spec["paths"]:
self.openapi_spec["paths"][rest_path] = {}
# Add the operation for this method
self.openapi_spec["paths"][rest_path][http_method] = (
self._convert_operation(method_data, path_params)
)
def _extract_path_parameters(self, path: str) -> List[str]:
"""Extract path parameters from a URL path.
Args:
path: The URL path with path parameters
Returns:
List of parameter names
"""
params = []
segments = path.split("/")
for segment in segments:
# Google APIs often use {param} format for path parameters
if segment.startswith("{") and segment.endswith("}"):
param_name = segment[1:-1]
params.append(param_name)
return params
def _convert_operation(
self, method_data: Dict[str, Any], path_params: List[str]
) -> Dict[str, Any]:
"""Convert a Google API method to an OpenAPI operation.
Args:
method_data: Google API method data
path_params: List of path parameter names
Returns:
OpenAPI operation object
"""
operation = {
"operationId": method_data.get("id", ""),
"summary": method_data.get("description", ""),
"description": method_data.get("description", ""),
"parameters": [],
"responses": {
"200": {"description": "Successful operation"},
"400": {"description": "Bad request"},
"401": {"description": "Unauthorized"},
"403": {"description": "Forbidden"},
"404": {"description": "Not found"},
"500": {"description": "Server error"},
},
}
# Add path parameters
for param_name in path_params:
param = {
"name": param_name,
"in": "path",
"required": True,
"schema": {"type": "string"},
}
operation["parameters"].append(param)
# Add query parameters
for param_name, param_data in method_data.get("parameters", {}).items():
# Skip parameters already included in path
if param_name in path_params:
continue
param = {
"name": param_name,
"in": "query",
"description": param_data.get("description", ""),
"required": param_data.get("required", False),
"schema": self._convert_parameter_schema(param_data),
}
operation["parameters"].append(param)
# Handle request body
if "request" in method_data:
request_ref = method_data.get("request", {}).get("$ref", "")
if request_ref:
if request_ref.startswith("#"):
# Convert Google's reference format to OpenAPI format
openapi_ref = request_ref.replace("#", "#/components/schemas/")
else:
openapi_ref = "#/components/schemas/" + request_ref
operation["requestBody"] = {
"description": "Request body",
"content": {"application/json": {"schema": {"$ref": openapi_ref}}},
"required": True,
}
# Handle response body
if "response" in method_data:
response_ref = method_data.get("response", {}).get("$ref", "")
if response_ref:
if response_ref.startswith("#"):
# Convert Google's reference format to OpenAPI format
openapi_ref = response_ref.replace("#", "#/components/schemas/")
else:
openapi_ref = "#/components/schemas/" + response_ref
operation["responses"]["200"]["content"] = {
"application/json": {"schema": {"$ref": openapi_ref}}
}
# Add scopes if available
scopes = method_data.get("scopes", [])
if scopes:
# Add method-specific security requirement if different from global
operation["security"] = [{"oauth2": scopes}]
return operation
def _convert_parameter_schema(
self, param_data: Dict[str, Any]
) -> Dict[str, Any]:
"""Convert a parameter definition to an OpenAPI schema.
Args:
param_data: Google API parameter data
Returns:
OpenAPI schema for the parameter
"""
schema = {}
# Convert type
param_type = param_data.get("type", "string")
schema["type"] = param_type
# Handle enum values
if "enum" in param_data:
schema["enum"] = param_data["enum"]
# Handle format
if "format" in param_data:
schema["format"] = param_data["format"]
# Handle default value
if "default" in param_data:
schema["default"] = param_data["default"]
# Handle pattern
if "pattern" in param_data:
schema["pattern"] = param_data["pattern"]
return schema
def save_openapi_spec(self, output_path: str) -> None:
"""Save the OpenAPI specification to a file.
Args:
output_path: Path where the OpenAPI spec should be saved
"""
with open(output_path, "w", encoding="utf-8") as f:
json.dump(self.openapi_spec, f, indent=2)
logger.info("OpenAPI specification saved to %s", output_path)
def main():
"""Command line interface for the converter."""
parser = argparse.ArgumentParser(
description=(
"Convert Google API Discovery documents to OpenAPI v3 specifications"
)
)
parser.add_argument(
"api_name", help="Name of the Google API (e.g., 'calendar')"
)
parser.add_argument("api_version", help="Version of the API (e.g., 'v3')")
parser.add_argument(
"--output",
"-o",
default="openapi_spec.json",
help="Output file path for the OpenAPI specification",
)
args = parser.parse_args()
try:
# Create and run the converter
converter = GoogleApiToOpenApiConverter(args.api_name, args.api_version)
converter.convert()
converter.save_openapi_spec(args.output)
print(
f"Successfully converted {args.api_name} {args.api_version} to"
" OpenAPI v3"
)
print(f"Output saved to {args.output}")
except Exception as e:
logger.error("Conversion failed: %s", e)
return 1
return 0
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,68 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from .base_tool import BaseTool
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..models import LlmRequest
class GoogleSearchTool(BaseTool):
"""A built-in tool that is automatically invoked by Gemini 2 models to retrieve search results from Google Search.
This tool operates internally within the model and does not require or perform
local code execution.
"""
def __init__(self):
# Name and description are not used because this is a model built-in tool.
super().__init__(name='google_search', description='google_search')
@override
async def process_llm_request(
self,
*,
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
if llm_request.model and llm_request.model.startswith('gemini-1'):
if llm_request.config.tools:
print(llm_request.config.tools)
raise ValueError(
'Google search tool can not be used with other tools in Gemini 1.x.'
)
llm_request.config.tools.append(
types.Tool(google_search_retrieval=types.GoogleSearchRetrieval())
)
elif llm_request.model and llm_request.model.startswith('gemini-2'):
llm_request.config.tools.append(
types.Tool(google_search=types.GoogleSearch())
)
else:
raise ValueError(
f'Google search tool is not supported for model {llm_request.model}'
)
google_search = GoogleSearchTool()

View File

@@ -0,0 +1,86 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Callable
from google.genai import types
from pydantic import model_validator
from typing_extensions import override
from . import _automatic_function_calling_util
from .function_tool import FunctionTool
class LangchainTool(FunctionTool):
"""Use this class to wrap a langchain tool.
If the original tool name and description are not suitable, you can override
them in the constructor.
"""
tool: Any
"""The wrapped langchain tool."""
def __init__(self, tool: Any):
super().__init__(tool._run)
self.tool = tool
if tool.name:
self.name = tool.name
if tool.description:
self.description = tool.description
@model_validator(mode='before')
@classmethod
def populate_name(cls, data: Any) -> Any:
# Override this to not use function's signature name as it's
# mostly "run" or "invoke" for thir-party tools.
return data
@override
def _get_declaration(self) -> types.FunctionDeclaration:
"""Build the function declaration for the tool."""
from langchain.agents import Tool
from langchain_core.tools import BaseTool
# There are two types of tools:
# 1. BaseTool: the tool is defined in langchain.tools.
# 2. Other tools: the tool doesn't inherit any class but follow some
# conventions, like having a "run" method.
if isinstance(self.tool, BaseTool):
tool_wrapper = Tool(
name=self.name,
func=self.func,
description=self.description,
)
if self.tool.args_schema:
tool_wrapper.args_schema = self.tool.args_schema
function_declaration = _automatic_function_calling_util.build_function_declaration_for_langchain(
False,
self.name,
self.description,
tool_wrapper.func,
tool_wrapper.args,
)
return function_declaration
else:
# Need to provide a way to override the function names and descriptions
# as the original function names are mostly ".run" and the descriptions
# may not meet users' needs.
function_declaration = (
_automatic_function_calling_util.build_function_declaration(
func=self.tool.run,
)
)
return function_declaration

View File

@@ -0,0 +1,113 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
from typing import Any
from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from .base_tool import BaseTool
if TYPE_CHECKING:
from ..models.llm_request import LlmRequest
from .tool_context import ToolContext
class LoadArtifactsTool(BaseTool):
"""A tool that loads the artifacts and adds them to the session."""
def __init__(self):
super().__init__(
name='load_artifacts',
description='Loads the artifacts and adds them to the session.',
)
def _get_declaration(self) -> types.FunctionDeclaration | None:
return types.FunctionDeclaration(
name=self.name,
description=self.description,
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
'artifact_names': types.Schema(
type=types.Type.ARRAY,
items=types.Schema(
type=types.Type.STRING,
),
)
},
),
)
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext
) -> Any:
artifact_names: list[str] = args.get('artifact_names', [])
return {'artifact_names': artifact_names}
@override
async def process_llm_request(
self, *, tool_context: ToolContext, llm_request: LlmRequest
) -> None:
await super().process_llm_request(
tool_context=tool_context,
llm_request=llm_request,
)
self._append_artifacts_to_llm_request(
tool_context=tool_context, llm_request=llm_request
)
def _append_artifacts_to_llm_request(
self, *, tool_context: ToolContext, llm_request: LlmRequest
):
artifact_names = tool_context.list_artifacts()
if not artifact_names:
return
# Tell the model about the available artifacts.
llm_request.append_instructions([f"""You have a list of artifacts:
{json.dumps(artifact_names)}
When the user asks questions about any of the artifacts, you should call the
`load_artifacts` function to load the artifact. Do not generate any text other
than the function call.
"""])
# Attache the content of the artifacts if the model requests them.
# This only adds the content to the model request, instead of the session.
if llm_request.contents and llm_request.contents[-1].parts:
function_response = llm_request.contents[-1].parts[0].function_response
if function_response and function_response.name == 'load_artifacts':
artifact_names = function_response.response['artifact_names']
for artifact_name in artifact_names:
artifact = tool_context.load_artifact(artifact_name)
llm_request.contents.append(
types.Content(
role='user',
parts=[
types.Part.from_text(
text=f'Artifact {artifact_name} is:'
),
artifact,
],
)
)
load_artifacts_tool = LoadArtifactsTool()

View File

@@ -0,0 +1,58 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from typing_extensions import override
from .function_tool import FunctionTool
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..models import LlmRequest
from ..memory.base_memory_service import MemoryResult
def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]':
"""Loads the memory for the current user."""
response = tool_context.search_memory(query)
return response.memories
class LoadMemoryTool(FunctionTool):
"""A tool that loads the memory for the current user."""
def __init__(self):
super().__init__(load_memory)
@override
async def process_llm_request(
self,
*,
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
await super().process_llm_request(
tool_context=tool_context, llm_request=llm_request
)
# Tell the model about the memory.
llm_request.append_instructions(["""
You have memory. You can use it to answer questions. If any questions need
you to look up the memory, you should call load_memory function with a query.
"""])
load_memory_tool = LoadMemoryTool()

View File

@@ -0,0 +1,41 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool for web browse."""
import requests
def load_web_page(url: str) -> str:
"""Fetches the content in the url and returns the text in it.
Args:
url (str): The url to browse.
Returns:
str: The text content of the url.
"""
from bs4 import BeautifulSoup
response = requests.get(url)
if response.status_code == 200:
soup = BeautifulSoup(response.content, 'lxml')
text = soup.get_text(separator='\n', strip=True)
else:
text = f'Failed to fetch url: {url}'
# Split the text into lines, filtering out very short lines
# (e.g., single words or short subtitles)
return '\n'.join(line for line in text.splitlines() if len(line.split()) > 3)

View File

@@ -0,0 +1,39 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
from .function_tool import FunctionTool
class LongRunningFunctionTool(FunctionTool):
"""A function tool that returns the result asynchronously.
This tool is used for long-running operations that may take a significant
amount of time to complete. The framework will call the function. Once the
function returns, the response will be returned asynchronously to the
framework which is identified by the function_call_id.
Example:
```python
tool = LongRunningFunctionTool(a_long_running_function)
```
Attributes:
is_long_running: Whether the tool is a long running operation.
"""
def __init__(self, func: Callable):
super().__init__(func)
self.is_long_running = True

View File

@@ -0,0 +1,42 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = []
try:
from .conversion_utils import adk_to_mcp_tool_type, gemini_to_json_schema
from .mcp_tool import MCPTool
from .mcp_toolset import MCPToolset
__all__.extend([
'adk_to_mcp_tool_type',
'gemini_to_json_schema',
'MCPTool',
'MCPToolset',
])
except ImportError as e:
import logging
import sys
logger = logging.getLogger(__name__)
if sys.version_info < (3, 10):
logger.warning(
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
' version.'
)
else:
logger.debug('MCP Tool is not installed')
logger.debug(e)

View File

@@ -0,0 +1,161 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from google.genai.types import Schema, Type
import mcp.types as mcp_types
from ..base_tool import BaseTool
def adk_to_mcp_tool_type(tool: BaseTool) -> mcp_types.Tool:
"""Convert a Tool in ADK into MCP tool type.
This function transforms an ADK tool definition into its equivalent
representation in the MCP (Model Control Plane) system.
Args:
tool: The ADK tool to convert. It should be an instance of a class derived
from `BaseTool`.
Returns:
An object of MCP Tool type, representing the converted tool.
Examples:
# Assuming 'my_tool' is an instance of a BaseTool derived class
mcp_tool = adk_to_mcp_tool_type(my_tool)
print(mcp_tool)
"""
tool_declaration = tool._get_declaration()
if not tool_declaration:
input_schema = {}
else:
input_schema = gemini_to_json_schema(tool._get_declaration().parameters)
return mcp_types.Tool(
name=tool.name,
description=tool.description,
inputSchema=input_schema,
)
def gemini_to_json_schema(gemini_schema: Schema) -> Dict[str, Any]:
"""Converts a Gemini Schema object into a JSON Schema dictionary.
Args:
gemini_schema: An instance of the Gemini Schema class.
Returns:
A dictionary representing the equivalent JSON Schema.
Raises:
TypeError: If the input is not an instance of the expected Schema class.
ValueError: If an invalid Gemini Type enum value is encountered.
"""
if not isinstance(gemini_schema, Schema):
raise TypeError(
f"Input must be an instance of Schema, got {type(gemini_schema)}"
)
json_schema_dict: Dict[str, Any] = {}
# Map Type
gemini_type = getattr(gemini_schema, "type", None)
if gemini_type and gemini_type != Type.TYPE_UNSPECIFIED:
json_schema_dict["type"] = gemini_type.lower()
else:
json_schema_dict["type"] = "null"
# Map Nullable
if getattr(gemini_schema, "nullable", None) == True:
json_schema_dict["nullable"] = True
# --- Map direct fields ---
direct_mappings = {
"title": "title",
"description": "description",
"default": "default",
"enum": "enum",
"format": "format",
"example": "example",
}
for gemini_key, json_key in direct_mappings.items():
value = getattr(gemini_schema, gemini_key, None)
if value is not None:
json_schema_dict[json_key] = value
# String validation
if gemini_type == Type.STRING:
str_mappings = {
"pattern": "pattern",
"min_length": "minLength",
"max_length": "maxLength",
}
for gemini_key, json_key in str_mappings.items():
value = getattr(gemini_schema, gemini_key, None)
if value is not None:
json_schema_dict[json_key] = value
# Number/Integer validation
if gemini_type in (Type.NUMBER, Type.INTEGER):
num_mappings = {
"minimum": "minimum",
"maximum": "maximum",
}
for gemini_key, json_key in num_mappings.items():
value = getattr(gemini_schema, gemini_key, None)
if value is not None:
json_schema_dict[json_key] = value
# Array validation (Recursive call for items)
if gemini_type == Type.ARRAY:
items_schema = getattr(gemini_schema, "items", None)
if items_schema is not None:
json_schema_dict["items"] = gemini_to_json_schema(items_schema)
arr_mappings = {
"min_items": "minItems",
"max_items": "maxItems",
}
for gemini_key, json_key in arr_mappings.items():
value = getattr(gemini_schema, gemini_key, None)
if value is not None:
json_schema_dict[json_key] = value
# Object validation (Recursive call for properties)
if gemini_type == Type.OBJECT:
properties_dict = getattr(gemini_schema, "properties", None)
if properties_dict is not None:
json_schema_dict["properties"] = {
prop_name: gemini_to_json_schema(prop_schema)
for prop_name, prop_schema in properties_dict.items()
}
obj_mappings = {
"required": "required",
"min_properties": "minProperties",
"max_properties": "maxProperties",
# Note: Ignoring 'property_ordering' as it's not standard JSON Schema
}
for gemini_key, json_key in obj_mappings.items():
value = getattr(gemini_schema, gemini_key, None)
if value is not None:
json_schema_dict[json_key] = value
# Map anyOf (Recursive call for subschemas)
any_of_list = getattr(gemini_schema, "any_of", None)
if any_of_list is not None:
json_schema_dict["anyOf"] = [
gemini_to_json_schema(sub_schema) for sub_schema in any_of_list
]
return json_schema_dict

View File

@@ -0,0 +1,113 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from google.genai.types import FunctionDeclaration
from typing_extensions import override
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
from mcp import ClientSession
from mcp.types import Tool as McpBaseTool
except ImportError as e:
import sys
if sys.version_info < (3, 10):
raise ImportError(
"MCP Tool requires Python 3.10 or above. Please upgrade your Python"
" version."
) from e
else:
raise e
from ..base_tool import BaseTool
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from ..tool_context import ToolContext
class MCPTool(BaseTool):
"""Turns a MCP Tool into a Vertex Agent Framework Tool.
Internally, the tool initializes from a MCP Tool, and uses the MCP Session to
call the tool.
"""
def __init__(
self,
mcp_tool: McpBaseTool,
mcp_session: ClientSession,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] | None = None,
):
"""Initializes a MCPTool.
This tool wraps a MCP Tool interface and an active MCP Session. It invokes
the MCP Tool through executing the tool from remote MCP Session.
Example:
tool = MCPTool(mcp_tool=mcp_tool, mcp_session=mcp_session)
Args:
mcp_tool: The MCP tool to wrap.
mcp_session: The MCP session to use to call the tool.
auth_scheme: The authentication scheme to use.
auth_credential: The authentication credential to use.
Raises:
ValueError: If mcp_tool or mcp_session is None.
"""
if mcp_tool is None:
raise ValueError("mcp_tool cannot be None")
if mcp_session is None:
raise ValueError("mcp_session cannot be None")
self.name = mcp_tool.name
self.description = mcp_tool.description if mcp_tool.description else ""
self.mcp_tool = mcp_tool
self.mcp_session = mcp_session
# TODO(cheliu): Support passing auth to MCP Server.
self.auth_scheme = auth_scheme
self.auth_credential = auth_credential
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Gets the function declaration for the tool.
Returns:
FunctionDeclaration: The Gemini function declaration for the tool.
"""
schema_dict = self.mcp_tool.inputSchema
parameters = to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
return function_decl
@override
async def run_async(self, *, args, tool_context: ToolContext):
"""Runs the tool asynchronously.
Args:
args: The arguments as a dict to pass to the tool.
tool_context: The tool context from upper level ADK agent.
Returns:
Any: The response from the tool.
"""
# TODO(cheliu): Support passing tool context to MCP Server.
response = await self.mcp_session.call_tool(self.name, arguments=args)
return response

View File

@@ -0,0 +1,272 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import AsyncExitStack
from types import TracebackType
from typing import Any, List, Optional, Tuple, Type
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import ListToolsResult
except ImportError as e:
import sys
if sys.version_info < (3, 10):
raise ImportError(
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
' version.'
) from e
else:
raise e
from pydantic import BaseModel
from .mcp_tool import MCPTool
class SseServerParams(BaseModel):
url: str
headers: dict[str, Any] | None = None
timeout: float = 5
sse_read_timeout: float = 60 * 5
class MCPToolset:
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
Usage:
Example 1: (using from_server helper):
```
async def load_tools():
return await MCPToolset.from_server(
connection_params=StdioServerParameters(
command='npx',
args=["-y", "@modelcontextprotocol/server-filesystem"],
)
)
# Use the tools in an LLM agent
tools, exit_stack = await load_tools()
agent = LlmAgent(
tools=tools
)
...
await exit_stack.aclose()
```
Example 2: (using `async with`):
```
async def load_tools():
async with MCPToolset(
connection_params=SseServerParams(url="http://0.0.0.0:8090/sse")
) as toolset:
tools = await toolset.load_tools()
agent = LlmAgent(
...
tools=tools
)
```
Example 3: (provide AsyncExitStack):
```
async def load_tools():
async_exit_stack = AsyncExitStack()
toolset = MCPToolset(
connection_params=StdioServerParameters(...),
)
async_exit_stack.enter_async_context(toolset)
tools = await toolset.load_tools()
agent = LlmAgent(
...
tools=tools
)
...
await async_exit_stack.aclose()
```
Attributes:
connection_params: The connection parameters to the MCP server. Can be
either `StdioServerParameters` or `SseServerParams`.
exit_stack: The async exit stack to manage the connection to the MCP server.
session: The MCP session being initialized with the connection.
"""
def __init__(
self, *, connection_params: StdioServerParameters | SseServerParams
):
"""Initializes the MCPToolset.
Usage:
Example 1: (using from_server helper):
```
async def load_tools():
return await MCPToolset.from_server(
connection_params=StdioServerParameters(
command='npx',
args=["-y", "@modelcontextprotocol/server-filesystem"],
)
)
# Use the tools in an LLM agent
tools, exit_stack = await load_tools()
agent = LlmAgent(
tools=tools
)
...
await exit_stack.aclose()
```
Example 2: (using `async with`):
```
async def load_tools():
async with MCPToolset(
connection_params=SseServerParams(url="http://0.0.0.0:8090/sse")
) as toolset:
tools = await toolset.load_tools()
agent = LlmAgent(
...
tools=tools
)
```
Example 3: (provide AsyncExitStack):
```
async def load_tools():
async_exit_stack = AsyncExitStack()
toolset = MCPToolset(
connection_params=StdioServerParameters(...),
)
async_exit_stack.enter_async_context(toolset)
tools = await toolset.load_tools()
agent = LlmAgent(
...
tools=tools
)
...
await async_exit_stack.aclose()
```
Args:
connection_params: The connection parameters to the MCP server. Can be:
`StdioServerParameters` for using local mcp server (e.g. using `npx` or
`python3`); or `SseServerParams` for a local/remote SSE server.
"""
if not connection_params:
raise ValueError('Missing connection params in MCPToolset.')
self.connection_params = connection_params
self.exit_stack = AsyncExitStack()
@classmethod
async def from_server(
cls,
*,
connection_params: StdioServerParameters | SseServerParams,
async_exit_stack: Optional[AsyncExitStack] = None,
) -> Tuple[List[MCPTool], AsyncExitStack]:
"""Retrieve all tools from the MCP connection.
Usage:
```
async def load_tools():
tools, exit_stack = await MCPToolset.from_server(
connection_params=StdioServerParameters(
command='npx',
args=["-y", "@modelcontextprotocol/server-filesystem"],
)
)
```
Args:
connection_params: The connection parameters to the MCP server.
async_exit_stack: The async exit stack to use. If not provided, a new
AsyncExitStack will be created.
Returns:
A tuple of the list of MCPTools and the AsyncExitStack.
- tools: The list of MCPTools.
- async_exit_stack: The AsyncExitStack used to manage the connection to
the MCP server. Use `await async_exit_stack.aclose()` to close the
connection when server shuts down.
"""
toolset = cls(connection_params=connection_params)
async_exit_stack = async_exit_stack or AsyncExitStack()
await async_exit_stack.enter_async_context(toolset)
tools = await toolset.load_tools()
return (tools, async_exit_stack)
async def _initialize(self) -> ClientSession:
"""Connects to the MCP Server and initializes the ClientSession."""
if isinstance(self.connection_params, StdioServerParameters):
client = stdio_client(self.connection_params)
elif isinstance(self.connection_params, SseServerParams):
client = sse_client(
url=self.connection_params.url,
headers=self.connection_params.headers,
timeout=self.connection_params.timeout,
sse_read_timeout=self.connection_params.sse_read_timeout,
)
else:
raise ValueError(
'Unable to initialize connection. Connection should be'
' StdioServerParameters or SseServerParams, but got'
f' {self.connection_params}'
)
transports = await self.exit_stack.enter_async_context(client)
self.session = await self.exit_stack.enter_async_context(
ClientSession(*transports)
)
await self.session.initialize()
return self.session
async def _exit(self):
"""Closes the connection to MCP Server."""
await self.exit_stack.aclose()
async def load_tools(self) -> List[MCPTool]:
"""Loads all tools from the MCP Server.
Returns:
A list of MCPTools imported from the MCP Server.
"""
tools_response: ListToolsResult = await self.session.list_tools()
return [
MCPTool(mcp_tool=tool, mcp_session=self.session)
for tool in tools_response.tools
]
async def __aenter__(self):
try:
await self._initialize()
return self
except Exception as e:
raise e
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
await self._exit()

View File

@@ -0,0 +1,21 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .openapi_spec_parser import OpenAPIToolset
from .openapi_spec_parser import RestApiTool
__all__ = [
'OpenAPIToolset',
'RestApiTool',
]

View File

@@ -0,0 +1,19 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import auth_helpers
__all__ = [
'auth_helpers',
]

View File

@@ -0,0 +1,498 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from fastapi.openapi.models import HTTPBase
from fastapi.openapi.models import HTTPBearer
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OpenIdConnect
from fastapi.openapi.models import Schema
from pydantic import BaseModel
from pydantic import ValidationError
import requests
from ....auth.auth_credential import AuthCredential
from ....auth.auth_credential import AuthCredentialTypes
from ....auth.auth_credential import HttpAuth
from ....auth.auth_credential import HttpCredentials
from ....auth.auth_credential import OAuth2Auth
from ....auth.auth_credential import ServiceAccount
from ....auth.auth_credential import ServiceAccountCredential
from ....auth.auth_schemes import AuthScheme
from ....auth.auth_schemes import AuthSchemeType
from ....auth.auth_schemes import OpenIdConnectWithConfig
from ..common.common import ApiParameter
class OpenIdConfig(BaseModel):
"""Represents OpenID Connect configuration.
Attributes:
client_id: The client ID.
auth_uri: The authorization URI.
token_uri: The token URI.
client_secret: The client secret.
Example:
config = OpenIdConfig(
client_id="your_client_id",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
client_secret="your_client_secret",
redirect
)
"""
client_id: str
auth_uri: str
token_uri: str
client_secret: str
redirect_uri: Optional[str]
def token_to_scheme_credential(
token_type: Literal["apikey", "oauth2Token"],
location: Optional[Literal["header", "query", "cookie"]] = None,
name: Optional[str] = None,
credential_value: Optional[str] = None,
) -> Tuple[AuthScheme, AuthCredential]:
"""Creates a AuthScheme and AuthCredential for API key or bearer token.
Examples:
```
# API Key in header
auth_scheme, auth_credential = token_to_scheme_credential("apikey", "header",
"X-API-Key", "your_api_key_value")
# API Key in query parameter
auth_scheme, auth_credential = token_to_scheme_credential("apikey", "query",
"api_key", "your_api_key_value")
# OAuth2 Bearer Token in Authorization header
auth_scheme, auth_credential = token_to_scheme_credential("oauth2Token",
"header", "Authorization", "your_bearer_token_value")
```
Args:
type: 'apikey' or 'oauth2Token'.
location: 'header', 'query', or 'cookie' (only 'header' for oauth2Token).
name: The name of the header, query parameter, or cookie.
credential_value: The value of the API Key/ Token.
Returns:
Tuple: (AuthScheme, AuthCredential)
Raises:
ValueError: For invalid type or location.
"""
if token_type == "apikey":
in_: APIKeyIn
if location == "header":
in_ = APIKeyIn.header
elif location == "query":
in_ = APIKeyIn.query
elif location == "cookie":
in_ = APIKeyIn.cookie
else:
raise ValueError(f"Invalid location for apiKey: {location}")
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": in_,
"name": name,
})
if credential_value:
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key=credential_value
)
else:
auth_credential = None
return auth_scheme, auth_credential
elif token_type == "oauth2Token":
# ignore location. OAuth2 Bearer Token is always in Authorization header.
auth_scheme = HTTPBearer(
bearerFormat="JWT"
) # Common format, can be omitted.
if credential_value:
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token=credential_value),
),
)
else:
auth_credential = None
return auth_scheme, auth_credential
else:
raise ValueError(f"Invalid security scheme type: {type}")
def service_account_dict_to_scheme_credential(
config: Dict[str, Any],
scopes: List[str],
) -> Tuple[AuthScheme, AuthCredential]:
"""Creates AuthScheme and AuthCredential for Google Service Account.
Returns a bearer token scheme, and a service account credential.
Args:
config: A ServiceAccount object containing the Google Service Account
configuration.
scopes: A list of scopes to be used.
Returns:
Tuple: (AuthScheme, AuthCredential)
"""
auth_scheme = HTTPBearer(bearerFormat="JWT")
service_account = ServiceAccount(
service_account_credential=ServiceAccountCredential.model_construct(
**config
),
scopes=scopes,
)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=service_account,
)
return auth_scheme, auth_credential
def service_account_scheme_credential(
config: ServiceAccount,
) -> Tuple[AuthScheme, AuthCredential]:
"""Creates AuthScheme and AuthCredential for Google Service Account.
Returns a bearer token scheme, and a service account credential.
Args:
config: A ServiceAccount object containing the Google Service Account
configuration.
Returns:
Tuple: (AuthScheme, AuthCredential)
"""
auth_scheme = HTTPBearer(bearerFormat="JWT")
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=config
)
return auth_scheme, auth_credential
def openid_dict_to_scheme_credential(
config_dict: Dict[str, Any],
scopes: List[str],
credential_dict: Dict[str, Any],
) -> Tuple[OpenIdConnectWithConfig, AuthCredential]:
"""Constructs OpenID scheme and credential from configuration and credential dictionaries.
Args:
config_dict: Dictionary containing OpenID Connect configuration, must
include at least 'authorization_endpoint' and 'token_endpoint'.
scopes: List of scopes to be used.
credential_dict: Dictionary containing credential information, must
include 'client_id', 'client_secret', and 'scopes'. May optionally
include 'redirect_uri'.
Returns:
Tuple: (OpenIdConnectWithConfig, AuthCredential)
Raises:
ValueError: If required fields are missing in the input dictionaries.
"""
# Validate and create the OpenIdConnectWithConfig scheme
try:
config_dict["scopes"] = scopes
# If user provides the OpenID Config as a static dict, it may not contain
# openIdConnect URL.
if "openIdConnectUrl" not in config_dict:
config_dict["openIdConnectUrl"] = ""
openid_scheme = OpenIdConnectWithConfig.model_validate(config_dict)
except ValidationError as e:
raise ValueError(f"Invalid OpenID Connect configuration: {e}") from e
# Attempt to adjust credential_dict if this is a key downloaded from Google
# OAuth config
if len(list(credential_dict.values())) == 1:
credential_value = list(credential_dict.values())[0]
if "client_id" in credential_value and "client_secret" in credential_value:
credential_dict = credential_value
# Validate credential_dict
required_credential_fields = ["client_id", "client_secret"]
missing_fields = [
field
for field in required_credential_fields
if field not in credential_dict
]
if missing_fields:
raise ValueError(
"Missing required fields in credential_dict:"
f" {', '.join(missing_fields)}"
)
# Construct AuthCredential
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
oauth2=OAuth2Auth(
client_id=credential_dict["client_id"],
client_secret=credential_dict["client_secret"],
redirect_uri=credential_dict.get("redirect_uri", None),
),
)
return openid_scheme, auth_credential
def openid_url_to_scheme_credential(
openid_url: str, scopes: List[str], credential_dict: Dict[str, Any]
) -> Tuple[OpenIdConnectWithConfig, AuthCredential]:
"""Constructs OpenID scheme and credential from OpenID URL, scopes, and credential dictionary.
Fetches OpenID configuration from the provided URL.
Args:
openid_url: The OpenID Connect discovery URL.
scopes: List of scopes to be used.
credential_dict: Dictionary containing credential information, must
include at least "client_id" and "client_secret", may optionally include
"redirect_uri" and "scope"
Returns:
Tuple: (AuthScheme, AuthCredential)
Raises:
ValueError: If the OpenID URL is invalid, fetching fails, or required
fields are missing.
requests.exceptions.RequestException: If there's an error during the
HTTP request.
"""
try:
response = requests.get(openid_url, timeout=10)
response.raise_for_status()
config_dict = response.json()
except requests.exceptions.RequestException as e:
raise ValueError(
f"Failed to fetch OpenID configuration from {openid_url}: {e}"
) from e
except ValueError as e:
raise ValueError(
"Invalid JSON response from OpenID configuration endpoint"
f" {openid_url}: {e}"
) from e
# Add openIdConnectUrl to config dict
config_dict["openIdConnectUrl"] = openid_url
return openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
INTERNAL_AUTH_PREFIX = "_auth_prefix_vaf_"
def credential_to_param(
auth_scheme: AuthScheme,
auth_credential: AuthCredential,
) -> Tuple[Optional[ApiParameter], Optional[Dict[str, Any]]]:
"""Converts AuthCredential and AuthScheme to a Parameter and a dictionary for additional kwargs.
This function now supports all credential types returned by the exchangers:
- API Key
- HTTP Bearer (for Bearer tokens, OAuth2, Service Account, OpenID Connect)
- OAuth2 and OpenID Connect (returns None, None, as the token is now a Bearer
token)
- Service Account (returns None, None, as the token is now a Bearer token)
Args:
auth_scheme: The AuthScheme object.
auth_credential: The AuthCredential object.
Returns:
Tuple: (ApiParameter, Dict[str, Any])
"""
if not auth_credential:
return None, None
if (
auth_scheme.type_ == AuthSchemeType.apiKey
and auth_credential
and auth_credential.api_key
):
param_name = auth_scheme.name or ""
python_name = INTERNAL_AUTH_PREFIX + param_name
if auth_scheme.in_ == APIKeyIn.header:
param_location = "header"
elif auth_scheme.in_ == APIKeyIn.query:
param_location = "query"
elif auth_scheme.in_ == APIKeyIn.cookie:
param_location = "cookie"
else:
raise ValueError(f"Invalid API Key location: {auth_scheme.in_}")
param = ApiParameter(
original_name=param_name,
param_location=param_location,
param_schema=Schema(type="string"),
description=auth_scheme.description or "",
py_name=python_name,
)
kwargs = {param.py_name: auth_credential.api_key}
return param, kwargs
# TODO(cheliu): Split handling for OpenIDConnect scheme and native HTTPBearer
# Scheme
elif (
auth_credential and auth_credential.auth_type == AuthCredentialTypes.HTTP
):
if (
auth_credential
and auth_credential.http
and auth_credential.http.credentials
and auth_credential.http.credentials.token
):
param = ApiParameter(
original_name="Authorization",
param_location="header",
param_schema=Schema(type="string"),
description=auth_scheme.description or "Bearer token",
py_name=INTERNAL_AUTH_PREFIX + "Authorization",
)
kwargs = {
param.py_name: f"Bearer {auth_credential.http.credentials.token}"
}
return param, kwargs
elif (
auth_credential
and auth_credential.http
and auth_credential.http.credentials
and (
auth_credential.http.credentials.username
or auth_credential.http.credentials.password
)
):
# Basic Auth is explicitly NOT supported
raise NotImplementedError("Basic Authentication is not supported.")
else:
raise ValueError("Invalid HTTP auth credentials")
# Service Account tokens, OAuth2 Tokens and OpenID Tokens are now handled as
# Bearer tokens.
elif (auth_scheme.type_ == AuthSchemeType.oauth2 and auth_credential) or (
auth_scheme.type_ == AuthSchemeType.openIdConnect and auth_credential
):
if (
auth_credential.http
and auth_credential.http.credentials
and auth_credential.http.credentials.token
):
param = ApiParameter(
original_name="Authorization",
param_location="header",
param_schema=Schema(type="string"),
description=auth_scheme.description or "Bearer token",
py_name=INTERNAL_AUTH_PREFIX + "Authorization",
)
kwargs = {
param.py_name: f"Bearer {auth_credential.http.credentials.token}"
}
return param, kwargs
return None, None
else:
raise ValueError("Invalid security scheme and credential combination")
def dict_to_auth_scheme(data: Dict[str, Any]) -> AuthScheme:
"""Converts a dictionary to a FastAPI AuthScheme object.
Args:
data: The dictionary representing the security scheme.
Returns:
A AuthScheme object (APIKey, HTTPBase, OAuth2, OpenIdConnect, or
HTTPBearer).
Raises:
ValueError: If the 'type' field is missing or invalid, or if the
dictionary cannot be converted to the corresponding Pydantic model.
Example:
```python
api_key_data = {
"type": "apiKey",
"in": "header",
"name": "X-API-Key",
}
api_key_scheme = dict_to_auth_scheme(api_key_data)
bearer_data = {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT",
}
bearer_scheme = dict_to_auth_scheme(bearer_data)
oauth2_data = {
"type": "oauth2",
"flows": {
"authorizationCode": {
"authorizationUrl": "https://example.com/auth",
"tokenUrl": "https://example.com/token",
}
}
}
oauth2_scheme = dict_to_auth_scheme(oauth2_data)
openid_data = {
"type": "openIdConnect",
"openIdConnectUrl": "https://example.com/.well-known/openid-configuration"
}
openid_scheme = dict_to_auth_scheme(openid_data)
```
"""
if "type" not in data:
raise ValueError("Missing 'type' field in security scheme dictionary.")
security_type = data["type"]
try:
if security_type == "apiKey":
return APIKey.model_validate(data)
elif security_type == "http":
if data.get("scheme") == "bearer":
return HTTPBearer.model_validate(data)
else:
return HTTPBase.model_validate(data) # Generic HTTP
elif security_type == "oauth2":
return OAuth2.model_validate(data)
elif security_type == "openIdConnect":
return OpenIdConnect.model_validate(data)
else:
raise ValueError(f"Invalid security scheme type: {security_type}")
except ValidationError as e:
raise ValueError(f"Invalid security scheme data: {e}") from e

View File

@@ -0,0 +1,25 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .auto_auth_credential_exchanger import AutoAuthCredentialExchanger
from .base_credential_exchanger import BaseAuthCredentialExchanger
from .oauth2_exchanger import OAuth2CredentialExchanger
from .service_account_exchanger import ServiceAccountCredentialExchanger
__all__ = [
'AutoAuthCredentialExchanger',
'BaseAuthCredentialExchanger',
'OAuth2CredentialExchanger',
'ServiceAccountCredentialExchanger',
]

View File

@@ -0,0 +1,105 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import Optional
from typing import Type
from .....auth.auth_credential import AuthCredential
from .....auth.auth_credential import AuthCredentialTypes
from .....auth.auth_schemes import AuthScheme
from .base_credential_exchanger import BaseAuthCredentialExchanger
from .oauth2_exchanger import OAuth2CredentialExchanger
from .service_account_exchanger import ServiceAccountCredentialExchanger
class AutoAuthCredentialExchanger(BaseAuthCredentialExchanger):
"""Automatically selects the appropriate credential exchanger based on the auth scheme.
Optionally, an override can be provided to use a specific exchanger for a
given auth scheme.
Example (common case):
```
exchanger = AutoAuthCredentialExchanger()
auth_credential = exchanger.exchange_credential(
auth_scheme=service_account_scheme,
auth_credential=service_account_credential,
)
# Returns an oauth token in the form of a bearer token.
```
Example (use CustomAuthExchanger for OAuth2):
```
exchanger = AutoAuthCredentialExchanger(
custom_exchangers={
AuthScheme.OAUTH2: CustomAuthExchanger,
}
)
```
Attributes:
exchangers: A dictionary mapping auth scheme to credential exchanger class.
"""
def __init__(
self,
custom_exchangers: Optional[
Dict[str, Type[BaseAuthCredentialExchanger]]
] = None,
):
"""Initializes the AutoAuthCredentialExchanger.
Args:
custom_exchangers: Optional dictionary for adding or overriding auth
exchangers. The key is the auth scheme, and the value is the credential
exchanger class.
"""
self.exchangers = {
AuthCredentialTypes.OAUTH2: OAuth2CredentialExchanger,
AuthCredentialTypes.OPEN_ID_CONNECT: OAuth2CredentialExchanger,
AuthCredentialTypes.SERVICE_ACCOUNT: ServiceAccountCredentialExchanger,
}
if custom_exchangers:
self.exchangers.update(custom_exchangers)
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> Optional[AuthCredential]:
"""Automatically exchanges for the credential uses the appropriate credential exchanger.
Args:
auth_scheme (AuthScheme): The security scheme.
auth_credential (AuthCredential): Optional. The authentication
credential.
Returns: (AuthCredential)
A new AuthCredential object containing the exchanged credential.
"""
if not auth_credential:
return None
exchanger_class = self.exchangers.get(
auth_credential.auth_type if auth_credential else None
)
if not exchanger_class:
return auth_credential
exchanger = exchanger_class()
return exchanger.exchange_credential(auth_scheme, auth_credential)

View File

@@ -0,0 +1,55 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Optional
from .....auth.auth_credential import (
AuthCredential,
)
from .....auth.auth_schemes import AuthScheme
class AuthCredentialMissingError(Exception):
"""Exception raised when required authentication credentials are missing."""
def __init__(self, message: str):
super().__init__(message)
self.message = message
class BaseAuthCredentialExchanger:
"""Base class for authentication credential exchangers."""
@abc.abstractmethod
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
"""Exchanges the provided authentication credential for a usable token/credential.
Args:
auth_scheme: The security scheme.
auth_credential: The authentication credential.
Returns:
An updated AuthCredential object containing the fetched credential.
For simple schemes like API key, it may return the original credential
if no exchange is needed.
Raises:
NotImplementedError: If the method is not implemented by a subclass.
"""
raise NotImplementedError("Subclasses must implement exchange_credential.")

View File

@@ -0,0 +1,117 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Credential fetcher for OpenID Connect."""
from typing import Optional
from .....auth.auth_credential import AuthCredential
from .....auth.auth_credential import AuthCredentialTypes
from .....auth.auth_credential import HttpAuth
from .....auth.auth_credential import HttpCredentials
from .....auth.auth_schemes import AuthScheme
from .....auth.auth_schemes import AuthSchemeType
from .base_credential_exchanger import BaseAuthCredentialExchanger
class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
"""Fetches credentials for OAuth2 and OpenID Connect."""
def _check_scheme_credential_type(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
):
if not auth_credential:
raise ValueError(
"auth_credential is empty. Please create AuthCredential using"
" OAuth2Auth."
)
if auth_scheme.type_ not in (
AuthSchemeType.openIdConnect,
AuthSchemeType.oauth2,
):
raise ValueError(
"Invalid security scheme, expect AuthSchemeType.openIdConnect or "
f"AuthSchemeType.oauth2 auth scheme, but got {auth_scheme.type_}"
)
if not auth_credential.oauth2 and not auth_credential.http:
raise ValueError(
"auth_credential is not configured with oauth2. Please"
" create AuthCredential and set OAuth2Auth."
)
def generate_auth_token(
self,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
"""Generates an auth token from the authorization response.
Args:
auth_scheme: The OpenID Connect or OAuth2 auth scheme.
auth_credential: The auth credential.
Returns:
An AuthCredential object containing the HTTP bearer access token. If the
HTTO bearer token cannot be generated, return the origianl credential
"""
if "access_token" not in auth_credential.oauth2.token:
return auth_credential
# Return the access token as a bearer token.
updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(
token=auth_credential.oauth2.token["access_token"]
),
),
)
return updated_credential
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
"""Exchanges the OpenID Connect auth credential for an access token or an auth URI.
Args:
auth_scheme: The auth scheme.
auth_credential: The auth credential.
Returns:
An AuthCredential object containing the HTTP Bearer access token.
Raises:
ValueError: If the auth scheme or auth credential is invalid.
"""
# TODO(cheliu): Implement token refresh flow
self._check_scheme_credential_type(auth_scheme, auth_credential)
# If token is already HTTPBearer token, do nothing assuming that this token
# is valid.
if auth_credential.http:
return auth_credential
# If access token is exchanged, exchange a HTTPBearer token.
if auth_credential.oauth2.token:
return self.generate_auth_token(auth_credential)
return None

View File

@@ -0,0 +1,97 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Credential fetcher for Google Service Account."""
from typing import Optional
import google.auth
from google.auth.transport.requests import Request
from google.oauth2 import service_account
import google.oauth2.credentials
from .....auth.auth_credential import (
AuthCredential,
AuthCredentialTypes,
HttpAuth,
HttpCredentials,
)
from .....auth.auth_schemes import AuthScheme
from .base_credential_exchanger import AuthCredentialMissingError, BaseAuthCredentialExchanger
class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
"""Fetches credentials for Google Service Account.
Uses the default service credential if `use_default_credential = True`.
Otherwise, uses the service account credential provided in the auth
credential.
"""
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
"""Exchanges the service account auth credential for an access token.
If auth_credential contains a service account credential, it will be used
to fetch an access token. Otherwise, the default service credential will be
used for fetching an access token.
Args:
auth_scheme: The auth scheme.
auth_credential: The auth credential.
Returns:
An AuthCredential in HTTPBearer format, containing the access token.
"""
if (
auth_credential is None
or auth_credential.service_account is None
or (
auth_credential.service_account.service_account_credential is None
and not auth_credential.service_account.use_default_credential
)
):
raise AuthCredentialMissingError(
"Service account credentials are missing. Please provide them, or set"
" `use_default_credential = True` to use application default"
" credential in a hosted service like Cloud Run."
)
try:
if auth_credential.service_account.use_default_credential:
credentials, _ = google.auth.default()
else:
config = auth_credential.service_account
credentials = service_account.Credentials.from_service_account_info(
config.service_account_credential.model_dump(), scopes=config.scopes
)
credentials.refresh(Request())
updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token=credentials.token),
),
)
return updated_credential
except Exception as e:
raise AuthCredentialMissingError(
f"Failed to exchange service account token: {e}"
) from e

View File

@@ -0,0 +1,19 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import common
__all__ = [
'common',
]

View File

@@ -0,0 +1,300 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keyword
import re
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from fastapi.openapi.models import Response
from fastapi.openapi.models import Schema
from pydantic import BaseModel
from pydantic import Field
from pydantic import model_serializer
def to_snake_case(text: str) -> str:
"""Converts a string into snake_case.
Handles lowerCamelCase, UpperCamelCase, or space-separated case, acronyms
(e.g., "REST API") and consecutive uppercase letters correctly. Also handles
mixed cases with and without spaces.
Examples:
```
to_snake_case('camelCase') -> 'camel_case'
to_snake_case('UpperCamelCase') -> 'upper_camel_case'
to_snake_case('space separated') -> 'space_separated'
```
Args:
text: The input string.
Returns:
The snake_case version of the string.
"""
# Handle spaces and non-alphanumeric characters (replace with underscores)
text = re.sub(r'[^a-zA-Z0-9]+', '_', text)
# Insert underscores before uppercase letters (handling both CamelCases)
text = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', text) # lowerCamelCase
text = re.sub(
r'([A-Z]+)([A-Z][a-z])', r'\1_\2', text
) # UpperCamelCase and acronyms
# Convert to lowercase
text = text.lower()
# Remove consecutive underscores (clean up extra underscores)
text = re.sub(r'_+', '_', text)
# Remove leading and trailing underscores
text = text.strip('_')
return text
def rename_python_keywords(s: str, prefix: str = 'param_') -> str:
"""Renames Python keywords by adding a prefix.
Example:
```
rename_python_keywords('if') -> 'param_if'
rename_python_keywords('for') -> 'param_for'
```
Args:
s: The input string.
prefix: The prefix to add to the keyword.
Returns:
The renamed string.
"""
if keyword.iskeyword(s):
return prefix + s
return s
class ApiParameter(BaseModel):
"""Data class representing a function parameter."""
original_name: str
param_location: str
param_schema: Union[str, Schema]
description: Optional[str] = ''
py_name: Optional[str] = ''
type_value: type[Any] = Field(default=None, init_var=False)
type_hint: str = Field(default=None, init_var=False)
def model_post_init(self, _: Any):
self.py_name = (
self.py_name
if self.py_name
else rename_python_keywords(to_snake_case(self.original_name))
)
if isinstance(self.param_schema, str):
self.param_schema = Schema.model_validate_json(self.param_schema)
self.description = self.description or self.param_schema.description or ''
self.type_value = TypeHintHelper.get_type_value(self.param_schema)
self.type_hint = TypeHintHelper.get_type_hint(self.param_schema)
return self
@model_serializer
def _serialize(self):
return {
'original_name': self.original_name,
'param_location': self.param_location,
'param_schema': self.param_schema,
'description': self.description,
'py_name': self.py_name,
}
def __str__(self):
return f'{self.py_name}: {self.type_hint}'
def to_arg_string(self):
"""Converts the parameter to an argument string for function call."""
return f'{self.py_name}={self.py_name}'
def to_dict_property(self):
"""Converts the parameter to a key:value string for dict property."""
return f'"{self.py_name}": {self.py_name}'
def to_pydoc_string(self):
"""Converts the parameter to a PyDoc parameter docstr."""
return PydocHelper.generate_param_doc(self)
class TypeHintHelper:
"""Helper class for generating type hints."""
@staticmethod
def get_type_value(schema: Schema) -> Any:
"""Generates the Python type value for a given parameter."""
param_type = schema.type if schema.type else Any
if param_type == 'integer':
return int
elif param_type == 'number':
return float
elif param_type == 'boolean':
return bool
elif param_type == 'string':
return str
elif param_type == 'array':
items_type = Any
if schema.items and schema.items.type:
items_type = schema.items.type
if items_type == 'object':
return List[Dict[str, Any]]
else:
type_map = {
'integer': int,
'number': float,
'boolean': bool,
'string': str,
'object': Dict[str, Any],
'array': List[Any],
}
return List[type_map.get(items_type, 'Any')]
elif param_type == 'object':
return Dict[str, Any]
else:
return Any
@staticmethod
def get_type_hint(schema: Schema) -> str:
"""Generates the Python type in string for a given parameter."""
param_type = schema.type if schema.type else 'Any'
if param_type == 'integer':
return 'int'
elif param_type == 'number':
return 'float'
elif param_type == 'boolean':
return 'bool'
elif param_type == 'string':
return 'str'
elif param_type == 'array':
items_type = 'Any'
if schema.items and schema.items.type:
items_type = schema.items.type
if items_type == 'object':
return 'List[Dict[str, Any]]'
else:
type_map = {
'integer': 'int',
'number': 'float',
'boolean': 'bool',
'string': 'str',
}
return f"List[{type_map.get(items_type, 'Any')}]"
elif param_type == 'object':
return 'Dict[str, Any]'
else:
return 'Any'
class PydocHelper:
"""Helper class for generating PyDoc strings."""
@staticmethod
def generate_param_doc(
param: ApiParameter,
) -> str:
"""Generates a parameter documentation string.
Args:
param: ApiParameter - The parameter to generate the documentation for.
Returns:
str: The generated parameter Python documentation string.
"""
description = param.description.strip() if param.description else ''
param_doc = f'{param.py_name} ({param.type_hint}): {description}'
if param.param_schema.type == 'object':
properties = param.param_schema.properties
if properties:
param_doc += ' Object properties:\n'
for prop_name, prop_details in properties.items():
prop_desc = prop_details.description or ''
prop_type = TypeHintHelper.get_type_hint(prop_details)
param_doc += f' {prop_name} ({prop_type}): {prop_desc}\n'
return param_doc
@staticmethod
def generate_return_doc(responses: Dict[str, Response]) -> str:
"""Generates a return value documentation string.
Args:
responses: Dict[str, TypedDict[Response]] - Response in an OpenAPI
Operation
Returns:
str: The generated return value Python documentation string.
"""
return_doc = ''
# Only consider 2xx responses for return type hinting.
# Returns the 2xx response with the smallest status code number and with
# content defined.
sorted_responses = sorted(responses.items(), key=lambda item: int(item[0]))
qualified_response = next(
filter(
lambda r: r[0].startswith('2') and r[1].content,
sorted_responses,
),
None,
)
if not qualified_response:
return ''
response_details = qualified_response[1]
description = (response_details.description or '').strip()
content = response_details.content or {}
# Generate return type hint and properties for the first response type.
# TODO(cheliu): Handle multiple content types.
for _, schema_details in content.items():
schema = schema_details.schema_ or {}
# Use a dummy Parameter object for return type hinting.
dummy_param = ApiParameter(
original_name='', param_location='', param_schema=schema
)
return_doc = f'Returns ({dummy_param.type_hint}): {description}'
response_type = schema.type or 'Any'
if response_type != 'object':
break
properties = schema.properties
if not properties:
break
return_doc += ' Object properties:\n'
for prop_name, prop_details in properties.items():
prop_desc = prop_details.description or ''
prop_type = TypeHintHelper.get_type_hint(prop_details)
return_doc += f' {prop_name} ({prop_type}): {prop_desc}\n'
break
return return_doc

View File

@@ -0,0 +1,32 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .openapi_spec_parser import OpenApiSpecParser, OperationEndpoint, ParsedOperation
from .openapi_toolset import OpenAPIToolset
from .operation_parser import OperationParser
from .rest_api_tool import AuthPreparationState, RestApiTool, snake_to_lower_camel, to_gemini_schema
from .tool_auth_handler import ToolAuthHandler
__all__ = [
'OpenApiSpecParser',
'OperationEndpoint',
'ParsedOperation',
'OpenAPIToolset',
'OperationParser',
'RestApiTool',
'to_gemini_schema',
'snake_to_lower_camel',
'AuthPreparationState',
'ToolAuthHandler',
]

View File

@@ -0,0 +1,231 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from fastapi.openapi.models import Operation
from pydantic import BaseModel
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ..common.common import ApiParameter
from ..common.common import to_snake_case
from .operation_parser import OperationParser
class OperationEndpoint(BaseModel):
base_url: str
path: str
method: str
class ParsedOperation(BaseModel):
name: str
description: str
endpoint: OperationEndpoint
operation: Operation
parameters: List[ApiParameter]
return_value: ApiParameter
auth_scheme: Optional[AuthScheme] = None
auth_credential: Optional[AuthCredential] = None
additional_context: Optional[Any] = None
class OpenApiSpecParser:
"""Generates Python code, JSON schema, and callables for an OpenAPI operation.
This class takes an OpenApiOperation object and provides methods to generate:
1. A string representation of a Python function that handles the operation.
2. A JSON schema representing the input parameters of the operation.
3. A callable Python object (a function) that can execute the operation.
"""
def parse(self, openapi_spec_dict: Dict[str, Any]) -> List[ParsedOperation]:
"""Extracts an OpenAPI spec dict into a list of ParsedOperation objects.
ParsedOperation objects are further used for generating RestApiTool.
Args:
openapi_spec_dict: A dictionary representing the OpenAPI specification.
Returns:
A list of ParsedOperation objects.
"""
openapi_spec_dict = self._resolve_references(openapi_spec_dict)
operations = self._collect_operations(openapi_spec_dict)
return operations
def _collect_operations(
self, openapi_spec: Dict[str, Any]
) -> List[ParsedOperation]:
"""Collects operations from an OpenAPI spec."""
operations = []
# Taking first server url, or default to empty string if not present
base_url = ""
if openapi_spec.get("servers"):
base_url = openapi_spec["servers"][0].get("url", "")
# Get global security scheme (if any)
global_scheme_name = None
if openapi_spec.get("security"):
# Use first scheme by default.
scheme_names = list(openapi_spec["security"][0].keys())
global_scheme_name = scheme_names[0] if scheme_names else None
auth_schemes = openapi_spec.get("components", {}).get("securitySchemes", {})
for path, path_item in openapi_spec.get("paths", {}).items():
if path_item is None:
continue
for method in (
"get",
"post",
"put",
"delete",
"patch",
"head",
"options",
"trace",
):
operation_dict = path_item.get(method)
if operation_dict is None:
continue
# If operation ID is missing, assign an operation id based on path
# and method
if "operationId" not in operation_dict:
temp_id = to_snake_case(f"{path}_{method}")
operation_dict["operationId"] = temp_id
url = OperationEndpoint(base_url=base_url, path=path, method=method)
operation = Operation.model_validate(operation_dict)
operation_parser = OperationParser(operation)
# Check for operation-specific auth scheme
auth_scheme_name = operation_parser.get_auth_scheme_name()
auth_scheme_name = (
auth_scheme_name if auth_scheme_name else global_scheme_name
)
auth_scheme = (
auth_schemes.get(auth_scheme_name) if auth_scheme_name else None
)
parsed_op = ParsedOperation(
name=operation_parser.get_function_name(),
description=operation.description or operation.summary or "",
endpoint=url,
operation=operation,
parameters=operation_parser.get_parameters(),
return_value=operation_parser.get_return_value(),
auth_scheme=auth_scheme,
auth_credential=None, # Placeholder
additional_context={},
)
operations.append(parsed_op)
return operations
def _resolve_references(self, openapi_spec: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively resolves all $ref references in an OpenAPI specification.
Handles circular references correctly.
Args:
openapi_spec: A dictionary representing the OpenAPI specification.
Returns:
A dictionary representing the OpenAPI specification with all references
resolved.
"""
openapi_spec = copy.deepcopy(openapi_spec) # Work on a copy
resolved_cache = {} # Cache resolved references
def resolve_ref(ref_string, current_doc):
"""Resolves a single $ref string."""
parts = ref_string.split("/")
if parts[0] != "#":
raise ValueError(f"External references not supported: {ref_string}")
current = current_doc
for part in parts[1:]:
if part in current:
current = current[part]
else:
return None # Reference not found
return current
def recursive_resolve(obj, current_doc, seen_refs=None):
"""Recursively resolves references, handling circularity.
Args:
obj: The object to traverse.
current_doc: Document to search for refs.
seen_refs: A set to track already-visited references (for circularity
detection).
Returns:
The resolved object.
"""
if seen_refs is None:
seen_refs = set() # Initialize the set if it's the first call
if isinstance(obj, dict):
if "$ref" in obj and isinstance(obj["$ref"], str):
ref_string = obj["$ref"]
# Check for circularity
if ref_string in seen_refs and ref_string not in resolved_cache:
# Circular reference detected! Return a *copy* of the object,
# but *without* the $ref. This breaks the cycle while
# still maintaining the overall structure.
return {k: v for k, v in obj.items() if k != "$ref"}
seen_refs.add(ref_string) # Add the reference to the set
# Check if we have a cached resolved value
if ref_string in resolved_cache:
return copy.deepcopy(resolved_cache[ref_string])
resolved_value = resolve_ref(ref_string, current_doc)
if resolved_value is not None:
# Recursively resolve the *resolved* value,
# passing along the 'seen_refs' set
resolved_value = recursive_resolve(
resolved_value, current_doc, seen_refs
)
resolved_cache[ref_string] = resolved_value
return copy.deepcopy(resolved_value) # return the cached result
else:
return obj # return original if no resolved value.
else:
new_dict = {}
for key, value in obj.items():
new_dict[key] = recursive_resolve(value, current_doc, seen_refs)
return new_dict
elif isinstance(obj, list):
return [recursive_resolve(item, current_doc, seen_refs) for item in obj]
else:
return obj
return recursive_resolve(openapi_spec, openapi_spec)

View File

@@ -0,0 +1,144 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from typing import Any
from typing import Dict
from typing import Final
from typing import List
from typing import Literal
from typing import Optional
import yaml
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from .openapi_spec_parser import OpenApiSpecParser
from .rest_api_tool import RestApiTool
logger = logging.getLogger(__name__)
class OpenAPIToolset:
"""Class for parsing OpenAPI spec into a list of RestApiTool.
Usage:
```
# Initialize OpenAPI toolset from a spec string.
openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
spec_str_type="json")
# Or, initialize OpenAPI toolset from a spec dictionary.
openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)
# Add all tools to an agent.
agent = Agent(
tools=[*openapi_toolset.get_tools()]
)
# Or, add a single tool to an agent.
agent = Agent(
tools=[openapi_toolset.get_tool('tool_name')]
)
```
"""
def __init__(
self,
*,
spec_dict: Optional[Dict[str, Any]] = None,
spec_str: Optional[str] = None,
spec_str_type: Literal["json", "yaml"] = "json",
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
):
"""Initializes the OpenAPIToolset.
Usage:
```
# Initialize OpenAPI toolset from a spec string.
openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
spec_str_type="json")
# Or, initialize OpenAPI toolset from a spec dictionary.
openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)
# Add all tools to an agent.
agent = Agent(
tools=[*openapi_toolset.get_tools()]
)
# Or, add a single tool to an agent.
agent = Agent(
tools=[openapi_toolset.get_tool('tool_name')]
)
```
Args:
spec_dict: The OpenAPI spec dictionary. If provided, it will be used
instead of loading the spec from a string.
spec_str: The OpenAPI spec string in JSON or YAML format. It will be used
when spec_dict is not provided.
spec_str_type: The type of the OpenAPI spec string. Can be "json" or
"yaml".
auth_scheme: The auth scheme to use for all tools. Use AuthScheme or use
helpers in `google.adk.tools.openapi_tool.auth.auth_helpers`
auth_credential: The auth credential to use for all tools. Use
AuthCredential or use helpers in
`google.adk.tools.openapi_tool.auth.auth_helpers`
"""
if not spec_dict:
spec_dict = self._load_spec(spec_str, spec_str_type)
self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
if auth_scheme or auth_credential:
self._configure_auth_all(auth_scheme, auth_credential)
def _configure_auth_all(
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
):
"""Configure auth scheme and credential for all tools."""
for tool in self.tools:
if auth_scheme:
tool.configure_auth_scheme(auth_scheme)
if auth_credential:
tool.configure_auth_credential(auth_credential)
def get_tools(self) -> List[RestApiTool]:
"""Get all tools in the toolset."""
return self.tools
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
"""Get a tool by name."""
matching_tool = filter(lambda t: t.name == tool_name, self.tools)
return next(matching_tool, None)
def _load_spec(
self, spec_str: str, spec_type: Literal["json", "yaml"]
) -> Dict[str, Any]:
"""Loads the OpenAPI spec string into adictionary."""
if spec_type == "json":
return json.loads(spec_str)
elif spec_type == "yaml":
return yaml.safe_load(spec_str)
else:
raise ValueError(f"Unsupported spec type: {spec_type}")
def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]:
"""Parse OpenAPI spec into a list of RestApiTool."""
operations = OpenApiSpecParser().parse(openapi_spec_dict)
tools = []
for o in operations:
tool = RestApiTool.from_parsed_operation(o)
logger.info("Parsed tool: %s", tool.name)
tools.append(tool)
return tools

View File

@@ -0,0 +1,260 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from textwrap import dedent
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.models import Operation
from fastapi.openapi.models import Parameter
from fastapi.openapi.models import Schema
from ..common.common import ApiParameter
from ..common.common import PydocHelper
from ..common.common import to_snake_case
class OperationParser:
"""Generates parameters for Python functions from an OpenAPI operation.
This class processes an OpenApiOperation object and provides helper methods
to extract information needed to generate Python function declarations,
docstrings, signatures, and JSON schemas. It handles parameter processing,
name deduplication, and type hint generation.
"""
def __init__(
self, operation: Union[Operation, Dict[str, Any], str], should_parse=True
):
"""Initializes the OperationParser with an OpenApiOperation.
Args:
operation: The OpenApiOperation object or a dictionary to process.
should_parse: Whether to parse the operation during initialization.
"""
if isinstance(operation, dict):
self.operation = Operation.model_validate(operation)
elif isinstance(operation, str):
self.operation = Operation.model_validate_json(operation)
else:
self.operation = operation
self.params: List[ApiParameter] = []
self.return_value: Optional[ApiParameter] = None
if should_parse:
self._process_operation_parameters()
self._process_request_body()
self._process_return_value()
self._dedupe_param_names()
@classmethod
def load(
cls,
operation: Union[Operation, Dict[str, Any]],
params: List[ApiParameter],
return_value: Optional[ApiParameter] = None,
) -> 'OperationParser':
parser = cls(operation, should_parse=False)
parser.params = params
parser.return_value = return_value
return parser
def _process_operation_parameters(self):
"""Processes parameters from the OpenAPI operation."""
parameters = self.operation.parameters or []
for param in parameters:
if isinstance(param, Parameter):
original_name = param.name
description = param.description or ''
location = param.in_ or ''
schema = param.schema_ or {} # Use schema_ instead of .schema
self.params.append(
ApiParameter(
original_name=original_name,
param_location=location,
param_schema=schema,
description=description,
)
)
def _process_request_body(self):
"""Processes the request body from the OpenAPI operation."""
request_body = self.operation.requestBody
if not request_body:
return
content = request_body.content or {}
if not content:
return
# If request body is an object, expand the properties as parameters
for _, media_type_object in content.items():
schema = media_type_object.schema_ or {}
description = request_body.description or ''
if schema and schema.type == 'object':
for prop_name, prop_details in schema.properties.items():
self.params.append(
ApiParameter(
original_name=prop_name,
param_location='body',
param_schema=prop_details,
description=prop_details.description,
)
)
elif schema and schema.type == 'array':
self.params.append(
ApiParameter(
original_name='array',
param_location='body',
param_schema=schema,
description=description,
)
)
else:
self.params.append(
# Empty name for unnamed body param
ApiParameter(
original_name='',
param_location='body',
param_schema=schema,
description=description,
)
)
break # Process first mime type only
def _dedupe_param_names(self):
"""Deduplicates parameter names to avoid conflicts."""
params_cnt = {}
for param in self.params:
name = param.py_name
if name not in params_cnt:
params_cnt[name] = 0
else:
params_cnt[name] += 1
param.py_name = f'{name}_{params_cnt[name] -1}'
def _process_return_value(self) -> Parameter:
"""Returns a Parameter object representing the return type."""
responses = self.operation.responses or {}
# Default to Any if no 2xx response or if schema is missing
return_schema = Schema(type='Any')
# Take the 20x response with the smallest response code.
valid_codes = list(
filter(lambda k: k.startswith('2'), list(responses.keys()))
)
min_20x_status_code = min(valid_codes) if valid_codes else None
if min_20x_status_code and responses[min_20x_status_code].content:
content = responses[min_20x_status_code].content
for mime_type in content:
if content[mime_type].schema_:
return_schema = content[mime_type].schema_
break
self.return_value = ApiParameter(
original_name='',
param_location='',
param_schema=return_schema,
)
def get_function_name(self) -> str:
"""Returns the generated function name."""
operation_id = self.operation.operationId
if not operation_id:
raise ValueError('Operation ID is missing')
return to_snake_case(operation_id)[:60]
def get_return_type_hint(self) -> str:
"""Returns the return type hint string (like 'str', 'int', etc.)."""
return self.return_value.type_hint
def get_return_type_value(self) -> Any:
"""Returns the return type value (like str, int, List[str], etc.)."""
return self.return_value.type_value
def get_parameters(self) -> List[ApiParameter]:
"""Returns the list of Parameter objects."""
return self.params
def get_return_value(self) -> ApiParameter:
"""Returns the list of Parameter objects."""
return self.return_value
def get_auth_scheme_name(self) -> str:
"""Returns the name of the auth scheme for this operation from the spec."""
if self.operation.security:
scheme_name = list(self.operation.security[0].keys())[0]
return scheme_name
return ''
def get_pydoc_string(self) -> str:
"""Returns the generated PyDoc string."""
pydoc_params = [param.to_pydoc_string() for param in self.params]
pydoc_description = (
self.operation.summary or self.operation.description or ''
)
pydoc_return = PydocHelper.generate_return_doc(
self.operation.responses or {}
)
pydoc_arg_list = chr(10).join(
f' {param_doc}' for param_doc in pydoc_params
)
return dedent(f"""
\"\"\"{pydoc_description}
Args:
{pydoc_arg_list}
{pydoc_return}
\"\"\"
""").strip()
def get_json_schema(self) -> Dict[str, Any]:
"""Returns the JSON schema for the function arguments."""
properties = {
p.py_name: jsonable_encoder(p.param_schema, exclude_none=True)
for p in self.params
}
return {
'properties': properties,
'required': [p.py_name for p in self.params],
'title': f"{self.operation.operationId or 'unnamed'}_Arguments",
'type': 'object',
}
def get_signature_parameters(self) -> List[inspect.Parameter]:
"""Returns a list of inspect.Parameter objects for the function."""
return [
inspect.Parameter(
param.py_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=param.type_value,
)
for param in self.params
]
def get_annotations(self) -> Dict[str, Any]:
"""Returns a dictionary of parameter annotations for the function."""
annotations = {p.py_name: p.type_value for p in self.params}
annotations['return'] = self.get_return_type_value()
return annotations

View File

@@ -0,0 +1,496 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Union
from fastapi.openapi.models import Operation
from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
import requests
from typing_extensions import override
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ....tools import BaseTool
from ...tool_context import ToolContext
from ..auth.auth_helpers import credential_to_param
from ..auth.auth_helpers import dict_to_auth_scheme
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
from ..common.common import ApiParameter
from ..common.common import to_snake_case
from .openapi_spec_parser import OperationEndpoint
from .openapi_spec_parser import ParsedOperation
from .operation_parser import OperationParser
from .tool_auth_handler import ToolAuthHandler
def snake_to_lower_camel(snake_case_string: str):
"""Converts a snake_case string to a lower_camel_case string.
Args:
snake_case_string: The input snake_case string.
Returns:
The lower_camel_case string.
"""
if "_" not in snake_case_string:
return snake_case_string
return "".join([
s.lower() if i == 0 else s.capitalize()
for i, s in enumerate(snake_case_string.split("_"))
])
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
"""Converts an OpenAPI schema dictionary to a Gemini Schema object.
Args:
openapi_schema: The OpenAPI schema dictionary.
Returns:
A Pydantic Schema object. Returns None if input is None.
Raises TypeError if input is not a dict.
"""
if openapi_schema is None:
return None
if not isinstance(openapi_schema, dict):
raise TypeError("openapi_schema must be a dictionary")
pydantic_schema_data = {}
# Adding this to force adding a type to an empty dict
# This avoid "... one_of or any_of must specify a type" error
if not openapi_schema.get("type"):
openapi_schema["type"] = "object"
# Adding this to avoid "properties: should be non-empty for OBJECT type" error
# See b/385165182
if openapi_schema.get("type", "") == "object" and not openapi_schema.get(
"properties"
):
openapi_schema["properties"] = {"dummy_DO_NOT_GENERATE": {"type": "string"}}
for key, value in openapi_schema.items():
snake_case_key = to_snake_case(key)
# Check if the snake_case_key exists in the Schema model's fields.
if snake_case_key in Schema.model_fields:
if snake_case_key in ["title", "default", "format"]:
# Ignore these fields as Gemini backend doesn't recognize them, and will
# throw exception if they appear in the schema.
# Format: properties[expiration].format: only 'enum' and 'date-time' are
# supported for STRING type
continue
if snake_case_key == "properties" and isinstance(value, dict):
pydantic_schema_data[snake_case_key] = {
k: to_gemini_schema(v) for k, v in value.items()
}
elif snake_case_key == "items" and isinstance(value, dict):
pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
elif snake_case_key == "any_of" and isinstance(value, list):
pydantic_schema_data[snake_case_key] = [
to_gemini_schema(item) for item in value
]
# Important: Handle cases where the OpenAPI schema might contain lists
# or other structures that need to be recursively processed.
elif isinstance(value, list) and snake_case_key not in (
"enum",
"required",
"property_ordering",
):
new_list = []
for item in value:
if isinstance(item, dict):
new_list.append(to_gemini_schema(item))
else:
new_list.append(item)
pydantic_schema_data[snake_case_key] = new_list
elif isinstance(value, dict) and snake_case_key not in ("properties"):
# Handle dictionary which is neither properties or items
pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
else:
# Simple value assignment (int, str, bool, etc.)
pydantic_schema_data[snake_case_key] = value
return Schema(**pydantic_schema_data)
AuthPreparationState = Literal["pending", "done"]
class RestApiTool(BaseTool):
"""A generic tool that interacts with a REST API.
* Generates request params and body
* Attaches auth credentials to API call.
Example:
```
# Each API operation in the spec will be turned into its own tool
# Name of the tool is the operationId of that operation, in snake case
operations = OperationGenerator().parse(openapi_spec_dict)
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
```
"""
def __init__(
self,
name: str,
description: str,
endpoint: Union[OperationEndpoint, str],
operation: Union[Operation, str],
auth_scheme: Optional[Union[AuthScheme, str]] = None,
auth_credential: Optional[Union[AuthCredential, str]] = None,
should_parse_operation=True,
):
"""Initializes the RestApiTool with the given parameters.
To generate RestApiTool from OpenAPI Specs, use OperationGenerator.
Example:
```
# Each API operation in the spec will be turned into its own tool
# Name of the tool is the operationId of that operation, in snake case
operations = OperationGenerator().parse(openapi_spec_dict)
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
```
Hint: Use google.adk.tools.openapi_tool.auth.auth_helpers to construct
auth_scheme and auth_credential.
Args:
name: The name of the tool.
description: The description of the tool.
endpoint: Include the base_url, path, and method of the tool.
operation: Pydantic object or a dict. Representing the OpenAPI Operation
object
(https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#operation-object)
auth_scheme: The auth scheme of the tool. Representing the OpenAPI
SecurityScheme object
(https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#security-scheme-object)
auth_credential: The authentication credential of the tool.
should_parse_operation: Whether to parse the operation.
"""
# Gemini restrict the length of function name to be less than 64 characters
self.name = name[:60]
self.description = description
self.endpoint = (
OperationEndpoint.model_validate_json(endpoint)
if isinstance(endpoint, str)
else endpoint
)
self.operation = (
Operation.model_validate_json(operation)
if isinstance(operation, str)
else operation
)
self.auth_credential, self.auth_scheme = None, None
self.configure_auth_credential(auth_credential)
self.configure_auth_scheme(auth_scheme)
# Private properties
self.credential_exchanger = AutoAuthCredentialExchanger()
if should_parse_operation:
self._operation_parser = OperationParser(self.operation)
@classmethod
def from_parsed_operation(cls, parsed: ParsedOperation) -> "RestApiTool":
"""Initializes the RestApiTool from a ParsedOperation object.
Args:
parsed: A ParsedOperation object.
Returns:
A RestApiTool object.
"""
operation_parser = OperationParser.load(
parsed.operation, parsed.parameters, parsed.return_value
)
tool_name = to_snake_case(operation_parser.get_function_name())
generated = cls(
name=tool_name,
description=parsed.operation.description
or parsed.operation.summary
or "",
endpoint=parsed.endpoint,
operation=parsed.operation,
auth_scheme=parsed.auth_scheme,
auth_credential=parsed.auth_credential,
)
generated._operation_parser = operation_parser
return generated
@classmethod
def from_parsed_operation_str(
cls, parsed_operation_str: str
) -> "RestApiTool":
"""Initializes the RestApiTool from a dict.
Args:
parsed: A dict representation of a ParsedOperation object.
Returns:
A RestApiTool object.
"""
operation = ParsedOperation.model_validate_json(parsed_operation_str)
return RestApiTool.from_parsed_operation(operation)
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Returns the function declaration in the Gemini Schema format."""
schema_dict = self._operation_parser.get_json_schema()
parameters = to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
return function_decl
def configure_auth_scheme(
self, auth_scheme: Union[AuthScheme, Dict[str, Any]]
):
"""Configures the authentication scheme for the API call.
Args:
auth_scheme: AuthScheme|dict -: The authentication scheme. The dict is
converted to a AuthScheme object.
"""
if isinstance(auth_scheme, dict):
auth_scheme = dict_to_auth_scheme(auth_scheme)
self.auth_scheme = auth_scheme
def configure_auth_credential(
self, auth_credential: Optional[Union[AuthCredential, str]] = None
):
"""Configures the authentication credential for the API call.
Args:
auth_credential: AuthCredential|dict - The authentication credential.
The dict is converted to an AuthCredential object.
"""
if isinstance(auth_credential, str):
auth_credential = AuthCredential.model_validate_json(auth_credential)
self.auth_credential = auth_credential
def _prepare_auth_request_params(
self,
auth_scheme: AuthScheme,
auth_credential: AuthCredential,
) -> Tuple[List[ApiParameter], Dict[str, Any]]:
# Handle Authentication
if not auth_scheme or not auth_credential:
return
return credential_to_param(auth_scheme, auth_credential)
def _prepare_request_params(
self, parameters: List[ApiParameter], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
"""Prepares the request parameters for the API call.
Args:
parameters: A list of ApiParameter objects representing the parameters
for the API call.
kwargs: The keyword arguments passed to the call function from the Tool
caller.
Returns:
A dictionary containing the request parameters for the API call. This
initializes a requests.request() call.
Example:
self._prepare_request_params({"input_id": "test-id"})
"""
method = self.endpoint.method.lower()
if not method:
raise ValueError("Operation method not found.")
path_params: Dict[str, Any] = {}
query_params: Dict[str, Any] = {}
header_params: Dict[str, Any] = {}
cookie_params: Dict[str, Any] = {}
params_map: Dict[str, ApiParameter] = {p.py_name: p for p in parameters}
# Fill in path, query, header and cookie parameters to the request
for param_k, v in kwargs.items():
param_obj = params_map.get(param_k)
if not param_obj:
continue # If input arg not in the ApiParameter list, ignore it.
original_k = param_obj.original_name
param_location = param_obj.param_location
if param_location == "path":
path_params[original_k] = v
elif param_location == "query":
if v:
query_params[original_k] = v
elif param_location == "header":
header_params[original_k] = v
elif param_location == "cookie":
cookie_params[original_k] = v
# Construct URL
base_url = self.endpoint.base_url or ""
base_url = base_url[:-1] if base_url.endswith("/") else base_url
url = f"{base_url}{self.endpoint.path.format(**path_params)}"
# Construct body
body_kwargs: Dict[str, Any] = {}
request_body = self.operation.requestBody
if request_body:
for mime_type, media_type_object in request_body.content.items():
schema = media_type_object.schema_
body_data = None
if schema.type == "object":
body_data = {}
for param in parameters:
if param.param_location == "body" and param.py_name in kwargs:
body_data[param.original_name] = kwargs[param.py_name]
elif schema.type == "array":
for param in parameters:
if param.param_location == "body" and param.py_name == "array":
body_data = kwargs.get("array")
break
else: # like string
for param in parameters:
# original_name = '' indicating this param applies to the full body.
if param.param_location == "body" and not param.original_name:
body_data = (
kwargs.get(param.py_name) if param.py_name in kwargs else None
)
break
if mime_type == "application/json" or mime_type.endswith("+json"):
if body_data is not None:
body_kwargs["json"] = body_data
elif mime_type == "application/x-www-form-urlencoded":
body_kwargs["data"] = body_data
elif mime_type == "multipart/form-data":
body_kwargs["files"] = body_data
elif mime_type == "application/octet-stream":
body_kwargs["data"] = body_data
elif mime_type == "text/plain":
body_kwargs["data"] = body_data
if mime_type:
header_params["Content-Type"] = mime_type
break # Process only the first mime_type
filtered_query_params: Dict[str, Any] = {
k: v for k, v in query_params.items() if v is not None
}
request_params: Dict[str, Any] = {
"method": method,
"url": url,
"params": filtered_query_params,
"headers": header_params,
"cookies": cookie_params,
**body_kwargs,
}
return request_params
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]:
return self.call(args=args, tool_context=tool_context)
def call(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]:
"""Executes the REST API call.
Args:
args: Keyword arguments representing the operation parameters.
tool_context: The tool context (not used here, but required by the
interface).
Returns:
The API response as a dictionary.
"""
# Prepare auth credentials for the API call
tool_auth_handler = ToolAuthHandler.from_tool_context(
tool_context, self.auth_scheme, self.auth_credential
)
auth_result = tool_auth_handler.prepare_auth_credentials()
auth_state, auth_scheme, auth_credential = (
auth_result.state,
auth_result.auth_scheme,
auth_result.auth_credential,
)
if auth_state == "pending":
return {
"pending": True,
"message": "Needs your authorization to access your data.",
}
# Attach parameters from auth into main parameters list
api_params, api_args = self._operation_parser.get_parameters().copy(), args
if auth_credential:
# Attach parameters from auth into main parameters list
auth_param, auth_args = self._prepare_auth_request_params(
auth_scheme, auth_credential
)
if auth_param and auth_args:
api_params = [auth_param] + api_params
api_args.update(auth_args)
# Got all parameters. Call the API.
request_params = self._prepare_request_params(api_params, api_args)
response = requests.request(**request_params)
# Parse API response
try:
response.raise_for_status() # Raise HTTPError for bad responses
return response.json() # Try to decode JSON
except requests.exceptions.HTTPError:
error_details = response.content.decode("utf-8")
return {
"error": (
f"Tool {self.name} execution failed. Analyze this execution error"
" and your inputs. Retry with adjustments if applicable. But"
" make sure don't retry more than 3 times. Execution Error:"
f" {error_details}"
)
}
except ValueError:
return {"text": response.text} # Return text if not JSON
def __str__(self):
return (
f'RestApiTool(name="{self.name}", description="{self.description}",'
f' endpoint="{self.endpoint}")'
)
def __repr__(self):
return (
f'RestApiTool(name="{self.name}", description="{self.description}",'
f' endpoint="{self.endpoint}", operation="{self.operation}",'
f' auth_scheme="{self.auth_scheme}",'
f' auth_credential="{self.auth_credential}")'
)

View File

@@ -0,0 +1,268 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Literal
from typing import Optional
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from ....auth.auth_credential import AuthCredential
from ....auth.auth_credential import AuthCredentialTypes
from ....auth.auth_schemes import AuthScheme
from ....auth.auth_schemes import AuthSchemeType
from ....auth.auth_tool import AuthConfig
from ...tool_context import ToolContext
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
from ..auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
from ..auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
logger = logging.getLogger(__name__)
AuthPreparationState = Literal["pending", "done"]
class AuthPreparationResult(BaseModel):
"""Result of the credential preparation process."""
state: AuthPreparationState
auth_scheme: Optional[AuthScheme] = None
auth_credential: Optional[AuthCredential] = None
class ToolContextCredentialStore:
"""Handles storage and retrieval of credentials within a ToolContext."""
def __init__(self, tool_context: ToolContext):
self.tool_context = tool_context
def get_credential_key(
self,
auth_scheme: Optional[AuthScheme],
auth_credential: Optional[AuthCredential],
) -> str:
"""Generates a unique key for the given auth scheme and credential."""
scheme_name = (
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
if auth_scheme
else ""
)
credential_name = (
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
if auth_credential
else ""
)
# no need to prepend temp: namespace, session state is a copy, changes to
# it won't be persisted , only changes in event_action.state_delta will be
# persisted. temp: namespace will be cleared after current run. but tool
# want access token to be there stored across runs
return f"{scheme_name}_{credential_name}_existing_exchanged_credential"
def get_credential(
self,
auth_scheme: Optional[AuthScheme],
auth_credential: Optional[AuthCredential],
) -> Optional[AuthCredential]:
if not self.tool_context:
return None
token_key = self.get_credential_key(auth_scheme, auth_credential)
# TODO try not to use session state, this looks a hacky way, depend on
# session implementation, we don't want session to persist the token,
# meanwhile we want the token shared across runs.
serialized_credential = self.tool_context.state.get(token_key)
if not serialized_credential:
return None
return AuthCredential.model_validate(serialized_credential)
def store_credential(
self,
key: str,
auth_credential: Optional[AuthCredential],
):
if self.tool_context:
serializable_credential = jsonable_encoder(
auth_credential, exclude_none=True
)
self.tool_context.state[key] = serializable_credential
def remove_credential(self, key: str):
del self.tool_context.state[key]
class ToolAuthHandler:
"""Handles the preparation and exchange of authentication credentials for tools."""
def __init__(
self,
tool_context: ToolContext,
auth_scheme: Optional[AuthScheme],
auth_credential: Optional[AuthCredential],
credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
credential_store: Optional["ToolContextCredentialStore"] = None,
):
self.tool_context = tool_context
self.auth_scheme = (
auth_scheme.model_copy(deep=True) if auth_scheme else None
)
self.auth_credential = (
auth_credential.model_copy(deep=True) if auth_credential else None
)
self.credential_exchanger = (
credential_exchanger or AutoAuthCredentialExchanger()
)
self.credential_store = credential_store
self.should_store_credential = True
@classmethod
def from_tool_context(
cls,
tool_context: ToolContext,
auth_scheme: Optional[AuthScheme],
auth_credential: Optional[AuthCredential],
credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
) -> "ToolAuthHandler":
"""Creates a ToolAuthHandler instance from a ToolContext."""
credential_store = ToolContextCredentialStore(tool_context)
return cls(
tool_context,
auth_scheme,
auth_credential,
credential_exchanger,
credential_store,
)
def _handle_existing_credential(
self,
) -> Optional[AuthPreparationResult]:
"""Checks for and returns an existing, exchanged credential."""
if self.credential_store:
existing_credential = self.credential_store.get_credential(
self.auth_scheme, self.auth_credential
)
if existing_credential:
return AuthPreparationResult(
state="done",
auth_scheme=self.auth_scheme,
auth_credential=existing_credential,
)
return None
def _exchange_credential(
self, auth_credential: AuthCredential
) -> Optional[AuthPreparationResult]:
"""Handles an OpenID Connect authorization response."""
exchanged_credential = None
try:
exchanged_credential = self.credential_exchanger.exchange_credential(
self.auth_scheme, auth_credential
)
except Exception as e:
logger.error("Failed to exchange credential: %s", e)
return exchanged_credential
def _store_credential(self, auth_credential: AuthCredential) -> None:
"""stores the auth_credential."""
if self.credential_store:
key = self.credential_store.get_credential_key(
self.auth_scheme, self.auth_credential
)
self.credential_store.store_credential(key, auth_credential)
def _reqeust_credential(self) -> None:
"""Handles the case where an OpenID Connect or OAuth2 authentication request is needed."""
if self.auth_scheme.type_ in (
AuthSchemeType.openIdConnect,
AuthSchemeType.oauth2,
):
if not self.auth_credential or not self.auth_credential.oauth2:
raise ValueError(
f"auth_credential is empty for scheme {self.auth_scheme.type_}."
"Please create AuthCredential using OAuth2Auth."
)
if not self.auth_credential.oauth2.client_id:
raise AuthCredentialMissingError(
"OAuth2 credentials client_id is missing."
)
if not self.auth_credential.oauth2.client_secret:
raise AuthCredentialMissingError(
"OAuth2 credentials client_secret is missing."
)
self.tool_context.request_credential(
AuthConfig(
auth_scheme=self.auth_scheme,
raw_auth_credential=self.auth_credential,
)
)
return None
def _get_auth_response(self) -> AuthCredential:
return self.tool_context.get_auth_response(
AuthConfig(
auth_scheme=self.auth_scheme,
raw_auth_credential=self.auth_credential,
)
)
def _request_credential(self, auth_config: AuthConfig):
if not self.tool_context:
return
self.tool_context.request_credential(auth_config)
def prepare_auth_credentials(
self,
) -> AuthPreparationResult:
"""Prepares authentication credentials, handling exchange and user interaction."""
# no auth is needed
if not self.auth_scheme:
return AuthPreparationResult(state="done")
# Check for existing credential.
existing_result = self._handle_existing_credential()
if existing_result:
return existing_result
# fetch credential from adk framework
# Some auth scheme like OAuth2 AuthCode & OpenIDConnect may require
# multi-step exchange:
# client_id , client_secret -> auth_uri -> auth_code -> access_token
# -> bearer token
# adk framework supports exchange access_token already
fetched_credential = self._get_auth_response() or self.auth_credential
exchanged_credential = self._exchange_credential(fetched_credential)
if exchanged_credential:
self._store_credential(exchanged_credential)
return AuthPreparationResult(
state="done",
auth_scheme=self.auth_scheme,
auth_credential=exchanged_credential,
)
else:
self._reqeust_credential()
return AuthPreparationResult(
state="pending",
auth_scheme=self.auth_scheme,
auth_credential=self.auth_credential,
)

View File

@@ -0,0 +1,72 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from typing_extensions import override
from .base_tool import BaseTool
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..models import LlmRequest
class PreloadMemoryTool(BaseTool):
"""A tool that preloads the memory for the current user."""
def __init__(self):
# Name and description are not used because this tool only
# changes llm_request.
super().__init__(name='preload_memory', description='preload_memory')
@override
async def process_llm_request(
self,
*,
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
parts = tool_context.user_content.parts
if not parts or not parts[0].text:
return
query = parts[0].text
response = tool_context.search_memory(query)
if not response.memories:
return
memory_text = ''
for memory in response.memories:
time_str = datetime.fromtimestamp(memory.events[0].timestamp).isoformat()
memory_text += f'Time: {time_str}\n'
for event in memory.events:
# TODO: support multi-part content.
if (
event.content
and event.content.parts
and event.content.parts[0].text
):
memory_text += f'{event.author}: {event.content.parts[0].text}\n'
si = f"""The following content is from your previous conversations with the user.
They may be useful for answering the user's current query.
<PAST_CONVERSATIONS>
{memory_text}
</PAST_CONVERSATIONS>
"""
llm_request.append_instructions([si])
preload_memory_tool = PreloadMemoryTool()

View File

@@ -0,0 +1,36 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_retrieval_tool import BaseRetrievalTool
from .files_retrieval import FilesRetrieval
from .llama_index_retrieval import LlamaIndexRetrieval
__all__ = [
'BaseRetrievalTool',
'FilesRetrieval',
'LlamaIndexRetrieval',
]
try:
from .vertex_ai_rag_retrieval import VertexAiRagRetrieval
__all__.append('VertexAiRagRetrieval')
except ImportError:
import logging
logger = logging.getLogger(__name__)
logger.debug(
'The Vertex sdk is not installed. If you want to use the Vertex RAG with'
' agents, please install it. If not, you can ignore this warning.'
)

View File

@@ -0,0 +1,37 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.genai import types
from typing_extensions import override
from ..base_tool import BaseTool
class BaseRetrievalTool(BaseTool):
@override
def _get_declaration(self) -> types.FunctionDeclaration:
return types.FunctionDeclaration(
name=self.name,
description=self.description,
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
'query': types.Schema(
type=types.Type.STRING,
description='The query to retrieve.',
),
},
),
)

View File

@@ -0,0 +1,33 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides data for the agent."""
from llama_index.core import SimpleDirectoryReader
from llama_index.core import VectorStoreIndex
from .llama_index_retrieval import LlamaIndexRetrieval
class FilesRetrieval(LlamaIndexRetrieval):
def __init__(self, *, name: str, description: str, input_dir: str):
self.input_dir = input_dir
print(f'Loading data from {input_dir}')
retriever = VectorStoreIndex.from_documents(
SimpleDirectoryReader(input_dir).load_data()
).as_retriever()
super().__init__(name=name, description=description, retriever=retriever)

View File

@@ -0,0 +1,41 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides data for the agent."""
from __future__ import annotations
from typing import Any
from typing import TYPE_CHECKING
from typing_extensions import override
from ..tool_context import ToolContext
from .base_retrieval_tool import BaseRetrievalTool
if TYPE_CHECKING:
from llama_index.core.base.base_retriever import BaseRetriever
class LlamaIndexRetrieval(BaseRetrievalTool):
def __init__(self, *, name: str, description: str, retriever: BaseRetriever):
super().__init__(name=name, description=description)
self.retriever = retriever
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext
) -> Any:
return self.retriever.retrieve(args['query'])[0].text

View File

@@ -0,0 +1,107 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A retrieval tool that uses Vertex AI RAG to retrieve data."""
from __future__ import annotations
import logging
from typing import Any
from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from vertexai.preview import rag
from ..tool_context import ToolContext
from .base_retrieval_tool import BaseRetrievalTool
if TYPE_CHECKING:
from ...models.llm_request import LlmRequest
logger = logging.getLogger(__name__)
class VertexAiRagRetrieval(BaseRetrievalTool):
"""A retrieval tool that uses Vertex AI RAG (Retrieval-Augmented Generation) to retrieve data."""
def __init__(
self,
*,
name: str,
description: str,
rag_corpora: list[str] = None,
rag_resources: list[rag.RagResource] = None,
similarity_top_k: int = None,
vector_distance_threshold: float = None,
):
super().__init__(name=name, description=description)
self.vertex_rag_store = types.VertexRagStore(
rag_corpora=rag_corpora,
rag_resources=rag_resources,
similarity_top_k=similarity_top_k,
vector_distance_threshold=vector_distance_threshold,
)
@override
async def process_llm_request(
self,
*,
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
# Use Gemini built-in Vertex AI RAG tool for Gemini 2 models.
if llm_request.model and llm_request.model.startswith('gemini-2'):
llm_request.config = (
types.GenerateContentConfig()
if not llm_request.config
else llm_request.config
)
llm_request.config.tools = (
[] if not llm_request.config.tools else llm_request.config.tools
)
llm_request.config.tools.append(
types.Tool(
retrieval=types.Retrieval(vertex_rag_store=self.vertex_rag_store)
)
)
else:
# Add the function declaration to the tools
await super().process_llm_request(
tool_context=tool_context, llm_request=llm_request
)
@override
async def run_async(
self,
*,
args: dict[str, Any],
tool_context: ToolContext,
) -> Any:
response = rag.retrieval_query(
text=args['query'],
rag_resources=self.vertex_rag_store.rag_resources,
rag_corpora=self.vertex_rag_store.rag_corpora,
similarity_top_k=self.vertex_rag_store.similarity_top_k,
vector_distance_threshold=self.vertex_rag_store.vector_distance_threshold,
)
logging.debug('RAG raw response: %s', response)
return (
f'No matching result found with the config: {self.vertex_rag_store}'
if not response.contexts.contexts
else [context.text for context in response.contexts.contexts]
)

View File

@@ -0,0 +1,90 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional
from typing import TYPE_CHECKING
from ..agents.callback_context import CallbackContext
from ..auth.auth_credential import AuthCredential
from ..auth.auth_handler import AuthHandler
from ..auth.auth_tool import AuthConfig
if TYPE_CHECKING:
from ..agents.invocation_context import InvocationContext
from ..events.event_actions import EventActions
from ..memory.base_memory_service import SearchMemoryResponse
class ToolContext(CallbackContext):
"""The context of the tool.
This class provides the context for a tool invocation, including access to
the invocation context, function call ID, event actions, and authentication
response. It also provides methods for requesting credentials, retrieving
authentication responses, listing artifacts, and searching memory.
Attributes:
invocation_context: The invocation context of the tool.
function_call_id: The function call id of the current tool call. This id was
returned in the function call event from LLM to identify a function call.
If LLM didn't return this id, ADK will assign one to it. This id is used
to map function call response to the original function call.
event_actions: The event actions of the current tool call.
"""
def __init__(
self,
invocation_context: InvocationContext,
*,
function_call_id: Optional[str] = None,
event_actions: Optional[EventActions] = None,
):
super().__init__(invocation_context, event_actions=event_actions)
self.function_call_id = function_call_id
@property
def actions(self) -> EventActions:
return self._event_actions
def request_credential(self, auth_config: AuthConfig) -> None:
if not self.function_call_id:
raise ValueError('function_call_id is not set.')
self._event_actions.requested_auth_configs[self.function_call_id] = (
AuthHandler(auth_config).generate_auth_request()
)
def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential:
return AuthHandler(auth_config).get_auth_response(self.state)
def list_artifacts(self) -> list[str]:
"""Lists the filenames of the artifacts attached to the current session."""
if self._invocation_context.artifact_service is None:
raise ValueError('Artifact service is not initialized.')
return self._invocation_context.artifact_service.list_artifact_keys(
app_name=self._invocation_context.app_name,
user_id=self._invocation_context.user_id,
session_id=self._invocation_context.session.id,
)
def search_memory(self, query: str) -> 'SearchMemoryResponse':
"""Searches the memory of the current user."""
if self._invocation_context.memory_service is None:
raise ValueError('Memory service is not available.')
return self._invocation_context.memory_service.search_memory(
app_name=self._invocation_context.app_name,
user_id=self._invocation_context.user_id,
query=query,
)

View File

@@ -0,0 +1,46 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from . import _automatic_function_calling_util
from .langchain_tool import LangchainTool
class ToolboxTool:
"""A class that provides access to toolbox tools.
Example:
```python
toolbox = ToolboxTool("http://127.0.0.1:8080")
tool = toolbox.get_tool("tool_name")
toolset = toolbox.get_toolset("toolset_name")
```
"""
toolbox_client: Any
"""The toolbox client."""
def __init__(self, url: str):
from toolbox_langchain import ToolboxClient
self.toolbox_client = ToolboxClient(url)
def get_tool(self, tool_name: str) -> LangchainTool:
tool = self.toolbox_client.load_tool(tool_name)
return LangchainTool(tool)
def get_toolset(self, toolset_name: str) -> list[LangchainTool]:
tools = self.toolbox_client.load_toolset(toolset_name)
return [LangchainTool(tool) for tool in tools]

View File

@@ -0,0 +1,21 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tool_context import ToolContext
# TODO: make this internal, since user doesn't need to use this tool directly.
def transfer_to_agent(agent_name: str, tool_context: ToolContext):
"""Transfer the question to another agent."""
tool_context.actions.transfer_to_agent = agent_name

View File

@@ -0,0 +1,96 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from .base_tool import BaseTool
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..models import LlmRequest
class VertexAiSearchTool(BaseTool):
"""A built-in tool using Vertex AI Search.
Attributes:
data_store_id: The Vertex AI search data store resource ID.
search_engine_id: The Vertex AI search engine resource ID.
"""
def __init__(
self,
*,
data_store_id: Optional[str] = None,
search_engine_id: Optional[str] = None,
):
"""Initializes the Vertex AI Search tool.
Args:
data_store_id: The Vertex AI search data store resource ID in the format
of
"projects/{project}/locations/{location}/collections/{collection}/dataStores/{dataStore}".
search_engine_id: The Vertex AI search engine resource ID in the format of
"projects/{project}/locations/{location}/collections/{collection}/engines/{engine}".
Raises:
ValueError: If both data_store_id and search_engine_id are not specified
or both are specified.
"""
# Name and description are not used because this is a model built-in tool.
super().__init__(name='vertex_ai_search', description='vertex_ai_search')
if (data_store_id is None and search_engine_id is None) or (
data_store_id is not None and search_engine_id is not None
):
raise ValueError(
'Either data_store_id or search_engine_id must be specified.'
)
self.data_store_id = data_store_id
self.search_engine_id = search_engine_id
@override
async def process_llm_request(
self,
*,
tool_context: ToolContext,
llm_request: LlmRequest,
) -> None:
if llm_request.model and llm_request.model.startswith('gemini-'):
if llm_request.model.startswith('gemini-1') and llm_request.config.tools:
raise ValueError(
'Vertex AI search tool can not be used with other tools in Gemini'
' 1.x.'
)
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
llm_request.config.tools.append(
types.Tool(
retrieval=types.Retrieval(
vertex_ai_search=types.VertexAISearch(
datastore=self.data_store_id, engine=self.search_engine_id
)
)
)
)
else:
raise ValueError(
'Vertex AI search tool is not supported for model'
f' {llm_request.model}'
)