mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-27 15:17:44 -06:00
Agent Development Kit (ADK)
An easy-to-use and powerful framework to build AI agents.
This commit is contained in:
51
src/google/adk/tools/__init__.py
Normal file
51
src/google/adk/tools/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# pylint: disable=g-bad-import-order
|
||||
from .base_tool import BaseTool
|
||||
|
||||
from ..auth.auth_tool import AuthToolArguments
|
||||
from .apihub_tool.apihub_toolset import APIHubToolset
|
||||
from .built_in_code_execution_tool import built_in_code_execution
|
||||
from .google_search_tool import google_search
|
||||
from .vertex_ai_search_tool import VertexAiSearchTool
|
||||
from .example_tool import ExampleTool
|
||||
from .exit_loop_tool import exit_loop
|
||||
from .function_tool import FunctionTool
|
||||
from .get_user_choice_tool import get_user_choice_tool as get_user_choice
|
||||
from .load_artifacts_tool import load_artifacts_tool as load_artifacts
|
||||
from .load_memory_tool import load_memory_tool as load_memory
|
||||
from .long_running_tool import LongRunningFunctionTool
|
||||
from .preload_memory_tool import preload_memory_tool as preload_memory
|
||||
from .tool_context import ToolContext
|
||||
from .transfer_to_agent_tool import transfer_to_agent
|
||||
|
||||
|
||||
# Public API surface of google.adk.tools. PascalCase entries are classes;
# snake_case entries are ready-made tool instances or functions re-exported
# (some under shorter aliases) from the submodules imported above.
__all__ = [
    'APIHubToolset',
    'AuthToolArguments',
    'BaseTool',
    'built_in_code_execution',
    'google_search',
    'VertexAiSearchTool',
    'ExampleTool',
    'exit_loop',
    'FunctionTool',
    'get_user_choice',
    'load_artifacts',
    'load_memory',
    'LongRunningFunctionTool',
    'preload_memory',
    'ToolContext',
    'transfer_to_agent',
]
346
src/google/adk/tools/_automatic_function_calling_util.py
Normal file
346
src/google/adk/tools/_automatic_function_calling_util.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Forked from google3/third_party/py/google/genai/_automatic_function_calling_util.py temporarily."""
|
||||
|
||||
import inspect
|
||||
from types import FunctionType
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
from pydantic import create_model
|
||||
from pydantic import fields as pydantic_fields
|
||||
|
||||
from . import function_parameter_parse_util
|
||||
|
||||
# Maps type names to google.genai Schema types. Covers Python builtin
# spellings ('str', 'list'), JSON-schema spellings ('string', 'array'),
# and capitalized typing names ('Dict', 'List'). Lookups for unknown names
# fall through to whatever default the caller passes to dict.get().
_py_type_2_schema_type = {
    'str': types.Type.STRING,
    'int': types.Type.INTEGER,
    'float': types.Type.NUMBER,
    'bool': types.Type.BOOLEAN,
    'string': types.Type.STRING,
    'integer': types.Type.INTEGER,
    'number': types.Type.NUMBER,
    'boolean': types.Type.BOOLEAN,
    'list': types.Type.ARRAY,
    'array': types.Type.ARRAY,
    'tuple': types.Type.ARRAY,
    'object': types.Type.OBJECT,
    'Dict': types.Type.OBJECT,
    'List': types.Type.ARRAY,
    'Tuple': types.Type.ARRAY,
    'Any': types.Type.TYPE_UNSPECIFIED,
}
|
||||
|
||||
def _get_fields_dict(func: Callable) -> Dict:
  """Builds a pydantic `create_model()` fields dict from func's signature.

  Each supported parameter maps to a `(type, FieldInfo)` tuple. Parameters
  of kind *args / **kwargs are skipped entirely. Unannotated parameters
  fall back to `Any` so pydantic does not infer a type from the default.

  Args:
    func: The callable whose signature is introspected.

  Returns:
    A dict of parameter name -> (annotation, pydantic.Field) suitable for
    passing as **kwargs to pydantic.create_model().
  """
  param_signature = dict(inspect.signature(func).parameters)
  fields_dict = {
      name: (
          # 1. We infer the argument type here: use Any rather than None so
          # it will not try to auto-infer the type based on the default value.
          (
              param.annotation
              if param.annotation != inspect.Parameter.empty
              else Any
          ),
          pydantic.Field(
              # 2. We do not support default values for now.
              default=(
                  param.default
                  if param.default != inspect.Parameter.empty
                  # ! Need to use Undefined instead of None
                  else pydantic_fields.PydanticUndefined
              ),
              # 3. Do not support parameter description for now.
              description=None,
          ),
      )
      for name, param in param_signature.items()
      # We do not support *args or **kwargs
      if param.kind
      in (
          inspect.Parameter.POSITIONAL_OR_KEYWORD,
          inspect.Parameter.KEYWORD_ONLY,
          inspect.Parameter.POSITIONAL_ONLY,
      )
  }
  return fields_dict
|
||||
|
||||
def _annotate_nullable_fields(schema: Dict):
|
||||
for _, property_schema in schema.get('properties', {}).items():
|
||||
# for Optional[T], the pydantic schema is:
|
||||
# {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "anyOf": [
|
||||
# {
|
||||
# "type": "null"
|
||||
# },
|
||||
# {
|
||||
# "type": "T"
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# }
|
||||
for type_ in property_schema.get('anyOf', []):
|
||||
if type_.get('type') == 'null':
|
||||
property_schema['nullable'] = True
|
||||
property_schema['anyOf'].remove(type_)
|
||||
break
|
||||
|
||||
|
||||
def _annotate_required_fields(schema: Dict):
|
||||
required = [
|
||||
field_name
|
||||
for field_name, field_schema in schema.get('properties', {}).items()
|
||||
if not field_schema.get('nullable') and 'default' not in field_schema
|
||||
]
|
||||
schema['required'] = required
|
||||
|
||||
|
||||
def _remove_any_of(schema: Dict):
|
||||
for _, property_schema in schema.get('properties', {}).items():
|
||||
union_types = property_schema.pop('anyOf', None)
|
||||
# Take the first non-null type.
|
||||
if union_types:
|
||||
for type_ in union_types:
|
||||
if type_.get('type') != 'null':
|
||||
property_schema.update(type_)
|
||||
|
||||
|
||||
def _remove_default(schema: Dict):
|
||||
for _, property_schema in schema.get('properties', {}).items():
|
||||
property_schema.pop('default', None)
|
||||
|
||||
|
||||
def _remove_nullable(schema: Dict):
|
||||
for _, property_schema in schema.get('properties', {}).items():
|
||||
property_schema.pop('nullable', None)
|
||||
|
||||
|
||||
def _remove_title(schema: Dict):
|
||||
for _, property_schema in schema.get('properties', {}).items():
|
||||
property_schema.pop('title', None)
|
||||
|
||||
|
||||
def _get_pydantic_schema(func: Callable) -> Dict:
  """Builds a JSON schema for func's parameters via a synthetic pydantic model.

  The reserved `tool_context` parameter is excluded: it is injected by the
  framework at call time and must not appear in the model-facing schema.

  Args:
    func: The callable to describe.

  Returns:
    The model_json_schema() dict of a pydantic model named after `func`.
  """
  fields_dict = _get_fields_dict(func)
  # pop with a default is a single lookup; replaces `in ... keys()` + pop.
  fields_dict.pop('tool_context', None)
  return pydantic.create_model(func.__name__, **fields_dict).model_json_schema()
||||
|
||||
def _process_pydantic_schema(vertexai: bool, schema: Dict) -> Dict:
  """Normalizes a pydantic JSON schema for the target API surface.

  Always annotates nullability and required fields. For the Google AI
  (non-Vertex) surface, additionally strips constructs it does not accept:
  anyOf unions, defaults, nullable flags and titles.

  Args:
    vertexai: True when targeting Vertex AI (keeps the richer schema).
    schema: The pydantic JSON schema; mutated in place.

  Returns:
    The same `schema` object, for call chaining.
  """
  _annotate_nullable_fields(schema)
  _annotate_required_fields(schema)
  if not vertexai:
    _remove_any_of(schema)
    _remove_default(schema)
    _remove_nullable(schema)
    _remove_title(schema)
  return schema
||||
|
||||
def _map_pydantic_type_to_property_schema(property_schema: Dict):
  """Rewrites pydantic/JSON-schema type names to genai Schema types, in place.

  Recurses into array item schemas and rewrites each `anyOf` member.
  Unknown type names map to 'TYPE_UNSPECIFIED'.
  """
  if 'type' in property_schema:
    property_schema['type'] = _py_type_2_schema_type.get(
        property_schema['type'], 'TYPE_UNSPECIFIED'
    )
    if property_schema['type'] == 'ARRAY':
      # NOTE(review): assumes every array schema carries an 'items' key —
      # a bare {'type': 'array'} would raise KeyError here; confirm inputs.
      _map_pydantic_type_to_property_schema(property_schema['items'])
  for type_ in property_schema.get('anyOf', []):
    if 'type' in type_:
      type_['type'] = _py_type_2_schema_type.get(
          type_['type'], 'TYPE_UNSPECIFIED'
      )
      # TODO: To investigate. Unclear why a Type is needed with 'anyOf' to
      # avoid google.genai.errors.ClientError: 400 INVALID_ARGUMENT.
      # (Note: each loop iteration overwrites this, so the LAST typed
      # member of anyOf ends up as the property's top-level type.)
      property_schema['type'] = type_['type']
|
||||
def _map_pydantic_type_to_schema_type(schema: Dict):
  """Rewrites every property's type names to genai Schema types, in place."""
  # Keys are unused — iterate the values directly instead of .items().
  for property_schema in schema.get('properties', {}).values():
    _map_pydantic_type_to_property_schema(property_schema)
||||
|
||||
def _get_return_type(func: Callable) -> Any:
  """Maps func's return annotation to a genai schema type name.

  Falls back to the annotation's own name when it is not a known primitive.

  Fixes two issues in the original:
    * `inspect.signature(func)` was computed twice — hoisted to one call.
    * `.__name__` was accessed unconditionally, which raises AttributeError
      for annotations without one (e.g. `None`, subscripted typing generics
      such as Optional[int] on older Pythons) — now falls back to
      `str(annotation)`.
  """
  annotation = inspect.signature(func).return_annotation
  name = getattr(annotation, '__name__', str(annotation))
  return _py_type_2_schema_type.get(name, name)
||||
|
||||
def build_function_declaration(
    func: Union[Callable, BaseModel],
    ignore_params: Optional[list[str]] = None,
    variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI',
) -> types.FunctionDeclaration:
  """Builds a FunctionDeclaration for `func`, optionally hiding parameters.

  Any parameter named in `ignore_params` is stripped from the declared
  signature before introspection. `func` itself is never mutated: when
  stripping is needed, a pydantic model is rebuilt (class input) or the
  function is cloned with a narrowed __signature__ (callable input).

  Args:
    func: A callable or pydantic model class to describe.
    ignore_params: Parameter names to omit from the declaration.
    variant: Target API surface, forwarded to from_function_with_options.

  Returns:
    The FunctionDeclaration produced by from_function_with_options.
  """
  ignore_params = ignore_params or []
  signature = inspect.signature(func)
  should_update_signature = any(
      name in ignore_params for name in signature.parameters
  )
  target = func
  if should_update_signature:
    kept = {
        name: param
        for name, param in signature.parameters.items()
        if name not in ignore_params
    }
    if isinstance(func, type):
      # Model class: rebuild it without the ignored fields.
      target = create_model(
          func.__name__,
          **{
              name: (param.annotation, param.default)
              for name, param in kept.items()
          },
      )
    else:
      # Plain function: clone it and attach the narrowed signature.
      target = FunctionType(
          func.__code__,
          func.__globals__,
          func.__name__,
          func.__defaults__,
          func.__closure__,
      )
      target.__signature__ = signature.replace(parameters=list(kept.values()))

  return from_function_with_options(target, variant)
||||
|
||||
def build_function_declaration_for_langchain(
    vertexai: bool, name, description, func, param_pydantic_schema
) -> types.FunctionDeclaration:
  """Builds a FunctionDeclaration from a LangChain tool's parameter schema.

  The incoming `param_pydantic_schema` is a bare properties mapping; it is
  wrapped, normalized via _process_pydantic_schema, then repackaged with a
  'required' list before delegating to build_function_declaration_util.
  """
  param_pydantic_schema = _process_pydantic_schema(
      vertexai, {'properties': param_pydantic_schema}
  )['properties']
  param_copy = param_pydantic_schema.copy()
  # Separate any 'required' entry from the properties mapping.
  # NOTE(review): _annotate_required_fields writes 'required' onto the
  # wrapper dict, which is discarded by the ['properties'] access above —
  # so this pop likely always yields []. Confirm intended behavior.
  required_fields = param_copy.pop('required', [])
  before_param_pydantic_schema = {
      'properties': param_copy,
      'required': required_fields,
  }
  return build_function_declaration_util(
      vertexai, name, description, func, before_param_pydantic_schema
  )
||||
|
||||
def build_function_declaration_for_params_for_crewai(
    vertexai: bool, name, description, func, param_pydantic_schema
) -> types.FunctionDeclaration:
  """Builds a FunctionDeclaration from a crewAI tool's pydantic schema.

  Unlike the LangChain path, the incoming schema is already a complete
  JSON-schema dict (with its own 'properties' key), so it is normalized
  whole and passed straight through.
  """
  param_pydantic_schema = _process_pydantic_schema(
      vertexai, param_pydantic_schema
  )
  # Shallow copy so the caller's top-level dict is not mutated downstream
  # (nested property dicts are still shared and will be rewritten in place).
  param_copy = param_pydantic_schema.copy()
  return build_function_declaration_util(
      vertexai, name, description, func, param_copy
  )
||||
|
||||
def build_function_declaration_util(
    vertexai: bool, name, description, func, before_param_pydantic_schema
) -> types.FunctionDeclaration:
  """Assembles the final FunctionDeclaration from a normalized schema.

  Rewrites type names to genai Schema types, wraps the properties in an
  OBJECT schema (omitted entirely when there are no properties), and — on
  Vertex AI only — attaches a response schema derived from func's return
  annotation.

  NOTE(review): any 'required' key prepared by callers inside
  `before_param_pydantic_schema` is not consumed here; confirm whether it
  should populate `parameters.required`.
  """
  _map_pydantic_type_to_schema_type(before_param_pydantic_schema)
  properties = before_param_pydantic_schema.get('properties', {})
  function_declaration = types.FunctionDeclaration(
      parameters=types.Schema(
          type='OBJECT',
          properties=properties,
      )
      if properties
      else None,
      description=description,
      name=name,
  )
  if vertexai and isinstance(func, Callable):
    return_pydantic_schema = _get_return_type(func)
    function_declaration.response = types.Schema(
        type=return_pydantic_schema,
    )
  return function_declaration
||||
|
||||
def from_function_with_options(
    func: Callable,
    variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI',
) -> 'types.FunctionDeclaration':
  """Builds a FunctionDeclaration by introspecting a Python callable.

  Args:
    func: The callable to describe. Its name, docstring and signature drive
      the declaration.
    variant: Target API surface. Only 'VERTEX_AI' receives `required`
      fields and a response schema.

  Returns:
    A types.FunctionDeclaration describing `func`.

  Raises:
    ValueError: If `variant` is not one of the supported values.
  """
  supported_variants = ['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT']
  if variant not in supported_variants:
    raise ValueError(
        f'Unsupported variant: {variant}. Supported variants are:'
        f' {", ".join(supported_variants)}'
    )

  # Hoisted: inspect.signature() re-parses on every call; the original
  # computed it twice (parameters + return annotation).
  signature = inspect.signature(func)

  parameters_properties = {}
  for name, param in signature.parameters.items():
    # *args / **kwargs cannot be represented in a schema and are skipped.
    if param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.POSITIONAL_ONLY,
    ):
      schema = function_parameter_parse_util._parse_schema_from_parameter(
          variant, param, func.__name__
      )
      parameters_properties[name] = schema
  declaration = types.FunctionDeclaration(
      name=func.__name__,
      description=func.__doc__,
  )
  if parameters_properties:
    declaration.parameters = types.Schema(
        type='OBJECT',
        properties=parameters_properties,
    )
    if variant == 'VERTEX_AI':
      declaration.parameters.required = (
          function_parameter_parse_util._get_required_fields(
              declaration.parameters
          )
      )
  # Only Vertex AI accepts a response schema (idiomatic `!=` replaces the
  # original `if not variant == 'VERTEX_AI'`).
  if variant != 'VERTEX_AI':
    return declaration

  return_annotation = signature.return_annotation
  if return_annotation is inspect._empty:
    return declaration

  # Describe the return value as a synthetic 'return_value' parameter so
  # the same schema parser can be reused.
  declaration.response = (
      function_parameter_parse_util._parse_schema_from_parameter(
          variant,
          inspect.Parameter(
              'return_value',
              inspect.Parameter.POSITIONAL_OR_KEYWORD,
              annotation=return_annotation,
          ),
          func.__name__,
      )
  )
  return declaration
||||
176
src/google/adk/tools/agent_tool.py
Normal file
176
src/google/adk/tools/agent_tool.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
||||
from ..runners import Runner
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from . import _automatic_function_calling_util
|
||||
from .base_tool import BaseTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..agents.base_agent import BaseAgent
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
|
||||
|
||||
class AgentTool(BaseTool):
  """A tool that wraps an agent.

  This tool allows an agent to be called as a tool within a larger application.
  The agent's input schema is used to define the tool's input parameters, and
  the agent's output is returned as the tool's result.

  Attributes:
    agent: The agent to wrap.
    skip_summarization: Whether to skip summarization of the agent output.
  """

  def __init__(self, agent: BaseAgent):
    self.agent = agent
    self.skip_summarization: bool = False
    """Whether to skip summarization of the agent output."""

    super().__init__(name=agent.name, description=agent.description)

  @model_validator(mode='before')
  @classmethod
  def populate_name(cls, data: Any) -> Any:
    # Keeps the tool name in sync with the wrapped agent's name.
    # NOTE(review): assumes `data` is a dict containing an 'agent' entry;
    # a mode='before' validator can also receive non-dict input — confirm.
    data['name'] = data['agent'].name
    return data

  @override
  def _get_declaration(self) -> types.FunctionDeclaration:
    """Returns the function declaration exposed to the model.

    When the wrapped agent is an LlmAgent with an input_schema, that schema
    defines the parameters; otherwise a single required string parameter
    named 'request' is declared.
    """
    from ..agents.llm_agent import LlmAgent

    if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
      result = _automatic_function_calling_util.build_function_declaration(
          func=self.agent.input_schema, variant=self._api_variant
      )
    else:
      result = types.FunctionDeclaration(
          parameters=types.Schema(
              type=types.Type.OBJECT,
              properties={
                  'request': types.Schema(
                      type=types.Type.STRING,
                  ),
              },
              required=['request'],
          ),
          description=self.agent.description,
          name=self.name,
      )
    # Always report the tool's own name, regardless of how the declaration
    # was built (build_function_declaration names it after the schema class).
    result.name = self.name
    return result

  @override
  async def run_async(
      self,
      *,
      args: dict[str, Any],
      tool_context: ToolContext,
  ) -> Any:
    """Runs the wrapped agent in an isolated Runner and returns its output.

    State deltas and artifacts produced by the child run are forwarded to
    the parent session through `tool_context`.
    """
    from ..agents.llm_agent import LlmAgent

    if self.skip_summarization:
      tool_context.actions.skip_summarization = True

    # Parse the input: a validated model instance when an input_schema is
    # declared, otherwise the raw 'request' string argument.
    if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
      input_value = self.agent.input_schema.model_validate(args)
    else:
      input_value = args['request']

    if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
      # NOTE(review): input_value was already validated just above, so the
      # dict re-validation here appears unreachable/redundant — confirm.
      if isinstance(input_value, dict):
        input_value = self.agent.input_schema.model_validate(input_value)
      if not isinstance(input_value, self.agent.input_schema):
        raise ValueError(
            f'Input value {input_value} is not of type'
            f' `{self.agent.input_schema}`.'
        )
      # Structured input is serialized to JSON for the child agent.
      content = types.Content(
          role='user',
          parts=[
              types.Part.from_text(
                  text=input_value.model_dump_json(exclude_none=True)
              )
          ],
      )
    else:
      content = types.Content(
          role='user',
          parts=[types.Part.from_text(text=input_value)],
      )
    # Run the child agent with fresh in-memory session/memory services so
    # it cannot interfere with the parent's state.
    runner = Runner(
        app_name=self.agent.name,
        agent=self.agent,
        # TODO(kech): Remove the access to the invocation context.
        # It seems we don't need re-use artifact_service if we forward below.
        artifact_service=tool_context._invocation_context.artifact_service,
        session_service=InMemorySessionService(),
        memory_service=InMemoryMemoryService(),
    )
    # Seed the child session with a snapshot of the parent's state.
    session = runner.session_service.create_session(
        app_name=self.agent.name,
        user_id='tmp_user',
        state=tool_context.state.to_dict(),
    )

    last_event = None
    async for event in runner.run_async(
        user_id=session.user_id, session_id=session.id, new_message=content
    ):
      # Forward state delta to parent session.
      if event.actions.state_delta:
        tool_context.state.update(event.actions.state_delta)
      last_event = event

    if runner.artifact_service:
      # Forward all artifacts to parent session.
      for artifact_name in runner.artifact_service.list_artifact_keys(
          app_name=session.app_name,
          user_id=session.user_id,
          session_id=session.id,
      ):
        if artifact := runner.artifact_service.load_artifact(
            app_name=session.app_name,
            user_id=session.user_id,
            session_id=session.id,
            filename=artifact_name,
        ):
          tool_context.save_artifact(filename=artifact_name, artifact=artifact)

    # No usable text output from the child run: return an empty string so
    # the tool contract stays total.
    if (
        not last_event
        or not last_event.content
        or not last_event.content.parts
        or not last_event.content.parts[0].text
    ):
      return ''
    if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
      # Structured output: validate the child's JSON text against the
      # declared output_schema and return a plain dict.
      tool_result = self.agent.output_schema.model_validate_json(
          last_event.content.parts[0].text
      ).model_dump(exclude_none=True)
    else:
      tool_result = last_event.content.parts[0].text
    return tool_result
19
src/google/adk/tools/apihub_tool/__init__.py
Normal file
19
src/google/adk/tools/apihub_tool/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .apihub_toolset import APIHubToolset
|
||||
|
||||
# Public API of the apihub_tool package.
__all__ = [
    'APIHubToolset',
]
209
src/google/adk/tools/apihub_tool/apihub_toolset.py
Normal file
209
src/google/adk/tools/apihub_tool/apihub_toolset.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_schemes import AuthScheme
|
||||
from ..openapi_tool.common.common import to_snake_case
|
||||
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from .clients.apihub_client import APIHubClient
|
||||
|
||||
|
||||
class APIHubToolset:
  """APIHubTool generates tools from a given API Hub resource.

  Examples:

  ```
  apihub_toolset = APIHubToolset(
      apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
      service_account_json="...",
  )

  # Get all available tools
  agent = LlmAgent(tools=apihub_toolset.get_tools())

  # Get a specific tool
  agent = LlmAgent(tools=[
      ...
      apihub_toolset.get_tool('my_tool'),
  ])
  ```

  **apihub_resource_name** is the resource name from API Hub. It must include
  API name, and can optionally include API version and spec name.
  - If apihub_resource_name includes a spec resource name, the content of that
    spec will be used for generating the tools.
  - If apihub_resource_name includes only an api or a version name, the
    first spec of the first version of that API will be used.
  """

  def __init__(
      self,
      *,
      # Parameters for fetching API Hub resource
      apihub_resource_name: str,
      access_token: Optional[str] = None,
      service_account_json: Optional[str] = None,
      # Parameters for the toolset itself
      name: str = '',
      description: str = '',
      # Parameters for generating tools
      lazy_load_spec=False,
      auth_scheme: Optional[AuthScheme] = None,
      auth_credential: Optional[AuthCredential] = None,
      # Optionally, you can provide a custom API Hub client
      apihub_client: Optional[APIHubClient] = None,
  ):
    """Initializes the APIHubTool with the given parameters.

    Example:
      * projects/xxx/locations/us-central1/apis/apiname/...
      * https://console.cloud.google.com/apigee/api-hub/apis/apiname?project=xxx

    Args:
      apihub_resource_name: The resource name of the API in API Hub.
        Example: `projects/test-project/locations/us-central1/apis/test-api`.
      access_token: Google Access token. Generate with gcloud cli `gcloud auth
        print-access-token`. Used for fetching API Specs from API Hub.
      service_account_json: The service account config as a json string.
        Required if not using default service credential. It is used for
        creating the API Hub client and fetching the API Specs from API Hub.
      apihub_client: Optional custom API Hub client.
      name: Name of the toolset. Optional. Defaults to the spec's title when
        tools are generated.
      description: Description of the toolset. Optional. Defaults to the
        spec's description when tools are generated.
      auth_scheme: Auth scheme that applies to all the tool in the toolset.
      auth_credential: Auth credential that applies to all the tool in the
        toolset.
      lazy_load_spec: If True, the spec will be loaded lazily when needed.
        Otherwise, the spec will be loaded immediately and the tools will be
        generated during initialization.
    """
    self.name = name
    self.description = description
    self.apihub_resource_name = apihub_resource_name
    self.lazy_load_spec = lazy_load_spec
    self.apihub_client = apihub_client or APIHubClient(
        access_token=access_token,
        service_account_json=service_account_json,
    )

    self.generated_tools: Dict[str, RestApiTool] = {}
    self.auth_scheme = auth_scheme
    self.auth_credential = auth_credential

    if not self.lazy_load_spec:
      self._prepare_tools()

  def get_tool(self, name: str) -> Optional[RestApiTool]:
    """Retrieves a specific tool by its name.

    Example:
    ```
    apihub_tool = apihub_toolset.get_tool('my_tool')
    ```

    Args:
      name: The name of the tool to retrieve.

    Returns:
      The tool with the given name, or None if no such tool exists.
    """
    if not self._are_tools_ready():
      self._prepare_tools()

    # dict.get already returns None for missing keys — single lookup
    # instead of the original `x[name] if name in x else None`.
    return self.generated_tools.get(name)

  def get_tools(self) -> List[RestApiTool]:
    """Retrieves all available tools.

    Returns:
      A list of all available RestApiTool objects.
    """
    if not self._are_tools_ready():
      self._prepare_tools()

    return list(self.generated_tools.values())

  def _are_tools_ready(self) -> bool:
    # Wrapped in bool(): the original returned the dict itself on the
    # second branch, contradicting the -> bool annotation.
    return not self.lazy_load_spec or bool(self.generated_tools)

  def _prepare_tools(self) -> None:
    """Fetches the spec from API Hub and generates the tools.

    Populates self.generated_tools in place. (The original was annotated
    `-> str` and documented as returning a bool, but returned nothing;
    fixed to `-> None`.)
    """
    # For each API, get the first version and the first spec of that version.
    spec = self.apihub_client.get_spec_content(self.apihub_resource_name)
    self.generated_tools: Dict[str, RestApiTool] = {}

    tools = self._parse_spec_to_tools(spec)
    for tool in tools:
      self.generated_tools[tool.name] = tool

  def _parse_spec_to_tools(self, spec_str: str) -> List[RestApiTool]:
    """Parses the spec string to a list of RestApiTool.

    Also backfills self.name / self.description from the spec's info block
    when they were not provided at construction time.

    Args:
      spec_str: The spec string to parse.

    Returns:
      A list of RestApiTool objects.
    """
    spec_dict = yaml.safe_load(spec_str)
    if not spec_dict:
      return []

    self.name = self.name or to_snake_case(
        spec_dict.get('info', {}).get('title', 'unnamed')
    )
    self.description = self.description or spec_dict.get('info', {}).get(
        'description', ''
    )
    tools = OpenAPIToolset(
        spec_dict=spec_dict,
        auth_credential=self.auth_credential,
        auth_scheme=self.auth_scheme,
    ).get_tools()
    return tools
13
src/google/adk/tools/apihub_tool/clients/__init__.py
Normal file
13
src/google/adk/tools/apihub_tool/clients/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
332
src/google/adk/tools/apihub_tool/clients/apihub_client.py
Normal file
332
src/google/adk/tools/apihub_tool/clients/apihub_client.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import base64
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from google.auth import default as default_service_credential
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import requests
|
||||
|
||||
|
||||
class BaseAPIHubClient(ABC):
  """Base class for API Hub clients."""

  @abstractmethod
  def get_spec_content(self, resource_name: str) -> str:
    """From a given resource name, get the spec in the API Hub."""
    raise NotImplementedError()
||||
|
||||
class APIHubClient(BaseAPIHubClient):
  """Client for interacting with the API Hub service."""

  def __init__(
      self,
      *,
      access_token: Optional[str] = None,
      service_account_json: Optional[str] = None,
  ):
    """Initializes the APIHubClient.

    You must set either access_token or service_account_json. This
    credential is used for sending requests to the API Hub API.

    Args:
      access_token: Google Access token. Generate with gcloud cli `gcloud auth
        print-access-token`. Useful for local testing.
      service_account_json: The service account configuration as a dictionary.
        Required if not using default service credential.
    """
    self.root_url = "https://apihub.googleapis.com/v1"
    # Lazily populated google.auth credentials; reused until expired.
    self.credential_cache = None
    self.access_token, self.service_account = None, None

    if access_token:
      self.access_token = access_token
    elif service_account_json:
      self.service_account = service_account_json

  def get_spec_content(self, path: str) -> str:
    """From a given path, get the first spec available in the API Hub.

    - If path includes /apis/apiname, get the first spec of that API
    - If path includes /apis/apiname/versions/versionname, get the first spec
      of that API Version
    - If path includes /apis/apiname/versions/versionname/specs/specname,
      return that spec

    Path can be resource name (projects/xxx/locations/us-central1/apis/apiname),
    and URL from the UI
    (https://console.cloud.google.com/apigee/api-hub/apis/apiname?project=xxx)

    Args:
      path: The path to the API, API Version, or API Spec.

    Returns:
      The content of the first spec available in the API Hub.

    Raises:
      ValueError: If no API Hub resource can be derived from the path, if an
        API has no versions, or if a version has no specs.
    """
    apihub_resource_name, api_version_resource_name, api_spec_resource_name = (
        self._extract_resource_name(path)
    )

    # Resolve API -> first version when only the API is identified.
    if apihub_resource_name and not api_version_resource_name:
      api = self.get_api(apihub_resource_name)
      versions = api.get("versions", [])
      if not versions:
        raise ValueError(
            f"No versions found in API Hub resource: {apihub_resource_name}"
        )
      api_version_resource_name = versions[0]

    # Resolve version -> first spec when no spec is identified.
    if api_version_resource_name and not api_spec_resource_name:
      api_version = self.get_api_version(api_version_resource_name)
      spec_resource_names = api_version.get("specs", [])
      if not spec_resource_names:
        raise ValueError(
            f"No specs found in API Hub version: {api_version_resource_name}"
        )
      api_spec_resource_name = spec_resource_names[0]

    if api_spec_resource_name:
      return self._fetch_spec(api_spec_resource_name)

    # Bug fix: the original message was missing the f-prefix, so the literal
    # text "{path}" was raised instead of the offending path.
    raise ValueError(f"No API Hub resource found in path: {path}")

  def _request_headers(self) -> Dict[str, str]:
    """Builds the common request headers with a fresh Authorization token."""
    return {
        "accept": "application/json, text/plain, */*",
        "Authorization": f"Bearer {self._get_access_token()}",
    }

  def list_apis(self, project: str, location: str) -> List[Dict[str, Any]]:
    """Lists all APIs in the specified project and location.

    Args:
      project: The Google Cloud project name.
      location: The location of the API Hub resources (e.g., 'us-central1').

    Returns:
      A list of API dictionaries; an empty list if the project/location has
      no APIs.

    Raises:
      requests.HTTPError: If the API Hub request fails.
    """
    url = f"{self.root_url}/projects/{project}/locations/{location}/apis"
    response = requests.get(url, headers=self._request_headers())
    response.raise_for_status()
    return response.json().get("apis", [])

  def get_api(self, api_resource_name: str) -> Dict[str, Any]:
    """Get API detail by API name.

    Args:
      api_resource_name: Resource name of this API, like
        projects/xxx/locations/us-central1/apis/apiname

    Returns:
      An API and details in a dict.

    Raises:
      requests.HTTPError: If the API Hub request fails.
    """
    url = f"{self.root_url}/{api_resource_name}"
    response = requests.get(url, headers=self._request_headers())
    response.raise_for_status()
    return response.json()

  def get_api_version(self, api_version_name: str) -> Dict[str, Any]:
    """Gets details of a specific API version.

    Args:
      api_version_name: The resource name of the API version.

    Returns:
      The API version details as a dictionary.

    Raises:
      requests.HTTPError: If the API Hub request fails.
    """
    url = f"{self.root_url}/{api_version_name}"
    response = requests.get(url, headers=self._request_headers())
    response.raise_for_status()
    return response.json()

  def _fetch_spec(self, api_spec_resource_name: str) -> str:
    """Retrieves the content of a specific API specification.

    Args:
      api_spec_resource_name: The resource name of the API spec.

    Returns:
      The base64-decoded content of the specification as a string, or an
      empty string if the spec has no contents.

    Raises:
      requests.HTTPError: If the API Hub request fails.
    """
    url = f"{self.root_url}/{api_spec_resource_name}:contents"
    response = requests.get(url, headers=self._request_headers())
    response.raise_for_status()
    content_base64 = response.json().get("contents", "")
    if not content_base64:
      return ""
    return base64.b64decode(content_base64).decode("utf-8")

  def _extract_resource_name(
      self, url_or_path: str
  ) -> Tuple[str, Optional[str], Optional[str]]:
    """Extracts API, API Version, and API Spec resource names from a URL/path.

    Args:
      url_or_path: The URL (UI or resource) or path string.

    Returns:
      A tuple (api_resource_name, api_version_resource_name,
      api_spec_resource_name):
        - api_resource_name: "projects/*/locations/*/apis/*"
        - api_version_resource_name:
          "projects/*/locations/*/apis/*/versions/*" or None if the input
          does not identify a version.
        - api_spec_resource_name:
          "projects/*/locations/*/apis/*/versions/*/specs/*" or None if the
          input does not identify a spec.

    Raises:
      ValueError: If the URL or path is invalid or if required components
        (project, location, api) are missing.
    """
    query_params = None
    try:
      parsed_url = urlparse(url_or_path)
      path = parsed_url.path
      query_params = parse_qs(parsed_url.query)

      # This is a path from UI. Remove unnecessary prefix.
      if "api-hub/" in path:
        path = path.split("api-hub")[1]
    except Exception:
      path = url_or_path

    path_segments = [segment for segment in path.split("/") if segment]

    project = None
    location = None
    api_id = None
    version_id = None
    spec_id = None

    if "projects" in path_segments:
      project_index = path_segments.index("projects")
      if project_index + 1 < len(path_segments):
        project = path_segments[project_index + 1]
    elif query_params and "project" in query_params:
      # UI URLs carry the project as a query parameter instead of a segment.
      project = query_params["project"][0]

    if not project:
      raise ValueError(
          "Project ID not found in URL or path in APIHubClient. Input path is"
          f" '{url_or_path}'. Please make sure there is either"
          " '/projects/PROJECT_ID' in the path or 'project=PROJECT_ID' query"
          " param in the input."
      )

    if "locations" in path_segments:
      location_index = path_segments.index("locations")
      if location_index + 1 < len(path_segments):
        location = path_segments[location_index + 1]
    if not location:
      raise ValueError(
          "Location not found in URL or path in APIHubClient. Input path is"
          f" '{url_or_path}'. Please make sure there is either"
          " '/location/LOCATION_ID' in the path."
      )

    if "apis" in path_segments:
      api_index = path_segments.index("apis")
      if api_index + 1 < len(path_segments):
        api_id = path_segments[api_index + 1]
    if not api_id:
      raise ValueError(
          "API id not found in URL or path in APIHubClient. Input path is"
          f" '{url_or_path}'. Please make sure there is either"
          " '/apis/API_ID' in the path."
      )

    if "versions" in path_segments:
      version_index = path_segments.index("versions")
      if version_index + 1 < len(path_segments):
        version_id = path_segments[version_index + 1]

    if "specs" in path_segments:
      spec_index = path_segments.index("specs")
      if spec_index + 1 < len(path_segments):
        spec_id = path_segments[spec_index + 1]

    api_resource_name = f"projects/{project}/locations/{location}/apis/{api_id}"
    api_version_resource_name = (
        f"{api_resource_name}/versions/{version_id}" if version_id else None
    )
    api_spec_resource_name = (
        f"{api_version_resource_name}/specs/{spec_id}"
        if version_id and spec_id
        else None
    )

    return (
        api_resource_name,
        api_version_resource_name,
        api_spec_resource_name,
    )

  def _get_access_token(self) -> str:
    """Gets the access token for the configured credential.

    Returns:
      The access token.

    Raises:
      ValueError: If the service account JSON is invalid, or if neither a
        service account, an access token, nor default credentials are
        available.
    """
    if self.access_token:
      return self.access_token

    if self.credential_cache and not self.credential_cache.expired:
      return self.credential_cache.token

    if self.service_account:
      try:
        credentials = service_account.Credentials.from_service_account_info(
            json.loads(self.service_account),
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
      except json.JSONDecodeError as e:
        raise ValueError(f"Invalid service account JSON: {e}") from e
    else:
      try:
        credentials, _ = default_service_credential()
      except Exception:
        # Narrowed from a bare `except:`; default-credential lookup failures
        # fall through to the explicit error below.
        credentials = None

    if not credentials:
      raise ValueError(
          "Please provide a service account or an access token to API Hub"
          " client."
      )

    credentials.refresh(Request())
    self.credential_cache = credentials
    return credentials.token
|
||||
115
src/google/adk/tools/apihub_tool/clients/secret_client.py
Normal file
115
src/google/adk/tools/apihub_tool/clients/secret_client.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
import google.auth
|
||||
from google.auth import default as default_service_credential
|
||||
import google.auth.transport.requests
|
||||
from google.cloud import secretmanager
|
||||
from google.oauth2 import service_account
|
||||
|
||||
|
||||
class SecretManagerClient:
  """A client for interacting with Google Cloud Secret Manager.

  This class provides a simplified interface for retrieving secrets from
  Secret Manager, handling authentication using either a service account
  JSON keyfile (passed as a string) or a pre-existing authorization token.

  Attributes:
    _credentials: Google Cloud credentials object (ServiceAccountCredentials
      or Credentials).
    _client: Secret Manager client instance.
  """

  def __init__(
      self,
      service_account_json: Optional[str] = None,
      auth_token: Optional[str] = None,
  ):
    """Initializes the SecretManagerClient.

    Credential precedence: service_account_json, then auth_token, then
    application-default credentials.

    Args:
      service_account_json: The content of a service account JSON keyfile (as
        a string), not the file path. Must be valid JSON.
      auth_token: An existing Google Cloud authorization token.

    Raises:
      ValueError: If the service_account_json is not valid JSON, or if no
        usable credential can be obtained.
      google.auth.exceptions.GoogleAuthError: If authentication fails.
    """
    if service_account_json:
      try:
        credentials = service_account.Credentials.from_service_account_info(
            json.loads(service_account_json)
        )
      except json.JSONDecodeError as e:
        raise ValueError(f"Invalid service account JSON: {e}") from e
    elif auth_token:
      # NOTE(review): google.auth.credentials.Credentials is the abstract
      # base class and does not accept token/refresh_token/... keyword
      # arguments — this looks like it was meant to be
      # google.oauth2.credentials.Credentials; verify before relying on the
      # auth_token path.
      credentials = google.auth.credentials.Credentials(
          token=auth_token,
          refresh_token=None,
          token_uri=None,
          client_id=None,
          client_secret=None,
      )
      # NOTE(review): refreshing a token-only credential (no refresh_token /
      # token_uri) presumably fails — confirm intended behavior.
      request = google.auth.transport.requests.Request()
      credentials.refresh(request)
    else:
      # Neither explicit credential was supplied; fall back to
      # application-default credentials.
      try:
        credentials, _ = default_service_credential()
      except Exception as e:
        raise ValueError(
            "'service_account_json' or 'auth_token' are both missing, and"
            f" error occurred while trying to use default credentials: {e}"
        ) from e

    # Defensive check: default_service_credential may return a falsy
    # credential without raising.
    if not credentials:
      raise ValueError(
          "Must provide either 'service_account_json' or 'auth_token', not both"
          " or neither."
      )

    self._credentials = credentials
    self._client = secretmanager.SecretManagerServiceClient(
        credentials=self._credentials
    )

  def get_secret(self, resource_name: str) -> str:
    """Retrieves a secret from Google Cloud Secret Manager.

    Args:
      resource_name: The full resource name of the secret, in the format
        "projects/*/secrets/*/versions/*". Usually you want the "latest"
        version, e.g.,
        "projects/my-project/secrets/my-secret/versions/latest".

    Returns:
      The secret payload as a string.

    Raises:
      google.api_core.exceptions.GoogleAPIError: If the Secret Manager API
        returns an error (e.g., secret not found, permission denied).
      Exception: For other unexpected errors.
    """
    try:
      response = self._client.access_secret_version(name=resource_name)
      return response.payload.data.decode("UTF-8")
    except Exception as e:
      raise e  # Re-raise the exception to allow for handling by the caller
      # Consider logging the exception here before re-raising.
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .application_integration_toolset import ApplicationIntegrationToolset
|
||||
|
||||
__all__ = [
|
||||
'ApplicationIntegrationToolset',
|
||||
]
|
||||
@@ -0,0 +1,230 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.openapi.models import HTTPBearer
|
||||
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
|
||||
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_credential import AuthCredentialTypes
|
||||
from ...auth.auth_credential import ServiceAccount
|
||||
from ...auth.auth_credential import ServiceAccountCredential
|
||||
|
||||
|
||||
# TODO(cheliu): Apply a common toolset interface
|
||||
# TODO(cheliu): Apply a common toolset interface
class ApplicationIntegrationToolset:
  """Generates tools from an Application Integration or Integration Connector resource.

  Example Usage:
  ```
  # Get all available tools for an integration with api trigger
  application_integration_toolset = ApplicationIntegrationToolset(
      project="test-project",
      location="us-central1",
      integration="test-integration",
      trigger="api_trigger/test_trigger",
      service_account_json={...},
  )

  # Get all available tools for a connection using entity operations and
  # actions.
  # Note: Find the list of supported entity operations and actions for a
  # connection using integration connector apis:
  # https://cloud.google.com/integration-connectors/docs/reference/rest/v1/projects.locations.connections.connectionSchemaMetadata
  application_integration_toolset = ApplicationIntegrationToolset(
      project="test-project",
      location="us-central1",
      connection="test-connection",
      # An empty operations list means all operations on the entity are
      # supported.
      entity_operations={"EntityId1": ["LIST", "CREATE"], "EntityId2": []},
      actions=["action1"],
      service_account_json={...},
  )

  # Get all available tools
  agent = LlmAgent(tools=[
      ...
      *application_integration_toolset.get_tools(),
  ])
  ```
  """

  def __init__(
      self,
      project: str,
      location: str,
      integration: Optional[str] = None,
      trigger: Optional[str] = None,
      connection: Optional[str] = None,
      # Annotation corrected from Optional[str]: the documented usage is a
      # mapping of entity id -> list of supported operations.
      entity_operations: Optional[Dict[str, List[str]]] = None,
      # Annotation corrected from Optional[str]: the documented usage is a
      # list of action names.
      actions: Optional[List[str]] = None,
      # Optional parameter for the toolset. This is prepended to the generated
      # tool/python function name.
      tool_name: Optional[str] = "",
      # Optional parameter for the toolset. This is appended to the generated
      # tool/python function description.
      tool_instructions: Optional[str] = "",
      service_account_json: Optional[str] = None,
  ):
    """Initializes the ApplicationIntegrationToolset.

    Either (integration and trigger) or (connection and (entity_operations
    or actions)) must be provided; the former builds tools from an
    Application Integration, the latter from an Integration Connector.

    Args:
      project: The GCP project ID.
      location: The GCP location.
      integration: The integration name.
      trigger: The trigger name.
      connection: The connection name.
      entity_operations: Mapping of entity id to the operations supported by
        the connection for that entity; an empty list means all operations.
      actions: The actions supported by the connection.
      tool_name: The name of the tool.
      tool_instructions: The instructions for the tool.
      service_account_json: The service account configuration as a dictionary.
        Required if not using default service credential. Used for fetching
        the Application Integration or Integration Connector resource.

    Raises:
      ValueError: If neither integration and trigger nor connection and
        (entity_operations or actions) is provided.
      Exception: If there is an error during the initialization of the
        integration or connection client.
    """
    self.project = project
    self.location = location
    self.integration = integration
    self.trigger = trigger
    self.connection = connection
    self.entity_operations = entity_operations
    self.actions = actions
    self.tool_name = tool_name
    self.tool_instructions = tool_instructions
    self.service_account_json = service_account_json
    # Name -> tool map populated by _parse_spec_to_tools.
    self.generated_tools: Dict[str, RestApiTool] = {}

    integration_client = IntegrationClient(
        project,
        location,
        integration,
        trigger,
        connection,
        entity_operations,
        actions,
        service_account_json,
    )
    if integration and trigger:
      spec = integration_client.get_openapi_spec_for_integration()
    elif connection and (entity_operations or actions):
      connections_client = ConnectionsClient(
          project, location, connection, service_account_json
      )
      connection_details = connections_client.get_connection_details()
      # Bake the connection coordinates into the tool description so the LLM
      # does not prompt the user for values we already know.
      tool_instructions += (
          "ALWAYS use serviceName = "
          + connection_details["serviceName"]
          + ", host = "
          + connection_details["host"]
          + " and the connection name = "
          + f"projects/{project}/locations/{location}/connections/{connection} when"
          " using this tool"
          + ". DONOT ask the user for these values as you already have those."
      )
      spec = integration_client.get_openapi_spec_for_connection(
          tool_name,
          tool_instructions,
      )
    else:
      raise ValueError(
          "Either (integration and trigger) or (connection and"
          " (entity_operations or actions)) should be provided."
      )
    self._parse_spec_to_tools(spec)

  def _parse_spec_to_tools(self, spec_dict) -> None:
    """Parses the OpenAPI spec dict into RestApiTools stored on this toolset."""
    if self.service_account_json:
      # Explicit service account: build a service-account auth scheme.
      sa_credential = ServiceAccountCredential.model_validate_json(
          self.service_account_json
      )
      service_account = ServiceAccount(
          service_account_credential=sa_credential,
          scopes=["https://www.googleapis.com/auth/cloud-platform"],
      )
      auth_scheme, auth_credential = service_account_scheme_credential(
          config=service_account
      )
    else:
      # No explicit service account: rely on application-default credentials
      # with a bearer-token scheme.
      auth_credential = AuthCredential(
          auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
          service_account=ServiceAccount(
              use_default_credential=True,
              scopes=["https://www.googleapis.com/auth/cloud-platform"],
          ),
      )
      auth_scheme = HTTPBearer(bearerFormat="JWT")
    tools = OpenAPIToolset(
        spec_dict=spec_dict,
        auth_credential=auth_credential,
        auth_scheme=auth_scheme,
    ).get_tools()
    for tool in tools:
      self.generated_tools[tool.name] = tool

  def get_tools(self) -> List[RestApiTool]:
    """Returns all tools generated from the integration/connection spec."""
    return list(self.generated_tools.values())
|
||||
@@ -0,0 +1,903 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import google.auth
|
||||
from google.auth import default as default_service_credential
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import requests
|
||||
|
||||
|
||||
class ConnectionsClient:
|
||||
"""Utility class for interacting with Google Cloud Connectors API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
project: str,
|
||||
location: str,
|
||||
connection: str,
|
||||
service_account_json: Optional[str] = None,
|
||||
):
|
||||
"""Initializes the ConnectionsClient.
|
||||
|
||||
Args:
|
||||
project: The Google Cloud project ID.
|
||||
location: The Google Cloud location (e.g., us-central1).
|
||||
connection: The connection name.
|
||||
service_account_json: The service account configuration as a dictionary.
|
||||
Required if not using default service credential. Used for fetching
|
||||
connection details.
|
||||
"""
|
||||
self.project = project
|
||||
self.location = location
|
||||
self.connection = connection
|
||||
self.connector_url = "https://connectors.googleapis.com"
|
||||
self.service_account_json = service_account_json
|
||||
self.credential_cache = None
|
||||
|
||||
  def get_connection_details(self) -> Dict[str, Any]:
    """Retrieves service details (service name and host) for a given connection.

    Also returns if auth override is enabled for the connection.

    Returns:
      dict: A dictionary with keys "serviceName", "host", and
        "authOverrideEnabled". (The original docstring claimed a tuple;
        the method returns a dict.)

    Raises:
      PermissionError: If there are credential issues.
      ValueError: If there's a request error.
      Exception: For any other unexpected errors.
    """
    url = f"{self.connector_url}/v1/projects/{self.project}/locations/{self.location}/connections/{self.connection}?view=BASIC"

    response = self._execute_api_call(url)

    connection_data = response.json()
    service_name = connection_data.get("serviceDirectory", "")
    host = connection_data.get("host", "")
    # When a host is present, prefer the TLS service directory —
    # presumably the connection is TLS-backed in that case (verify against
    # Connectors API docs).
    if host:
      service_name = connection_data.get("tlsServiceDirectory", "")
    auth_override_enabled = connection_data.get("authOverrideEnabled", False)
    return {
        "serviceName": service_name,
        "host": host,
        "authOverrideEnabled": auth_override_enabled,
    }
|
||||
|
||||
def get_entity_schema_and_operations(
|
||||
self, entity: str
|
||||
) -> Tuple[Dict[str, Any], List[str]]:
|
||||
"""Retrieves the JSON schema for a given entity in a connection.
|
||||
|
||||
Args:
|
||||
entity (str): The entity name.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing (schema, operations).
|
||||
|
||||
Raises:
|
||||
PermissionError: If there are credential issues.
|
||||
ValueError: If there's a request or processing error.
|
||||
Exception: For any other unexpected errors.
|
||||
"""
|
||||
url = f"{self.connector_url}/v1/projects/{self.project}/locations/{self.location}/connections/{self.connection}/connectionSchemaMetadata:getEntityType?entityId={entity}"
|
||||
|
||||
response = self._execute_api_call(url)
|
||||
operation_id = response.json().get("name")
|
||||
|
||||
if not operation_id:
|
||||
raise ValueError(
|
||||
f"Failed to get entity schema and operations for entity: {entity}"
|
||||
)
|
||||
|
||||
operation_response = self._poll_operation(operation_id)
|
||||
|
||||
schema = operation_response.get("response", {}).get("jsonSchema", {})
|
||||
operations = operation_response.get("response", {}).get("operations", [])
|
||||
return schema, operations
|
||||
|
||||
  def get_action_schema(self, action: str) -> Dict[str, Any]:
    """Retrieves the input and output JSON schema for a given action in a connection.

    Args:
      action (str): The action name.

    Returns:
      dict: A dictionary with keys "inputSchema", "outputSchema",
        "description", and "displayName". (The original docstring claimed a
        tuple; the method returns a dict.)

    Raises:
      PermissionError: If there are credential issues.
      ValueError: If there's a request or processing error.
      Exception: For any other unexpected errors.
    """
    url = f"{self.connector_url}/v1/projects/{self.project}/locations/{self.location}/connections/{self.connection}/connectionSchemaMetadata:getAction?actionId={action}"

    response = self._execute_api_call(url)

    # The call returns a long-running operation; its name is required for
    # polling.
    operation_id = response.json().get("name")

    if not operation_id:
      raise ValueError(f"Failed to get action schema for action: {action}")

    operation_response = self._poll_operation(operation_id)

    input_schema = operation_response.get("response", {}).get(
        "inputJsonSchema", {}
    )
    output_schema = operation_response.get("response", {}).get(
        "outputJsonSchema", {}
    )
    description = operation_response.get("response", {}).get("description", "")
    display_name = operation_response.get("response", {}).get("displayName", "")
    return {
        "inputSchema": input_schema,
        "outputSchema": output_schema,
        "description": description,
        "displayName": display_name,
    }
|
||||
|
||||
  @staticmethod
  def get_connector_base_spec() -> Dict[str, Any]:
    """Returns the static OpenAPI 3.0 skeleton shared by all connector tools.

    "paths" starts empty; callers register per-entity / per-action path items
    and request/response schemas on top of this base spec. The component
    schemas here are the building blocks those path items reference by $ref.
    """
    return {
        "openapi": "3.0.1",
        "info": {
            "title": "ExecuteConnection",
            "description": "This tool can execute a query on connection",
            "version": "4",
        },
        "servers": [{"url": "https://integrations.googleapis.com"}],
        "security": [
            {"google_auth": ["https://www.googleapis.com/auth/cloud-platform"]}
        ],
        "paths": {},
        "components": {
            "schemas": {
                "operation": {
                    "type": "string",
                    "default": "LIST_ENTITIES",
                    "description": (
                        "Operation to execute. Possible values are"
                        " LIST_ENTITIES, GET_ENTITY, CREATE_ENTITY,"
                        " UPDATE_ENTITY, DELETE_ENTITY in case of entities."
                        " EXECUTE_ACTION in case of actions. and EXECUTE_QUERY"
                        " in case of custom queries."
                    ),
                },
                "entityId": {
                    "type": "string",
                    "description": "Name of the entity",
                },
                "connectorInputPayload": {"type": "object"},
                "filterClause": {
                    "type": "string",
                    "default": "",
                    "description": "WHERE clause in SQL query",
                },
                "pageSize": {
                    "type": "integer",
                    "default": 50,
                    "description": (
                        "Number of entities to return in the response"
                    ),
                },
                "pageToken": {
                    "type": "string",
                    "default": "",
                    "description": (
                        "Page token to return the next page of entities"
                    ),
                },
                "connectionName": {
                    "type": "string",
                    "default": "",
                    "description": (
                        "Connection resource name to run the query for"
                    ),
                },
                "serviceName": {
                    "type": "string",
                    "default": "",
                    "description": "Service directory for the connection",
                },
                "host": {
                    "type": "string",
                    "default": "",
                    "description": "Host name incase of tls service directory",
                },
                "entity": {
                    "type": "string",
                    "default": "Issues",
                    "description": "Entity to run the query for",
                },
                "action": {
                    "type": "string",
                    "default": "ExecuteCustomQuery",
                    "description": "Action to run the query for",
                },
                "query": {
                    "type": "string",
                    "default": "",
                    "description": "Custom Query to execute on the connection",
                },
                "dynamicAuthConfig": {
                    "type": "object",
                    "default": {},
                    "description": "Dynamic auth config for the connection",
                },
                "timeout": {
                    "type": "integer",
                    "default": 120,
                    "description": (
                        "Timeout in seconds for execution of custom query"
                    ),
                },
                "connectorOutputPayload": {"type": "object"},
                "nextPageToken": {"type": "string"},
                # Shared response envelope referenced by every operation.
                "execute-connector_Response": {
                    "required": ["connectorOutputPayload"],
                    "type": "object",
                    "properties": {
                        "connectorOutputPayload": {
                            "$ref": (
                                "#/components/schemas/connectorOutputPayload"
                            )
                        },
                        "nextPageToken": {
                            "$ref": "#/components/schemas/nextPageToken"
                        },
                    },
                },
            },
            "securitySchemes": {
                "google_auth": {
                    "type": "oauth2",
                    "flows": {
                        "implicit": {
                            "authorizationUrl": (
                                "https://accounts.google.com/o/oauth2/auth"
                            ),
                            "scopes": {
                                "https://www.googleapis.com/auth/cloud-platform": (
                                    "Auth for google cloud services"
                                )
                            },
                        }
                    },
                }
            },
        },
    }
|
||||
|
||||
@staticmethod
|
||||
def get_action_operation(
|
||||
action: str,
|
||||
operation: str,
|
||||
action_display_name: str,
|
||||
tool_name: str = "",
|
||||
tool_instructions: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
description = (
|
||||
f"Use this tool with" f' action = "{action}" and'
|
||||
) + f' operation = "{operation}" only. Dont ask these values from user.'
|
||||
if operation == "EXECUTE_QUERY":
|
||||
description = (
|
||||
(f"Use this tool with" f' action = "{action}" and')
|
||||
+ f' operation = "{operation}" only. Dont ask these values from user.'
|
||||
" Use pageSize = 50 and timeout = 120 until user specifies a"
|
||||
" different value otherwise. If user provides a query in natural"
|
||||
" language, convert it to SQL query and then execute it using the"
|
||||
" tool."
|
||||
)
|
||||
return {
|
||||
"post": {
|
||||
"summary": f"{action_display_name}",
|
||||
"description": f"{description} {tool_instructions}",
|
||||
"operationId": f"{tool_name}_{action_display_name}",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": (
|
||||
f"#/components/schemas/{action_display_name}_Request"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": (
|
||||
f"#/components/schemas/{action_display_name}_Response"
|
||||
),
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def list_operation(
|
||||
entity: str,
|
||||
schema_as_string: str = "",
|
||||
tool_name: str = "",
|
||||
tool_instructions: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"post": {
|
||||
"summary": f"List {entity}",
|
||||
"description": (
|
||||
f"Returns all entities of type {entity}. Use this tool with"
|
||||
+ f' entity = "{entity}" and'
|
||||
+ ' operation = "LIST_ENTITIES" only. Dont ask these values'
|
||||
" from"
|
||||
+ ' user. Always use ""'
|
||||
+ ' as filter clause and ""'
|
||||
+ " as page token and 50 as page size until user specifies a"
|
||||
" different value otherwise. Use single quotes for strings in"
|
||||
f" filter clause. {tool_instructions}"
|
||||
),
|
||||
"operationId": f"{tool_name}_list_{entity}",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": (
|
||||
f"#/components/schemas/list_{entity}_Request"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"description": (
|
||||
f"Returns a list of {entity} of json"
|
||||
f" schema: {schema_as_string}"
|
||||
),
|
||||
"$ref": (
|
||||
"#/components/schemas/execute-connector_Response"
|
||||
),
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_operation(
|
||||
entity: str,
|
||||
schema_as_string: str = "",
|
||||
tool_name: str = "",
|
||||
tool_instructions: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"post": {
|
||||
"summary": f"Get {entity}",
|
||||
"description": (
|
||||
(
|
||||
f"Returns the details of the {entity}. Use this tool with"
|
||||
f' entity = "{entity}" and'
|
||||
)
|
||||
+ ' operation = "GET_ENTITY" only. Dont ask these values from'
|
||||
f" user. {tool_instructions}"
|
||||
),
|
||||
"operationId": f"{tool_name}_get_{entity}",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": f"#/components/schemas/get_{entity}_Request"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"description": (
|
||||
f"Returns {entity} of json schema:"
|
||||
f" {schema_as_string}"
|
||||
),
|
||||
"$ref": (
|
||||
"#/components/schemas/execute-connector_Response"
|
||||
),
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_operation(
|
||||
entity: str, tool_name: str = "", tool_instructions: str = ""
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"post": {
|
||||
"summary": f"Create {entity}",
|
||||
"description": (
|
||||
(
|
||||
f"Creates a new entity of type {entity}. Use this tool with"
|
||||
f' entity = "{entity}" and'
|
||||
)
|
||||
+ ' operation = "CREATE_ENTITY" only. Dont ask these values'
|
||||
" from"
|
||||
+ " user. Follow the schema of the entity provided in the"
|
||||
f" instructions to create {entity}. {tool_instructions}"
|
||||
),
|
||||
"operationId": f"{tool_name}_create_{entity}",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": (
|
||||
f"#/components/schemas/create_{entity}_Request"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": (
|
||||
"#/components/schemas/execute-connector_Response"
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def update_operation(
|
||||
entity: str, tool_name: str = "", tool_instructions: str = ""
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"post": {
|
||||
"summary": f"Update {entity}",
|
||||
"description": (
|
||||
(
|
||||
f"Updates an entity of type {entity}. Use this tool with"
|
||||
f' entity = "{entity}" and'
|
||||
)
|
||||
+ ' operation = "UPDATE_ENTITY" only. Dont ask these values'
|
||||
" from"
|
||||
+ " user. Use entityId to uniquely identify the entity to"
|
||||
" update. Follow the schema of the entity provided in the"
|
||||
f" instructions to update {entity}. {tool_instructions}"
|
||||
),
|
||||
"operationId": f"{tool_name}_update_{entity}",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": (
|
||||
f"#/components/schemas/update_{entity}_Request"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": (
|
||||
"#/components/schemas/execute-connector_Response"
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def delete_operation(
|
||||
entity: str, tool_name: str = "", tool_instructions: str = ""
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"post": {
|
||||
"summary": f"Delete {entity}",
|
||||
"description": (
|
||||
(
|
||||
f"Deletes an entity of type {entity}. Use this tool with"
|
||||
f' entity = "{entity}" and'
|
||||
)
|
||||
+ ' operation = "DELETE_ENTITY" only. Dont ask these values'
|
||||
" from"
|
||||
f" user. {tool_instructions}"
|
||||
),
|
||||
"operationId": f"{tool_name}_delete_{entity}",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": (
|
||||
f"#/components/schemas/delete_{entity}_Request"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": (
|
||||
"#/components/schemas/execute-connector_Response"
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_operation_request(entity: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"connectorInputPayload",
|
||||
"operation",
|
||||
"connectionName",
|
||||
"serviceName",
|
||||
"host",
|
||||
"entity",
|
||||
],
|
||||
"properties": {
|
||||
"connectorInputPayload": {
|
||||
"$ref": f"#/components/schemas/connectorInputPayload_{entity}"
|
||||
},
|
||||
"operation": {"$ref": "#/components/schemas/operation"},
|
||||
"connectionName": {"$ref": "#/components/schemas/connectionName"},
|
||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||
"host": {"$ref": "#/components/schemas/host"},
|
||||
"entity": {"$ref": "#/components/schemas/entity"},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def update_operation_request(entity: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"connectorInputPayload",
|
||||
"entityId",
|
||||
"operation",
|
||||
"connectionName",
|
||||
"serviceName",
|
||||
"host",
|
||||
"entity",
|
||||
],
|
||||
"properties": {
|
||||
"connectorInputPayload": {
|
||||
"$ref": f"#/components/schemas/connectorInputPayload_{entity}"
|
||||
},
|
||||
"entityId": {"$ref": "#/components/schemas/entityId"},
|
||||
"operation": {"$ref": "#/components/schemas/operation"},
|
||||
"connectionName": {"$ref": "#/components/schemas/connectionName"},
|
||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||
"host": {"$ref": "#/components/schemas/host"},
|
||||
"entity": {"$ref": "#/components/schemas/entity"},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_operation_request() -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"entityId",
|
||||
"operation",
|
||||
"connectionName",
|
||||
"serviceName",
|
||||
"host",
|
||||
"entity",
|
||||
],
|
||||
"properties": {
|
||||
"entityId": {"$ref": "#/components/schemas/entityId"},
|
||||
"operation": {"$ref": "#/components/schemas/operation"},
|
||||
"connectionName": {"$ref": "#/components/schemas/connectionName"},
|
||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||
"host": {"$ref": "#/components/schemas/host"},
|
||||
"entity": {"$ref": "#/components/schemas/entity"},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def delete_operation_request() -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"entityId",
|
||||
"operation",
|
||||
"connectionName",
|
||||
"serviceName",
|
||||
"host",
|
||||
"entity",
|
||||
],
|
||||
"properties": {
|
||||
"entityId": {"$ref": "#/components/schemas/entityId"},
|
||||
"operation": {"$ref": "#/components/schemas/operation"},
|
||||
"connectionName": {"$ref": "#/components/schemas/connectionName"},
|
||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||
"host": {"$ref": "#/components/schemas/host"},
|
||||
"entity": {"$ref": "#/components/schemas/entity"},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def list_operation_request() -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"operation",
|
||||
"connectionName",
|
||||
"serviceName",
|
||||
"host",
|
||||
"entity",
|
||||
],
|
||||
"properties": {
|
||||
"filterClause": {"$ref": "#/components/schemas/filterClause"},
|
||||
"pageSize": {"$ref": "#/components/schemas/pageSize"},
|
||||
"pageToken": {"$ref": "#/components/schemas/pageToken"},
|
||||
"operation": {"$ref": "#/components/schemas/operation"},
|
||||
"connectionName": {"$ref": "#/components/schemas/connectionName"},
|
||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||
"host": {"$ref": "#/components/schemas/host"},
|
||||
"entity": {"$ref": "#/components/schemas/entity"},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def action_request(action: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"operation",
|
||||
"connectionName",
|
||||
"serviceName",
|
||||
"host",
|
||||
"action",
|
||||
"connectorInputPayload",
|
||||
],
|
||||
"properties": {
|
||||
"operation": {"$ref": "#/components/schemas/operation"},
|
||||
"connectionName": {"$ref": "#/components/schemas/connectionName"},
|
||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||
"host": {"$ref": "#/components/schemas/host"},
|
||||
"action": {"$ref": "#/components/schemas/action"},
|
||||
"connectorInputPayload": {
|
||||
"$ref": f"#/components/schemas/connectorInputPayload_{action}"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def action_response(action: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"connectorOutputPayload": {
|
||||
"$ref": f"#/components/schemas/connectorOutputPayload_{action}"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def execute_custom_query_request() -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"operation",
|
||||
"connectionName",
|
||||
"serviceName",
|
||||
"host",
|
||||
"action",
|
||||
"query",
|
||||
"timeout",
|
||||
"pageSize",
|
||||
],
|
||||
"properties": {
|
||||
"operation": {"$ref": "#/components/schemas/operation"},
|
||||
"connectionName": {"$ref": "#/components/schemas/connectionName"},
|
||||
"serviceName": {"$ref": "#/components/schemas/serviceName"},
|
||||
"host": {"$ref": "#/components/schemas/host"},
|
||||
"action": {"$ref": "#/components/schemas/action"},
|
||||
"query": {"$ref": "#/components/schemas/query"},
|
||||
"timeout": {"$ref": "#/components/schemas/timeout"},
|
||||
"pageSize": {"$ref": "#/components/schemas/pageSize"},
|
||||
},
|
||||
}
|
||||
|
||||
  def connector_payload(self, json_schema: Dict[str, Any]) -> Dict[str, Any]:
    """Converts a connector JSON schema into an OpenAPI-style schema dict.

    Thin public wrapper over _convert_json_schema_to_openapi_schema.
    """
    return self._convert_json_schema_to_openapi_schema(json_schema)
|
||||
|
||||
def _convert_json_schema_to_openapi_schema(self, json_schema):
|
||||
"""Converts a JSON schema dictionary to an OpenAPI schema dictionary, handling variable types, properties, items, nullable, and description.
|
||||
|
||||
Args:
|
||||
json_schema (dict): The input JSON schema dictionary.
|
||||
|
||||
Returns:
|
||||
dict: The converted OpenAPI schema dictionary.
|
||||
"""
|
||||
openapi_schema = {}
|
||||
|
||||
if "description" in json_schema:
|
||||
openapi_schema["description"] = json_schema["description"]
|
||||
|
||||
if "type" in json_schema:
|
||||
if isinstance(json_schema["type"], list):
|
||||
if "null" in json_schema["type"]:
|
||||
openapi_schema["nullable"] = True
|
||||
other_types = [t for t in json_schema["type"] if t != "null"]
|
||||
if other_types:
|
||||
openapi_schema["type"] = other_types[0]
|
||||
else:
|
||||
openapi_schema["type"] = json_schema["type"][0]
|
||||
else:
|
||||
openapi_schema["type"] = json_schema["type"]
|
||||
|
||||
if openapi_schema.get("type") == "object" and "properties" in json_schema:
|
||||
openapi_schema["properties"] = {}
|
||||
for prop_name, prop_schema in json_schema["properties"].items():
|
||||
openapi_schema["properties"][prop_name] = (
|
||||
self._convert_json_schema_to_openapi_schema(prop_schema)
|
||||
)
|
||||
|
||||
elif openapi_schema.get("type") == "array" and "items" in json_schema:
|
||||
if isinstance(json_schema["items"], list):
|
||||
openapi_schema["items"] = [
|
||||
self._convert_json_schema_to_openapi_schema(item)
|
||||
for item in json_schema["items"]
|
||||
]
|
||||
else:
|
||||
openapi_schema["items"] = self._convert_json_schema_to_openapi_schema(
|
||||
json_schema["items"]
|
||||
)
|
||||
|
||||
return openapi_schema
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
"""Gets the access token for the service account.
|
||||
|
||||
Returns:
|
||||
The access token.
|
||||
"""
|
||||
if self.credential_cache and not self.credential_cache.expired:
|
||||
return self.credential_cache.token
|
||||
|
||||
if self.service_account_json:
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
json.loads(self.service_account_json),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
else:
|
||||
try:
|
||||
credentials, _ = default_service_credential()
|
||||
except:
|
||||
credentials = None
|
||||
|
||||
if not credentials:
|
||||
raise ValueError(
|
||||
"Please provide a service account that has the required permissions"
|
||||
" to access the connection."
|
||||
)
|
||||
|
||||
credentials.refresh(Request())
|
||||
self.credential_cache = credentials
|
||||
return credentials.token
|
||||
|
||||
  def _execute_api_call(self, url):
    """Executes an authenticated GET call against the given URL.

    Args:
      url (str): The URL to call.

    Returns:
      requests.Response: The response object from the API call.

    Raises:
      PermissionError: If there are credential issues.
      ValueError: If there's a request error.
      Exception: For any other unexpected errors.
    """
    try:
      headers = {
          "Content-Type": "application/json",
          "Authorization": f"Bearer {self._get_access_token()}",
      }

      response = requests.get(url, headers=headers)
      response.raise_for_status()
      return response

    except google.auth.exceptions.DefaultCredentialsError as e:
      raise PermissionError(f"Credentials error: {e}") from e

    except requests.exceptions.RequestException as e:
      # Client-side errors (404/400) are detected by inspecting the exception
      # text, since raise_for_status() folds the status into the message.
      if (
          "404" in str(e)
          or "Not found" in str(e)
          or "400" in str(e)
          or "Bad request" in str(e)
      ):
        raise ValueError(
            "Invalid request. Please check the provided"
            f" values of project({self.project}), location({self.location}),"
            f" connection({self.connection})."
        ) from e
      raise ValueError(f"Request error: {e}") from e

    except Exception as e:
      raise Exception(f"An unexpected error occurred: {e}") from e
|
||||
|
||||
def _poll_operation(self, operation_id: str) -> Dict[str, Any]:
|
||||
"""Polls an operation until it is done.
|
||||
|
||||
Args:
|
||||
operation_id: The ID of the operation to poll.
|
||||
|
||||
Returns:
|
||||
The final response of the operation.
|
||||
|
||||
Raises:
|
||||
PermissionError: If there are credential issues.
|
||||
ValueError: If there's a request error.
|
||||
Exception: For any other unexpected errors.
|
||||
"""
|
||||
operation_done: bool = False
|
||||
operation_response: Dict[str, Any] = {}
|
||||
while not operation_done:
|
||||
get_operation_url = f"{self.connector_url}/v1/{operation_id}"
|
||||
response = self._execute_api_call(get_operation_url)
|
||||
operation_response = response.json()
|
||||
operation_done = operation_response.get("done", False)
|
||||
time.sleep(1)
|
||||
return operation_response
|
||||
@@ -0,0 +1,253 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
|
||||
import google.auth
|
||||
from google.auth import default as default_service_credential
|
||||
import google.auth.transport.requests
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import requests
|
||||
|
||||
|
||||
class IntegrationClient:
|
||||
"""A client for interacting with Google Cloud Application Integration.
|
||||
|
||||
This class provides methods for retrieving OpenAPI spec for an integration or
|
||||
a connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
project: str,
|
||||
location: str,
|
||||
integration: Optional[str] = None,
|
||||
trigger: Optional[str] = None,
|
||||
connection: Optional[str] = None,
|
||||
entity_operations: Optional[dict[str, list[str]]] = None,
|
||||
actions: Optional[list[str]] = None,
|
||||
service_account_json: Optional[str] = None,
|
||||
):
|
||||
"""Initializes the ApplicationIntegrationClient.
|
||||
|
||||
Args:
|
||||
project: The Google Cloud project ID.
|
||||
location: The Google Cloud location (e.g., us-central1).
|
||||
integration: The integration name.
|
||||
trigger: The trigger ID for the integration.
|
||||
connection: The connection name.
|
||||
entity_operations: A dictionary mapping entity names to a list of
|
||||
operations (e.g., LIST, CREATE, UPDATE, DELETE, GET).
|
||||
actions: List of actions.
|
||||
service_account_json: The service account configuration as a dictionary.
|
||||
Required if not using default service credential. Used for fetching
|
||||
connection details.
|
||||
"""
|
||||
self.project = project
|
||||
self.location = location
|
||||
self.integration = integration
|
||||
self.trigger = trigger
|
||||
self.connection = connection
|
||||
self.entity_operations = (
|
||||
entity_operations if entity_operations is not None else {}
|
||||
)
|
||||
self.actions = actions if actions is not None else []
|
||||
self.service_account_json = service_account_json
|
||||
self.credential_cache = None
|
||||
|
||||
def get_openapi_spec_for_integration(self):
|
||||
"""Gets the OpenAPI spec for the integration.
|
||||
|
||||
Returns:
|
||||
dict: The OpenAPI spec as a dictionary.
|
||||
Raises:
|
||||
PermissionError: If there are credential issues.
|
||||
ValueError: If there's a request error or processing error.
|
||||
Exception: For any other unexpected errors.
|
||||
"""
|
||||
try:
|
||||
url = f"https://{self.location}-integrations.googleapis.com/v1/projects/{self.project}/locations/{self.location}:generateOpenApiSpec"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self._get_access_token()}",
|
||||
}
|
||||
data = {
|
||||
"apiTriggerResources": [
|
||||
{
|
||||
"integrationResource": self.integration,
|
||||
"triggerId": [self.trigger],
|
||||
},
|
||||
],
|
||||
"fileFormat": "JSON",
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
spec = response.json().get("openApiSpec", {})
|
||||
return json.loads(spec)
|
||||
except google.auth.exceptions.DefaultCredentialsError as e:
|
||||
raise PermissionError(f"Credentials error: {e}") from e
|
||||
except requests.exceptions.RequestException as e:
|
||||
if (
|
||||
"404" in str(e)
|
||||
or "Not found" in str(e)
|
||||
or "400" in str(e)
|
||||
or "Bad request" in str(e)
|
||||
):
|
||||
raise ValueError(
|
||||
"Invalid request. Please check the provided values of"
|
||||
f" project({self.project}), location({self.location}),"
|
||||
f" integration({self.integration}) and trigger({self.trigger})."
|
||||
) from e
|
||||
raise ValueError(f"Request error: {e}") from e
|
||||
except Exception as e:
|
||||
raise Exception(f"An unexpected error occurred: {e}") from e
|
||||
|
||||
def get_openapi_spec_for_connection(self, tool_name="", tool_instructions=""):
|
||||
"""Gets the OpenAPI spec for the connection.
|
||||
|
||||
Returns:
|
||||
dict: The OpenAPI spec as a dictionary.
|
||||
Raises:
|
||||
ValueError: If there's an error retrieving the OpenAPI spec.
|
||||
PermissionError: If there are credential issues.
|
||||
Exception: For any other unexpected errors.
|
||||
"""
|
||||
# Application Integration needs to be provisioned in the same region as connection and an integration with name "ExecuteConnection" and trigger "api_trigger/ExecuteConnection" should be created as per the documentation.
|
||||
integration_name = "ExecuteConnection"
|
||||
connections_client = ConnectionsClient(
|
||||
self.project,
|
||||
self.location,
|
||||
self.connection,
|
||||
self.service_account_json,
|
||||
)
|
||||
if not self.entity_operations and not self.actions:
|
||||
raise ValueError(
|
||||
"No entity operations or actions provided. Please provide at least"
|
||||
" one of them."
|
||||
)
|
||||
connector_spec = connections_client.get_connector_base_spec()
|
||||
for entity, operations in self.entity_operations.items():
|
||||
schema, supported_operations = (
|
||||
connections_client.get_entity_schema_and_operations(entity)
|
||||
)
|
||||
if not operations:
|
||||
operations = supported_operations
|
||||
json_schema_as_string = json.dumps(schema)
|
||||
entity_lower = entity
|
||||
connector_spec["components"]["schemas"][
|
||||
f"connectorInputPayload_{entity_lower}"
|
||||
] = connections_client.connector_payload(schema)
|
||||
for operation in operations:
|
||||
operation_lower = operation.lower()
|
||||
path = f"/v2/projects/{self.project}/locations/{self.location}/integrations/{integration_name}:execute?triggerId=api_trigger/{integration_name}#{operation_lower}_{entity_lower}"
|
||||
if operation_lower == "create":
|
||||
connector_spec["paths"][path] = connections_client.create_operation(
|
||||
entity_lower, tool_name, tool_instructions
|
||||
)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"create_{entity_lower}_Request"
|
||||
] = connections_client.create_operation_request(entity_lower)
|
||||
elif operation_lower == "update":
|
||||
connector_spec["paths"][path] = connections_client.update_operation(
|
||||
entity_lower, tool_name, tool_instructions
|
||||
)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"update_{entity_lower}_Request"
|
||||
] = connections_client.update_operation_request(entity_lower)
|
||||
elif operation_lower == "delete":
|
||||
connector_spec["paths"][path] = connections_client.delete_operation(
|
||||
entity_lower, tool_name, tool_instructions
|
||||
)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"delete_{entity_lower}_Request"
|
||||
] = connections_client.delete_operation_request()
|
||||
elif operation_lower == "list":
|
||||
connector_spec["paths"][path] = connections_client.list_operation(
|
||||
entity_lower, json_schema_as_string, tool_name, tool_instructions
|
||||
)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"list_{entity_lower}_Request"
|
||||
] = connections_client.list_operation_request()
|
||||
elif operation_lower == "get":
|
||||
connector_spec["paths"][path] = connections_client.get_operation(
|
||||
entity_lower, json_schema_as_string, tool_name, tool_instructions
|
||||
)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"get_{entity_lower}_Request"
|
||||
] = connections_client.get_operation_request()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid operation: {operation} for entity: {entity}"
|
||||
)
|
||||
for action in self.actions:
|
||||
action_details = connections_client.get_action_schema(action)
|
||||
input_schema = action_details["inputSchema"]
|
||||
output_schema = action_details["outputSchema"]
|
||||
action_display_name = action_details["displayName"]
|
||||
operation = "EXECUTE_ACTION"
|
||||
if action == "ExecuteCustomQuery":
|
||||
connector_spec["components"]["schemas"][
|
||||
f"{action}_Request"
|
||||
] = connections_client.execute_custom_query_request()
|
||||
operation = "EXECUTE_QUERY"
|
||||
else:
|
||||
connector_spec["components"]["schemas"][
|
||||
f"{action_display_name}_Request"
|
||||
] = connections_client.action_request(action_display_name)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"connectorInputPayload_{action_display_name}"
|
||||
] = connections_client.connector_payload(input_schema)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"connectorOutputPayload_{action_display_name}"
|
||||
] = connections_client.connector_payload(output_schema)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"{action_display_name}_Response"
|
||||
] = connections_client.action_response(action_display_name)
|
||||
path = f"/v2/projects/{self.project}/locations/{self.location}/integrations/{integration_name}:execute?triggerId=api_trigger/{integration_name}#{action}"
|
||||
connector_spec["paths"][path] = connections_client.get_action_operation(
|
||||
action, operation, action_display_name, tool_name, tool_instructions
|
||||
)
|
||||
return connector_spec
|
||||
|
||||
def _get_access_token(self) -> str:
  """Gets the access token for the service account or default credentials.

  The refreshed credentials are cached on the instance and reused until
  they expire.

  Returns:
    The access token.

  Raises:
    ValueError: If no usable credentials are available.
  """
  # Reuse cached credentials while they are still valid.
  if self.credential_cache and not self.credential_cache.expired:
    return self.credential_cache.token

  if self.service_account_json:
    credentials = service_account.Credentials.from_service_account_info(
        json.loads(self.service_account_json),
        scopes=["https://www.googleapis.com/auth/cloud-platform"],
    )
  else:
    try:
      credentials, _ = default_service_credential()
    except Exception:  # pylint: disable=broad-except
      # Fall through to the explicit error below. A bare `except:` here
      # would also swallow KeyboardInterrupt/SystemExit.
      credentials = None

  if not credentials:
    raise ValueError(
        "Please provide a service account that has the required permissions"
        " to access the connection."
    )

  credentials.refresh(Request())
  self.credential_cache = credentials
  return credentials.token
|
||||
144
src/google/adk/tools/base_tool.py
Normal file
144
src/google/adk/tools/base_tool.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deprecated import deprecated
|
||||
from google.genai import types
|
||||
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models.llm_request import LlmRequest
|
||||
|
||||
|
||||
class BaseTool(ABC):
  """The base class for all tools."""

  name: str
  """The name of the tool."""
  description: str
  """The description of the tool."""

  is_long_running: bool = False
  """Whether the tool is a long running operation, which typically returns a
  resource id first and finishes the operation later."""

  def __init__(self, *, name, description, is_long_running: bool = False):
    self.name = name
    self.description = description
    self.is_long_running = is_long_running

  def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
    """Gets the OpenAPI specification of this tool as a FunctionDeclaration.

    NOTE
    - Required if subclass uses the default implementation of
      `process_llm_request` to add function declaration to LLM request.
    - Otherwise, can be skipped, e.g. for a built-in GoogleSearch tool for
      Gemini.

    Returns:
      The FunctionDeclaration of this tool, or None if it doesn't need to be
        added to LlmRequest.config.
    """
    return None

  async def run_async(
      self, *, args: dict[str, Any], tool_context: ToolContext
  ) -> Any:
    """Runs the tool with the given arguments and context.

    NOTE
    - Required if this tool needs to run at the client side.
    - Otherwise, can be skipped, e.g. for a built-in GoogleSearch tool for
      Gemini.

    Args:
      args: The LLM-filled arguments.
      tool_context: The context of the tool.

    Returns:
      The result of running the tool.
    """
    raise NotImplementedError(f'{type(self)} is not implemented')

  async def process_llm_request(
      self, *, tool_context: ToolContext, llm_request: LlmRequest
  ) -> None:
    """Processes the outgoing LLM request for this tool.

    Use cases:
    - Most common use case is adding this tool to the LLM request.
    - Some tools may just preprocess the LLM request before it's sent out.

    Args:
      tool_context: The context of the tool.
      llm_request: The outgoing LLM request, mutated by this method.
    """
    declaration = self._get_declaration()
    if declaration is None:
      return

    llm_request.tools_dict[self.name] = self
    existing_tool = _find_tool_with_function_declarations(llm_request)
    if existing_tool:
      # Merge into the tool that already carries function declarations.
      if existing_tool.function_declarations is None:
        existing_tool.function_declarations = []
      existing_tool.function_declarations.append(declaration)
    else:
      # No such tool yet: ensure config and its tools list exist, then add
      # a fresh Tool carrying this declaration.
      if not llm_request.config:
        llm_request.config = types.GenerateContentConfig()
      if not llm_request.config.tools:
        llm_request.config.tools = []
      llm_request.config.tools.append(
          types.Tool(function_declarations=[declaration])
      )

  @property
  def _api_variant(self) -> str:
    # Vertex AI vs. Google AI endpoints are toggled via an env variable.
    flag = os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower()
    return 'VERTEX_AI' if flag in ('true', '1') else 'GOOGLE_AI'
|
||||
|
||||
|
||||
def _find_tool_with_function_declarations(
    llm_request: LlmRequest,
) -> Optional[types.Tool]:
  """Returns the first request tool that carries function declarations.

  Returns None when the request has no config, no tools, or no Tool with a
  non-empty `function_declarations` list.
  """
  # TODO: add individual tool with declaration and merge in google_llm.py
  config = llm_request.config
  if not config or not config.tools:
    return None

  for candidate in config.tools:
    if isinstance(candidate, types.Tool) and candidate.function_declarations:
      return candidate
  return None
|
||||
59
src/google/adk/tools/built_in_code_execution_tool.py
Normal file
59
src/google/adk/tools/built_in_code_execution_tool.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_tool import BaseTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import LlmRequest
|
||||
|
||||
|
||||
class BuiltInCodeExecutionTool(BaseTool):
  """A built-in code execution tool that is automatically invoked by Gemini 2 models.

  This tool operates internally within the model and does not require or
  perform local code execution.
  """

  def __init__(self):
    # The model resolves this tool internally; name/description are unused.
    super().__init__(name='code_execution', description='code_execution')

  @override
  async def process_llm_request(
      self,
      *,
      tool_context: ToolContext,
      llm_request: LlmRequest,
  ) -> None:
    # Only Gemini 2.x models support the built-in code-execution tool.
    model_name = llm_request.model
    if not (model_name and model_name.startswith('gemini-2')):
      raise ValueError(
          f'Code execution tool is not supported for model {llm_request.model}'
      )
    if not llm_request.config:
      llm_request.config = types.GenerateContentConfig()
    if not llm_request.config.tools:
      llm_request.config.tools = []
    llm_request.config.tools.append(
        types.Tool(code_execution=types.ToolCodeExecution())
    )


built_in_code_execution = BuiltInCodeExecutionTool()
|
||||
72
src/google/adk/tools/crewai_tool.py
Normal file
72
src/google/adk/tools/crewai_tool.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from . import _automatic_function_calling_util
|
||||
from .function_tool import FunctionTool
|
||||
|
||||
try:
|
||||
from crewai.tools import BaseTool as CrewaiBaseTool
|
||||
except ImportError as e:
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
raise ImportError(
|
||||
"Crewai Tools require Python 3.10+. Please upgrade your Python version."
|
||||
) from e
|
||||
else:
|
||||
raise ImportError(
|
||||
"Crewai Tools require pip install 'google-adk[extensions]'."
|
||||
) from e
|
||||
|
||||
|
||||
class CrewaiTool(FunctionTool):
  """Use this class to wrap a CrewAI tool.

  If the original tool name and description are not suitable, you can override
  them in the constructor.
  """

  tool: CrewaiBaseTool
  """The wrapped CrewAI tool."""

  def __init__(self, tool: CrewaiBaseTool, *, name: str, description: str):
    super().__init__(tool.run)
    self.tool = tool
    # Prefer the caller-supplied name; otherwise derive one from the CrewAI
    # tool. CrewAI tool names may contain white spaces, which this framework
    # does not support, so they are replaced with "_".
    if name:
      self.name = name
    elif tool.name:
      self.name = tool.name.replace(" ", "_").lower()
    # Same precedence for the description.
    if description:
      self.description = description
    elif tool.description:
      self.description = tool.description

  @override
  def _get_declaration(self) -> types.FunctionDeclaration:
    """Build the function declaration for the tool."""
    args_json_schema = self.tool.args_schema.model_json_schema()
    return _automatic_function_calling_util.build_function_declaration_for_params_for_crewai(
        False,
        self.name,
        self.description,
        self.func,
        args_json_schema,
    )
|
||||
62
src/google/adk/tools/example_tool.py
Normal file
62
src/google/adk/tools/example_tool.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from typing_extensions import override
|
||||
|
||||
from ..examples import example_util
|
||||
from ..examples.base_example_provider import BaseExampleProvider
|
||||
from ..examples.example import Example
|
||||
from .base_tool import BaseTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models.llm_request import LlmRequest
|
||||
|
||||
|
||||
class ExampleTool(BaseTool):
  """A tool that adds (few-shot) examples to the LLM request.

  Attributes:
    examples: The examples to add to the LLM request.
  """

  def __init__(self, examples: Union[list[Example], BaseExampleProvider]):
    # This tool only rewrites llm_request; name/description are unused.
    super().__init__(name='example_tool', description='example tool')
    if isinstance(examples, list):
      # Validate raw example data into a list of Example models.
      self.examples = TypeAdapter(list[Example]).validate_python(examples)
    else:
      self.examples = examples

  @override
  async def process_llm_request(
      self, *, tool_context: ToolContext, llm_request: LlmRequest
  ) -> None:
    parts = tool_context.user_content.parts
    # Examples are keyed off the user's query text; nothing to do without it.
    if not parts or not parts[0].text:
      return

    instruction = example_util.build_example_si(
        self.examples, parts[0].text, llm_request.model
    )
    llm_request.append_instructions([instruction])
|
||||
23
src/google/adk/tools/exit_loop_tool.py
Normal file
23
src/google/adk/tools/exit_loop_tool.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .tool_context import ToolContext
|
||||
|
||||
|
||||
def exit_loop(tool_context: ToolContext):
  """Exits the loop.

  Call this function only when you are instructed to do so.
  """
  # Escalating signals the enclosing loop agent to stop iterating.
  tool_context.actions.escalate = True
|
||||
307
src/google/adk/tools/function_parameter_parse_util.py
Normal file
307
src/google/adk/tools/function_parameter_parse_util.py
Normal file
@@ -0,0 +1,307 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import types as typing_types
|
||||
from typing import _GenericAlias
|
||||
from typing import Any
|
||||
from typing import get_args
|
||||
from typing import get_origin
|
||||
from typing import Literal
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
import pydantic
|
||||
|
||||
_py_builtin_type_to_schema_type = {
|
||||
str: types.Type.STRING,
|
||||
int: types.Type.INTEGER,
|
||||
float: types.Type.NUMBER,
|
||||
bool: types.Type.BOOLEAN,
|
||||
list: types.Type.ARRAY,
|
||||
dict: types.Type.OBJECT,
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_builtin_primitive_or_compound(
    annotation: inspect.Parameter.annotation,
) -> bool:
  """Returns True if the annotation maps directly to a schema type.

  Covers the builtin types str/int/float/bool/list/dict, i.e. the keys of
  `_py_builtin_type_to_schema_type`.
  """
  # Membership on the dict itself is equivalent to `.keys()` and avoids
  # creating an intermediate view object.
  return annotation in _py_builtin_type_to_schema_type
|
||||
|
||||
|
||||
def _raise_for_any_of_if_mldev(schema: types.Schema):
|
||||
if schema.any_of:
|
||||
raise ValueError(
|
||||
'AnyOf is not supported in function declaration schema for Google AI.'
|
||||
)
|
||||
|
||||
|
||||
def _update_for_default_if_mldev(schema: types.Schema):
|
||||
if schema.default is not None:
|
||||
# TODO(kech): Remove this walkaround once mldev supports default value.
|
||||
schema.default = None
|
||||
logger.warning(
|
||||
'Default value is not supported in function declaration schema for'
|
||||
' Google AI.'
|
||||
)
|
||||
|
||||
|
||||
def _raise_if_schema_unsupported(variant: str, schema: types.Schema):
|
||||
if not variant == 'VERTEX_AI':
|
||||
_raise_for_any_of_if_mldev(schema)
|
||||
_update_for_default_if_mldev(schema)
|
||||
|
||||
|
||||
def _is_default_value_compatible(
    default_value: Any, annotation: inspect.Parameter.annotation
) -> bool:
  """Checks whether a parameter's default value matches its annotation.

  Used to validate defaults before copying them into a function-declaration
  schema. Recurses into union and generic-alias annotations.
  """
  # None type is expected to be handled external to this function
  if _is_builtin_primitive_or_compound(annotation):
    return isinstance(default_value, annotation)

  if (
      isinstance(annotation, _GenericAlias)
      or isinstance(annotation, typing_types.GenericAlias)
      or isinstance(annotation, typing_types.UnionType)
  ):
    origin = get_origin(annotation)
    if origin in (Union, typing_types.UnionType):
      # Compatible if the default matches any member of the union.
      return any(
          _is_default_value_compatible(default_value, arg)
          for arg in get_args(annotation)
      )

    if origin is dict:
      return isinstance(default_value, dict)

    if origin is list:
      if not isinstance(default_value, list):
        return False
      # most tricky case, element in list is union type
      # need to apply any logic within all
      # see test case test_generic_alias_complex_array_with_default_value
      # a: typing.List[int | str | float | bool]
      # default_value: [1, 'a', 1.1, True]
      return all(
          any(
              _is_default_value_compatible(item, arg)
              for arg in get_args(annotation)
          )
          for item in default_value
      )

    if origin is Literal:
      return default_value in get_args(annotation)

  # return False for any other unrecognized annotation
  # let caller handle the raise
  return False
|
||||
|
||||
|
||||
def _parse_schema_from_parameter(
    variant: str, param: inspect.Parameter, func_name: str
) -> types.Schema:
  """parse schema from parameter.

  from the simplest case to the most complex case.
  """
  schema = types.Schema()
  default_value_error_msg = (
      f'Default value {param.default} of parameter {param} of function'
      f' {func_name} is not compatible with the parameter annotation'
      f' {param.annotation}.'
  )
  # Case 1: builtin primitive/compound (str, int, float, bool, list, dict).
  if _is_builtin_primitive_or_compound(param.annotation):
    if param.default is not inspect.Parameter.empty:
      if not _is_default_value_compatible(param.default, param.annotation):
        raise ValueError(default_value_error_msg)
      schema.default = param.default
    schema.type = _py_builtin_type_to_schema_type[param.annotation]
    _raise_if_schema_unsupported(variant, schema)
    return schema
  # Case 2: a simple Union of builtins (optionally with NoneType),
  # e.g. int | str | None. NoneType members set `nullable` instead of
  # contributing an any_of entry.
  if (
      get_origin(param.annotation) is Union
      # only parse simple UnionType, example int | str | float | bool
      # complex types.UnionType will be invoked in raise branch
      and all(
          (_is_builtin_primitive_or_compound(arg) or arg is type(None))
          for arg in get_args(param.annotation)
      )
  ):
    schema.type = types.Type.OBJECT
    schema.any_of = []
    unique_types = set()
    for arg in get_args(param.annotation):
      if arg.__name__ == 'NoneType':  # Optional type
        schema.nullable = True
        continue
      schema_in_any_of = _parse_schema_from_parameter(
          variant,
          inspect.Parameter(
              'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
          ),
          func_name,
      )
      # De-duplicate union members by their serialized schema.
      if (
          schema_in_any_of.model_dump_json(exclude_none=True)
          not in unique_types
      ):
        schema.any_of.append(schema_in_any_of)
        unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
    if len(schema.any_of) == 1:  # param: list | None -> Array
      schema.type = schema.any_of[0].type
      schema.any_of = None
    if (
        param.default is not inspect.Parameter.empty
        and param.default is not None
    ):
      if not _is_default_value_compatible(param.default, param.annotation):
        raise ValueError(default_value_error_msg)
      schema.default = param.default
    _raise_if_schema_unsupported(variant, schema)
    return schema
  # Case 3: generic aliases — dict[...], Literal[...], list[...], Union[...].
  if isinstance(param.annotation, _GenericAlias) or isinstance(
      param.annotation, typing_types.GenericAlias
  ):
    origin = get_origin(param.annotation)
    args = get_args(param.annotation)
    if origin is dict:
      schema.type = types.Type.OBJECT
      if param.default is not inspect.Parameter.empty:
        if not _is_default_value_compatible(param.default, param.annotation):
          raise ValueError(default_value_error_msg)
        schema.default = param.default
      _raise_if_schema_unsupported(variant, schema)
      return schema
    if origin is Literal:
      if not all(isinstance(arg, str) for arg in args):
        raise ValueError(
            f'Literal type {param.annotation} must be a list of strings.'
        )
      schema.type = types.Type.STRING
      schema.enum = list(args)
      if param.default is not inspect.Parameter.empty:
        if not _is_default_value_compatible(param.default, param.annotation):
          raise ValueError(default_value_error_msg)
        schema.default = param.default
      _raise_if_schema_unsupported(variant, schema)
      return schema
    if origin is list:
      schema.type = types.Type.ARRAY
      # Element type is parsed recursively from the first type argument.
      schema.items = _parse_schema_from_parameter(
          variant,
          inspect.Parameter(
              'item',
              inspect.Parameter.POSITIONAL_OR_KEYWORD,
              annotation=args[0],
          ),
          func_name,
      )
      if param.default is not inspect.Parameter.empty:
        if not _is_default_value_compatible(param.default, param.annotation):
          raise ValueError(default_value_error_msg)
        schema.default = param.default
      _raise_if_schema_unsupported(variant, schema)
      return schema
    if origin is Union:
      schema.any_of = []
      schema.type = types.Type.OBJECT
      unique_types = set()
      for arg in args:
        if arg.__name__ == 'NoneType':  # Optional type
          schema.nullable = True
          continue
        schema_in_any_of = _parse_schema_from_parameter(
            variant,
            inspect.Parameter(
                'item',
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=arg,
            ),
            func_name,
        )
        if (
            len(param.annotation.__args__) == 2
            and type(None) in param.annotation.__args__
        ):  # Optional type
          for optional_arg in param.annotation.__args__:
            if (
                hasattr(optional_arg, '__origin__')
                and optional_arg.__origin__ is list
            ):
              # Optional type with list, for example Optional[list[str]]
              schema.items = schema_in_any_of.items
        # De-duplicate union members by their serialized schema.
        if (
            schema_in_any_of.model_dump_json(exclude_none=True)
            not in unique_types
        ):
          schema.any_of.append(schema_in_any_of)
          unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
      if len(schema.any_of) == 1:  # param: Union[List, None] -> Array
        schema.type = schema.any_of[0].type
        schema.any_of = None
      if (
          param.default is not None
          and param.default is not inspect.Parameter.empty
      ):
        if not _is_default_value_compatible(param.default, param.annotation):
          raise ValueError(default_value_error_msg)
        schema.default = param.default
      _raise_if_schema_unsupported(variant, schema)
      return schema
    # all other generic alias will be invoked in raise branch
  # Case 4: pydantic models — mapped to OBJECT with per-field schemas.
  if (
      inspect.isclass(param.annotation)
      # for user defined class, we only support pydantic model
      and issubclass(param.annotation, pydantic.BaseModel)
  ):
    if (
        param.default is not inspect.Parameter.empty
        and param.default is not None
    ):
      schema.default = param.default
    schema.type = types.Type.OBJECT
    schema.properties = {}
    for field_name, field_info in param.annotation.model_fields.items():
      schema.properties[field_name] = _parse_schema_from_parameter(
          variant,
          inspect.Parameter(
              field_name,
              inspect.Parameter.POSITIONAL_OR_KEYWORD,
              annotation=field_info.annotation,
          ),
          func_name,
      )
    _raise_if_schema_unsupported(variant, schema)
    return schema
  raise ValueError(
      f'Failed to parse the parameter {param} of function {func_name} for'
      ' automatic function calling.Automatic function calling works best with'
      ' simpler function signature schema,consider manually parse your'
      f' function declaration for function {func_name}.'
  )
|
||||
|
||||
|
||||
def _get_required_fields(schema: types.Schema) -> list[str]:
|
||||
if not schema.properties:
|
||||
return
|
||||
return [
|
||||
field_name
|
||||
for field_name, field_schema in schema.properties.items()
|
||||
if not field_schema.nullable and field_schema.default is None
|
||||
]
|
||||
87
src/google/adk/tools/function_tool.py
Normal file
87
src/google/adk/tools/function_tool.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ._automatic_function_calling_util import build_function_declaration
|
||||
from .base_tool import BaseTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
|
||||
class FunctionTool(BaseTool):
  """A tool that wraps a user-defined Python function.

  Attributes:
    func: The function to wrap.
  """

  def __init__(self, func: Callable[..., Any]):
    # The tool inherits its name and description from the wrapped function.
    super().__init__(name=func.__name__, description=func.__doc__)
    self.func = func

  @override
  def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
    # Builds the declaration by introspecting the wrapped function's
    # signature; framework-internal parameters are hidden from the model.
    function_decl = types.FunctionDeclaration.model_validate(
        build_function_declaration(
            func=self.func,
            # The model doesn't understand the function context.
            # input_stream is for streaming tool
            ignore_params=['tool_context', 'input_stream'],
            variant=self._api_variant,
        )
    )

    return function_decl

  @override
  async def run_async(
      self, *, args: dict[str, Any], tool_context: ToolContext
  ) -> Any:
    # Copy so the caller's dict is not mutated when tool_context is injected.
    args_to_call = args.copy()
    signature = inspect.signature(self.func)
    if 'tool_context' in signature.parameters:
      args_to_call['tool_context'] = tool_context

    # NOTE(review): `or {}` coerces any falsy result (None, 0, '', False)
    # into an empty dict — confirm callers expect that for non-None falsy
    # return values.
    if inspect.iscoroutinefunction(self.func):
      return await self.func(**args_to_call) or {}
    else:
      return self.func(**args_to_call) or {}

  # TODO(hangfei): fix call live for function stream.
  async def _call_live(
      self,
      *,
      args: dict[str, Any],
      tool_context: ToolContext,
      invocation_context,
  ) -> Any:
    # Streaming variant: injects the active input stream registered for this
    # tool on the invocation context, then iterates the wrapped async
    # generator and re-yields its items.
    args_to_call = args.copy()
    signature = inspect.signature(self.func)
    if (
        self.name in invocation_context.active_streaming_tools
        and invocation_context.active_streaming_tools[self.name].stream
    ):
      args_to_call['input_stream'] = invocation_context.active_streaming_tools[
          self.name
      ].stream
    if 'tool_context' in signature.parameters:
      args_to_call['tool_context'] = tool_context
    # assumes self.func is an async generator here — TODO confirm.
    async for item in self.func(**args_to_call):
      yield item
|
||||
28
src/google/adk/tools/get_user_choice_tool.py
Normal file
28
src/google/adk/tools/get_user_choice_tool.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from .long_running_tool import LongRunningFunctionTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
|
||||
def get_user_choice(
    options: list[str], tool_context: ToolContext
) -> Optional[str]:
  """Provides the options to the user and asks them to choose one."""
  # The actual choice arrives later through the long-running tool flow, so
  # summarization is skipped and no immediate value is returned.
  tool_context.actions.skip_summarization = True
  return None
|
||||
|
||||
|
||||
# Exposed as a long-running tool: the user's actual choice is delivered
# later through the long-running function-call flow.
get_user_choice_tool = LongRunningFunctionTool(func=get_user_choice)
|
||||
14
src/google/adk/tools/google_api_tool/__init__.py
Normal file
14
src/google/adk/tools/google_api_tool/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .google_api_tool_sets import calendar_tool_set
|
||||
59
src/google/adk/tools/google_api_tool/google_api_tool.py
Normal file
59
src/google/adk/tools/google_api_tool/google_api_tool.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from typing_extensions import override
|
||||
|
||||
from ...auth import AuthCredential
|
||||
from ...auth import AuthCredentialTypes
|
||||
from ...auth import OAuth2Auth
|
||||
from .. import BaseTool
|
||||
from ..openapi_tool import RestApiTool
|
||||
from ..tool_context import ToolContext
|
||||
|
||||
|
||||
class GoogleApiTool(BaseTool):
  """Adapter that exposes a RestApiTool as an ADK tool.

  Declaration and execution are delegated entirely to the wrapped
  RestApiTool; configure_auth() additionally lets callers attach OAuth2
  client credentials to the delegate.
  """

  def __init__(self, rest_api_tool: RestApiTool):
    # Mirror the delegate's identity so this wrapper is indistinguishable
    # from the underlying REST tool to the agent framework.
    super().__init__(
        name=rest_api_tool.name,
        description=rest_api_tool.description,
        is_long_running=rest_api_tool.is_long_running,
    )
    self.rest_api_tool = rest_api_tool

  @override
  def _get_declaration(self) -> FunctionDeclaration:
    """Forward the function declaration of the wrapped REST tool."""
    return self.rest_api_tool._get_declaration()

  @override
  async def run_async(
      self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
  ) -> Dict[str, Any]:
    """Execute the wrapped REST tool and relay its response unchanged."""
    delegate = self.rest_api_tool
    return await delegate.run_async(args=args, tool_context=tool_context)

  def configure_auth(self, client_id: str, client_secret: str):
    """Attach OIDC/OAuth2 client credentials to the wrapped tool."""
    oauth_settings = OAuth2Auth(
        client_id=client_id,
        client_secret=client_secret,
    )
    self.rest_api_tool.auth_credential = AuthCredential(
        auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
        oauth2=oauth_settings,
    )
|
||||
107
src/google/adk/tools/google_api_tool/google_api_tool_set.py
Normal file
107
src/google/adk/tools/google_api_tool/google_api_tool_set.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Final
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
|
||||
from ...auth import OpenIdConnectWithConfig
|
||||
from ..openapi_tool import OpenAPIToolset
|
||||
from ..openapi_tool import RestApiTool
|
||||
from .google_api_tool import GoogleApiTool
|
||||
from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
|
||||
|
||||
|
||||
class GoogleApiToolSet:
  """A collection of GoogleApiTool instances for one Google API.

  Built from an OpenAPI toolset derived from the API's discovery document;
  each RestApiTool is wrapped in a GoogleApiTool.
  """

  def __init__(self, tools: List[RestApiTool]):
    # Wrap every REST tool; Final documents that the list is not rebound
    # after construction (the list object itself is still mutable).
    self.tools: Final[List[GoogleApiTool]] = [
        GoogleApiTool(tool) for tool in tools
    ]

  def get_tools(self) -> List[GoogleApiTool]:
    """Get all tools in the toolset."""
    return self.tools

  def get_tool(self, tool_name: str) -> Optional[GoogleApiTool]:
    """Get a tool by name.

    Returns the first tool whose name matches, or None if absent.
    """
    matching_tool = filter(lambda t: t.name == tool_name, self.tools)
    return next(matching_tool, None)

  @staticmethod
  def _load_tool_set_with_oidc_auth(
      spec_file: str = None,
      spec_dict: Dict[str, Any] = None,
      scopes: list[str] = None,
  ) -> Optional[OpenAPIToolset]:
    """Build an OpenAPIToolset with Google OIDC endpoints pre-configured.

    Exactly one of spec_file / spec_dict is expected; when spec_file is
    given it is resolved relative to the *caller's* source directory.
    """
    spec_str = None
    if spec_file:
      # Get the frame of the caller
      # NOTE(review): inspect.stack()[1] makes the spec_file path depend on
      # the immediate caller's file location; wrapping or re-exporting this
      # function changes which directory is searched -- confirm intended.
      caller_frame = inspect.stack()[1]
      # Get the filename of the caller
      caller_filename = caller_frame.filename
      # Get the directory of the caller
      caller_dir = os.path.dirname(os.path.abspath(caller_filename))
      # Join the directory path with the filename
      yaml_path = os.path.join(caller_dir, spec_file)
      with open(yaml_path, 'r', encoding='utf-8') as file:
        spec_str = file.read()
    tool_set = OpenAPIToolset(
        spec_dict=spec_dict,
        spec_str=spec_str,
        spec_str_type='yaml',
        # Standard Google OIDC endpoints; only the scopes vary per API.
        auth_scheme=OpenIdConnectWithConfig(
            authorization_endpoint=(
                'https://accounts.google.com/o/oauth2/v2/auth'
            ),
            token_endpoint='https://oauth2.googleapis.com/token',
            userinfo_endpoint=(
                'https://openidconnect.googleapis.com/v1/userinfo'
            ),
            revocation_endpoint='https://oauth2.googleapis.com/revoke',
            token_endpoint_auth_methods_supported=[
                'client_secret_post',
                'client_secret_basic',
            ],
            grant_types_supported=['authorization_code'],
            scopes=scopes,
        ),
    )
    return tool_set

  def configure_auth(self, client_id: str, client_secret: str):
    """Attach the same OAuth2 client credentials to every tool."""
    for tool in self.tools:
      tool.configure_auth(client_id, client_secret)

  @classmethod
  def load_tool_set(
      cl: Type['GoogleApiToolSet'],
      api_name: str,
      api_version: str,
  ) -> 'GoogleApiToolSet':
    """Create a toolset by converting the API's discovery doc to OpenAPI.

    Note: performs a network fetch of the discovery document via
    GoogleApiToOpenApiConverter.  ('cl' is the conventional 'cls'.)
    """
    spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
    # Use the first OAuth2 scope declared by the converted spec.
    scope = list(
        spec_dict['components']['securitySchemes']['oauth2']['flows'][
            'authorizationCode'
        ]['scopes'].keys()
    )[0]
    return cl(
        cl._load_tool_set_with_oidc_auth(
            spec_dict=spec_dict, scopes=[scope]
        ).get_tools()
    )
|
||||
55
src/google/adk/tools/google_api_tool/google_api_tool_sets.py
Normal file
55
src/google/adk/tools/google_api_tool/google_api_tool_sets.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from .google_api_tool_set import GoogleApiToolSet
|
||||
|
||||
logger = logging.getLogger(__name__)

# NOTE(review): each load_tool_set call below fetches the API's discovery
# document over the network at import time, so importing this module requires
# network access and Google API client credentials -- confirm this is intended.
calendar_tool_set = GoogleApiToolSet.load_tool_set(
    api_name="calendar",
    api_version="v3",
)

bigquery_tool_set = GoogleApiToolSet.load_tool_set(
    api_name="bigquery",
    api_version="v2",
)

gmail_tool_set = GoogleApiToolSet.load_tool_set(
    api_name="gmail",
    api_version="v1",
)

youtube_tool_set = GoogleApiToolSet.load_tool_set(
    api_name="youtube",
    api_version="v3",
)

slides_tool_set = GoogleApiToolSet.load_tool_set(
    api_name="slides",
    api_version="v1",
)

sheets_tool_set = GoogleApiToolSet.load_tool_set(
    api_name="sheets",
    api_version="v4",
)

docs_tool_set = GoogleApiToolSet.load_tool_set(
    api_name="docs",
    api_version="v1",
)
|
||||
@@ -0,0 +1,521 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
# Google API client
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.discovery import Resource
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
# Configure logging
# NOTE(review): logging.basicConfig at import time configures the process-wide
# root logger from library code; consider leaving configuration to the
# application entry point -- confirm before changing.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GoogleApiToOpenApiConverter:
  """Converts Google API Discovery documents to OpenAPI v3 format."""

  def __init__(self, api_name: str, api_version: str):
    """Initialize the converter with the API name and version.

    Args:
      api_name: The name of the Google API (e.g., "calendar")
      api_version: The version of the API (e.g., "v3")
    """
    self.api_name = api_name
    self.api_version = api_version
    # Populated lazily by fetch_google_api_spec().
    self.google_api_resource = None
    self.google_api_spec = None
    # Skeleton of the output document; filled in by the _convert_* helpers.
    self.openapi_spec = {
        "openapi": "3.0.0",
        "info": {},
        "servers": [],
        "paths": {},
        "components": {"schemas": {}, "securitySchemes": {}},
    }

  def fetch_google_api_spec(self) -> None:
    """Fetches the Google API specification using discovery service.

    Raises:
      HttpError: if the discovery request fails at the HTTP level.
      ValueError: if the discovery document comes back empty.
    """
    try:
      logger.info(
          "Fetching Google API spec for %s %s", self.api_name, self.api_version
      )
      # Build a resource object for the specified API
      self.google_api_resource = build(self.api_name, self.api_version)

      # Access the underlying API discovery document
      # NOTE(review): _rootDesc is a private attribute of the discovery
      # Resource -- confirm it is stable across client library versions.
      self.google_api_spec = self.google_api_resource._rootDesc

      if not self.google_api_spec:
        raise ValueError("Failed to retrieve API specification")

      logger.info("Successfully fetched %s API specification", self.api_name)
    except HttpError as e:
      logger.error("HTTP Error: %s", e)
      raise
    except Exception as e:
      logger.error("Error fetching API spec: %s", e)
      raise

  def convert(self) -> Dict[str, Any]:
    """Convert the Google API spec to OpenAPI v3 format.

    Fetches the discovery document on first use, then fills in each section
    of self.openapi_spec in order.

    Returns:
      Dict containing the converted OpenAPI v3 specification
    """
    if not self.google_api_spec:
      self.fetch_google_api_spec()

    # Convert basic API information
    self._convert_info()

    # Convert server information
    self._convert_servers()

    # Convert authentication/authorization schemes
    self._convert_security_schemes()

    # Convert schemas (models)
    self._convert_schemas()

    # Convert endpoints/paths
    self._convert_resources(self.google_api_spec.get("resources", {}))

    # Convert top-level methods, if any
    self._convert_methods(self.google_api_spec.get("methods", {}), "/")

    return self.openapi_spec

  def _convert_info(self) -> None:
    """Convert basic API information."""
    self.openapi_spec["info"] = {
        "title": self.google_api_spec.get("title", f"{self.api_name} API"),
        "description": self.google_api_spec.get("description", ""),
        "version": self.google_api_spec.get("version", self.api_version),
        "contact": {},
        "termsOfService": self.google_api_spec.get("documentationLink", ""),
    }

    # Add documentation links if available
    docs_link = self.google_api_spec.get("documentationLink")
    if docs_link:
      self.openapi_spec["externalDocs"] = {
          "description": "API Documentation",
          "url": docs_link,
      }

  def _convert_servers(self) -> None:
    """Convert server information."""
    # rootUrl + servicePath form the API's base URL in discovery documents.
    base_url = self.google_api_spec.get(
        "rootUrl", ""
    ) + self.google_api_spec.get("servicePath", "")

    # Remove trailing slash if present
    if base_url.endswith("/"):
      base_url = base_url[:-1]

    self.openapi_spec["servers"] = [{
        "url": base_url,
        "description": f"{self.api_name} {self.api_version} API",
    }]

  def _convert_security_schemes(self) -> None:
    """Convert authentication and authorization schemes."""
    auth = self.google_api_spec.get("auth", {})
    oauth2 = auth.get("oauth2", {})

    if oauth2:
      # Handle OAuth2
      scopes = oauth2.get("scopes", {})
      formatted_scopes = {}

      # Discovery maps scope URL -> {"description": ...}; OpenAPI wants
      # scope URL -> description string.
      for scope, scope_info in scopes.items():
        formatted_scopes[scope] = scope_info.get("description", "")

      self.openapi_spec["components"]["securitySchemes"]["oauth2"] = {
          "type": "oauth2",
          "description": "OAuth 2.0 authentication",
          "flows": {
              "authorizationCode": {
                  "authorizationUrl": (
                      "https://accounts.google.com/o/oauth2/auth"
                  ),
                  "tokenUrl": "https://oauth2.googleapis.com/token",
                  "scopes": formatted_scopes,
              }
          },
      }

    # Add API key authentication (most Google APIs support this)
    self.openapi_spec["components"]["securitySchemes"]["apiKey"] = {
        "type": "apiKey",
        "in": "query",
        "name": "key",
        "description": "API key for accessing this API",
    }

    # Create global security requirement
    # (formatted_scopes is only evaluated when oauth2 is truthy, i.e. when
    # it was assigned above.)
    self.openapi_spec["security"] = [
        {"oauth2": list(formatted_scopes.keys())} if oauth2 else {},
        {"apiKey": []},
    ]

  def _convert_schemas(self) -> None:
    """Convert schema definitions (models)."""
    schemas = self.google_api_spec.get("schemas", {})

    for schema_name, schema_def in schemas.items():
      converted_schema = self._convert_schema_object(schema_def)
      self.openapi_spec["components"]["schemas"][schema_name] = converted_schema

  def _convert_schema_object(
      self, schema_def: Dict[str, Any]
  ) -> Dict[str, Any]:
    """Recursively convert a Google API schema object to OpenAPI schema.

    Args:
      schema_def: Google API schema definition

    Returns:
      Converted OpenAPI schema object
    """
    result = {}

    # Convert the type
    if "type" in schema_def:
      gtype = schema_def["type"]
      if gtype == "object":
        result["type"] = "object"

        # Handle properties
        if "properties" in schema_def:
          result["properties"] = {}
          for prop_name, prop_def in schema_def["properties"].items():
            result["properties"][prop_name] = self._convert_schema_object(
                prop_def
            )

        # Handle required fields
        # Discovery marks "required" on each property; OpenAPI collects the
        # required property names into a list on the parent object.
        required_fields = []
        for prop_name, prop_def in schema_def.get("properties", {}).items():
          if prop_def.get("required", False):
            required_fields.append(prop_name)
        if required_fields:
          result["required"] = required_fields

      elif gtype == "array":
        result["type"] = "array"
        if "items" in schema_def:
          result["items"] = self._convert_schema_object(schema_def["items"])

      elif gtype == "any":
        # OpenAPI doesn't have direct "any" type
        # Use oneOf with multiple options as alternative
        result["oneOf"] = [
            {"type": "object"},
            {"type": "array"},
            {"type": "string"},
            {"type": "number"},
            {"type": "boolean"},
            {"type": "null"},
        ]

      else:
        # Handle other primitive types
        result["type"] = gtype

    # Handle references
    if "$ref" in schema_def:
      ref = schema_def["$ref"]
      # Google refs use "#" at start, OpenAPI uses "#/components/schemas/"
      # NOTE(review): str.replace substitutes every "#" in the ref; this is
      # only safe if discovery refs contain at most a leading "#" -- confirm.
      if ref.startswith("#"):
        ref = ref.replace("#", "#/components/schemas/")
      else:
        ref = "#/components/schemas/" + ref
      result["$ref"] = ref

    # Handle format
    if "format" in schema_def:
      result["format"] = schema_def["format"]

    # Handle enum values
    if "enum" in schema_def:
      result["enum"] = schema_def["enum"]

    # Handle description
    if "description" in schema_def:
      result["description"] = schema_def["description"]

    # Handle pattern
    if "pattern" in schema_def:
      result["pattern"] = schema_def["pattern"]

    # Handle default value
    if "default" in schema_def:
      result["default"] = schema_def["default"]

    return result

  def _convert_resources(
      self, resources: Dict[str, Any], parent_path: str = ""
  ) -> None:
    """Recursively convert all resources and their methods.

    Args:
      resources: Dictionary of resources from the Google API spec
      parent_path: The parent path prefix for nested resources
    """
    for resource_name, resource_data in resources.items():
      # Process methods for this resource
      resource_path = f"{parent_path}/{resource_name}"
      methods = resource_data.get("methods", {})
      self._convert_methods(methods, resource_path)

      # Process nested resources recursively
      nested_resources = resource_data.get("resources", {})
      if nested_resources:
        self._convert_resources(nested_resources, resource_path)

  def _convert_methods(
      self, methods: Dict[str, Any], resource_path: str
  ) -> None:
    """Convert methods for a specific resource path.

    Args:
      methods: Dictionary of methods from the Google API spec
      resource_path: The path of the resource these methods belong to
    """
    # method_name (the discovery key) is not used: the path/verb come from
    # the method data itself.
    for method_name, method_data in methods.items():
      http_method = method_data.get("httpMethod", "GET").lower()

      # Determine the actual endpoint path
      # Google often has the format something like 'users.messages.list'
      rest_path = method_data.get("path", "/")
      if not rest_path.startswith("/"):
        rest_path = "/" + rest_path

      path_params = self._extract_path_parameters(rest_path)

      # Create path entry if it doesn't exist
      if rest_path not in self.openapi_spec["paths"]:
        self.openapi_spec["paths"][rest_path] = {}

      # Add the operation for this method
      self.openapi_spec["paths"][rest_path][http_method] = (
          self._convert_operation(method_data, path_params)
      )

  def _extract_path_parameters(self, path: str) -> List[str]:
    """Extract path parameters from a URL path.

    Args:
      path: The URL path with path parameters

    Returns:
      List of parameter names
    """
    params = []
    segments = path.split("/")

    for segment in segments:
      # Google APIs often use {param} format for path parameters
      if segment.startswith("{") and segment.endswith("}"):
        param_name = segment[1:-1]
        params.append(param_name)

    return params

  def _convert_operation(
      self, method_data: Dict[str, Any], path_params: List[str]
  ) -> Dict[str, Any]:
    """Convert a Google API method to an OpenAPI operation.

    Args:
      method_data: Google API method data
      path_params: List of path parameter names

    Returns:
      OpenAPI operation object
    """
    operation = {
        "operationId": method_data.get("id", ""),
        "summary": method_data.get("description", ""),
        "description": method_data.get("description", ""),
        "parameters": [],
        # Generic response set; 200 gets a schema below when the method
        # declares a response $ref.
        "responses": {
            "200": {"description": "Successful operation"},
            "400": {"description": "Bad request"},
            "401": {"description": "Unauthorized"},
            "403": {"description": "Forbidden"},
            "404": {"description": "Not found"},
            "500": {"description": "Server error"},
        },
    }

    # Add path parameters
    for param_name in path_params:
      param = {
          "name": param_name,
          "in": "path",
          "required": True,
          "schema": {"type": "string"},
      }
      operation["parameters"].append(param)

    # Add query parameters
    for param_name, param_data in method_data.get("parameters", {}).items():
      # Skip parameters already included in path
      if param_name in path_params:
        continue

      param = {
          "name": param_name,
          "in": "query",
          "description": param_data.get("description", ""),
          "required": param_data.get("required", False),
          "schema": self._convert_parameter_schema(param_data),
      }
      operation["parameters"].append(param)

    # Handle request body
    if "request" in method_data:
      request_ref = method_data.get("request", {}).get("$ref", "")
      if request_ref:
        if request_ref.startswith("#"):
          # Convert Google's reference format to OpenAPI format
          openapi_ref = request_ref.replace("#", "#/components/schemas/")
        else:
          openapi_ref = "#/components/schemas/" + request_ref
        operation["requestBody"] = {
            "description": "Request body",
            "content": {"application/json": {"schema": {"$ref": openapi_ref}}},
            "required": True,
        }

    # Handle response body
    if "response" in method_data:
      response_ref = method_data.get("response", {}).get("$ref", "")
      if response_ref:
        if response_ref.startswith("#"):
          # Convert Google's reference format to OpenAPI format
          openapi_ref = response_ref.replace("#", "#/components/schemas/")
        else:
          openapi_ref = "#/components/schemas/" + response_ref
        operation["responses"]["200"]["content"] = {
            "application/json": {"schema": {"$ref": openapi_ref}}
        }

    # Add scopes if available
    scopes = method_data.get("scopes", [])
    if scopes:
      # Add method-specific security requirement if different from global
      operation["security"] = [{"oauth2": scopes}]

    return operation

  def _convert_parameter_schema(
      self, param_data: Dict[str, Any]
  ) -> Dict[str, Any]:
    """Convert a parameter definition to an OpenAPI schema.

    Args:
      param_data: Google API parameter data

    Returns:
      OpenAPI schema for the parameter
    """
    schema = {}

    # Convert type
    param_type = param_data.get("type", "string")
    schema["type"] = param_type

    # Handle enum values
    if "enum" in param_data:
      schema["enum"] = param_data["enum"]

    # Handle format
    if "format" in param_data:
      schema["format"] = param_data["format"]

    # Handle default value
    if "default" in param_data:
      schema["default"] = param_data["default"]

    # Handle pattern
    if "pattern" in param_data:
      schema["pattern"] = param_data["pattern"]

    return schema

  def save_openapi_spec(self, output_path: str) -> None:
    """Save the OpenAPI specification to a file.

    Args:
      output_path: Path where the OpenAPI spec should be saved
    """
    with open(output_path, "w", encoding="utf-8") as f:
      json.dump(self.openapi_spec, f, indent=2)
    logger.info("OpenAPI specification saved to %s", output_path)
|
||||
|
||||
|
||||
def main():
  """Command line interface for the converter.

  Returns 0 on success and 1 when the conversion or save step fails.
  """
  arg_parser = argparse.ArgumentParser(
      description=(
          "Convert Google API Discovery documents to OpenAPI v3 specifications"
      )
  )
  arg_parser.add_argument(
      "api_name", help="Name of the Google API (e.g., 'calendar')"
  )
  arg_parser.add_argument("api_version", help="Version of the API (e.g., 'v3')")
  arg_parser.add_argument(
      "--output",
      "-o",
      default="openapi_spec.json",
      help="Output file path for the OpenAPI specification",
  )
  cli_args = arg_parser.parse_args()

  try:
    # Fetch, convert, and persist the spec; any failure maps to exit code 1.
    converter = GoogleApiToOpenApiConverter(
        cli_args.api_name, cli_args.api_version
    )
    converter.convert()
    converter.save_openapi_spec(cli_args.output)
    print(
        f"Successfully converted {cli_args.api_name} {cli_args.api_version} to"
        " OpenAPI v3"
    )
    print(f"Output saved to {cli_args.output}")
  except Exception as e:
    logger.error("Conversion failed: %s", e)
    return 1

  return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # main() returns 0/1, but the original guard discarded it, so the process
  # always exited 0 even on failure. Propagate the status to the shell.
  import sys

  sys.exit(main())
|
||||
68
src/google/adk/tools/google_search_tool.py
Normal file
68
src/google/adk/tools/google_search_tool.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_tool import BaseTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import LlmRequest
|
||||
|
||||
|
||||
class GoogleSearchTool(BaseTool):
  """A built-in tool that is automatically invoked by Gemini 2 models to retrieve search results from Google Search.

  This tool operates internally within the model and does not require or perform
  local code execution.
  """

  def __init__(self):
    # Name and description are not used because this is a model built-in tool.
    super().__init__(name='google_search', description='google_search')

  @override
  async def process_llm_request(
      self,
      *,
      tool_context: ToolContext,
      llm_request: LlmRequest,
  ) -> None:
    """Appends the model-native search tool to the outgoing request config.

    Gemini 1.x uses `google_search_retrieval` and cannot combine it with any
    other tool; Gemini 2.x uses `google_search`. Other models are rejected.

    Raises:
      ValueError: if other tools are present on a Gemini 1.x request, or if
        the model is not a supported Gemini version.
    """
    llm_request.config = llm_request.config or types.GenerateContentConfig()
    llm_request.config.tools = llm_request.config.tools or []
    if llm_request.model and llm_request.model.startswith('gemini-1'):
      # Fix: removed a leftover debug `print(llm_request.config.tools)` that
      # leaked request internals to stdout before raising.
      if llm_request.config.tools:
        raise ValueError(
            'Google search tool can not be used with other tools in Gemini 1.x.'
        )
      llm_request.config.tools.append(
          types.Tool(google_search_retrieval=types.GoogleSearchRetrieval())
      )
    elif llm_request.model and llm_request.model.startswith('gemini-2'):
      llm_request.config.tools.append(
          types.Tool(google_search=types.GoogleSearch())
      )
    else:
      raise ValueError(
          f'Google search tool is not supported for model {llm_request.model}'
      )
|
||||
|
||||
|
||||
# Shared singleton instance; agents reference this to enable built-in search.
google_search = GoogleSearchTool()
|
||||
86
src/google/adk/tools/langchain_tool.py
Normal file
86
src/google/adk/tools/langchain_tool.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from . import _automatic_function_calling_util
|
||||
from .function_tool import FunctionTool
|
||||
|
||||
|
||||
class LangchainTool(FunctionTool):
  """Use this class to wrap a langchain tool.

  If the original tool name and description are not suitable, you can override
  them in the constructor.
  """

  tool: Any
  """The wrapped langchain tool."""

  def __init__(self, tool: Any):
    # Register the langchain tool's private _run as the callable; name and
    # description are copied from the tool when it provides them.
    super().__init__(tool._run)
    self.tool = tool
    if tool.name:
      self.name = tool.name
    if tool.description:
      self.description = tool.description

  @model_validator(mode='before')
  @classmethod
  def populate_name(cls, data: Any) -> Any:
    # Override this to not use function's signature name as it's
    # mostly "run" or "invoke" for third-party tools.
    return data

  @override
  def _get_declaration(self) -> types.FunctionDeclaration:
    """Build the function declaration for the tool."""
    # Imported lazily so langchain is only required when this class is used.
    from langchain.agents import Tool
    from langchain_core.tools import BaseTool

    # There are two types of tools:
    # 1. BaseTool: the tool is defined in langchain.tools.
    # 2. Other tools: the tool doesn't inherit any class but follow some
    #    conventions, like having a "run" method.
    if isinstance(self.tool, BaseTool):
      tool_wrapper = Tool(
          name=self.name,
          func=self.func,
          description=self.description,
      )
      if self.tool.args_schema:
        tool_wrapper.args_schema = self.tool.args_schema
      function_declaration = _automatic_function_calling_util.build_function_declaration_for_langchain(
          False,
          self.name,
          self.description,
          tool_wrapper.func,
          tool_wrapper.args,
      )
      return function_declaration
    else:
      # Need to provide a way to override the function names and descriptions
      # as the original function names are mostly ".run" and the descriptions
      # may not meet users' needs.
      function_declaration = (
          _automatic_function_calling_util.build_function_declaration(
              func=self.tool.run,
          )
      )
      return function_declaration
|
||||
113
src/google/adk/tools/load_artifacts_tool.py
Normal file
113
src/google/adk/tools/load_artifacts_tool.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_tool import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models.llm_request import LlmRequest
|
||||
from .tool_context import ToolContext
|
||||
|
||||
|
||||
class LoadArtifactsTool(BaseTool):
  """A tool that loads the artifacts and adds them to the session."""

  def __init__(self):
    super().__init__(
        name='load_artifacts',
        description='Loads the artifacts and adds them to the session.',
    )

  def _get_declaration(self) -> types.FunctionDeclaration | None:
    # The model calls this tool with the list of artifact names it wants
    # loaded into the request.
    name_list_schema = types.Schema(
        type=types.Type.ARRAY,
        items=types.Schema(type=types.Type.STRING),
    )
    return types.FunctionDeclaration(
        name=self.name,
        description=self.description,
        parameters=types.Schema(
            type=types.Type.OBJECT,
            properties={'artifact_names': name_list_schema},
        ),
    )

  @override
  async def run_async(
      self, *, args: dict[str, Any], tool_context: ToolContext
  ) -> Any:
    # Echo the requested names back; the artifact payloads themselves are
    # attached to the LLM request in process_llm_request, not stored here.
    requested: list[str] = args.get('artifact_names', [])
    return {'artifact_names': requested}

  @override
  async def process_llm_request(
      self, *, tool_context: ToolContext, llm_request: LlmRequest
  ) -> None:
    await super().process_llm_request(
        tool_context=tool_context,
        llm_request=llm_request,
    )
    self._append_artifacts_to_llm_request(
        tool_context=tool_context, llm_request=llm_request
    )

  def _append_artifacts_to_llm_request(
      self, *, tool_context: ToolContext, llm_request: LlmRequest
  ):
    """Advertises available artifacts and inlines any the model requested."""
    artifact_names = tool_context.list_artifacts()
    if not artifact_names:
      return

    # Tell the model about the available artifacts.
    llm_request.append_instructions([f"""You have a list of artifacts:
{json.dumps(artifact_names)}

When the user asks questions about any of the artifacts, you should call the
`load_artifacts` function to load the artifact. Do not generate any text other
than the function call.
"""])

    # Attach the content of the artifacts if the model requested them.
    # This only adds the content to the model request, instead of the session.
    if not (llm_request.contents and llm_request.contents[-1].parts):
      return
    function_response = llm_request.contents[-1].parts[0].function_response
    if not function_response or function_response.name != 'load_artifacts':
      return
    for requested_name in function_response.response['artifact_names']:
      artifact = tool_context.load_artifact(requested_name)
      llm_request.contents.append(
          types.Content(
              role='user',
              parts=[
                  types.Part.from_text(text=f'Artifact {requested_name} is:'),
                  artifact,
              ],
          )
      )
|
||||
|
||||
|
||||
# Module-level singleton; re-exported by the tools package as `load_artifacts`.
load_artifacts_tool = LoadArtifactsTool()
|
||||
58
src/google/adk/tools/load_memory_tool.py
Normal file
58
src/google/adk/tools/load_memory_tool.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .function_tool import FunctionTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import LlmRequest
|
||||
from ..memory.base_memory_service import MemoryResult
|
||||
|
||||
|
||||
def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]':
  """Searches the current user's memory and returns the matching entries.

  Args:
    query: Free-text query to search the memory with.
    tool_context: Tool context that provides access to the memory service.

  Returns:
    The list of memory results matching the query.
  """
  return tool_context.search_memory(query).memories
|
||||
|
||||
|
||||
class LoadMemoryTool(FunctionTool):
  """Function tool wrapping `load_memory` to look up the user's memory."""

  def __init__(self):
    super().__init__(load_memory)

  @override
  async def process_llm_request(
      self,
      *,
      tool_context: ToolContext,
      llm_request: LlmRequest,
  ) -> None:
    await super().process_llm_request(
        tool_context=tool_context, llm_request=llm_request
    )
    # Tell the model about the memory.
    memory_instruction = """
You have memory. You can use it to answer questions. If any questions need
you to look up the memory, you should call load_memory function with a query.
"""
    llm_request.append_instructions([memory_instruction])
|
||||
|
||||
|
||||
# Module-level singleton; re-exported by the tools package as `load_memory`.
load_memory_tool = LoadMemoryTool()
|
||||
41
src/google/adk/tools/load_web_page.py
Normal file
41
src/google/adk/tools/load_web_page.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tool for web browse."""
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def load_web_page(url: str) -> str:
  """Fetches the content in the url and returns the text in it.

  Args:
    url (str): The url to browse.

  Returns:
    str: The text content of the url, or a failure message if the url could
      not be fetched.
  """
  # Imported lazily so the module can be imported without bs4 installed.
  from bs4 import BeautifulSoup

  try:
    # A finite timeout keeps the tool from hanging indefinitely on an
    # unresponsive server; network-level failures (DNS, refused connection,
    # timeout) are mapped to the same failure message used for HTTP errors.
    response = requests.get(url, timeout=30)
  except requests.RequestException:
    return f'Failed to fetch url: {url}'

  if response.status_code == 200:
    soup = BeautifulSoup(response.content, 'lxml')
    text = soup.get_text(separator='\n', strip=True)
  else:
    text = f'Failed to fetch url: {url}'

  # Split the text into lines, filtering out very short lines
  # (e.g., single words or short subtitles)
  return '\n'.join(line for line in text.splitlines() if len(line.split()) > 3)
|
||||
39
src/google/adk/tools/long_running_tool.py
Normal file
39
src/google/adk/tools/long_running_tool.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from .function_tool import FunctionTool
|
||||
|
||||
|
||||
class LongRunningFunctionTool(FunctionTool):
  """A function tool whose final result arrives asynchronously.

  Wraps a function backing a long-running operation. The framework calls the
  function as usual; once it returns, the response is delivered back to the
  framework asynchronously, correlated by the function_call_id.

  Example:
    ```python
    tool = LongRunningFunctionTool(a_long_running_function)
    ```

  Attributes:
    is_long_running: Whether the tool is a long running operation.
  """

  def __init__(self, func: Callable):
    super().__init__(func)
    # Flag the tool so the framework treats its response as asynchronous.
    self.is_long_running = True
|
||||
42
src/google/adk/tools/mcp_tool/__init__.py
Normal file
42
src/google/adk/tools/mcp_tool/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Public names; populated only when the optional `mcp` dependency imports.
__all__ = []

try:
  # These imports require the third-party `mcp` package (Python 3.10+).
  from .conversion_utils import adk_to_mcp_tool_type, gemini_to_json_schema
  from .mcp_tool import MCPTool
  from .mcp_toolset import MCPToolset

  __all__.extend([
      'adk_to_mcp_tool_type',
      'gemini_to_json_schema',
      'MCPTool',
      'MCPToolset',
  ])

except ImportError as e:
  import logging
  import sys

  logger = logging.getLogger(__name__)

  # Distinguish "Python too old" (actionable warning) from "mcp simply not
  # installed" (expected for users who don't use MCP; log at debug only).
  if sys.version_info < (3, 10):
    logger.warning(
        'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
        ' version.'
    )
  else:
    logger.debug('MCP Tool is not installed')
    logger.debug(e)
|
||||
161
src/google/adk/tools/mcp_tool/conversion_utils.py
Normal file
161
src/google/adk/tools/mcp_tool/conversion_utils.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict
|
||||
from google.genai.types import Schema, Type
|
||||
import mcp.types as mcp_types
|
||||
from ..base_tool import BaseTool
|
||||
|
||||
|
||||
def adk_to_mcp_tool_type(tool: BaseTool) -> mcp_types.Tool:
  """Convert a Tool in ADK into MCP tool type.

  This function transforms an ADK tool definition into its equivalent
  representation in the MCP (Model Control Plane) system.

  Args:
    tool: The ADK tool to convert. It should be an instance of a class derived
      from `BaseTool`.

  Returns:
    An object of MCP Tool type, representing the converted tool.

  Examples:
    # Assuming 'my_tool' is an instance of a BaseTool derived class
    mcp_tool = adk_to_mcp_tool_type(my_tool)
    print(mcp_tool)
  """
  tool_declaration = tool._get_declaration()
  if not tool_declaration:
    input_schema = {}
  else:
    # Reuse the declaration computed above rather than calling
    # tool._get_declaration() a second time.
    input_schema = gemini_to_json_schema(tool_declaration.parameters)
  return mcp_types.Tool(
      name=tool.name,
      description=tool.description,
      inputSchema=input_schema,
  )
|
||||
|
||||
|
||||
def _copy_present_fields(
    gemini_schema: Schema,
    mappings: Dict[str, str],
    json_schema_dict: Dict[str, Any],
) -> None:
  """Copies each non-None attribute named in `mappings` into the JSON dict.

  Args:
    gemini_schema: Source Gemini Schema object.
    mappings: Map of Gemini attribute name -> JSON Schema key.
    json_schema_dict: Destination dictionary, mutated in place.
  """
  for gemini_key, json_key in mappings.items():
    value = getattr(gemini_schema, gemini_key, None)
    if value is not None:
      json_schema_dict[json_key] = value


def gemini_to_json_schema(gemini_schema: Schema) -> Dict[str, Any]:
  """Converts a Gemini Schema object into a JSON Schema dictionary.

  Args:
    gemini_schema: An instance of the Gemini Schema class.

  Returns:
    A dictionary representing the equivalent JSON Schema.

  Raises:
    TypeError: If the input is not an instance of the expected Schema class.
    ValueError: If an invalid Gemini Type enum value is encountered.
  """
  if not isinstance(gemini_schema, Schema):
    raise TypeError(
        f"Input must be an instance of Schema, got {type(gemini_schema)}"
    )

  json_schema_dict: Dict[str, Any] = {}

  # Map Type. Gemini's Type is a string enum, so .lower() yields the JSON
  # Schema type name; unspecified/absent types fall back to "null".
  gemini_type = getattr(gemini_schema, "type", None)
  if gemini_type and gemini_type != Type.TYPE_UNSPECIFIED:
    json_schema_dict["type"] = gemini_type.lower()
  else:
    json_schema_dict["type"] = "null"

  # Map Nullable (kept as a non-standard "nullable" key, mirroring OpenAPI).
  if getattr(gemini_schema, "nullable", None) is True:
    json_schema_dict["nullable"] = True

  # Fields copied verbatim regardless of the schema type.
  _copy_present_fields(
      gemini_schema,
      {
          "title": "title",
          "description": "description",
          "default": "default",
          "enum": "enum",
          "format": "format",
          "example": "example",
      },
      json_schema_dict,
  )

  # String validation
  if gemini_type == Type.STRING:
    _copy_present_fields(
        gemini_schema,
        {
            "pattern": "pattern",
            "min_length": "minLength",
            "max_length": "maxLength",
        },
        json_schema_dict,
    )

  # Number/Integer validation
  if gemini_type in (Type.NUMBER, Type.INTEGER):
    _copy_present_fields(
        gemini_schema,
        {
            "minimum": "minimum",
            "maximum": "maximum",
        },
        json_schema_dict,
    )

  # Array validation (recursive call for items)
  if gemini_type == Type.ARRAY:
    items_schema = getattr(gemini_schema, "items", None)
    if items_schema is not None:
      json_schema_dict["items"] = gemini_to_json_schema(items_schema)

    _copy_present_fields(
        gemini_schema,
        {
            "min_items": "minItems",
            "max_items": "maxItems",
        },
        json_schema_dict,
    )

  # Object validation (recursive call for properties)
  if gemini_type == Type.OBJECT:
    properties_dict = getattr(gemini_schema, "properties", None)
    if properties_dict is not None:
      json_schema_dict["properties"] = {
          prop_name: gemini_to_json_schema(prop_schema)
          for prop_name, prop_schema in properties_dict.items()
      }

    _copy_present_fields(
        gemini_schema,
        {
            "required": "required",
            "min_properties": "minProperties",
            "max_properties": "maxProperties",
            # Note: Ignoring 'property_ordering' as it's not standard JSON Schema
        },
        json_schema_dict,
    )

  # Map anyOf (recursive call for subschemas)
  any_of_list = getattr(gemini_schema, "any_of", None)
  if any_of_list is not None:
    json_schema_dict["anyOf"] = [
        gemini_to_json_schema(sub_schema) for sub_schema in any_of_list
    ]

  return json_schema_dict
|
||||
113
src/google/adk/tools/mcp_tool/mcp_tool.py
Normal file
113
src/google/adk/tools/mcp_tool/mcp_tool.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from typing_extensions import override
|
||||
|
||||
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
||||
# their Python version to 3.10 if it fails.
|
||||
try:
  from mcp import ClientSession
  from mcp.types import Tool as McpBaseTool
except ImportError as e:
  import sys

  # `mcp` only supports Python 3.10+; surface a clearer error on older
  # interpreters, otherwise re-raise the original ImportError unchanged.
  if sys.version_info < (3, 10):
    raise ImportError(
        "MCP Tool requires Python 3.10 or above. Please upgrade your Python"
        " version."
    ) from e
  else:
    raise e
|
||||
|
||||
from ..base_tool import BaseTool
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_schemes import AuthScheme
|
||||
from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
||||
from ..tool_context import ToolContext
|
||||
|
||||
|
||||
class MCPTool(BaseTool):
  """Turns a MCP Tool into a Vertex Agent Framework Tool.

  Internally, the tool initializes from a MCP Tool, and uses the MCP Session to
  call the tool.
  """

  def __init__(
      self,
      mcp_tool: McpBaseTool,
      mcp_session: ClientSession,
      auth_scheme: Optional[AuthScheme] = None,
      # Fixed: was `Optional[AuthCredential] | None`, which states the same
      # union twice.
      auth_credential: Optional[AuthCredential] = None,
  ):
    """Initializes a MCPTool.

    This tool wraps a MCP Tool interface and an active MCP Session. It invokes
    the MCP Tool through executing the tool from remote MCP Session.

    Example:
      tool = MCPTool(mcp_tool=mcp_tool, mcp_session=mcp_session)

    Args:
      mcp_tool: The MCP tool to wrap.
      mcp_session: The MCP session to use to call the tool.
      auth_scheme: The authentication scheme to use.
      auth_credential: The authentication credential to use.

    Raises:
      ValueError: If mcp_tool or mcp_session is None.
    """
    if mcp_tool is None:
      raise ValueError("mcp_tool cannot be None")
    if mcp_session is None:
      raise ValueError("mcp_session cannot be None")
    self.name = mcp_tool.name
    # MCP tool descriptions are optional; fall back to an empty string.
    self.description = mcp_tool.description if mcp_tool.description else ""
    self.mcp_tool = mcp_tool
    self.mcp_session = mcp_session
    # TODO(cheliu): Support passing auth to MCP Server.
    self.auth_scheme = auth_scheme
    self.auth_credential = auth_credential

  @override
  def _get_declaration(self) -> FunctionDeclaration:
    """Gets the function declaration for the tool.

    Returns:
      FunctionDeclaration: The Gemini function declaration for the tool.
    """
    schema_dict = self.mcp_tool.inputSchema
    parameters = to_gemini_schema(schema_dict)
    function_decl = FunctionDeclaration(
        name=self.name, description=self.description, parameters=parameters
    )
    return function_decl

  @override
  async def run_async(self, *, args, tool_context: ToolContext):
    """Runs the tool asynchronously.

    Args:
      args: The arguments as a dict to pass to the tool.
      tool_context: The tool context from upper level ADK agent.

    Returns:
      Any: The response from the tool.
    """
    # TODO(cheliu): Support passing tool context to MCP Server.
    response = await self.mcp_session.call_tool(self.name, arguments=args)
    return response
|
||||
272
src/google/adk/tools/mcp_tool/mcp_toolset.py
Normal file
272
src/google/adk/tools/mcp_tool/mcp_toolset.py
Normal file
@@ -0,0 +1,272 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import AsyncExitStack
|
||||
from types import TracebackType
|
||||
from typing import Any, List, Optional, Tuple, Type
|
||||
|
||||
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
||||
# their Python version to 3.10 if it fails.
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.types import ListToolsResult
|
||||
except ImportError as e:
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
raise ImportError(
|
||||
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
|
||||
' version.'
|
||||
) from e
|
||||
else:
|
||||
raise e
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .mcp_tool import MCPTool
|
||||
|
||||
|
||||
class SseServerParams(BaseModel):
  """Connection parameters for an MCP server exposed over SSE.

  All fields are forwarded verbatim to `mcp.client.sse.sse_client`.
  """

  # URL of the server's SSE endpoint.
  url: str
  # Optional HTTP headers to send with the connection request.
  headers: dict[str, Any] | None = None
  # Timeout in seconds; presumably the connect timeout — confirm against
  # sse_client's documentation.
  timeout: float = 5
  # Seconds to wait for SSE data before giving up (default five minutes).
  sse_read_timeout: float = 60 * 5
|
||||
|
||||
|
||||
class MCPToolset:
  """Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.

  Usage:
    Example 1: (using from_server helper):
    ```
    async def load_tools():
      return await MCPToolset.from_server(
          connection_params=StdioServerParameters(
              command='npx',
              args=["-y", "@modelcontextprotocol/server-filesystem"],
          )
      )

    # Use the tools in an LLM agent
    tools, exit_stack = await load_tools()
    agent = LlmAgent(
        tools=tools
    )
    ...
    await exit_stack.aclose()
    ```

    Example 2: (using `async with`):

    ```
    async def load_tools():
      async with MCPToolset(
          connection_params=SseServerParams(url="http://0.0.0.0:8090/sse")
      ) as toolset:
        tools = await toolset.load_tools()

        agent = LlmAgent(
            ...
            tools=tools
        )
    ```

    Example 3: (provide AsyncExitStack):
    ```
    async def load_tools():
      async_exit_stack = AsyncExitStack()
      toolset = MCPToolset(
          connection_params=StdioServerParameters(...),
      )
      async_exit_stack.enter_async_context(toolset)
      tools = await toolset.load_tools()
      agent = LlmAgent(
          ...
          tools=tools
      )
      ...
      await async_exit_stack.aclose()
    ```

  Attributes:
    connection_params: The connection parameters to the MCP server. Can be
      either `StdioServerParameters` or `SseServerParams`.
    exit_stack: The async exit stack to manage the connection to the MCP server.
    session: The MCP session being initialized with the connection.
  """

  def __init__(
      self, *, connection_params: StdioServerParameters | SseServerParams
  ):
    """Initializes the MCPToolset.

    See the class docstring for usage examples.

    Args:
      connection_params: The connection parameters to the MCP server. Can be:
        `StdioServerParameters` for using local mcp server (e.g. using `npx` or
        `python3`); or `SseServerParams` for a local/remote SSE server.

    Raises:
      ValueError: If connection_params is missing/falsy.
    """
    if not connection_params:
      raise ValueError('Missing connection params in MCPToolset.')
    self.connection_params = connection_params
    self.exit_stack = AsyncExitStack()

  @classmethod
  async def from_server(
      cls,
      *,
      connection_params: StdioServerParameters | SseServerParams,
      async_exit_stack: Optional[AsyncExitStack] = None,
  ) -> Tuple[List[MCPTool], AsyncExitStack]:
    """Retrieve all tools from the MCP connection.

    Usage:
    ```
    async def load_tools():
      tools, exit_stack = await MCPToolset.from_server(
          connection_params=StdioServerParameters(
              command='npx',
              args=["-y", "@modelcontextprotocol/server-filesystem"],
          )
      )
    ```

    Args:
      connection_params: The connection parameters to the MCP server.
      async_exit_stack: The async exit stack to use. If not provided, a new
        AsyncExitStack will be created.

    Returns:
      A tuple of the list of MCPTools and the AsyncExitStack.
      - tools: The list of MCPTools.
      - async_exit_stack: The AsyncExitStack used to manage the connection to
        the MCP server. Use `await async_exit_stack.aclose()` to close the
        connection when server shuts down.
    """
    toolset = cls(connection_params=connection_params)
    async_exit_stack = async_exit_stack or AsyncExitStack()
    await async_exit_stack.enter_async_context(toolset)
    tools = await toolset.load_tools()
    return (tools, async_exit_stack)

  async def _initialize(self) -> ClientSession:
    """Connects to the MCP Server and initializes the ClientSession.

    Raises:
      ValueError: If connection_params is neither StdioServerParameters nor
        SseServerParams.
    """
    if isinstance(self.connection_params, StdioServerParameters):
      client = stdio_client(self.connection_params)
    elif isinstance(self.connection_params, SseServerParams):
      client = sse_client(
          url=self.connection_params.url,
          headers=self.connection_params.headers,
          timeout=self.connection_params.timeout,
          sse_read_timeout=self.connection_params.sse_read_timeout,
      )
    else:
      raise ValueError(
          'Unable to initialize connection. Connection should be'
          ' StdioServerParameters or SseServerParams, but got'
          f' {self.connection_params}'
      )

    # Both the transport and the session are registered on the exit stack so
    # that _exit() tears them down in reverse order.
    transports = await self.exit_stack.enter_async_context(client)
    self.session = await self.exit_stack.enter_async_context(
        ClientSession(*transports)
    )
    await self.session.initialize()
    return self.session

  async def _exit(self):
    """Closes the connection to MCP Server."""
    await self.exit_stack.aclose()

  async def load_tools(self) -> List[MCPTool]:
    """Loads all tools from the MCP Server.

    Returns:
      A list of MCPTools imported from the MCP Server.
    """
    tools_response: ListToolsResult = await self.session.list_tools()
    return [
        MCPTool(mcp_tool=tool, mcp_session=self.session)
        for tool in tools_response.tools
    ]

  async def __aenter__(self):
    # Removed a no-op `try: ... except Exception as e: raise e` wrapper;
    # exceptions propagate unchanged either way.
    await self._initialize()
    return self

  async def __aexit__(
      self,
      exc_type: Optional[Type[BaseException]],
      exc: Optional[BaseException],
      tb: Optional[TracebackType],
  ) -> None:
    await self._exit()
|
||||
21
src/google/adk/tools/openapi_tool/__init__.py
Normal file
21
src/google/adk/tools/openapi_tool/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .openapi_spec_parser import OpenAPIToolset
|
||||
from .openapi_spec_parser import RestApiTool
|
||||
|
||||
__all__ = [
|
||||
'OpenAPIToolset',
|
||||
'RestApiTool',
|
||||
]
|
||||
19
src/google/adk/tools/openapi_tool/auth/__init__.py
Normal file
19
src/google/adk/tools/openapi_tool/auth/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import auth_helpers
|
||||
|
||||
__all__ = [
|
||||
'auth_helpers',
|
||||
]
|
||||
498
src/google/adk/tools/openapi_tool/auth/auth_helpers.py
Normal file
498
src/google/adk/tools/openapi_tool/auth/auth_helpers.py
Normal file
@@ -0,0 +1,498 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from fastapi.openapi.models import HTTPBase
|
||||
from fastapi.openapi.models import HTTPBearer
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OpenIdConnect
|
||||
from fastapi.openapi.models import Schema
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
import requests
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_credential import AuthCredentialTypes
|
||||
from ....auth.auth_credential import HttpAuth
|
||||
from ....auth.auth_credential import HttpCredentials
|
||||
from ....auth.auth_credential import OAuth2Auth
|
||||
from ....auth.auth_credential import ServiceAccount
|
||||
from ....auth.auth_credential import ServiceAccountCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ....auth.auth_schemes import AuthSchemeType
|
||||
from ....auth.auth_schemes import OpenIdConnectWithConfig
|
||||
from ..common.common import ApiParameter
|
||||
|
||||
|
||||
class OpenIdConfig(BaseModel):
  """Represents OpenID Connect client configuration.

  Attributes:
    client_id: The OAuth2 client ID.
    auth_uri: The authorization endpoint URI.
    token_uri: The token endpoint URI.
    client_secret: The OAuth2 client secret.
    redirect_uri: Optional redirect URI used in the authorization flow.

  Example:
    config = OpenIdConfig(
        client_id="your_client_id",
        auth_uri="https://accounts.google.com/o/oauth2/auth",
        token_uri="https://oauth2.googleapis.com/token",
        client_secret="your_client_secret",
        redirect_uri="https://your.app/callback",
    )
  """

  client_id: str
  auth_uri: str
  token_uri: str
  client_secret: str
  # Default to None so that configs without a redirect URI still validate;
  # without a default, pydantic v2 treats Optional fields as required.
  redirect_uri: Optional[str] = None
|
||||
|
||||
|
||||
def token_to_scheme_credential(
    token_type: Literal["apikey", "oauth2Token"],
    location: Optional[Literal["header", "query", "cookie"]] = None,
    name: Optional[str] = None,
    credential_value: Optional[str] = None,
) -> Tuple[AuthScheme, AuthCredential]:
  """Creates an AuthScheme and AuthCredential for an API key or bearer token.

  Examples:
    ```
    # API Key in header
    auth_scheme, auth_credential = token_to_scheme_credential("apikey", "header",
    "X-API-Key", "your_api_key_value")

    # API Key in query parameter
    auth_scheme, auth_credential = token_to_scheme_credential("apikey", "query",
    "api_key", "your_api_key_value")

    # OAuth2 Bearer Token in Authorization header
    auth_scheme, auth_credential = token_to_scheme_credential("oauth2Token",
    "header", "Authorization", "your_bearer_token_value")
    ```

  Args:
    token_type: 'apikey' or 'oauth2Token'.
    location: 'header', 'query', or 'cookie'. Ignored for 'oauth2Token', which
      always goes in the Authorization header.
    name: The name of the header, query parameter, or cookie.
    credential_value: The value of the API key / token. When omitted, only the
      scheme is returned and the credential is None.

  Returns:
    Tuple: (AuthScheme, AuthCredential)

  Raises:
    ValueError: For an invalid token_type or location.
  """
  if token_type == "apikey":
    in_: APIKeyIn
    if location == "header":
      in_ = APIKeyIn.header
    elif location == "query":
      in_ = APIKeyIn.query
    elif location == "cookie":
      in_ = APIKeyIn.cookie
    else:
      raise ValueError(f"Invalid location for apiKey: {location}")
    # "in" is a Python keyword, so the model field must be supplied through a
    # dict (it is declared with alias "in" on the fastapi model).
    auth_scheme = APIKey(**{
        "type": AuthSchemeType.apiKey,
        "in": in_,
        "name": name,
    })
    if credential_value:
      auth_credential = AuthCredential(
          auth_type=AuthCredentialTypes.API_KEY, api_key=credential_value
      )
    else:
      auth_credential = None

    return auth_scheme, auth_credential

  elif token_type == "oauth2Token":
    # Ignore location. OAuth2 Bearer Token is always in Authorization header.
    auth_scheme = HTTPBearer(
        bearerFormat="JWT"
    )  # Common format, can be omitted.
    if credential_value:
      auth_credential = AuthCredential(
          auth_type=AuthCredentialTypes.HTTP,
          http=HttpAuth(
              scheme="bearer",
              credentials=HttpCredentials(token=credential_value),
          ),
      )
    else:
      auth_credential = None

    return auth_scheme, auth_credential

  else:
    # Bug fix: the original message interpolated the builtin `type` instead of
    # the `token_type` argument.
    raise ValueError(f"Invalid security scheme type: {token_type}")
|
||||
|
||||
|
||||
def service_account_dict_to_scheme_credential(
    config: Dict[str, Any],
    scopes: List[str],
) -> Tuple[AuthScheme, AuthCredential]:
  """Creates AuthScheme and AuthCredential for a Google Service Account.

  Returns a bearer token scheme, and a service account credential.

  Args:
    config: A dict holding the Google Service Account key material (e.g. the
      parsed contents of a service account JSON key file) — note it is a raw
      dict, not a ServiceAccount object (see service_account_scheme_credential
      for that variant).
    scopes: A list of OAuth scopes to be used.

  Returns:
    Tuple: (AuthScheme, AuthCredential)
  """
  # Service account tokens are ultimately presented as HTTP bearer tokens.
  auth_scheme = HTTPBearer(bearerFormat="JWT")
  # model_construct skips validation: the dict comes straight from a key file
  # and may not carry every declared field.
  service_account = ServiceAccount(
      service_account_credential=ServiceAccountCredential.model_construct(
          **config
      ),
      scopes=scopes,
  )
  auth_credential = AuthCredential(
      auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
      service_account=service_account,
  )
  return auth_scheme, auth_credential
|
||||
|
||||
|
||||
def service_account_scheme_credential(
    config: ServiceAccount,
) -> Tuple[AuthScheme, AuthCredential]:
  """Wraps a ServiceAccount config into a bearer scheme and credential pair.

  Args:
    config: A ServiceAccount object containing the Google Service Account
      configuration.

  Returns:
    Tuple: (AuthScheme, AuthCredential)
  """
  # The eventual access token is carried as an HTTP bearer (JWT) token.
  scheme = HTTPBearer(bearerFormat="JWT")
  credential = AuthCredential(
      auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
      service_account=config,
  )
  return scheme, credential
|
||||
|
||||
|
||||
def openid_dict_to_scheme_credential(
    config_dict: Dict[str, Any],
    scopes: List[str],
    credential_dict: Dict[str, Any],
) -> Tuple[OpenIdConnectWithConfig, AuthCredential]:
  """Builds an OpenID scheme and credential from configuration dictionaries.

  Args:
    config_dict: Dictionary containing OpenID Connect configuration, must
      include at least 'authorization_endpoint' and 'token_endpoint'.
    scopes: List of scopes to be used.
    credential_dict: Dictionary containing credential information, must
      include 'client_id' and 'client_secret'. May optionally include
      'redirect_uri'.

  Returns:
    Tuple: (OpenIdConnectWithConfig, AuthCredential)

  Raises:
    ValueError: If required fields are missing in the input dictionaries.
  """

  # Validate and build the OpenIdConnectWithConfig scheme.
  try:
    config_dict["scopes"] = scopes
    # A statically supplied OpenID config may lack the discovery URL; fill in
    # an empty placeholder so validation passes.
    config_dict.setdefault("openIdConnectUrl", "")
    openid_scheme = OpenIdConnectWithConfig.model_validate(config_dict)
  except ValidationError as e:
    raise ValueError(f"Invalid OpenID Connect configuration: {e}") from e

  # A key downloaded from a Google OAuth config nests the client fields one
  # level deep (e.g. {"web": {...}}); unwrap that single wrapper if present.
  if len(credential_dict) == 1:
    (inner,) = credential_dict.values()
    if "client_id" in inner and "client_secret" in inner:
      credential_dict = inner

  # Ensure the mandatory credential fields are present.
  missing = [
      field
      for field in ("client_id", "client_secret")
      if field not in credential_dict
  ]
  if missing:
    raise ValueError(
        f"Missing required fields in credential_dict: {', '.join(missing)}"
    )

  auth_credential = AuthCredential(
      auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
      oauth2=OAuth2Auth(
          client_id=credential_dict["client_id"],
          client_secret=credential_dict["client_secret"],
          redirect_uri=credential_dict.get("redirect_uri", None),
      ),
  )

  return openid_scheme, auth_credential
|
||||
|
||||
|
||||
def openid_url_to_scheme_credential(
    openid_url: str, scopes: List[str], credential_dict: Dict[str, Any]
) -> Tuple[OpenIdConnectWithConfig, AuthCredential]:
  """Builds an OpenID scheme and credential from a discovery URL.

  Fetches the OpenID Connect configuration document from the given URL, then
  delegates to openid_dict_to_scheme_credential.

  Args:
    openid_url: The OpenID Connect discovery URL.
    scopes: List of scopes to be used.
    credential_dict: Dictionary containing credential information, must
      include at least "client_id" and "client_secret", may optionally include
      "redirect_uri" and "scope".

  Returns:
    Tuple: (AuthScheme, AuthCredential)

  Raises:
    ValueError: If the OpenID URL is invalid, fetching fails, or required
      fields are missing.
  """
  try:
    resp = requests.get(openid_url, timeout=10)
    resp.raise_for_status()
    discovery_doc = resp.json()
  except requests.exceptions.RequestException as e:
    raise ValueError(
        f"Failed to fetch OpenID configuration from {openid_url}: {e}"
    ) from e
  except ValueError as e:
    # response.json() raises ValueError on malformed JSON.
    raise ValueError(
        "Invalid JSON response from OpenID configuration endpoint"
        f" {openid_url}: {e}"
    ) from e

  # Record the discovery URL on the scheme configuration.
  discovery_doc["openIdConnectUrl"] = openid_url

  return openid_dict_to_scheme_credential(discovery_doc, scopes, credential_dict)
|
||||
|
||||
|
||||
# Prefix prepended to the python name of generated auth parameters so they
# cannot collide with an operation's regular parameters.
INTERNAL_AUTH_PREFIX = "_auth_prefix_vaf_"
|
||||
|
||||
|
||||
def credential_to_param(
    auth_scheme: AuthScheme,
    auth_credential: AuthCredential,
) -> Tuple[Optional[ApiParameter], Optional[Dict[str, Any]]]:
  """Converts an AuthCredential and AuthScheme to a request parameter.

  Supports all credential types returned by the exchangers:
  - API Key (becomes a header/query/cookie parameter)
  - HTTP Bearer (for Bearer tokens, OAuth2, Service Account, OpenID Connect)
  - OAuth2 / OpenID Connect credentials already exchanged into a bearer token;
    when the token has not been exchanged yet, returns (None, None)

  Args:
    auth_scheme: The AuthScheme object.
    auth_credential: The AuthCredential object.

  Returns:
    Tuple: (ApiParameter, Dict[str, Any]), or (None, None) when there is
    nothing to attach to the request.

  Raises:
    NotImplementedError: For HTTP Basic credentials.
    ValueError: For an invalid scheme/credential combination.
  """
  if not auth_credential:
    return None, None
  # After the early return above, auth_credential is known to be truthy, so
  # the redundant re-checks from the original version are dropped.

  if auth_scheme.type_ == AuthSchemeType.apiKey and auth_credential.api_key:
    param_name = auth_scheme.name or ""
    python_name = INTERNAL_AUTH_PREFIX + param_name
    if auth_scheme.in_ == APIKeyIn.header:
      param_location = "header"
    elif auth_scheme.in_ == APIKeyIn.query:
      param_location = "query"
    elif auth_scheme.in_ == APIKeyIn.cookie:
      param_location = "cookie"
    else:
      raise ValueError(f"Invalid API Key location: {auth_scheme.in_}")

    param = ApiParameter(
        original_name=param_name,
        param_location=param_location,
        param_schema=Schema(type="string"),
        description=auth_scheme.description or "",
        py_name=python_name,
    )
    kwargs = {param.py_name: auth_credential.api_key}
    return param, kwargs

  # TODO(cheliu): Split handling for OpenIDConnect scheme and native HTTPBearer
  # Scheme
  elif auth_credential.auth_type == AuthCredentialTypes.HTTP:
    creds = auth_credential.http.credentials if auth_credential.http else None
    if creds and creds.token:
      param = ApiParameter(
          original_name="Authorization",
          param_location="header",
          param_schema=Schema(type="string"),
          description=auth_scheme.description or "Bearer token",
          py_name=INTERNAL_AUTH_PREFIX + "Authorization",
      )
      kwargs = {param.py_name: f"Bearer {creds.token}"}
      return param, kwargs
    elif creds and (creds.username or creds.password):
      # Basic Auth is explicitly NOT supported
      raise NotImplementedError("Basic Authentication is not supported.")
    else:
      raise ValueError("Invalid HTTP auth credentials")

  # Service Account tokens, OAuth2 Tokens and OpenID Tokens are now handled as
  # Bearer tokens.
  elif auth_scheme.type_ in (
      AuthSchemeType.oauth2,
      AuthSchemeType.openIdConnect,
  ):
    if (
        auth_credential.http
        and auth_credential.http.credentials
        and auth_credential.http.credentials.token
    ):
      param = ApiParameter(
          original_name="Authorization",
          param_location="header",
          param_schema=Schema(type="string"),
          description=auth_scheme.description or "Bearer token",
          py_name=INTERNAL_AUTH_PREFIX + "Authorization",
      )
      kwargs = {
          param.py_name: f"Bearer {auth_credential.http.credentials.token}"
      }
      return param, kwargs
    # Token not yet exchanged into a bearer token; nothing to attach.
    return None, None
  else:
    raise ValueError("Invalid security scheme and credential combination")
|
||||
|
||||
|
||||
def dict_to_auth_scheme(data: Dict[str, Any]) -> AuthScheme:
  """Converts a dictionary to a FastAPI AuthScheme object.

  Dispatch is driven by the dictionary's 'type' field, matching the OpenAPI
  securityScheme object: apiKey, http (bearer or generic), oauth2, or
  openIdConnect.

  Args:
    data: The dictionary representing the security scheme.

  Returns:
    An AuthScheme object (APIKey, HTTPBase, OAuth2, OpenIdConnect, or
    HTTPBearer).

  Raises:
    ValueError: If the 'type' field is missing or invalid, or if the
      dictionary cannot be converted to the corresponding Pydantic model.

  Example:
    ```python
    api_key_scheme = dict_to_auth_scheme(
        {"type": "apiKey", "in": "header", "name": "X-API-Key"}
    )
    bearer_scheme = dict_to_auth_scheme(
        {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
    )
    openid_scheme = dict_to_auth_scheme({
        "type": "openIdConnect",
        "openIdConnectUrl": (
            "https://example.com/.well-known/openid-configuration"
        ),
    })
    ```
  """
  if "type" not in data:
    raise ValueError("Missing 'type' field in security scheme dictionary.")

  security_type = data["type"]
  try:
    if security_type == "apiKey":
      return APIKey.model_validate(data)
    if security_type == "http":
      # Bearer gets the dedicated model; everything else the generic one.
      model = HTTPBearer if data.get("scheme") == "bearer" else HTTPBase
      return model.model_validate(data)
    if security_type == "oauth2":
      return OAuth2.model_validate(data)
    if security_type == "openIdConnect":
      return OpenIdConnect.model_validate(data)
    raise ValueError(f"Invalid security scheme type: {security_type}")
  except ValidationError as e:
    raise ValueError(f"Invalid security scheme data: {e}") from e
|
||||
@@ -0,0 +1,25 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .auto_auth_credential_exchanger import AutoAuthCredentialExchanger
|
||||
from .base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
from .oauth2_exchanger import OAuth2CredentialExchanger
|
||||
from .service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
|
||||
__all__ = [
|
||||
'AutoAuthCredentialExchanger',
|
||||
'BaseAuthCredentialExchanger',
|
||||
'OAuth2CredentialExchanger',
|
||||
'ServiceAccountCredentialExchanger',
|
||||
]
|
||||
@@ -0,0 +1,105 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
|
||||
from .....auth.auth_credential import AuthCredential
|
||||
from .....auth.auth_credential import AuthCredentialTypes
|
||||
from .....auth.auth_schemes import AuthScheme
|
||||
from .base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
from .oauth2_exchanger import OAuth2CredentialExchanger
|
||||
from .service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
|
||||
|
||||
class AutoAuthCredentialExchanger(BaseAuthCredentialExchanger):
  """Selects the appropriate credential exchanger based on the credential type.

  Optionally, an override can be provided to use a specific exchanger for a
  given credential type.

  Example (common case):
  ```
  exchanger = AutoAuthCredentialExchanger()
  auth_credential = exchanger.exchange_credential(
      auth_scheme=service_account_scheme,
      auth_credential=service_account_credential,
  )
  # Returns an oauth token in the form of a bearer token.
  ```

  Example (use CustomAuthExchanger for OAuth2):
  ```
  exchanger = AutoAuthCredentialExchanger(
      custom_exchangers={
          AuthCredentialTypes.OAUTH2: CustomAuthExchanger,
      }
  )
  ```

  Attributes:
    exchangers: A dictionary mapping AuthCredentialTypes values to credential
      exchanger classes. (Doc fix: the dispatch key is the credential type,
      not the auth scheme as the original docstring claimed.)
  """

  def __init__(
      self,
      custom_exchangers: Optional[
          Dict[str, Type[BaseAuthCredentialExchanger]]
      ] = None,
  ):
    """Initializes the AutoAuthCredentialExchanger.

    Args:
      custom_exchangers: Optional dictionary for adding or overriding auth
        exchangers. The key is the credential type (AuthCredentialTypes), and
        the value is the credential exchanger class.
    """
    self.exchangers = {
        AuthCredentialTypes.OAUTH2: OAuth2CredentialExchanger,
        AuthCredentialTypes.OPEN_ID_CONNECT: OAuth2CredentialExchanger,
        AuthCredentialTypes.SERVICE_ACCOUNT: ServiceAccountCredentialExchanger,
    }
    if custom_exchangers:
      self.exchangers.update(custom_exchangers)

  def exchange_credential(
      self,
      auth_scheme: AuthScheme,
      auth_credential: Optional[AuthCredential] = None,
  ) -> Optional[AuthCredential]:
    """Exchanges the credential using the appropriate credential exchanger.

    Args:
      auth_scheme: The security scheme.
      auth_credential: Optional. The authentication credential.

    Returns:
      A new AuthCredential object containing the exchanged credential; the
      original credential when no exchanger is registered for its type; or
      None when no credential was supplied.
    """
    if not auth_credential:
      return None

    # The early return above guarantees auth_credential is set, so the
    # original's redundant `if auth_credential else None` guard is dropped.
    exchanger_class = self.exchangers.get(auth_credential.auth_type)
    if not exchanger_class:
      return auth_credential

    return exchanger_class().exchange_credential(auth_scheme, auth_credential)
|
||||
@@ -0,0 +1,55 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import Optional
|
||||
|
||||
from .....auth.auth_credential import (
|
||||
AuthCredential,
|
||||
)
|
||||
from .....auth.auth_schemes import AuthScheme
|
||||
|
||||
|
||||
class AuthCredentialMissingError(Exception):
  """Raised when required authentication credentials are missing."""

  def __init__(self, message: str):
    # Keep the message on the instance for callers that inspect it directly,
    # in addition to passing it to Exception for str()/args.
    self.message = message
    super().__init__(message)
|
||||
|
||||
|
||||
class BaseAuthCredentialExchanger:
  """Base class for authentication credential exchangers.

  NOTE(review): this class does not inherit from abc.ABC, so instantiation is
  not blocked; calling the abstract method raises NotImplementedError instead.
  """

  @abc.abstractmethod
  def exchange_credential(
      self,
      auth_scheme: AuthScheme,
      auth_credential: Optional[AuthCredential] = None,
  ) -> AuthCredential:
    """Exchanges the provided authentication credential for a usable token/credential.

    Args:
      auth_scheme: The security scheme.
      auth_credential: The authentication credential.

    Returns:
      An updated AuthCredential object containing the fetched credential.
      For simple schemes like API key, it may return the original credential
      if no exchange is needed.

    Raises:
      NotImplementedError: If the method is not implemented by a subclass.
    """
    raise NotImplementedError("Subclasses must implement exchange_credential.")
|
||||
@@ -0,0 +1,117 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Credential fetcher for OpenID Connect."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .....auth.auth_credential import AuthCredential
|
||||
from .....auth.auth_credential import AuthCredentialTypes
|
||||
from .....auth.auth_credential import HttpAuth
|
||||
from .....auth.auth_credential import HttpCredentials
|
||||
from .....auth.auth_schemes import AuthScheme
|
||||
from .....auth.auth_schemes import AuthSchemeType
|
||||
from .base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
|
||||
|
||||
class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
  """Fetches credentials for OAuth2 and OpenID Connect."""

  def _check_scheme_credential_type(
      self,
      auth_scheme: AuthScheme,
      auth_credential: Optional[AuthCredential] = None,
  ):
    """Validates the scheme/credential pair before an exchange.

    Raises:
      ValueError: If the credential is missing, the scheme type is not
        oauth2/openIdConnect, or the credential carries neither an oauth2 nor
        an http configuration.
    """
    if not auth_credential:
      raise ValueError(
          "auth_credential is empty. Please create AuthCredential using"
          " OAuth2Auth."
      )

    if auth_scheme.type_ not in (
        AuthSchemeType.openIdConnect,
        AuthSchemeType.oauth2,
    ):
      raise ValueError(
          "Invalid security scheme, expect AuthSchemeType.openIdConnect or "
          f"AuthSchemeType.oauth2 auth scheme, but got {auth_scheme.type_}"
      )

    # An http credential is also acceptable: it means the token was already
    # exchanged into a bearer token.
    if not auth_credential.oauth2 and not auth_credential.http:
      raise ValueError(
          "auth_credential is not configured with oauth2. Please"
          " create AuthCredential and set OAuth2Auth."
      )

  def generate_auth_token(
      self,
      auth_credential: Optional[AuthCredential] = None,
  ) -> AuthCredential:
    """Generates an auth token from the authorization response.

    Args:
      auth_credential: The auth credential; its oauth2.token dict is expected
        to be the token-endpoint response.

    Returns:
      An AuthCredential object containing the HTTP bearer access token. If the
      HTTP bearer token cannot be generated, returns the original credential.
    """

    # No access token in the token response: nothing to convert yet.
    if "access_token" not in auth_credential.oauth2.token:
      return auth_credential

    # Return the access token as a bearer token.
    updated_credential = AuthCredential(
        auth_type=AuthCredentialTypes.HTTP,  # Store as a bearer token
        http=HttpAuth(
            scheme="bearer",
            credentials=HttpCredentials(
                token=auth_credential.oauth2.token["access_token"]
            ),
        ),
    )
    return updated_credential

  def exchange_credential(
      self,
      auth_scheme: AuthScheme,
      auth_credential: Optional[AuthCredential] = None,
  ) -> AuthCredential:
    """Exchanges the OpenID Connect auth credential for an access token or an auth URI.

    Args:
      auth_scheme: The auth scheme.
      auth_credential: The auth credential.

    Returns:
      An AuthCredential object containing the HTTP Bearer access token; the
      original credential when it is already a bearer token; or None when no
      token is available yet.

    Raises:
      ValueError: If the auth scheme or auth credential is invalid.
    """
    # TODO(cheliu): Implement token refresh flow

    self._check_scheme_credential_type(auth_scheme, auth_credential)

    # If token is already HTTPBearer token, do nothing assuming that this token
    # is valid.
    if auth_credential.http:
      return auth_credential

    # If access token is exchanged, exchange a HTTPBearer token.
    if auth_credential.oauth2.token:
      return self.generate_auth_token(auth_credential)

    return None
|
||||
@@ -0,0 +1,97 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Credential fetcher for Google Service Account."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import google.auth
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import google.oauth2.credentials
|
||||
|
||||
from .....auth.auth_credential import (
|
||||
AuthCredential,
|
||||
AuthCredentialTypes,
|
||||
HttpAuth,
|
||||
HttpCredentials,
|
||||
)
|
||||
from .....auth.auth_schemes import AuthScheme
|
||||
from .base_credential_exchanger import AuthCredentialMissingError, BaseAuthCredentialExchanger
|
||||
|
||||
|
||||
class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
  """Fetches credentials for Google Service Account.

  Uses the default service credential if `use_default_credential = True`.
  Otherwise, uses the service account credential provided in the auth
  credential.
  """

  def exchange_credential(
      self,
      auth_scheme: AuthScheme,
      auth_credential: Optional[AuthCredential] = None,
  ) -> AuthCredential:
    """Exchanges the service account auth credential for an access token.

    If auth_credential contains a service account credential, it will be used
    to fetch an access token. Otherwise, the default service credential will be
    used for fetching an access token.

    Args:
      auth_scheme: The auth scheme (not consulted here; present to satisfy the
        base-class interface).
      auth_credential: The auth credential.

    Returns:
      An AuthCredential in HTTPBearer format, containing the access token.

    Raises:
      AuthCredentialMissingError: If credentials are missing, or if the token
        exchange fails for any reason.
    """
    # Reject when there is neither explicit key material nor the
    # use_default_credential fallback.
    if (
        auth_credential is None
        or auth_credential.service_account is None
        or (
            auth_credential.service_account.service_account_credential is None
            and not auth_credential.service_account.use_default_credential
        )
    ):
      raise AuthCredentialMissingError(
          "Service account credentials are missing. Please provide them, or set"
          " `use_default_credential = True` to use application default"
          " credential in a hosted service like Cloud Run."
      )

    try:
      if auth_credential.service_account.use_default_credential:
        # Application Default Credentials (e.g. on Cloud Run / GCE).
        credentials, _ = google.auth.default()
      else:
        config = auth_credential.service_account
        credentials = service_account.Credentials.from_service_account_info(
            config.service_account_credential.model_dump(), scopes=config.scopes
        )

      # Obtains a fresh access token for the credential.
      credentials.refresh(Request())

      updated_credential = AuthCredential(
          auth_type=AuthCredentialTypes.HTTP,  # Store as a bearer token
          http=HttpAuth(
              scheme="bearer",
              credentials=HttpCredentials(token=credentials.token),
          ),
      )
      return updated_credential

    except Exception as e:
      # NOTE(review): this broad except also maps unrelated failures (e.g.
      # network errors during refresh) to AuthCredentialMissingError.
      raise AuthCredentialMissingError(
          f"Failed to exchange service account token: {e}"
      ) from e
|
||||
19
src/google/adk/tools/openapi_tool/common/__init__.py
Normal file
19
src/google/adk/tools/openapi_tool/common/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import common
|
||||
|
||||
__all__ = [
|
||||
'common',
|
||||
]
|
||||
300
src/google/adk/tools/openapi_tool/common/common.py
Normal file
300
src/google/adk/tools/openapi_tool/common/common.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import keyword
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from fastapi.openapi.models import Response
|
||||
from fastapi.openapi.models import Schema
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_serializer
|
||||
|
||||
|
||||
def to_snake_case(text: str) -> str:
  """Converts a string into snake_case.

  Handles lowerCamelCase, UpperCamelCase, or space-separated case, acronyms
  (e.g., "REST API") and consecutive uppercase letters correctly. Also handles
  mixed cases with and without spaces.

  Examples:
    ```
    to_snake_case('camelCase') -> 'camel_case'
    to_snake_case('UpperCamelCase') -> 'upper_camel_case'
    to_snake_case('space separated') -> 'space_separated'
    ```

  Args:
    text: The input string.

  Returns:
    The snake_case version of the string.
  """
  # Ordered regex passes: each boundary pattern is rewritten with an
  # underscore separator before the final lowercasing/cleanup step.
  boundary_rules = (
      # Replace spaces and any other non-alphanumeric runs with underscores.
      (r'[^a-zA-Z0-9]+', '_'),
      # Break lowerCamelCase: a lowercase letter or digit followed by an
      # uppercase letter.
      (r'([a-z0-9])([A-Z])', r'\1_\2'),
      # Break UpperCamelCase and acronyms: e.g. "RESTApi" -> "REST_Api".
      (r'([A-Z]+)([A-Z][a-z])', r'\1_\2'),
  )
  for pattern, replacement in boundary_rules:
    text = re.sub(pattern, replacement, text)

  # Lowercase everything, collapse repeated underscores, and trim any
  # leading/trailing underscores left over from the substitutions.
  return re.sub(r'_+', '_', text.lower()).strip('_')
|
||||
|
||||
|
||||
def rename_python_keywords(s: str, prefix: str = 'param_') -> str:
  """Renames Python keywords by adding a prefix.

  Example:
    ```
    rename_python_keywords('if') -> 'param_if'
    rename_python_keywords('for') -> 'param_for'
    ```

  Args:
    s: The input string.
    prefix: The prefix to add to the keyword.

  Returns:
    The renamed string; non-keywords are returned unchanged.
  """
  # keyword.iskeyword covers exactly the reserved words that would be a
  # SyntaxError if used as an identifier.
  return prefix + s if keyword.iskeyword(s) else s
|
||||
|
||||
|
||||
class ApiParameter(BaseModel):
  """Data class representing a function parameter.

  Wraps an OpenAPI parameter (or request-body property) with the derived
  Python-side metadata: a Python-safe name, a type object and a type-hint
  string. The derived fields are computed in `model_post_init`.
  """

  # Parameter name exactly as it appears in the OpenAPI spec.
  original_name: str
  # Where the parameter lives in the request: e.g. 'path', 'query',
  # 'header' or 'body'.
  param_location: str
  # OpenAPI schema for the parameter; a JSON string is parsed into a
  # Schema in model_post_init.
  param_schema: Union[str, Schema]
  description: Optional[str] = ''
  # Python-safe identifier derived from original_name (snake_case with
  # keywords prefixed); computed in model_post_init when not supplied.
  py_name: Optional[str] = ''
  # Derived Python type object / type-hint string; populated in
  # model_post_init, excluded from __init__ via init_var=False.
  type_value: type[Any] = Field(default=None, init_var=False)
  type_hint: str = Field(default=None, init_var=False)

  def model_post_init(self, _: Any):
    # Derive a Python-safe name unless the caller explicitly provided one.
    self.py_name = (
        self.py_name
        if self.py_name
        else rename_python_keywords(to_snake_case(self.original_name))
    )
    # Accept the schema as a JSON string and normalize it into a Schema.
    if isinstance(self.param_schema, str):
      self.param_schema = Schema.model_validate_json(self.param_schema)

    # Fall back to the schema's own description when none was given.
    self.description = self.description or self.param_schema.description or ''
    self.type_value = TypeHintHelper.get_type_value(self.param_schema)
    self.type_hint = TypeHintHelper.get_type_hint(self.param_schema)
    return self

  @model_serializer
  def _serialize(self):
    # Serialize only the declared fields; the derived type_value/type_hint
    # are recomputed on deserialization and are not emitted.
    return {
        'original_name': self.original_name,
        'param_location': self.param_location,
        'param_schema': self.param_schema,
        'description': self.description,
        'py_name': self.py_name,
    }

  def __str__(self):
    # Renders like a function-signature fragment, e.g. "user_id: int".
    return f'{self.py_name}: {self.type_hint}'

  def to_arg_string(self):
    """Converts the parameter to an argument string for function call."""
    return f'{self.py_name}={self.py_name}'

  def to_dict_property(self):
    """Converts the parameter to a key:value string for dict property."""
    return f'"{self.py_name}": {self.py_name}'

  def to_pydoc_string(self):
    """Converts the parameter to a PyDoc parameter docstr."""
    return PydocHelper.generate_param_doc(self)
|
||||
|
||||
|
||||
class TypeHintHelper:
  """Helper class for generating Python type values and type-hint strings.

  Maps OpenAPI schema `type` values ('integer', 'number', 'boolean',
  'string', 'array', 'object') to the corresponding Python typing objects
  (`get_type_value`) or their string spellings (`get_type_hint`). Unknown
  or missing types map to `Any`.
  """

  @staticmethod
  def get_type_value(schema: 'Schema') -> Any:
    """Generates the Python type value for a given parameter.

    Args:
      schema: The OpenAPI schema of the parameter.

    Returns:
      The Python type object (e.g. int, List[str], Dict[str, Any]).
    """
    param_type = schema.type if schema.type else Any

    if param_type == 'integer':
      return int
    elif param_type == 'number':
      return float
    elif param_type == 'boolean':
      return bool
    elif param_type == 'string':
      return str
    elif param_type == 'array':
      items_type = Any
      if schema.items and schema.items.type:
        items_type = schema.items.type

      if items_type == 'object':
        return List[Dict[str, Any]]
      else:
        type_map = {
            'integer': int,
            'number': float,
            'boolean': bool,
            'string': str,
            'object': Dict[str, Any],
            'array': List[Any],
        }
        # BUGFIX: default to typing.Any, not the string 'Any' — a string
        # default would produce List['Any'] (a ForwardRef), which is not a
        # usable type and is != List[Any].
        return List[type_map.get(items_type, Any)]
    elif param_type == 'object':
      return Dict[str, Any]
    else:
      return Any

  @staticmethod
  def get_type_hint(schema: 'Schema') -> str:
    """Generates the Python type in string for a given parameter.

    Args:
      schema: The OpenAPI schema of the parameter.

    Returns:
      The type-hint string (e.g. 'int', 'List[str]', 'Dict[str, Any]').
    """
    param_type = schema.type if schema.type else 'Any'

    if param_type == 'integer':
      return 'int'
    elif param_type == 'number':
      return 'float'
    elif param_type == 'boolean':
      return 'bool'
    elif param_type == 'string':
      return 'str'
    elif param_type == 'array':
      items_type = 'Any'
      if schema.items and schema.items.type:
        items_type = schema.items.type

      if items_type == 'object':
        return 'List[Dict[str, Any]]'
      else:
        # Unrecognized item types fall back to 'Any' (correct here, since
        # this map produces a string, not a type object).
        type_map = {
            'integer': 'int',
            'number': 'float',
            'boolean': 'bool',
            'string': 'str',
        }
        return f"List[{type_map.get(items_type, 'Any')}]"
    elif param_type == 'object':
      return 'Dict[str, Any]'
    else:
      return 'Any'
|
||||
|
||||
|
||||
class PydocHelper:
  """Helper class for generating PyDoc strings for generated functions."""

  @staticmethod
  def generate_param_doc(
      param: 'ApiParameter',
  ) -> str:
    """Generates a parameter documentation string.

    Args:
      param: ApiParameter - The parameter to generate the documentation for.

    Returns:
      str: The generated parameter Python documentation string.
    """
    description = param.description.strip() if param.description else ''
    param_doc = f'{param.py_name} ({param.type_hint}): {description}'

    # For object parameters, also document each property on its own line.
    if param.param_schema.type == 'object':
      properties = param.param_schema.properties
      if properties:
        param_doc += ' Object properties:\n'
        for prop_name, prop_details in properties.items():
          prop_desc = prop_details.description or ''
          prop_type = TypeHintHelper.get_type_hint(prop_details)
          param_doc += f' {prop_name} ({prop_type}): {prop_desc}\n'

    return param_doc

  @staticmethod
  def generate_return_doc(responses: Dict[str, 'Response']) -> str:
    """Generates a return value documentation string.

    Args:
      responses: Dict[str, Response] - Response in an OpenAPI Operation

    Returns:
      str: The generated return value Python documentation string.
    """
    return_doc = ''

    # Only consider 2xx responses for return type hinting.
    # Returns the 2xx response with the smallest status code number and with
    # content defined.
    # BUGFIX: OpenAPI allows non-numeric response keys such as 'default' or
    # '2XX'; int() on those raised ValueError. Filter to purely numeric 2xx
    # codes before sorting.
    two_xx_responses = [
        (status_code, response)
        for status_code, response in responses.items()
        if status_code.isdigit() and status_code.startswith('2')
    ]
    two_xx_responses.sort(key=lambda item: int(item[0]))
    qualified_response = next(
        filter(lambda r: r[1].content, two_xx_responses),
        None,
    )
    if not qualified_response:
      return ''
    response_details = qualified_response[1]

    description = (response_details.description or '').strip()
    content = response_details.content or {}

    # Generate return type hint and properties for the first response type.
    # TODO(cheliu): Handle multiple content types.
    for _, schema_details in content.items():
      schema = schema_details.schema_ or {}

      # Use a dummy Parameter object for return type hinting.
      dummy_param = ApiParameter(
          original_name='', param_location='', param_schema=schema
      )
      return_doc = f'Returns ({dummy_param.type_hint}): {description}'

      response_type = schema.type or 'Any'
      if response_type != 'object':
        break
      properties = schema.properties
      if not properties:
        break
      return_doc += ' Object properties:\n'
      for prop_name, prop_details in properties.items():
        prop_desc = prop_details.description or ''
        prop_type = TypeHintHelper.get_type_hint(prop_details)
        return_doc += f' {prop_name} ({prop_type}): {prop_desc}\n'
      break

    return return_doc
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .openapi_spec_parser import OpenApiSpecParser, OperationEndpoint, ParsedOperation
|
||||
from .openapi_toolset import OpenAPIToolset
|
||||
from .operation_parser import OperationParser
|
||||
from .rest_api_tool import AuthPreparationState, RestApiTool, snake_to_lower_camel, to_gemini_schema
|
||||
from .tool_auth_handler import ToolAuthHandler
|
||||
|
||||
__all__ = [
|
||||
'OpenApiSpecParser',
|
||||
'OperationEndpoint',
|
||||
'ParsedOperation',
|
||||
'OpenAPIToolset',
|
||||
'OperationParser',
|
||||
'RestApiTool',
|
||||
'to_gemini_schema',
|
||||
'snake_to_lower_camel',
|
||||
'AuthPreparationState',
|
||||
'ToolAuthHandler',
|
||||
]
|
||||
@@ -0,0 +1,231 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.openapi.models import Operation
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ..common.common import ApiParameter
|
||||
from ..common.common import to_snake_case
|
||||
from .operation_parser import OperationParser
|
||||
|
||||
|
||||
class OperationEndpoint(BaseModel):
  """The (server URL, path, HTTP method) triple locating one operation."""

  # Server base URL taken from the spec's `servers` entry (may be '').
  base_url: str
  # Path template of the operation, e.g. '/users/{id}'.
  path: str
  # Lower-case HTTP method name, e.g. 'get' or 'post'.
  method: str
|
||||
|
||||
|
||||
class ParsedOperation(BaseModel):
  """A self-contained, parsed OpenAPI operation ready for tool generation."""

  # Snake-case function name derived from the operationId.
  name: str
  # Operation description (falls back to summary, possibly '').
  description: str
  # Where and how to call the operation.
  endpoint: OperationEndpoint
  # The original (reference-resolved) OpenAPI Operation model.
  operation: Operation
  # Flattened parameters (path/query/header/cookie plus request-body parts).
  parameters: List[ApiParameter]
  # Schema-derived description of the operation's return value.
  return_value: ApiParameter
  # Security scheme resolved for this operation, if any.
  auth_scheme: Optional[AuthScheme] = None
  auth_credential: Optional[AuthCredential] = None
  additional_context: Optional[Any] = None
|
||||
|
||||
|
||||
class OpenApiSpecParser:
  """Generates Python code, JSON schema, and callables for an OpenAPI operation.

  This class takes an OpenApiOperation object and provides methods to generate:
  1. A string representation of a Python function that handles the operation.
  2. A JSON schema representing the input parameters of the operation.
  3. A callable Python object (a function) that can execute the operation.
  """

  def parse(self, openapi_spec_dict: Dict[str, Any]) -> List[ParsedOperation]:
    """Extracts an OpenAPI spec dict into a list of ParsedOperation objects.

    ParsedOperation objects are further used for generating RestApiTool.

    Args:
      openapi_spec_dict: A dictionary representing the OpenAPI specification.

    Returns:
      A list of ParsedOperation objects.
    """

    # Inline all $ref references first so each operation can be read
    # standalone, then walk the paths to build ParsedOperation objects.
    openapi_spec_dict = self._resolve_references(openapi_spec_dict)
    operations = self._collect_operations(openapi_spec_dict)
    return operations

  def _collect_operations(
      self, openapi_spec: Dict[str, Any]
  ) -> List[ParsedOperation]:
    """Collects operations from an OpenAPI spec."""
    operations = []

    # Taking first server url, or default to empty string if not present
    base_url = ""
    if openapi_spec.get("servers"):
      base_url = openapi_spec["servers"][0].get("url", "")

    # Get global security scheme (if any)
    global_scheme_name = None
    if openapi_spec.get("security"):
      # Use first scheme by default.
      scheme_names = list(openapi_spec["security"][0].keys())
      global_scheme_name = scheme_names[0] if scheme_names else None

    # Scheme name -> scheme definition, for resolving per-operation security.
    auth_schemes = openapi_spec.get("components", {}).get("securitySchemes", {})

    for path, path_item in openapi_spec.get("paths", {}).items():
      if path_item is None:
        continue

      for method in (
          "get",
          "post",
          "put",
          "delete",
          "patch",
          "head",
          "options",
          "trace",
      ):
        operation_dict = path_item.get(method)
        if operation_dict is None:
          continue

        # If operation ID is missing, assign an operation id based on path
        # and method
        if "operationId" not in operation_dict:
          temp_id = to_snake_case(f"{path}_{method}")
          operation_dict["operationId"] = temp_id

        url = OperationEndpoint(base_url=base_url, path=path, method=method)
        operation = Operation.model_validate(operation_dict)
        operation_parser = OperationParser(operation)

        # Check for operation-specific auth scheme; fall back to the
        # spec-level (global) scheme when the operation declares none.
        auth_scheme_name = operation_parser.get_auth_scheme_name()
        auth_scheme_name = (
            auth_scheme_name if auth_scheme_name else global_scheme_name
        )
        auth_scheme = (
            auth_schemes.get(auth_scheme_name) if auth_scheme_name else None
        )

        parsed_op = ParsedOperation(
            name=operation_parser.get_function_name(),
            description=operation.description or operation.summary or "",
            endpoint=url,
            operation=operation,
            parameters=operation_parser.get_parameters(),
            return_value=operation_parser.get_return_value(),
            auth_scheme=auth_scheme,
            auth_credential=None,  # Placeholder
            additional_context={},
        )
        operations.append(parsed_op)

    return operations

  def _resolve_references(self, openapi_spec: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively resolves all $ref references in an OpenAPI specification.

    Handles circular references correctly.

    Args:
      openapi_spec: A dictionary representing the OpenAPI specification.

    Returns:
      A dictionary representing the OpenAPI specification with all references
      resolved.
    """

    openapi_spec = copy.deepcopy(openapi_spec)  # Work on a copy
    resolved_cache = {}  # Cache resolved references

    def resolve_ref(ref_string, current_doc):
      """Resolves a single local '#/...' $ref string to its target node."""
      parts = ref_string.split("/")
      if parts[0] != "#":
        raise ValueError(f"External references not supported: {ref_string}")

      # Walk the document following each path segment of the pointer.
      current = current_doc
      for part in parts[1:]:
        if part in current:
          current = current[part]
        else:
          return None  # Reference not found
      return current

    def recursive_resolve(obj, current_doc, seen_refs=None):
      """Recursively resolves references, handling circularity.

      Args:
        obj: The object to traverse.
        current_doc: Document to search for refs.
        seen_refs: A set to track already-visited references (for circularity
          detection).

      Returns:
        The resolved object.
      """
      if seen_refs is None:
        seen_refs = set()  # Initialize the set if it's the first call

      if isinstance(obj, dict):
        if "$ref" in obj and isinstance(obj["$ref"], str):
          ref_string = obj["$ref"]

          # Check for circularity
          if ref_string in seen_refs and ref_string not in resolved_cache:
            # Circular reference detected! Return a *copy* of the object,
            # but *without* the $ref. This breaks the cycle while
            # still maintaining the overall structure.
            return {k: v for k, v in obj.items() if k != "$ref"}

          seen_refs.add(ref_string)  # Add the reference to the set

          # Check if we have a cached resolved value
          if ref_string in resolved_cache:
            return copy.deepcopy(resolved_cache[ref_string])

          resolved_value = resolve_ref(ref_string, current_doc)
          if resolved_value is not None:
            # Recursively resolve the *resolved* value,
            # passing along the 'seen_refs' set
            resolved_value = recursive_resolve(
                resolved_value, current_doc, seen_refs
            )
            resolved_cache[ref_string] = resolved_value
            return copy.deepcopy(resolved_value)  # return the cached result
          else:
            return obj  # return original if no resolved value.

        else:
          # Plain dict: resolve each value in place.
          new_dict = {}
          for key, value in obj.items():
            new_dict[key] = recursive_resolve(value, current_doc, seen_refs)
          return new_dict

      elif isinstance(obj, list):
        return [recursive_resolve(item, current_doc, seen_refs) for item in obj]
      else:
        # Scalars (str/int/bool/None) are returned unchanged.
        return obj

    return recursive_resolve(openapi_spec, openapi_spec)
|
||||
@@ -0,0 +1,144 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Final
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from .openapi_spec_parser import OpenApiSpecParser
|
||||
from .rest_api_tool import RestApiTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAPIToolset:
  """Class for parsing OpenAPI spec into a list of RestApiTool.

  Usage:
  ```
  # Initialize OpenAPI toolset from a spec string.
  openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
      spec_str_type="json")
  # Or, initialize OpenAPI toolset from a spec dictionary.
  openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)

  # Add all tools to an agent.
  agent = Agent(
      tools=[*openapi_toolset.get_tools()]
  )
  # Or, add a single tool to an agent.
  agent = Agent(
      tools=[openapi_toolset.get_tool('tool_name')]
  )
  ```
  """

  def __init__(
      self,
      *,
      spec_dict: Optional[Dict[str, Any]] = None,
      spec_str: Optional[str] = None,
      spec_str_type: Literal["json", "yaml"] = "json",
      auth_scheme: Optional[AuthScheme] = None,
      auth_credential: Optional[AuthCredential] = None,
  ):
    """Initializes the OpenAPIToolset.

    Usage:
    ```
    # Initialize OpenAPI toolset from a spec string.
    openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
        spec_str_type="json")
    # Or, initialize OpenAPI toolset from a spec dictionary.
    openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)

    # Add all tools to an agent.
    agent = Agent(
        tools=[*openapi_toolset.get_tools()]
    )
    # Or, add a single tool to an agent.
    agent = Agent(
        tools=[openapi_toolset.get_tool('tool_name')]
    )
    ```

    Args:
      spec_dict: The OpenAPI spec dictionary. If provided, it will be used
        instead of loading the spec from a string.
      spec_str: The OpenAPI spec string in JSON or YAML format. It will be used
        when spec_dict is not provided.
      spec_str_type: The type of the OpenAPI spec string. Can be "json" or
        "yaml".
      auth_scheme: The auth scheme to use for all tools. Use AuthScheme or use
        helpers in `google.adk.tools.openapi_tool.auth.auth_helpers`
      auth_credential: The auth credential to use for all tools. Use
        AuthCredential or use helpers in
        `google.adk.tools.openapi_tool.auth.auth_helpers`
    """
    # spec_dict takes precedence; spec_str is only parsed when no dict given.
    if not spec_dict:
      spec_dict = self._load_spec(spec_str, spec_str_type)
    self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
    if auth_scheme or auth_credential:
      self._configure_auth_all(auth_scheme, auth_credential)

  def _configure_auth_all(
      self, auth_scheme: AuthScheme, auth_credential: AuthCredential
  ):
    """Configure auth scheme and credential for all tools."""

    # Either argument may be None; each tool only gets what was provided.
    for tool in self.tools:
      if auth_scheme:
        tool.configure_auth_scheme(auth_scheme)
      if auth_credential:
        tool.configure_auth_credential(auth_credential)

  def get_tools(self) -> List[RestApiTool]:
    """Get all tools in the toolset."""
    return self.tools

  def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
    """Get a tool by name. Returns None when no tool matches."""
    matching_tool = filter(lambda t: t.name == tool_name, self.tools)
    return next(matching_tool, None)

  def _load_spec(
      self, spec_str: str, spec_type: Literal["json", "yaml"]
  ) -> Dict[str, Any]:
    """Loads the OpenAPI spec string into a dictionary."""
    if spec_type == "json":
      return json.loads(spec_str)
    elif spec_type == "yaml":
      return yaml.safe_load(spec_str)
    else:
      raise ValueError(f"Unsupported spec type: {spec_type}")

  def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]:
    """Parse OpenAPI spec into a list of RestApiTool."""
    operations = OpenApiSpecParser().parse(openapi_spec_dict)

    tools = []
    for o in operations:
      tool = RestApiTool.from_parsed_operation(o)
      logger.info("Parsed tool: %s", tool.name)
      tools.append(tool)
    return tools
|
||||
@@ -0,0 +1,260 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.openapi.models import Operation
|
||||
from fastapi.openapi.models import Parameter
|
||||
from fastapi.openapi.models import Schema
|
||||
|
||||
from ..common.common import ApiParameter
|
||||
from ..common.common import PydocHelper
|
||||
from ..common.common import to_snake_case
|
||||
|
||||
|
||||
class OperationParser:
|
||||
"""Generates parameters for Python functions from an OpenAPI operation.
|
||||
|
||||
This class processes an OpenApiOperation object and provides helper methods
|
||||
to extract information needed to generate Python function declarations,
|
||||
docstrings, signatures, and JSON schemas. It handles parameter processing,
|
||||
name deduplication, and type hint generation.
|
||||
"""
|
||||
|
||||
  def __init__(
      self, operation: Union[Operation, Dict[str, Any], str], should_parse=True
  ):
    """Initializes the OperationParser with an OpenApiOperation.

    Args:
      operation: The OpenApiOperation object or a dictionary to process.
      should_parse: Whether to parse the operation during initialization.
    """
    # Accept the operation as a pydantic model, a dict, or a JSON string.
    if isinstance(operation, dict):
      self.operation = Operation.model_validate(operation)
    elif isinstance(operation, str):
      self.operation = Operation.model_validate_json(operation)
    else:
      self.operation = operation

    # Populated by the _process_* steps below (or externally via load()).
    self.params: List[ApiParameter] = []
    self.return_value: Optional[ApiParameter] = None
    if should_parse:
      self._process_operation_parameters()
      self._process_request_body()
      self._process_return_value()
      # Run last so request-body params are included in deduplication.
      self._dedupe_param_names()
|
||||
|
||||
  @classmethod
  def load(
      cls,
      operation: Union[Operation, Dict[str, Any]],
      params: List[ApiParameter],
      return_value: Optional[ApiParameter] = None,
  ) -> 'OperationParser':
    """Creates a parser from pre-computed params, skipping re-parsing.

    Args:
      operation: The operation (model or dict) the parser wraps.
      params: Already-parsed parameters to install on the parser.
      return_value: Already-parsed return value, if any.

    Returns:
      An OperationParser populated with the given params and return value.
    """
    parser = cls(operation, should_parse=False)
    parser.params = params
    parser.return_value = return_value
    return parser
|
||||
|
||||
  def _process_operation_parameters(self):
    """Processes parameters from the OpenAPI operation.

    Converts each declared (non-$ref) Parameter into an ApiParameter and
    appends it to self.params.
    """
    parameters = self.operation.parameters or []
    for param in parameters:
      # Only concrete Parameter objects are handled; anything else (e.g.
      # unresolved references) is skipped.
      if isinstance(param, Parameter):
        original_name = param.name
        description = param.description or ''
        location = param.in_ or ''
        schema = param.schema_ or {}  # Use schema_ instead of .schema

        self.params.append(
            ApiParameter(
                original_name=original_name,
                param_location=location,
                param_schema=schema,
                description=description,
            )
        )
|
||||
|
||||
  def _process_request_body(self):
    """Processes the request body from the OpenAPI operation.

    Flattens the body into ApiParameters with param_location='body':
    object bodies contribute one parameter per property; array bodies
    contribute a single 'array' parameter; anything else becomes one
    unnamed body parameter. Only the first media type is processed.
    """
    request_body = self.operation.requestBody
    if not request_body:
      return

    content = request_body.content or {}
    if not content:
      return

    # If request body is an object, expand the properties as parameters
    for _, media_type_object in content.items():
      schema = media_type_object.schema_ or {}
      description = request_body.description or ''

      if schema and schema.type == 'object':
        for prop_name, prop_details in schema.properties.items():
          self.params.append(
              ApiParameter(
                  original_name=prop_name,
                  param_location='body',
                  param_schema=prop_details,
                  description=prop_details.description,
              )
          )

      elif schema and schema.type == 'array':
        self.params.append(
            ApiParameter(
                original_name='array',
                param_location='body',
                param_schema=schema,
                description=description,
            )
        )
      else:
        self.params.append(
            # Empty name for unnamed body param
            ApiParameter(
                original_name='',
                param_location='body',
                param_schema=schema,
                description=description,
            )
        )
      break  # Process first mime type only
|
||||
|
||||
def _dedupe_param_names(self):
|
||||
"""Deduplicates parameter names to avoid conflicts."""
|
||||
params_cnt = {}
|
||||
for param in self.params:
|
||||
name = param.py_name
|
||||
if name not in params_cnt:
|
||||
params_cnt[name] = 0
|
||||
else:
|
||||
params_cnt[name] += 1
|
||||
param.py_name = f'{name}_{params_cnt[name] -1}'
|
||||
|
||||
  def _process_return_value(self) -> None:
    """Computes self.return_value from the operation's responses.

    Picks the lowest 2xx response that has content and uses the schema of
    its first media type; falls back to a schema of type 'Any' otherwise.
    """
    responses = self.operation.responses or {}
    # Default to Any if no 2xx response or if schema is missing
    return_schema = Schema(type='Any')

    # Take the 20x response with the smallest response code.
    valid_codes = list(
        filter(lambda k: k.startswith('2'), list(responses.keys()))
    )
    # NOTE(review): min() compares the codes as strings — equivalent to
    # numeric order for equal-length codes like '200'/'204'; confirm this is
    # acceptable for wildcard keys such as '2XX'.
    min_20x_status_code = min(valid_codes) if valid_codes else None

    if min_20x_status_code and responses[min_20x_status_code].content:
      content = responses[min_20x_status_code].content
      # Use the first media type that declares a schema.
      for mime_type in content:
        if content[mime_type].schema_:
          return_schema = content[mime_type].schema_
          break

    self.return_value = ApiParameter(
        original_name='',
        param_location='',
        param_schema=return_schema,
    )
|
||||
|
||||
def get_function_name(self) -> str:
|
||||
"""Returns the generated function name."""
|
||||
operation_id = self.operation.operationId
|
||||
if not operation_id:
|
||||
raise ValueError('Operation ID is missing')
|
||||
return to_snake_case(operation_id)[:60]
|
||||
|
||||
  def get_return_type_hint(self) -> str:
    """Returns the return type hint string (like 'str', 'int', etc.)."""
    # Requires _process_return_value (or load()) to have run first.
    return self.return_value.type_hint
|
||||
|
||||
def get_return_type_value(self) -> Any:
  """Returns the Python type object of the return value (e.g. str, List[str])."""
  type_value = self.return_value.type_value
  return type_value
|
||||
def get_parameters(self) -> List[ApiParameter]:
|
||||
"""Returns the list of Parameter objects."""
|
||||
return self.params
|
||||
|
||||
def get_return_value(self) -> ApiParameter:
|
||||
"""Returns the list of Parameter objects."""
|
||||
return self.return_value
|
||||
|
||||
def get_auth_scheme_name(self) -> str:
  """Returns the name of this operation's first security scheme, or ''.

  Each entry in ``operation.security`` is a dict keyed by scheme name; the
  first key of the first requirement is used.
  """
  security = self.operation.security
  if not security:
    return ''
  return list(security[0])[0]
|
||||
def get_pydoc_string(self) -> str:
  """Builds the PyDoc string for the generated tool function.

  Combines the operation's summary (or description), one doc line per
  parameter, and the generated return-value doc into a triple-quoted
  docstring literal.
  """
  # One pydoc line per operation parameter.
  pydoc_params = [param.to_pydoc_string() for param in self.params]
  # Summary wins over description; empty string when neither is set.
  pydoc_description = (
      self.operation.summary or self.operation.description or ''
  )
  pydoc_return = PydocHelper.generate_return_doc(
      self.operation.responses or {}
  )
  # chr(10) is '\n'; joins parameter docs into the Args section.
  pydoc_arg_list = chr(10).join(
      f' {param_doc}' for param_doc in pydoc_params
  )
  # NOTE(review): output whitespace depends on the literal's indentation
  # interacting with dedent(); confirm rendered layout against the original.
  return dedent(f"""
\"\"\"{pydoc_description}

Args:
{pydoc_arg_list}

{pydoc_return}
\"\"\"
""").strip()
|
||||
def get_json_schema(self) -> Dict[str, Any]:
  """Builds a JSON-schema dict describing the tool function's arguments.

  Every parameter appears both in 'properties' and in 'required'.
  """
  encoded_properties: Dict[str, Any] = {}
  required_names: List[str] = []
  for param in self.params:
    encoded_properties[param.py_name] = jsonable_encoder(
        param.param_schema, exclude_none=True
    )
    required_names.append(param.py_name)
  title = f"{self.operation.operationId or 'unnamed'}_Arguments"
  return {
      'properties': encoded_properties,
      'required': required_names,
      'title': title,
      'type': 'object',
  }
|
||||
def get_signature_parameters(self) -> List[inspect.Parameter]:
  """Builds inspect.Parameter objects, one per operation parameter.

  Every parameter is POSITIONAL_OR_KEYWORD and annotated with its resolved
  Python type.
  """
  signature_params = []
  for param in self.params:
    signature_params.append(
        inspect.Parameter(
            param.py_name,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=param.type_value,
        )
    )
  return signature_params
|
||||
def get_annotations(self) -> Dict[str, Any]:
  """Maps each parameter's py_name to its Python type, plus a 'return' key."""
  annotations: Dict[str, Any] = {}
  for param in self.params:
    annotations[param.py_name] = param.type_value
  annotations['return'] = self.get_return_type_value()
  return annotations
|
||||
@@ -0,0 +1,496 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from fastapi.openapi.models import Operation
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from google.genai.types import Schema
|
||||
import requests
|
||||
from typing_extensions import override
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ....tools import BaseTool
|
||||
from ...tool_context import ToolContext
|
||||
from ..auth.auth_helpers import credential_to_param
|
||||
from ..auth.auth_helpers import dict_to_auth_scheme
|
||||
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
|
||||
from ..common.common import ApiParameter
|
||||
from ..common.common import to_snake_case
|
||||
from .openapi_spec_parser import OperationEndpoint
|
||||
from .openapi_spec_parser import ParsedOperation
|
||||
from .operation_parser import OperationParser
|
||||
from .tool_auth_handler import ToolAuthHandler
|
||||
|
||||
|
||||
def snake_to_lower_camel(snake_case_string: str):
  """Converts a snake_case string to lowerCamelCase.

  Strings containing no underscore are returned unchanged, keeping their
  original casing.

  Args:
    snake_case_string: The input snake_case string.

  Returns:
    The lowerCamelCase string.
  """
  if "_" not in snake_case_string:
    return snake_case_string

  first, *rest = snake_case_string.split("_")
  return first.lower() + "".join(word.capitalize() for word in rest)
|
||||
|
||||
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
  """Converts an OpenAPI schema dictionary to a Gemini Schema object.

  The caller's dictionary is not mutated: normalization happens on a shallow
  copy (nested dicts are handled by the recursive calls).

  Args:
    openapi_schema: The OpenAPI schema dictionary.

  Returns:
    A Pydantic Schema object. Returns None if input is None.

  Raises:
    TypeError: If openapi_schema is not a dict.
  """
  if openapi_schema is None:
    return None

  if not isinstance(openapi_schema, dict):
    raise TypeError("openapi_schema must be a dictionary")

  # Bug fix: normalize on a copy instead of mutating the caller's dict.
  openapi_schema = dict(openapi_schema)

  pydantic_schema_data = {}

  # Adding this to force adding a type to an empty dict
  # This avoid "... one_of or any_of must specify a type" error
  if not openapi_schema.get("type"):
    openapi_schema["type"] = "object"

  # Adding this to avoid "properties: should be non-empty for OBJECT type" error
  # See b/385165182
  if openapi_schema.get("type", "") == "object" and not openapi_schema.get(
      "properties"
  ):
    openapi_schema["properties"] = {"dummy_DO_NOT_GENERATE": {"type": "string"}}

  for key, value in openapi_schema.items():
    snake_case_key = to_snake_case(key)
    # Check if the snake_case_key exists in the Schema model's fields.
    if snake_case_key in Schema.model_fields:
      if snake_case_key in ["title", "default", "format"]:
        # Ignore these fields as Gemini backend doesn't recognize them, and
        # will throw exception if they appear in the schema.
        # Format: properties[expiration].format: only 'enum' and 'date-time'
        # are supported for STRING type
        continue
      if snake_case_key == "properties" and isinstance(value, dict):
        pydantic_schema_data[snake_case_key] = {
            k: to_gemini_schema(v) for k, v in value.items()
        }
      elif snake_case_key == "items" and isinstance(value, dict):
        pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
      elif snake_case_key == "any_of" and isinstance(value, list):
        pydantic_schema_data[snake_case_key] = [
            to_gemini_schema(item) for item in value
        ]
      # Important: Handle cases where the OpenAPI schema might contain lists
      # or other structures that need to be recursively processed.
      elif isinstance(value, list) and snake_case_key not in (
          "enum",
          "required",
          "property_ordering",
      ):
        new_list = []
        for item in value:
          if isinstance(item, dict):
            new_list.append(to_gemini_schema(item))
          else:
            new_list.append(item)
        pydantic_schema_data[snake_case_key] = new_list
      # Bug fix: ("properties") was a plain string, so `not in` performed a
      # substring test; a real one-element tuple is intended here.
      elif isinstance(value, dict) and snake_case_key not in ("properties",):
        # Handle dictionary which is neither properties nor items.
        pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
      else:
        # Simple value assignment (int, str, bool, etc.)
        pydantic_schema_data[snake_case_key] = value

  return Schema(**pydantic_schema_data)
|
||||
|
||||
# Auth-preparation status: "pending" = waiting for user authorization,
# "done" = credentials are ready (or no auth is required).
AuthPreparationState = Literal["pending", "done"]
|
||||
|
||||
|
||||
class RestApiTool(BaseTool):
  """A generic tool that interacts with a REST API.

  * Generates request params and body
  * Attaches auth credentials to API call.

  Example:
  ```
  # Each API operation in the spec will be turned into its own tool
  # Name of the tool is the operationId of that operation, in snake case
  operations = OperationGenerator().parse(openapi_spec_dict)
  tool = [RestApiTool.from_parsed_operation(o) for o in operations]
  ```
  """

  def __init__(
      self,
      name: str,
      description: str,
      endpoint: Union[OperationEndpoint, str],
      operation: Union[Operation, str],
      auth_scheme: Optional[Union[AuthScheme, str]] = None,
      auth_credential: Optional[Union[AuthCredential, str]] = None,
      should_parse_operation: bool = True,
  ):
    """Initializes the RestApiTool with the given parameters.

    To generate RestApiTool from OpenAPI Specs, use OperationGenerator.
    Example:
    ```
    # Each API operation in the spec will be turned into its own tool
    # Name of the tool is the operationId of that operation, in snake case
    operations = OperationGenerator().parse(openapi_spec_dict)
    tool = [RestApiTool.from_parsed_operation(o) for o in operations]
    ```

    Hint: Use google.adk.tools.openapi_tool.auth.auth_helpers to construct
    auth_scheme and auth_credential.

    Args:
      name: The name of the tool.
      description: The description of the tool.
      endpoint: Include the base_url, path, and method of the tool.
      operation: Pydantic object or a JSON string. Representing the OpenAPI
        Operation object
        (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#operation-object)
      auth_scheme: The auth scheme of the tool. Representing the OpenAPI
        SecurityScheme object
        (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#security-scheme-object)
      auth_credential: The authentication credential of the tool.
      should_parse_operation: Whether to parse the operation.
    """
    # Gemini restrict the length of function name to be less than 64 characters
    self.name = name[:60]
    self.description = description
    # String inputs are treated as JSON and validated into pydantic models.
    self.endpoint = (
        OperationEndpoint.model_validate_json(endpoint)
        if isinstance(endpoint, str)
        else endpoint
    )
    self.operation = (
        Operation.model_validate_json(operation)
        if isinstance(operation, str)
        else operation
    )
    self.auth_credential, self.auth_scheme = None, None

    # These setters normalize str/dict inputs into the pydantic types.
    self.configure_auth_credential(auth_credential)
    self.configure_auth_scheme(auth_scheme)

    # Private properties
    self.credential_exchanger = AutoAuthCredentialExchanger()
    if should_parse_operation:
      self._operation_parser = OperationParser(self.operation)

  @classmethod
  def from_parsed_operation(cls, parsed: ParsedOperation) -> "RestApiTool":
    """Initializes the RestApiTool from a ParsedOperation object.

    Args:
      parsed: A ParsedOperation object.

    Returns:
      A RestApiTool object.
    """
    operation_parser = OperationParser.load(
        parsed.operation, parsed.parameters, parsed.return_value
    )

    tool_name = to_snake_case(operation_parser.get_function_name())
    generated = cls(
        name=tool_name,
        description=parsed.operation.description
        or parsed.operation.summary
        or "",
        endpoint=parsed.endpoint,
        operation=parsed.operation,
        auth_scheme=parsed.auth_scheme,
        auth_credential=parsed.auth_credential,
    )
    # Reuse the already-built parser instead of the one created in __init__.
    generated._operation_parser = operation_parser
    return generated

  @classmethod
  def from_parsed_operation_str(
      cls, parsed_operation_str: str
  ) -> "RestApiTool":
    """Initializes the RestApiTool from a JSON string of a ParsedOperation.

    Args:
      parsed_operation_str: JSON string representation of a ParsedOperation
        object.

    Returns:
      A RestApiTool object.
    """
    operation = ParsedOperation.model_validate_json(parsed_operation_str)
    return RestApiTool.from_parsed_operation(operation)

  @override
  def _get_declaration(self) -> FunctionDeclaration:
    """Returns the function declaration in the Gemini Schema format."""
    schema_dict = self._operation_parser.get_json_schema()
    parameters = to_gemini_schema(schema_dict)
    function_decl = FunctionDeclaration(
        name=self.name, description=self.description, parameters=parameters
    )
    return function_decl

  def configure_auth_scheme(
      self, auth_scheme: Union[AuthScheme, Dict[str, Any]]
  ):
    """Configures the authentication scheme for the API call.

    Args:
      auth_scheme: AuthScheme|dict - The authentication scheme. A dict is
        converted to an AuthScheme object.
    """
    if isinstance(auth_scheme, dict):
      auth_scheme = dict_to_auth_scheme(auth_scheme)
    self.auth_scheme = auth_scheme

  def configure_auth_credential(
      self, auth_credential: Optional[Union[AuthCredential, str]] = None
  ):
    """Configures the authentication credential for the API call.

    Args:
      auth_credential: AuthCredential|str - The authentication credential.
        A JSON string is converted to an AuthCredential object.
    """
    if isinstance(auth_credential, str):
      auth_credential = AuthCredential.model_validate_json(auth_credential)
    self.auth_credential = auth_credential

  def _prepare_auth_request_params(
      self,
      auth_scheme: AuthScheme,
      auth_credential: AuthCredential,
  ) -> Tuple[List[ApiParameter], Dict[str, Any]]:
    # Handle Authentication
    # NOTE(review): returns None (not a tuple) when either argument is
    # missing; callers must guard before unpacking the result.
    if not auth_scheme or not auth_credential:
      return

    return credential_to_param(auth_scheme, auth_credential)

  def _prepare_request_params(
      self, parameters: List[ApiParameter], kwargs: Dict[str, Any]
  ) -> Dict[str, Any]:
    """Prepares the request parameters for the API call.

    Args:
      parameters: A list of ApiParameter objects representing the parameters
        for the API call.
      kwargs: The keyword arguments passed to the call function from the Tool
        caller.

    Returns:
      A dictionary containing the request parameters for the API call. This
      initializes a requests.request() call.

    Example:
      self._prepare_request_params({"input_id": "test-id"})
    """
    method = self.endpoint.method.lower()
    if not method:
      raise ValueError("Operation method not found.")

    path_params: Dict[str, Any] = {}
    query_params: Dict[str, Any] = {}
    header_params: Dict[str, Any] = {}
    cookie_params: Dict[str, Any] = {}

    params_map: Dict[str, ApiParameter] = {p.py_name: p for p in parameters}

    # Fill in path, query, header and cookie parameters to the request
    for param_k, v in kwargs.items():
      param_obj = params_map.get(param_k)
      if not param_obj:
        continue  # If input arg not in the ApiParameter list, ignore it.

      original_k = param_obj.original_name
      param_location = param_obj.param_location

      if param_location == "path":
        path_params[original_k] = v
      elif param_location == "query":
        # Falsy query values (None, '', 0, False) are dropped here.
        if v:
          query_params[original_k] = v
      elif param_location == "header":
        header_params[original_k] = v
      elif param_location == "cookie":
        cookie_params[original_k] = v

    # Construct URL (trailing slash on base_url is stripped to avoid '//').
    base_url = self.endpoint.base_url or ""
    base_url = base_url[:-1] if base_url.endswith("/") else base_url
    url = f"{base_url}{self.endpoint.path.format(**path_params)}"

    # Construct body
    body_kwargs: Dict[str, Any] = {}
    request_body = self.operation.requestBody
    if request_body:
      for mime_type, media_type_object in request_body.content.items():
        schema = media_type_object.schema_
        body_data = None

        if schema.type == "object":
          # Collect every body-located kwarg into one JSON object.
          body_data = {}
          for param in parameters:
            if param.param_location == "body" and param.py_name in kwargs:
              body_data[param.original_name] = kwargs[param.py_name]

        elif schema.type == "array":
          for param in parameters:
            if param.param_location == "body" and param.py_name == "array":
              body_data = kwargs.get("array")
              break
        else:  # like string
          for param in parameters:
            # original_name = '' indicating this param applies to the full body.
            if param.param_location == "body" and not param.original_name:
              body_data = (
                  kwargs.get(param.py_name) if param.py_name in kwargs else None
              )
              break

        # Map the mime type onto the matching requests.request() kwarg.
        if mime_type == "application/json" or mime_type.endswith("+json"):
          if body_data is not None:
            body_kwargs["json"] = body_data
        elif mime_type == "application/x-www-form-urlencoded":
          body_kwargs["data"] = body_data
        elif mime_type == "multipart/form-data":
          body_kwargs["files"] = body_data
        elif mime_type == "application/octet-stream":
          body_kwargs["data"] = body_data
        elif mime_type == "text/plain":
          body_kwargs["data"] = body_data

        if mime_type:
          header_params["Content-Type"] = mime_type
        break  # Process only the first mime_type

    filtered_query_params: Dict[str, Any] = {
        k: v for k, v in query_params.items() if v is not None
    }

    request_params: Dict[str, Any] = {
        "method": method,
        "url": url,
        "params": filtered_query_params,
        "headers": header_params,
        "cookies": cookie_params,
        **body_kwargs,
    }

    return request_params

  @override
  async def run_async(
      self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
  ) -> Dict[str, Any]:
    # NOTE(review): delegates to the synchronous call(), so the blocking
    # requests call runs on the event loop thread.
    return self.call(args=args, tool_context=tool_context)

  def call(
      self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
  ) -> Dict[str, Any]:
    """Executes the REST API call.

    Args:
      args: Keyword arguments representing the operation parameters.
      tool_context: The tool context; used to prepare auth credentials.

    Returns:
      The API response as a dictionary, or a {"pending": True, ...} dict when
      user authorization is still required.
    """
    # Prepare auth credentials for the API call
    tool_auth_handler = ToolAuthHandler.from_tool_context(
        tool_context, self.auth_scheme, self.auth_credential
    )
    auth_result = tool_auth_handler.prepare_auth_credentials()
    auth_state, auth_scheme, auth_credential = (
        auth_result.state,
        auth_result.auth_scheme,
        auth_result.auth_credential,
    )

    if auth_state == "pending":
      # The user must complete an auth flow before the API can be called.
      return {
          "pending": True,
          "message": "Needs your authorization to access your data.",
      }

    # Attach parameters from auth into main parameters list
    api_params, api_args = self._operation_parser.get_parameters().copy(), args
    if auth_credential:
      # Attach parameters from auth into main parameters list
      auth_param, auth_args = self._prepare_auth_request_params(
          auth_scheme, auth_credential
      )
      if auth_param and auth_args:
        api_params = [auth_param] + api_params
        api_args.update(auth_args)

    # Got all parameters. Call the API.
    request_params = self._prepare_request_params(api_params, api_args)
    response = requests.request(**request_params)

    # Parse API response
    try:
      response.raise_for_status()  # Raise HTTPError for bad responses
      return response.json()  # Try to decode JSON
    except requests.exceptions.HTTPError:
      # Return the error as data so the model can analyze and retry.
      error_details = response.content.decode("utf-8")
      return {
          "error": (
              f"Tool {self.name} execution failed. Analyze this execution error"
              " and your inputs. Retry with adjustments if applicable. But"
              " make sure don't retry more than 3 times. Execution Error:"
              f" {error_details}"
          )
      }
    except ValueError:
      return {"text": response.text}  # Return text if not JSON

  def __str__(self):
    # Short human-readable form: name, description and endpoint only.
    return (
        f'RestApiTool(name="{self.name}", description="{self.description}",'
        f' endpoint="{self.endpoint}")'
    )

  def __repr__(self):
    # Full form: also includes operation, auth scheme and credential.
    return (
        f'RestApiTool(name="{self.name}", description="{self.description}",'
        f' endpoint="{self.endpoint}", operation="{self.operation}",'
        f' auth_scheme="{self.auth_scheme}",'
        f' auth_credential="{self.auth_credential}")'
    )
|
||||
@@ -0,0 +1,268 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_credential import AuthCredentialTypes
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ....auth.auth_schemes import AuthSchemeType
|
||||
from ....auth.auth_tool import AuthConfig
|
||||
from ...tool_context import ToolContext
|
||||
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
|
||||
from ..auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
|
||||
from ..auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Auth-preparation status: "pending" = waiting for user authorization,
# "done" = credentials are ready (or no auth is required).
AuthPreparationState = Literal["pending", "done"]
|
||||
|
||||
|
||||
class AuthPreparationResult(BaseModel):
  """Result of the credential preparation process."""

  # "pending": user authorization still required; "done": ready to call API.
  state: AuthPreparationState
  # Scheme/credential are populated when state is "done" (and, for "pending",
  # echo the raw inputs).
  auth_scheme: Optional[AuthScheme] = None
  auth_credential: Optional[AuthCredential] = None
|
||||
|
||||
class ToolContextCredentialStore:
  """Handles storage and retrieval of credentials within a ToolContext."""

  def __init__(self, tool_context: ToolContext):
    self.tool_context = tool_context

  def get_credential_key(
      self,
      auth_scheme: Optional[AuthScheme],
      auth_credential: Optional[AuthCredential],
  ) -> str:
    """Generates a unique key for the given auth scheme and credential.

    NOTE(review): the key embeds Python's hash() of the serialized models,
    which is randomized per interpreter run for strings — so keys are only
    stable within a single process. Confirm this matches the intended
    cross-run sharing described below.
    """
    scheme_name = (
        f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
        if auth_scheme
        else ""
    )
    credential_name = (
        f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
        if auth_credential
        else ""
    )
    # no need to prepend temp: namespace, session state is a copy, changes to
    # it won't be persisted , only changes in event_action.state_delta will be
    # persisted. temp: namespace will be cleared after current run. but tool
    # want access token to be there stored across runs

    return f"{scheme_name}_{credential_name}_existing_exchanged_credential"

  def get_credential(
      self,
      auth_scheme: Optional[AuthScheme],
      auth_credential: Optional[AuthCredential],
  ) -> Optional[AuthCredential]:
    """Returns the previously stored (exchanged) credential, or None."""
    if not self.tool_context:
      return None

    token_key = self.get_credential_key(auth_scheme, auth_credential)
    # TODO try not to use session state, this looks a hacky way, depend on
    # session implementation, we don't want session to persist the token,
    # meanwhile we want the token shared across runs.
    serialized_credential = self.tool_context.state.get(token_key)
    if not serialized_credential:
      return None
    return AuthCredential.model_validate(serialized_credential)

  def store_credential(
      self,
      key: str,
      auth_credential: Optional[AuthCredential],
  ):
    """Serializes the credential into session state under the given key."""
    if self.tool_context:
      serializable_credential = jsonable_encoder(
          auth_credential, exclude_none=True
      )
      self.tool_context.state[key] = serializable_credential

  def remove_credential(self, key: str):
    # NOTE(review): unlike the methods above, this does not guard against a
    # missing tool_context and raises KeyError when the key is absent —
    # confirm callers guarantee both.
    del self.tool_context.state[key]
|
||||
|
||||
class ToolAuthHandler:
  """Handles the preparation and exchange of authentication credentials for tools."""

  def __init__(
      self,
      tool_context: ToolContext,
      auth_scheme: Optional[AuthScheme],
      auth_credential: Optional[AuthCredential],
      credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
      credential_store: Optional["ToolContextCredentialStore"] = None,
  ):
    self.tool_context = tool_context
    # Deep-copy the inputs so later mutation never leaks back to the caller.
    self.auth_scheme = (
        auth_scheme.model_copy(deep=True) if auth_scheme else None
    )
    self.auth_credential = (
        auth_credential.model_copy(deep=True) if auth_credential else None
    )
    self.credential_exchanger = (
        credential_exchanger or AutoAuthCredentialExchanger()
    )
    self.credential_store = credential_store
    # Always True here; flag reserved for disabling persistence.
    self.should_store_credential = True

  @classmethod
  def from_tool_context(
      cls,
      tool_context: ToolContext,
      auth_scheme: Optional[AuthScheme],
      auth_credential: Optional[AuthCredential],
      credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
  ) -> "ToolAuthHandler":
    """Creates a ToolAuthHandler instance from a ToolContext."""
    credential_store = ToolContextCredentialStore(tool_context)
    return cls(
        tool_context,
        auth_scheme,
        auth_credential,
        credential_exchanger,
        credential_store,
    )

  def _handle_existing_credential(
      self,
  ) -> Optional[AuthPreparationResult]:
    """Checks for and returns an existing, exchanged credential."""
    if self.credential_store:
      existing_credential = self.credential_store.get_credential(
          self.auth_scheme, self.auth_credential
      )
      if existing_credential:
        return AuthPreparationResult(
            state="done",
            auth_scheme=self.auth_scheme,
            auth_credential=existing_credential,
        )
    return None

  def _exchange_credential(
      self, auth_credential: AuthCredential
  ) -> Optional[AuthCredential]:
    """Exchanges the given credential via the configured exchanger.

    Returns None when the exchange fails; the error is logged, not raised.
    (Original annotation said AuthPreparationResult, but the exchanger
    returns a credential.)
    """
    exchanged_credential = None
    try:
      exchanged_credential = self.credential_exchanger.exchange_credential(
          self.auth_scheme, auth_credential
      )
    except Exception as e:
      logger.error("Failed to exchange credential: %s", e)
    return exchanged_credential

  def _store_credential(self, auth_credential: AuthCredential) -> None:
    """Stores the (exchanged) credential in the credential store, if any."""
    if self.credential_store:
      key = self.credential_store.get_credential_key(
          self.auth_scheme, self.auth_credential
      )
      self.credential_store.store_credential(key, auth_credential)

  def _reqeust_credential(self) -> None:
    """Handles the case where an OpenID Connect or OAuth2 authentication request is needed."""
    # NOTE(review): method name contains a typo ('reqeust'); kept as-is
    # because it is invoked by name within this class.
    if self.auth_scheme.type_ in (
        AuthSchemeType.openIdConnect,
        AuthSchemeType.oauth2,
    ):
      # OAuth2/OIDC flows cannot start without a client_id/client_secret.
      if not self.auth_credential or not self.auth_credential.oauth2:
        raise ValueError(
            f"auth_credential is empty for scheme {self.auth_scheme.type_}."
            "Please create AuthCredential using OAuth2Auth."
        )

      if not self.auth_credential.oauth2.client_id:
        raise AuthCredentialMissingError(
            "OAuth2 credentials client_id is missing."
        )

      if not self.auth_credential.oauth2.client_secret:
        raise AuthCredentialMissingError(
            "OAuth2 credentials client_secret is missing."
        )

    # Ask the ADK framework to run the user-facing authorization flow.
    self.tool_context.request_credential(
        AuthConfig(
            auth_scheme=self.auth_scheme,
            raw_auth_credential=self.auth_credential,
        )
    )
    return None

  def _get_auth_response(self) -> AuthCredential:
    # Returns the credential produced by a previously completed auth flow,
    # if the framework has one for this scheme/credential pair.
    return self.tool_context.get_auth_response(
        AuthConfig(
            auth_scheme=self.auth_scheme,
            raw_auth_credential=self.auth_credential,
        )
    )

  def _request_credential(self, auth_config: AuthConfig):
    # NOTE(review): distinct from _reqeust_credential above — this variant
    # takes an explicit AuthConfig and tolerates a missing tool_context.
    if not self.tool_context:
      return
    self.tool_context.request_credential(auth_config)

  def prepare_auth_credentials(
      self,
  ) -> AuthPreparationResult:
    """Prepares authentication credentials, handling exchange and user interaction."""

    # no auth is needed
    if not self.auth_scheme:
      return AuthPreparationResult(state="done")

    # Check for existing credential.
    existing_result = self._handle_existing_credential()
    if existing_result:
      return existing_result

    # fetch credential from adk framework
    # Some auth scheme like OAuth2 AuthCode & OpenIDConnect may require
    # multi-step exchange:
    # client_id , client_secret -> auth_uri -> auth_code -> access_token
    # -> bearer token
    # adk framework supports exchange access_token already
    fetched_credential = self._get_auth_response() or self.auth_credential

    exchanged_credential = self._exchange_credential(fetched_credential)

    if exchanged_credential:
      self._store_credential(exchanged_credential)
      return AuthPreparationResult(
          state="done",
          auth_scheme=self.auth_scheme,
          auth_credential=exchanged_credential,
      )
    else:
      # Exchange failed or is not yet possible: start the user auth flow.
      self._reqeust_credential()
      return AuthPreparationResult(
          state="pending",
          auth_scheme=self.auth_scheme,
          auth_credential=self.auth_credential,
      )
|
||||
72
src/google/adk/tools/preload_memory_tool.py
Normal file
72
src/google/adk/tools/preload_memory_tool.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_tool import BaseTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import LlmRequest
|
||||
|
||||
|
||||
class PreloadMemoryTool(BaseTool):
  """A tool that preloads relevant memories for the current user.

  This tool is never exposed to the model as a callable function: it only
  rewrites ``llm_request`` during preprocessing, appending past-conversation
  snippets that match the user's latest text query.
  """

  def __init__(self):
    # Name and description are not used because this tool only
    # changes llm_request.
    super().__init__(name='preload_memory', description='preload_memory')

  @override
  async def process_llm_request(
      self,
      *,
      tool_context: ToolContext,
      llm_request: LlmRequest,
  ) -> None:
    """Searches memory for the user's query and injects matches as instructions.

    Args:
      tool_context: Context of the current invocation; provides the user
        content and the memory search entry point.
      llm_request: The outgoing LLM request, augmented in place.
    """
    user_content = tool_context.user_content
    # Guard against a missing user turn as well as an empty/non-text first
    # part; the original code dereferenced `.parts` without the None check.
    if (
        not user_content
        or not user_content.parts
        or not user_content.parts[0].text
    ):
      return
    # TODO: support multi-part queries instead of only the first text part.
    query = user_content.parts[0].text
    response = tool_context.search_memory(query)
    if not response.memories:
      return
    # Build a readable transcript of the matching memories; each memory is
    # headed by the timestamp of its first event. Collect lines and join once
    # instead of quadratic string concatenation.
    lines = []
    for memory in response.memories:
      time_str = datetime.fromtimestamp(memory.events[0].timestamp).isoformat()
      lines.append(f'Time: {time_str}')
      for event in memory.events:
        # TODO: support multi-part content.
        if (
            event.content
            and event.content.parts
            and event.content.parts[0].text
        ):
          lines.append(f'{event.author}: {event.content.parts[0].text}')
    memory_text = '\n'.join(lines) + '\n'
    si = f"""The following content is from your previous conversations with the user.
They may be useful for answering the user's current query.
<PAST_CONVERSATIONS>
{memory_text}
</PAST_CONVERSATIONS>
"""
    llm_request.append_instructions([si])
# Module-level singleton; re-exported from `google.adk.tools` as `preload_memory`.
preload_memory_tool = PreloadMemoryTool()
|
||||
36
src/google/adk/tools/retrieval/__init__.py
Normal file
36
src/google/adk/tools/retrieval/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base_retrieval_tool import BaseRetrievalTool
|
||||
from .files_retrieval import FilesRetrieval
|
||||
from .llama_index_retrieval import LlamaIndexRetrieval
|
||||
|
||||
# Retrieval tools that are always available.
__all__ = [
    'BaseRetrievalTool',
    'FilesRetrieval',
    'LlamaIndexRetrieval',
]

# VertexAiRagRetrieval depends on the optional Vertex AI SDK; export it only
# when that SDK is importable so the rest of the package works without it.
try:
  from .vertex_ai_rag_retrieval import VertexAiRagRetrieval

  __all__.append('VertexAiRagRetrieval')
except ImportError:
  import logging

  logger = logging.getLogger(__name__)
  logger.debug(
      'The Vertex sdk is not installed. If you want to use the Vertex RAG with'
      ' agents, please install it. If not, you can ignore this warning.'
  )
|
||||
37
src/google/adk/tools/retrieval/base_retrieval_tool.py
Normal file
37
src/google/adk/tools/retrieval/base_retrieval_tool.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ..base_tool import BaseTool
|
||||
|
||||
|
||||
class BaseRetrievalTool(BaseTool):
  """Base class for retrieval tools that accept a single string `query`."""

  @override
  def _get_declaration(self) -> types.FunctionDeclaration:
    """Builds the function declaration: an object schema with one string field."""
    query_schema = types.Schema(
        type=types.Type.STRING,
        description='The query to retrieve.',
    )
    params = types.Schema(
        type=types.Type.OBJECT,
        properties={'query': query_schema},
    )
    return types.FunctionDeclaration(
        name=self.name, description=self.description, parameters=params
    )
|
||||
33
src/google/adk/tools/retrieval/files_retrieval.py
Normal file
33
src/google/adk/tools/retrieval/files_retrieval.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Provides data for the agent."""
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core import VectorStoreIndex
|
||||
|
||||
from .llama_index_retrieval import LlamaIndexRetrieval
|
||||
|
||||
|
||||
class FilesRetrieval(LlamaIndexRetrieval):
  """Retrieval over files in a local directory via a LlamaIndex vector index."""

  def __init__(self, *, name: str, description: str, input_dir: str):
    """Loads and indexes every document under `input_dir`.

    Args:
      name: Tool name exposed to the model.
      description: Tool description exposed to the model.
      input_dir: Directory whose files are read and embedded.
    """
    self.input_dir = input_dir

    print(f'Loading data from {input_dir}')
    documents = SimpleDirectoryReader(input_dir).load_data()
    index = VectorStoreIndex.from_documents(documents)
    super().__init__(
        name=name, description=description, retriever=index.as_retriever()
    )
|
||||
41
src/google/adk/tools/retrieval/llama_index_retrieval.py
Normal file
41
src/google/adk/tools/retrieval/llama_index_retrieval.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Provides data for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..tool_context import ToolContext
|
||||
from .base_retrieval_tool import BaseRetrievalTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core.base.base_retriever import BaseRetriever
|
||||
|
||||
|
||||
class LlamaIndexRetrieval(BaseRetrievalTool):
  """Retrieval tool backed by an arbitrary LlamaIndex `BaseRetriever`."""

  def __init__(self, *, name: str, description: str, retriever: BaseRetriever):
    """Initializes the tool.

    Args:
      name: Tool name exposed to the model.
      description: Tool description exposed to the model.
      retriever: A LlamaIndex retriever used to answer queries.
    """
    super().__init__(name=name, description=description)
    self.retriever = retriever

  @override
  async def run_async(
      self, *, args: dict[str, Any], tool_context: ToolContext
  ) -> Any:
    """Retrieves for `args['query']` and returns the top match's text.

    Returns:
      The text of the best-matching node, or a 'no result' message when the
      retriever returns nothing. (The original indexed `[0]` unconditionally
      and raised IndexError on empty results; returning a message matches
      VertexAiRagRetrieval's behavior for the same situation.)
    """
    nodes = self.retriever.retrieve(args['query'])
    if not nodes:
      return 'No matching result found.'
    return nodes[0].text
|
||||
107
src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py
Normal file
107
src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A retrieval tool that uses Vertex AI RAG to retrieve data."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
from vertexai.preview import rag
|
||||
|
||||
from ..tool_context import ToolContext
|
||||
from .base_retrieval_tool import BaseRetrievalTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...models.llm_request import LlmRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VertexAiRagRetrieval(BaseRetrievalTool):
  """A retrieval tool that uses Vertex AI RAG (Retrieval-Augmented Generation) to retrieve data."""

  def __init__(
      self,
      *,
      name: str,
      description: str,
      rag_corpora: list[str] | None = None,
      rag_resources: list[rag.RagResource] | None = None,
      similarity_top_k: int | None = None,
      vector_distance_threshold: float | None = None,
  ):
    """Initializes the tool.

    Args:
      name: Tool name exposed to the model.
      description: Tool description exposed to the model.
      rag_corpora: RAG corpus resource names to query.
      rag_resources: RAG resources (corpus and optional file IDs) to query.
      similarity_top_k: Number of top results to retrieve.
      vector_distance_threshold: Only return results with vector distance
        below this threshold.
    """
    super().__init__(name=name, description=description)
    # The store is passed verbatim to both the Gemini built-in retrieval tool
    # and the explicit rag.retrieval_query fallback below.
    self.vertex_rag_store = types.VertexRagStore(
        rag_corpora=rag_corpora,
        rag_resources=rag_resources,
        similarity_top_k=similarity_top_k,
        vector_distance_threshold=vector_distance_threshold,
    )

  @override
  async def process_llm_request(
      self,
      *,
      tool_context: ToolContext,
      llm_request: LlmRequest,
  ) -> None:
    """Attaches retrieval to the request, natively for Gemini 2 models."""
    # Use Gemini built-in Vertex AI RAG tool for Gemini 2 models.
    if llm_request.model and llm_request.model.startswith('gemini-2'):
      llm_request.config = llm_request.config or types.GenerateContentConfig()
      llm_request.config.tools = llm_request.config.tools or []
      llm_request.config.tools.append(
          types.Tool(
              retrieval=types.Retrieval(vertex_rag_store=self.vertex_rag_store)
          )
      )
    else:
      # Add the function declaration to the tools; run_async below performs
      # the retrieval explicitly when the model calls the function.
      await super().process_llm_request(
          tool_context=tool_context, llm_request=llm_request
      )

  @override
  async def run_async(
      self,
      *,
      args: dict[str, Any],
      tool_context: ToolContext,
  ) -> Any:
    """Runs an explicit RAG query for `args['query']`.

    Returns:
      The list of matching context texts, or a 'no result' message string
      when nothing matched.
    """
    response = rag.retrieval_query(
        text=args['query'],
        rag_resources=self.vertex_rag_store.rag_resources,
        rag_corpora=self.vertex_rag_store.rag_corpora,
        similarity_top_k=self.vertex_rag_store.similarity_top_k,
        vector_distance_threshold=self.vertex_rag_store.vector_distance_threshold,
    )

    # Use the module-level named logger (the original called logging.debug on
    # the root logger, bypassing this package's logging configuration).
    logger.debug('RAG raw response: %s', response)

    return (
        f'No matching result found with the config: {self.vertex_rag_store}'
        if not response.contexts.contexts
        else [context.text for context in response.contexts.contexts]
    )
|
||||
90
src/google/adk/tools/tool_context.py
Normal file
90
src/google/adk/tools/tool_context.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..agents.callback_context import CallbackContext
|
||||
from ..auth.auth_credential import AuthCredential
|
||||
from ..auth.auth_handler import AuthHandler
|
||||
from ..auth.auth_tool import AuthConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from ..events.event_actions import EventActions
|
||||
from ..memory.base_memory_service import SearchMemoryResponse
|
||||
|
||||
|
||||
class ToolContext(CallbackContext):
  """The context of the tool.

  This class provides the context for a tool invocation, including access to
  the invocation context, function call ID, event actions, and authentication
  response. It also provides methods for requesting credentials, retrieving
  authentication responses, listing artifacts, and searching memory.

  Attributes:
    invocation_context: The invocation context of the tool.
    function_call_id: The function call id of the current tool call. This id was
      returned in the function call event from LLM to identify a function call.
      If LLM didn't return this id, ADK will assign one to it. This id is used
      to map function call response to the original function call.
    event_actions: The event actions of the current tool call.
  """

  def __init__(
      self,
      invocation_context: InvocationContext,
      *,
      function_call_id: Optional[str] = None,
      event_actions: Optional[EventActions] = None,
  ):
    super().__init__(invocation_context, event_actions=event_actions)
    self.function_call_id = function_call_id

  @property
  def actions(self) -> EventActions:
    """The event actions accumulated for the current tool call."""
    return self._event_actions

  def request_credential(self, auth_config: AuthConfig) -> None:
    """Records an auth request for this function call in the event actions.

    Raises:
      ValueError: If no function_call_id was assigned to this context.
    """
    if not self.function_call_id:
      raise ValueError('function_call_id is not set.')
    auth_request = AuthHandler(auth_config).generate_auth_request()
    self._event_actions.requested_auth_configs[self.function_call_id] = (
        auth_request
    )

  def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential:
    """Returns the auth credential resolved from the session state."""
    handler = AuthHandler(auth_config)
    return handler.get_auth_response(self.state)

  def list_artifacts(self) -> list[str]:
    """Lists the filenames of the artifacts attached to the current session.

    Raises:
      ValueError: If no artifact service is configured.
    """
    ctx = self._invocation_context
    if ctx.artifact_service is None:
      raise ValueError('Artifact service is not initialized.')
    return ctx.artifact_service.list_artifact_keys(
        app_name=ctx.app_name,
        user_id=ctx.user_id,
        session_id=ctx.session.id,
    )

  def search_memory(self, query: str) -> 'SearchMemoryResponse':
    """Searches the memory of the current user.

    Raises:
      ValueError: If no memory service is configured.
    """
    ctx = self._invocation_context
    if ctx.memory_service is None:
      raise ValueError('Memory service is not available.')
    return ctx.memory_service.search_memory(
        app_name=ctx.app_name,
        user_id=ctx.user_id,
        query=query,
    )
|
||||
46
src/google/adk/tools/toolbox_tool.py
Normal file
46
src/google/adk/tools/toolbox_tool.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from . import _automatic_function_calling_util
|
||||
from .langchain_tool import LangchainTool
|
||||
|
||||
|
||||
class ToolboxTool:
  """A class that provides access to toolbox tools.

  Example:
  ```python
  toolbox = ToolboxTool("http://127.0.0.1:8080")
  tool = toolbox.get_tool("tool_name")
  toolset = toolbox.get_toolset("toolset_name")
  ```
  """

  # The underlying toolbox client used to load tools and toolsets.
  toolbox_client: Any

  def __init__(self, url: str):
    """Creates a toolbox client for the server at `url`."""
    # Imported lazily so the optional dependency is only needed when used.
    from toolbox_langchain import ToolboxClient

    self.toolbox_client = ToolboxClient(url)

  def get_tool(self, tool_name: str) -> LangchainTool:
    """Loads one toolbox tool and wraps it as a LangchainTool."""
    return LangchainTool(self.toolbox_client.load_tool(tool_name))

  def get_toolset(self, toolset_name: str) -> list[LangchainTool]:
    """Loads a toolset and wraps each member as a LangchainTool."""
    loaded = self.toolbox_client.load_toolset(toolset_name)
    return [LangchainTool(t) for t in loaded]
|
||||
21
src/google/adk/tools/transfer_to_agent_tool.py
Normal file
21
src/google/adk/tools/transfer_to_agent_tool.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .tool_context import ToolContext
|
||||
|
||||
|
||||
# TODO: make this internal, since user doesn't need to use this tool directly.
def transfer_to_agent(agent_name: str, tool_context: ToolContext):
  """Transfer the question to another agent.

  Marks the current tool call's event actions so the framework routes the
  conversation to `agent_name` after this tool call completes.

  Args:
    agent_name: The name of the agent to hand the conversation to.
    tool_context: Context of the current tool call; its actions are mutated
      in place.
  """
  tool_context.actions.transfer_to_agent = agent_name
|
||||
96
src/google/adk/tools/vertex_ai_search_tool.py
Normal file
96
src/google/adk/tools/vertex_ai_search_tool.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_tool import BaseTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import LlmRequest
|
||||
|
||||
|
||||
class VertexAiSearchTool(BaseTool):
  """A built-in tool using Vertex AI Search.

  Attributes:
    data_store_id: The Vertex AI search data store resource ID.
    search_engine_id: The Vertex AI search engine resource ID.
  """

  def __init__(
      self,
      *,
      data_store_id: Optional[str] = None,
      search_engine_id: Optional[str] = None,
  ):
    """Initializes the Vertex AI Search tool.

    Args:
      data_store_id: The Vertex AI search data store resource ID in the format
        of
        "projects/{project}/locations/{location}/collections/{collection}/dataStores/{dataStore}".
      search_engine_id: The Vertex AI search engine resource ID in the format of
        "projects/{project}/locations/{location}/collections/{collection}/engines/{engine}".

    Raises:
      ValueError: If both data_store_id and search_engine_id are not specified
        or both are specified.
    """
    # Name and description are not used because this is a model built-in tool.
    super().__init__(name='vertex_ai_search', description='vertex_ai_search')
    # Exactly one of the two resource IDs must be provided.
    if (data_store_id is None) == (search_engine_id is None):
      raise ValueError(
          'Either data_store_id or search_engine_id must be specified.'
      )
    self.data_store_id = data_store_id
    self.search_engine_id = search_engine_id

  @override
  async def process_llm_request(
      self,
      *,
      tool_context: ToolContext,
      llm_request: LlmRequest,
  ) -> None:
    """Attaches the Vertex AI Search retrieval tool to the outgoing request.

    Raises:
      ValueError: If the model is not a Gemini model, or if the model is
        Gemini 1.x and other tools are already configured (not supported).
    """
    if llm_request.model and llm_request.model.startswith('gemini-'):
      # Normalize config before reading config.tools: the original code
      # dereferenced llm_request.config.tools for the Gemini-1 check while
      # config could still be None (the normalization below shows it can be).
      llm_request.config = llm_request.config or types.GenerateContentConfig()
      llm_request.config.tools = llm_request.config.tools or []
      if llm_request.model.startswith('gemini-1') and llm_request.config.tools:
        raise ValueError(
            'Vertex AI search tool can not be used with other tools in Gemini'
            ' 1.x.'
        )
      llm_request.config.tools.append(
          types.Tool(
              retrieval=types.Retrieval(
                  vertex_ai_search=types.VertexAISearch(
                      datastore=self.data_store_id, engine=self.search_engine_id
                  )
              )
          )
      )
    else:
      raise ValueError(
          'Vertex AI search tool is not supported for model'
          f' {llm_request.model}'
      )
|
||||
Reference in New Issue
Block a user