structure saas with tools
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# pylint: disable=g-bad-import-order
|
||||
from .base_tool import BaseTool
|
||||
|
||||
from ..auth.auth_tool import AuthToolArguments
|
||||
from .apihub_tool.apihub_toolset import APIHubToolset
|
||||
from .built_in_code_execution_tool import built_in_code_execution
|
||||
from .google_search_tool import google_search
|
||||
from .vertex_ai_search_tool import VertexAiSearchTool
|
||||
from .example_tool import ExampleTool
|
||||
from .exit_loop_tool import exit_loop
|
||||
from .function_tool import FunctionTool
|
||||
from .get_user_choice_tool import get_user_choice_tool as get_user_choice
|
||||
from .load_artifacts_tool import load_artifacts_tool as load_artifacts
|
||||
from .load_memory_tool import load_memory_tool as load_memory
|
||||
from .long_running_tool import LongRunningFunctionTool
|
||||
from .preload_memory_tool import preload_memory_tool as preload_memory
|
||||
from .tool_context import ToolContext
|
||||
from .transfer_to_agent_tool import transfer_to_agent
|
||||
|
||||
|
||||
# Public API of the tools package; order mirrors the import statements above.
__all__ = [
    'APIHubToolset',
    'AuthToolArguments',
    'BaseTool',
    'built_in_code_execution',
    'google_search',
    'VertexAiSearchTool',
    'ExampleTool',
    'exit_loop',
    'FunctionTool',
    'get_user_choice',
    'load_artifacts',
    'load_memory',
    'LongRunningFunctionTool',
    'preload_memory',
    'ToolContext',
    'transfer_to_agent',
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,346 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Forked from google3/third_party/py/google/genai/_automatic_function_calling_util.py temporarily."""
|
||||
|
||||
import inspect
|
||||
from types import FunctionType
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
from pydantic import create_model
|
||||
from pydantic import fields as pydantic_fields
|
||||
|
||||
from . import function_parameter_parse_util
|
||||
|
||||
# Maps Python / JSON-schema type names (as produced by pydantic or by
# annotation `__name__`s) to google.genai schema types. Lookups that miss
# this table fall back to the caller's default (usually TYPE_UNSPECIFIED).
_py_type_2_schema_type = {
    'str': types.Type.STRING,
    'int': types.Type.INTEGER,
    'float': types.Type.NUMBER,
    'bool': types.Type.BOOLEAN,
    'string': types.Type.STRING,
    'integer': types.Type.INTEGER,
    'number': types.Type.NUMBER,
    'boolean': types.Type.BOOLEAN,
    'list': types.Type.ARRAY,
    'array': types.Type.ARRAY,
    'tuple': types.Type.ARRAY,
    'object': types.Type.OBJECT,
    'Dict': types.Type.OBJECT,
    'List': types.Type.ARRAY,
    'Tuple': types.Type.ARRAY,
    'Any': types.Type.TYPE_UNSPECIFIED,
}
|
||||
|
||||
|
||||
def _get_fields_dict(func: Callable) -> Dict:
  """Maps each named parameter of `func` to a (type, pydantic.Field) pair.

  The result is suitable for `pydantic.create_model(...)`. Parameters of
  kind *args / **kwargs are skipped.

  Args:
    func: The callable whose signature is inspected.

  Returns:
    Dict of parameter name -> (annotation-or-Any, pydantic FieldInfo).
  """
  param_signature = dict(inspect.signature(func).parameters)
  fields_dict = {
      name: (
          # 1. We infer the argument type here: use Any rather than None so
          # it will not try to auto-infer the type based on the default value.
          (
              param.annotation
              if param.annotation != inspect.Parameter.empty
              else Any
          ),
          pydantic.Field(
              # 2. We do not support default values for now.
              default=(
                  param.default
                  if param.default != inspect.Parameter.empty
                  # ! Need to use Undefined instead of None
                  else pydantic_fields.PydanticUndefined
              ),
              # 3. Do not support parameter description for now.
              description=None,
          ),
      )
      for name, param in param_signature.items()
      # We do not support *args or **kwargs
      if param.kind
      in (
          inspect.Parameter.POSITIONAL_OR_KEYWORD,
          inspect.Parameter.KEYWORD_ONLY,
          inspect.Parameter.POSITIONAL_ONLY,
      )
  }
  return fields_dict
|
||||
|
||||
|
||||
def _annotate_nullable_fields(schema: Dict):
|
||||
for _, property_schema in schema.get('properties', {}).items():
|
||||
# for Optional[T], the pydantic schema is:
|
||||
# {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "anyOf": [
|
||||
# {
|
||||
# "type": "null"
|
||||
# },
|
||||
# {
|
||||
# "type": "T"
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# }
|
||||
for type_ in property_schema.get('anyOf', []):
|
||||
if type_.get('type') == 'null':
|
||||
property_schema['nullable'] = True
|
||||
property_schema['anyOf'].remove(type_)
|
||||
break
|
||||
|
||||
|
||||
def _annotate_required_fields(schema: Dict):
|
||||
required = [
|
||||
field_name
|
||||
for field_name, field_schema in schema.get('properties', {}).items()
|
||||
if not field_schema.get('nullable') and 'default' not in field_schema
|
||||
]
|
||||
schema['required'] = required
|
||||
|
||||
|
||||
def _remove_any_of(schema: Dict):
|
||||
for _, property_schema in schema.get('properties', {}).items():
|
||||
union_types = property_schema.pop('anyOf', None)
|
||||
# Take the first non-null type.
|
||||
if union_types:
|
||||
for type_ in union_types:
|
||||
if type_.get('type') != 'null':
|
||||
property_schema.update(type_)
|
||||
|
||||
|
||||
def _remove_default(schema: Dict):
|
||||
for _, property_schema in schema.get('properties', {}).items():
|
||||
property_schema.pop('default', None)
|
||||
|
||||
|
||||
def _remove_nullable(schema: Dict):
|
||||
for _, property_schema in schema.get('properties', {}).items():
|
||||
property_schema.pop('nullable', None)
|
||||
|
||||
|
||||
def _remove_title(schema: Dict):
|
||||
for _, property_schema in schema.get('properties', {}).items():
|
||||
property_schema.pop('title', None)
|
||||
|
||||
|
||||
def _get_pydantic_schema(func: Callable) -> Dict:
  """Builds the pydantic JSON schema for `func`'s parameters.

  Args:
    func: The callable to build a parameter schema for.

  Returns:
    The `model_json_schema()` of a pydantic model synthesized from the
    function's parameters.
  """
  fields_dict = _get_fields_dict(func)
  # `tool_context` is injected by the framework and must not appear in the
  # model-visible schema. pop(..., None) replaces the original
  # `if 'tool_context' in fields_dict.keys(): fields_dict.pop(...)` — one
  # lookup instead of two, same behavior when the key is absent.
  fields_dict.pop('tool_context', None)
  return pydantic.create_model(func.__name__, **fields_dict).model_json_schema()
|
||||
|
||||
|
||||
def _process_pydantic_schema(vertexai: bool, schema: Dict) -> Dict:
  """Normalizes a pydantic JSON schema for the target backend, in place.

  Annotates nullable and required fields for every backend; for non-Vertex
  backends additionally flattens `anyOf` unions and strips `default`,
  `nullable` and `title` keys, which those backends do not accept.

  Args:
    vertexai: True when targeting the Vertex AI backend.
    schema: The pydantic JSON schema; mutated in place.

  Returns:
    The same (mutated) schema object.
  """
  _annotate_nullable_fields(schema)
  _annotate_required_fields(schema)
  if vertexai:
    return schema
  _remove_any_of(schema)
  _remove_default(schema)
  _remove_nullable(schema)
  _remove_title(schema)
  return schema
|
||||
|
||||
|
||||
def _map_pydantic_type_to_property_schema(property_schema: Dict):
  """Rewrites pydantic type names to genai schema types, in place.

  Recurses into array `items` and converts each `anyOf` variant as well.
  Unknown type names fall back to 'TYPE_UNSPECIFIED'.
  """
  if 'type' in property_schema:
    property_schema['type'] = _py_type_2_schema_type.get(
        property_schema['type'], 'TYPE_UNSPECIFIED'
    )
    if property_schema['type'] == 'ARRAY':
      # NOTE(review): assumes an ARRAY property always carries an 'items'
      # key (true for pydantic-generated schemas) — raises KeyError otherwise.
      _map_pydantic_type_to_property_schema(property_schema['items'])
  for type_ in property_schema.get('anyOf', []):
    if 'type' in type_:
      type_['type'] = _py_type_2_schema_type.get(
          type_['type'], 'TYPE_UNSPECIFIED'
      )
      # TODO: To investigate. Unclear why a Type is needed with 'anyOf' to
      # avoid google.genai.errors.ClientError: 400 INVALID_ARGUMENT.
      # (The last typed variant wins, since each iteration overwrites.)
      property_schema['type'] = type_['type']
|
||||
|
||||
|
||||
def _map_pydantic_type_to_schema_type(schema: Dict):
  """Converts every property's pydantic type names to genai types, in place."""
  for prop in schema.get('properties', {}).values():
    _map_pydantic_type_to_property_schema(prop)
|
||||
|
||||
|
||||
def _get_return_type(func: Callable) -> Any:
  """Maps `func`'s return annotation name to a genai schema type.

  Falls back to the raw annotation name when it is not in the mapping
  table. The signature is inspected once instead of twice.

  NOTE(review): this reads `return_annotation.__name__`, so it presumes a
  plain class annotation; subscripted generics like list[int] have no
  `__name__` — confirm callers only pass simple annotations.
  """
  annotation_name = inspect.signature(func).return_annotation.__name__
  return _py_type_2_schema_type.get(annotation_name, annotation_name)
|
||||
|
||||
|
||||
def build_function_declaration(
    func: Union[Callable, BaseModel],
    ignore_params: Optional[list[str]] = None,
    variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI',
) -> types.FunctionDeclaration:
  """Builds a FunctionDeclaration for `func`, optionally dropping parameters.

  When any name in `ignore_params` appears in the signature, a copy of the
  callable (or, for classes, a synthesized pydantic model) is built without
  those parameters before declaration generation.

  Args:
    func: A callable or a pydantic model class to declare.
    ignore_params: Parameter names to exclude from the declaration.
    variant: Target API variant.

  Returns:
    The generated types.FunctionDeclaration.
  """
  signature = inspect.signature(func)
  should_update_signature = False
  new_func = None
  if not ignore_params:
    ignore_params = []
  # Only rebuild when at least one ignored name is actually present.
  for name, _ in signature.parameters.items():
    if name in ignore_params:
      should_update_signature = True
      break
  if should_update_signature:
    new_params = [
        param
        for name, param in signature.parameters.items()
        if name not in ignore_params
    ]
    if isinstance(func, type):
      # A class (e.g. pydantic model): synthesize a new model from the
      # surviving fields.
      fields = {
          name: (param.annotation, param.default)
          for name, param in signature.parameters.items()
          if name not in ignore_params
      }
      new_func = create_model(func.__name__, **fields)
    else:
      # A plain function: clone it and attach the narrowed signature.
      # NOTE(review): __defaults__ is copied verbatim; if an ignored
      # parameter had a default, defaults may misalign with the new
      # signature — confirm ignored params are default-free (e.g.
      # tool_context).
      new_sig = signature.replace(parameters=new_params)
      new_func = FunctionType(
          func.__code__,
          func.__globals__,
          func.__name__,
          func.__defaults__,
          func.__closure__,
      )
      new_func.__signature__ = new_sig

  return (
      from_function_with_options(func, variant)
      if not should_update_signature
      else from_function_with_options(new_func, variant)
  )
|
||||
|
||||
|
||||
def build_function_declaration_for_langchain(
    vertexai: bool, name, description, func, param_pydantic_schema
) -> types.FunctionDeclaration:
  """Builds a FunctionDeclaration from a LangChain-style parameter schema.

  The incoming schema is a bare properties mapping; it is normalized for
  the target backend and any `required` entry is lifted to the top level
  before delegating to `build_function_declaration_util`.
  """
  processed = _process_pydantic_schema(
      vertexai, {'properties': param_pydantic_schema}
  )['properties']
  schema_copy = processed.copy()
  required = schema_copy.pop('required', [])
  return build_function_declaration_util(
      vertexai,
      name,
      description,
      func,
      {'properties': schema_copy, 'required': required},
  )
|
||||
|
||||
|
||||
def build_function_declaration_for_params_for_crewai(
    vertexai: bool, name, description, func, param_pydantic_schema
) -> types.FunctionDeclaration:
  """Builds a FunctionDeclaration from a CrewAI-style parameter schema.

  Unlike the LangChain variant, the schema already carries its own
  top-level structure; it is normalized and passed through on a copy.
  """
  processed = _process_pydantic_schema(vertexai, param_pydantic_schema)
  return build_function_declaration_util(
      vertexai, name, description, func, processed.copy()
  )
|
||||
|
||||
|
||||
def build_function_declaration_util(
    vertexai: bool, name, description, func, before_param_pydantic_schema
) -> types.FunctionDeclaration:
  """Assembles the final FunctionDeclaration from a normalized schema.

  Converts pydantic type names to genai types, wraps the properties in an
  OBJECT schema (omitted entirely when there are none), and — on Vertex AI
  only — attaches a response schema derived from the callable's return
  annotation.
  """
  _map_pydantic_type_to_schema_type(before_param_pydantic_schema)
  properties = before_param_pydantic_schema.get('properties', {})
  function_declaration = types.FunctionDeclaration(
      parameters=types.Schema(
          type='OBJECT',
          properties=properties,
      )
      if properties
      else None,
      description=description,
      name=name,
  )
  if vertexai and isinstance(func, Callable):
    return_pydantic_schema = _get_return_type(func)
    function_declaration.response = types.Schema(
        type=return_pydantic_schema,
    )
  return function_declaration
|
||||
|
||||
|
||||
def from_function_with_options(
    func: Callable,
    variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI',
) -> 'types.FunctionDeclaration':
  """Builds a FunctionDeclaration by parsing `func`'s signature directly.

  Each positional/keyword parameter is turned into a property schema via
  `function_parameter_parse_util`. Only the VERTEX_AI variant gets a
  `required` list and a response schema (from the return annotation).

  Args:
    func: The callable to declare.
    variant: Target API variant; must be one of the supported values.

  Returns:
    The generated types.FunctionDeclaration.

  Raises:
    ValueError: If `variant` is not a supported variant.
  """

  supported_variants = ['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT']
  if variant not in supported_variants:
    raise ValueError(
        f'Unsupported variant: {variant}. Supported variants are:'
        f' {", ".join(supported_variants)}'
    )

  # *args / **kwargs parameters are skipped.
  parameters_properties = {}
  for name, param in inspect.signature(func).parameters.items():
    if param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.POSITIONAL_ONLY,
    ):
      schema = function_parameter_parse_util._parse_schema_from_parameter(
          variant, param, func.__name__
      )
      parameters_properties[name] = schema
  declaration = types.FunctionDeclaration(
      name=func.__name__,
      description=func.__doc__,
  )
  if parameters_properties:
    declaration.parameters = types.Schema(
        type='OBJECT',
        properties=parameters_properties,
    )
    if variant == 'VERTEX_AI':
      declaration.parameters.required = (
          function_parameter_parse_util._get_required_fields(
              declaration.parameters
          )
      )
  if not variant == 'VERTEX_AI':
    return declaration

  # Vertex AI only: derive the response schema from the return annotation,
  # when one is present.
  return_annotation = inspect.signature(func).return_annotation
  if return_annotation is inspect._empty:
    return declaration

  declaration.response = (
      function_parameter_parse_util._parse_schema_from_parameter(
          variant,
          inspect.Parameter(
              'return_value',
              inspect.Parameter.POSITIONAL_OR_KEYWORD,
              annotation=return_annotation,
          ),
          func.__name__,
      )
  )
  return declaration
|
||||
@@ -0,0 +1,175 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
||||
from ..runners import Runner
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from . import _automatic_function_calling_util
|
||||
from .base_tool import BaseTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..agents.base_agent import BaseAgent
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
|
||||
|
||||
class AgentTool(BaseTool):
  """A tool that wraps an agent.

  This tool allows an agent to be called as a tool within a larger application.
  The agent's input schema is used to define the tool's input parameters, and
  the agent's output is returned as the tool's result.

  Attributes:
    agent: The agent to wrap.
    skip_summarization: Whether to skip summarization of the agent output.
  """

  def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
    # Tool name/description mirror the wrapped agent's.
    self.agent = agent
    self.skip_summarization: bool = skip_summarization

    super().__init__(name=agent.name, description=agent.description)

  # NOTE(review): a pydantic `model_validator` on a class constructed via the
  # custom __init__ above — presumably only fires if instances are built
  # through pydantic validation elsewhere; confirm it is still needed.
  @model_validator(mode='before')
  @classmethod
  def populate_name(cls, data: Any) -> Any:
    """Copies the wrapped agent's name into the validation payload."""
    data['name'] = data['agent'].name
    return data

  @override
  def _get_declaration(self) -> types.FunctionDeclaration:
    """Builds the function declaration exposed to the model.

    Uses the agent's input schema when available; otherwise declares a
    single required string parameter named 'request'.
    """
    from ..agents.llm_agent import LlmAgent

    if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
      result = _automatic_function_calling_util.build_function_declaration(
          func=self.agent.input_schema, variant=self._api_variant
      )
    else:
      result = types.FunctionDeclaration(
          parameters=types.Schema(
              type=types.Type.OBJECT,
              properties={
                  'request': types.Schema(
                      type=types.Type.STRING,
                  ),
              },
              required=['request'],
          ),
          description=self.agent.description,
          name=self.name,
      )
    result.name = self.name
    return result

  @override
  async def run_async(
      self,
      *,
      args: dict[str, Any],
      tool_context: ToolContext,
  ) -> Any:
    """Runs the wrapped agent in an isolated Runner and returns its output.

    State deltas and artifacts produced by the sub-agent are forwarded to
    the parent session via `tool_context`.

    Args:
      args: Tool arguments; either the input-schema fields or a 'request'
        string when no input schema is defined.
      tool_context: The invoking tool's context.

    Returns:
      The sub-agent's final text output, parsed through the agent's output
      schema when one is defined; '' when the run produced no usable text.
    """
    from ..agents.llm_agent import LlmAgent

    if self.skip_summarization:
      tool_context.actions.skip_summarization = True

    if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
      input_value = self.agent.input_schema.model_validate(args)
    else:
      input_value = args['request']

    # NOTE(review): input_value was already validated above for the
    # input-schema branch, so the dict re-validation here looks redundant —
    # kept as-is; confirm before simplifying.
    if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
      if isinstance(input_value, dict):
        input_value = self.agent.input_schema.model_validate(input_value)
      if not isinstance(input_value, self.agent.input_schema):
        raise ValueError(
            f'Input value {input_value} is not of type'
            f' `{self.agent.input_schema}`.'
        )
      content = types.Content(
          role='user',
          parts=[
              types.Part.from_text(
                  text=input_value.model_dump_json(exclude_none=True)
              )
          ],
      )
    else:
      content = types.Content(
          role='user',
          parts=[types.Part.from_text(text=input_value)],
      )
    # Fresh in-memory session/memory services: the sub-agent run is
    # isolated from the parent session except for the explicit forwarding
    # below.
    runner = Runner(
        app_name=self.agent.name,
        agent=self.agent,
        # TODO(kech): Remove the access to the invocation context.
        # It seems we don't need re-use artifact_service if we forward below.
        artifact_service=tool_context._invocation_context.artifact_service,
        session_service=InMemorySessionService(),
        memory_service=InMemoryMemoryService(),
    )
    session = runner.session_service.create_session(
        app_name=self.agent.name,
        user_id='tmp_user',
        state=tool_context.state.to_dict(),
    )

    last_event = None
    async for event in runner.run_async(
        user_id=session.user_id, session_id=session.id, new_message=content
    ):
      # Forward state delta to parent session.
      if event.actions.state_delta:
        tool_context.state.update(event.actions.state_delta)
      last_event = event

    if runner.artifact_service:
      # Forward all artifacts to parent session.
      for artifact_name in runner.artifact_service.list_artifact_keys(
          app_name=session.app_name,
          user_id=session.user_id,
          session_id=session.id,
      ):
        if artifact := runner.artifact_service.load_artifact(
            app_name=session.app_name,
            user_id=session.user_id,
            session_id=session.id,
            filename=artifact_name,
        ):
          tool_context.save_artifact(filename=artifact_name, artifact=artifact)

    # Only the final event's first text part is surfaced as the result.
    if (
        not last_event
        or not last_event.content
        or not last_event.content.parts
        or not last_event.content.parts[0].text
    ):
      return ''
    if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
      tool_result = self.agent.output_schema.model_validate_json(
          last_event.content.parts[0].text
      ).model_dump(exclude_none=True)
    else:
      tool_result = last_event.content.parts[0].text
    return tool_result
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .apihub_toolset import APIHubToolset
|
||||
|
||||
# Public API of the apihub_tool package.
__all__ = [
    'APIHubToolset',
]
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,209 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_schemes import AuthScheme
|
||||
from ..openapi_tool.common.common import to_snake_case
|
||||
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from .clients.apihub_client import APIHubClient
|
||||
|
||||
|
||||
class APIHubToolset:
  """APIHubTool generates tools from a given API Hub resource.

  Examples:

  ```
  apihub_toolset = APIHubToolset(
      apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
      service_account_json="...",
  )

  # Get all available tools
  agent = LlmAgent(tools=apihub_toolset.get_tools())

  # Get a specific tool
  agent = LlmAgent(tools=[
      ...
      apihub_toolset.get_tool('my_tool'),
  ])
  ```

  **apihub_resource_name** is the resource name from API Hub. It must include
  API name, and can optionally include API version and spec name.
  - If apihub_resource_name includes a spec resource name, the content of that
    spec will be used for generating the tools.
  - If apihub_resource_name includes only an api or a version name, the
    first spec of the first version of that API will be used.
  """

  def __init__(
      self,
      *,
      # Parameters for fetching API Hub resource
      apihub_resource_name: str,
      access_token: Optional[str] = None,
      service_account_json: Optional[str] = None,
      # Parameters for the toolset itself
      name: str = '',
      description: str = '',
      # Parameters for generating tools
      lazy_load_spec: bool = False,
      auth_scheme: Optional[AuthScheme] = None,
      auth_credential: Optional[AuthCredential] = None,
      # Optionally, you can provide a custom API Hub client
      apihub_client: Optional[APIHubClient] = None,
  ):
    """Initializes the APIHubTool with the given parameters.

    Examples:
    ```
    apihub_toolset = APIHubToolset(
        apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
        service_account_json="...",
    )

    # Get all available tools
    agent = LlmAgent(tools=apihub_toolset.get_tools())

    # Get a specific tool
    agent = LlmAgent(tools=[
        ...
        apihub_toolset.get_tool('my_tool'),
    ])
    ```

    **apihub_resource_name** is the resource name from API Hub. It must include
    API name, and can optionally include API version and spec name.
    - If apihub_resource_name includes a spec resource name, the content of that
      spec will be used for generating the tools.
    - If apihub_resource_name includes only an api or a version name, the
      first spec of the first version of that API will be used.

    Example:
    * projects/xxx/locations/us-central1/apis/apiname/...
    * https://console.cloud.google.com/apigee/api-hub/apis/apiname?project=xxx

    Args:
      apihub_resource_name: The resource name of the API in API Hub.
        Example: `projects/test-project/locations/us-central1/apis/test-api`.
      access_token: Google Access token. Generate with gcloud cli `gcloud auth
        auth print-access-token`. Used for fetching API Specs from API Hub.
      service_account_json: The service account config as a json string.
        Required if not using default service credential. It is used for
        creating the API Hub client and fetching the API Specs from API Hub.
      apihub_client: Optional custom API Hub client.
      name: Name of the toolset. Optional.
      description: Description of the toolset. Optional.
      auth_scheme: Auth scheme that applies to all the tool in the toolset.
      auth_credential: Auth credential that applies to all the tool in the
        toolset.
      lazy_load_spec: If True, the spec will be loaded lazily when needed.
        Otherwise, the spec will be loaded immediately and the tools will be
        generated during initialization.
    """
    self.name = name
    self.description = description
    self.apihub_resource_name = apihub_resource_name
    self.lazy_load_spec = lazy_load_spec
    self.apihub_client = apihub_client or APIHubClient(
        access_token=access_token,
        service_account_json=service_account_json,
    )

    # Tool name -> RestApiTool; populated by _prepare_tools().
    self.generated_tools: Dict[str, RestApiTool] = {}
    self.auth_scheme = auth_scheme
    self.auth_credential = auth_credential

    if not self.lazy_load_spec:
      self._prepare_tools()

  def get_tool(self, name: str) -> Optional[RestApiTool]:
    """Retrieves a specific tool by its name.

    Example:
    ```
    apihub_tool = apihub_toolset.get_tool('my_tool')
    ```

    Args:
      name: The name of the tool to retrieve.

    Returns:
      The tool with the given name, or None if no such tool exists.
    """
    if not self._are_tools_ready():
      self._prepare_tools()

    return self.generated_tools[name] if name in self.generated_tools else None

  def get_tools(self) -> List[RestApiTool]:
    """Retrieves all available tools.

    Returns:
      A list of all available RestApiTool objects.
    """
    if not self._are_tools_ready():
      self._prepare_tools()

    return list(self.generated_tools.values())

  def _are_tools_ready(self) -> bool:
    # Eager mode is always "ready"; lazy mode is ready once tools exist.
    # NOTE(review): returns the dict itself (truthy/falsy) rather than a
    # strict bool in the lazy case.
    return not self.lazy_load_spec or self.generated_tools

  def _prepare_tools(self) -> None:
    """Fetches the spec from API Hub and (re)generates the tools in place."""
    # For each API, get the first version and the first spec of that version.
    spec = self.apihub_client.get_spec_content(self.apihub_resource_name)
    self.generated_tools: Dict[str, RestApiTool] = {}

    tools = self._parse_spec_to_tools(spec)
    for tool in tools:
      self.generated_tools[tool.name] = tool

  def _parse_spec_to_tools(self, spec_str: str) -> List[RestApiTool]:
    """Parses the spec string to a list of RestApiTool.

    Also backfills the toolset's name/description from the spec's `info`
    section when they were not provided at construction time.

    Args:
      spec_str: The spec string to parse.

    Returns:
      A list of RestApiTool objects.
    """
    spec_dict = yaml.safe_load(spec_str)
    if not spec_dict:
      return []

    self.name = self.name or to_snake_case(
        spec_dict.get('info', {}).get('title', 'unnamed')
    )
    self.description = self.description or spec_dict.get('info', {}).get(
        'description', ''
    )
    tools = OpenAPIToolset(
        spec_dict=spec_dict,
        auth_credential=self.auth_credential,
        auth_scheme=self.auth_scheme,
    ).get_tools()
    return tools
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,332 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import base64
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from google.auth import default as default_service_credential
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import requests
|
||||
|
||||
|
||||
class BaseAPIHubClient(ABC):
  """Base class for API Hub clients."""

  @abstractmethod
  def get_spec_content(self, resource_name: str) -> str:
    """From a given resource name, get the spec in the API Hub."""
    raise NotImplementedError()
|
||||
|
||||
|
||||
class APIHubClient(BaseAPIHubClient):
|
||||
"""Client for interacting with the API Hub service."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
access_token: Optional[str] = None,
|
||||
service_account_json: Optional[str] = None,
|
||||
):
|
||||
"""Initializes the APIHubClient.
|
||||
|
||||
You must set either access_token or service_account_json. This
|
||||
credential is used for sending request to API Hub API.
|
||||
|
||||
Args:
|
||||
access_token: Google Access token. Generate with gcloud cli `gcloud auth
|
||||
print-access-token`. Useful for local testing.
|
||||
service_account_json: The service account configuration as a dictionary.
|
||||
Required if not using default service credential.
|
||||
"""
|
||||
self.root_url = "https://apihub.googleapis.com/v1"
|
||||
self.credential_cache = None
|
||||
self.access_token, self.service_account = None, None
|
||||
|
||||
if access_token:
|
||||
self.access_token = access_token
|
||||
elif service_account_json:
|
||||
self.service_account = service_account_json
|
||||
|
||||
def get_spec_content(self, path: str) -> str:
|
||||
"""From a given path, get the first spec available in the API Hub.
|
||||
|
||||
- If path includes /apis/apiname, get the first spec of that API
|
||||
- If path includes /apis/apiname/versions/versionname, get the first spec
|
||||
of that API Version
|
||||
- If path includes /apis/apiname/versions/versionname/specs/specname, return
|
||||
that spec
|
||||
|
||||
Path can be resource name (projects/xxx/locations/us-central1/apis/apiname),
|
||||
and URL from the UI
|
||||
(https://console.cloud.google.com/apigee/api-hub/apis/apiname?project=xxx)
|
||||
|
||||
Args:
|
||||
path: The path to the API, API Version, or API Spec.
|
||||
|
||||
Returns:
|
||||
The content of the first spec available in the API Hub.
|
||||
"""
|
||||
apihub_resource_name, api_version_resource_name, api_spec_resource_name = (
|
||||
self._extract_resource_name(path)
|
||||
)
|
||||
|
||||
if apihub_resource_name and not api_version_resource_name:
|
||||
api = self.get_api(apihub_resource_name)
|
||||
versions = api.get("versions", [])
|
||||
if not versions:
|
||||
raise ValueError(
|
||||
f"No versions found in API Hub resource: {apihub_resource_name}"
|
||||
)
|
||||
api_version_resource_name = versions[0]
|
||||
|
||||
if api_version_resource_name and not api_spec_resource_name:
|
||||
api_version = self.get_api_version(api_version_resource_name)
|
||||
spec_resource_names = api_version.get("specs", [])
|
||||
if not spec_resource_names:
|
||||
raise ValueError(
|
||||
f"No specs found in API Hub version: {api_version_resource_name}"
|
||||
)
|
||||
api_spec_resource_name = spec_resource_names[0]
|
||||
|
||||
if api_spec_resource_name:
|
||||
spec_content = self._fetch_spec(api_spec_resource_name)
|
||||
return spec_content
|
||||
|
||||
raise ValueError("No API Hub resource found in path: {path}")
|
||||
|
||||
def list_apis(self, project: str, location: str) -> List[Dict[str, Any]]:
|
||||
"""Lists all APIs in the specified project and location.
|
||||
|
||||
Args:
|
||||
project: The Google Cloud project name.
|
||||
location: The location of the API Hub resources (e.g., 'us-central1').
|
||||
|
||||
Returns:
|
||||
A list of API dictionaries, or an empty list if an error occurs.
|
||||
"""
|
||||
url = f"{self.root_url}/projects/{project}/locations/{location}/apis"
|
||||
headers = {
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": f"Bearer {self._get_access_token()}",
|
||||
}
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
apis = response.json().get("apis", [])
|
||||
return apis
|
||||
|
||||
def get_api(self, api_resource_name: str) -> Dict[str, Any]:
|
||||
"""Get API detail by API name.
|
||||
|
||||
Args:
|
||||
api_resource_name: Resource name of this API, like
|
||||
projects/xxx/locations/us-central1/apis/apiname
|
||||
|
||||
Returns:
|
||||
An API and details in a dict.
|
||||
"""
|
||||
url = f"{self.root_url}/{api_resource_name}"
|
||||
headers = {
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": f"Bearer {self._get_access_token()}",
|
||||
}
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
apis = response.json()
|
||||
return apis
|
||||
|
||||
def get_api_version(self, api_version_name: str) -> Dict[str, Any]:
|
||||
"""Gets details of a specific API version.
|
||||
|
||||
Args:
|
||||
api_version_name: The resource name of the API version.
|
||||
|
||||
Returns:
|
||||
The API version details as a dictionary, or an empty dictionary if an
|
||||
error occurs.
|
||||
"""
|
||||
url = f"{self.root_url}/{api_version_name}"
|
||||
headers = {
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": f"Bearer {self._get_access_token()}",
|
||||
}
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _fetch_spec(self, api_spec_resource_name: str) -> str:
|
||||
"""Retrieves the content of a specific API specification.
|
||||
|
||||
Args:
|
||||
api_spec_resource_name: The resource name of the API spec.
|
||||
|
||||
Returns:
|
||||
The decoded content of the specification as a string, or an empty string
|
||||
if an error occurs.
|
||||
"""
|
||||
url = f"{self.root_url}/{api_spec_resource_name}:contents"
|
||||
headers = {
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": f"Bearer {self._get_access_token()}",
|
||||
}
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
content_base64 = response.json().get("contents", "")
|
||||
if content_base64:
|
||||
content_decoded = base64.b64decode(content_base64).decode("utf-8")
|
||||
return content_decoded
|
||||
else:
|
||||
return ""
|
||||
|
||||
def _extract_resource_name(self, url_or_path: str) -> Tuple[str, str, str]:
|
||||
"""Extracts the resource names of an API, API Version, and API Spec from a given URL or path.
|
||||
|
||||
Args:
|
||||
url_or_path: The URL (UI or resource) or path string.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the resource names:
|
||||
{
|
||||
"api_resource_name": "projects/*/locations/*/apis/*",
|
||||
"api_version_resource_name":
|
||||
"projects/*/locations/*/apis/*/versions/*",
|
||||
"api_spec_resource_name":
|
||||
"projects/*/locations/*/apis/*/versions/*/specs/*"
|
||||
}
|
||||
or raises ValueError if extraction fails.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL or path is invalid or if required components
|
||||
(project, location, api) are missing.
|
||||
"""
|
||||
|
||||
query_params = None
|
||||
try:
|
||||
parsed_url = urlparse(url_or_path)
|
||||
path = parsed_url.path
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
|
||||
# This is a path from UI. Remove unnecessary prefix.
|
||||
if "api-hub/" in path:
|
||||
path = path.split("api-hub")[1]
|
||||
except Exception:
|
||||
path = url_or_path
|
||||
|
||||
path_segments = [segment for segment in path.split("/") if segment]
|
||||
|
||||
project = None
|
||||
location = None
|
||||
api_id = None
|
||||
version_id = None
|
||||
spec_id = None
|
||||
|
||||
if "projects" in path_segments:
|
||||
project_index = path_segments.index("projects")
|
||||
if project_index + 1 < len(path_segments):
|
||||
project = path_segments[project_index + 1]
|
||||
elif query_params and "project" in query_params:
|
||||
project = query_params["project"][0]
|
||||
|
||||
if not project:
|
||||
raise ValueError(
|
||||
"Project ID not found in URL or path in APIHubClient. Input path is"
|
||||
f" '{url_or_path}'. Please make sure there is either"
|
||||
" '/projects/PROJECT_ID' in the path or 'project=PROJECT_ID' query"
|
||||
" param in the input."
|
||||
)
|
||||
|
||||
if "locations" in path_segments:
|
||||
location_index = path_segments.index("locations")
|
||||
if location_index + 1 < len(path_segments):
|
||||
location = path_segments[location_index + 1]
|
||||
if not location:
|
||||
raise ValueError(
|
||||
"Location not found in URL or path in APIHubClient. Input path is"
|
||||
f" '{url_or_path}'. Please make sure there is either"
|
||||
" '/location/LOCATION_ID' in the path."
|
||||
)
|
||||
|
||||
if "apis" in path_segments:
|
||||
api_index = path_segments.index("apis")
|
||||
if api_index + 1 < len(path_segments):
|
||||
api_id = path_segments[api_index + 1]
|
||||
if not api_id:
|
||||
raise ValueError(
|
||||
"API id not found in URL or path in APIHubClient. Input path is"
|
||||
f" '{url_or_path}'. Please make sure there is either"
|
||||
" '/apis/API_ID' in the path."
|
||||
)
|
||||
if "versions" in path_segments:
|
||||
version_index = path_segments.index("versions")
|
||||
if version_index + 1 < len(path_segments):
|
||||
version_id = path_segments[version_index + 1]
|
||||
|
||||
if "specs" in path_segments:
|
||||
spec_index = path_segments.index("specs")
|
||||
if spec_index + 1 < len(path_segments):
|
||||
spec_id = path_segments[spec_index + 1]
|
||||
|
||||
api_resource_name = f"projects/{project}/locations/{location}/apis/{api_id}"
|
||||
api_version_resource_name = (
|
||||
f"{api_resource_name}/versions/{version_id}" if version_id else None
|
||||
)
|
||||
api_spec_resource_name = (
|
||||
f"{api_version_resource_name}/specs/{spec_id}"
|
||||
if version_id and spec_id
|
||||
else None
|
||||
)
|
||||
|
||||
return (
|
||||
api_resource_name,
|
||||
api_version_resource_name,
|
||||
api_spec_resource_name,
|
||||
)
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
"""Gets the access token for the service account.
|
||||
|
||||
Returns:
|
||||
The access token.
|
||||
"""
|
||||
if self.access_token:
|
||||
return self.access_token
|
||||
|
||||
if self.credential_cache and not self.credential_cache.expired:
|
||||
return self.credential_cache.token
|
||||
|
||||
if self.service_account:
|
||||
try:
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
json.loads(self.service_account),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid service account JSON: {e}") from e
|
||||
else:
|
||||
try:
|
||||
credentials, _ = default_service_credential()
|
||||
except:
|
||||
credentials = None
|
||||
|
||||
if not credentials:
|
||||
raise ValueError(
|
||||
"Please provide a service account or an access token to API Hub"
|
||||
" client."
|
||||
)
|
||||
|
||||
credentials.refresh(Request())
|
||||
self.credential_cache = credentials
|
||||
return credentials.token
|
||||
@@ -0,0 +1,115 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
import google.auth
|
||||
from google.auth import default as default_service_credential
|
||||
import google.auth.transport.requests
|
||||
from google.cloud import secretmanager
|
||||
from google.oauth2 import service_account
|
||||
|
||||
|
||||
class SecretManagerClient:
  """A client for interacting with Google Cloud Secret Manager.

  This class provides a simplified interface for retrieving secrets from
  Secret Manager, handling authentication using either a service account
  JSON keyfile (passed as a string) or a pre-existing authorization token.

  Attributes:
    _credentials: Google Cloud credentials object (ServiceAccountCredentials
      or Credentials).
    _client: Secret Manager client instance.
  """

  def __init__(
      self,
      service_account_json: Optional[str] = None,
      auth_token: Optional[str] = None,
  ):
    """Initializes the SecretManagerClient.

    Credential resolution order: explicit service-account JSON first, then
    an explicit auth token, then application-default credentials.

    Args:
      service_account_json: The content of a service account JSON keyfile (as
        a string), not the file path. Must be valid JSON.
      auth_token: An existing Google Cloud authorization token.

    Raises:
      ValueError: If `service_account_json` is not valid JSON, or if no
        credential source yields usable credentials.
      google.auth.exceptions.GoogleAuthError: If authentication fails.
    """
    if service_account_json:
      try:
        credentials = service_account.Credentials.from_service_account_info(
            json.loads(service_account_json)
        )
      except json.JSONDecodeError as e:
        raise ValueError(f"Invalid service account JSON: {e}") from e
    elif auth_token:
      # NOTE(review): google.auth.credentials.Credentials is an abstract base
      # class whose constructor takes no such keyword arguments in current
      # google-auth releases — confirm this path actually works, or build a
      # concrete token credential (e.g. google.oauth2.credentials.Credentials)
      # explicitly.
      credentials = google.auth.credentials.Credentials(
          token=auth_token,
          refresh_token=None,
          token_uri=None,
          client_id=None,
          client_secret=None,
      )
      request = google.auth.transport.requests.Request()
      credentials.refresh(request)
    else:
      # No explicit credential given: fall back to application-default
      # credentials (ADC) from the environment.
      try:
        credentials, _ = default_service_credential()
      except Exception as e:
        raise ValueError(
            "'service_account_json' or 'auth_token' are both missing, and"
            f" error occurred while trying to use default credentials: {e}"
        ) from e

    if not credentials:
      raise ValueError(
          "Must provide either 'service_account_json' or 'auth_token', not both"
          " or neither."
      )

    self._credentials = credentials
    self._client = secretmanager.SecretManagerServiceClient(
        credentials=self._credentials
    )

  def get_secret(self, resource_name: str) -> str:
    """Retrieves a secret from Google Cloud Secret Manager.

    Args:
      resource_name: The full resource name of the secret, in the format
        "projects/*/secrets/*/versions/*". Usually you want the "latest"
        version, e.g.,
        "projects/my-project/secrets/my-secret/versions/latest".

    Returns:
      The secret payload as a string.

    Raises:
      google.api_core.exceptions.GoogleAPIError: If the Secret Manager API
        returns an error (e.g., secret not found, permission denied).
      Exception: For other unexpected errors.
    """
    # Previously wrapped in `try/except Exception as e: raise e`, which only
    # added a redundant traceback frame; errors propagate to the caller
    # unchanged either way.
    response = self._client.access_secret_version(name=resource_name)
    return response.payload.data.decode("UTF-8")
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .application_integration_toolset import ApplicationIntegrationToolset
|
||||
|
||||
__all__ = [
|
||||
'ApplicationIntegrationToolset',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,230 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.openapi.models import HTTPBearer
|
||||
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
|
||||
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_credential import AuthCredentialTypes
|
||||
from ...auth.auth_credential import ServiceAccount
|
||||
from ...auth.auth_credential import ServiceAccountCredential
|
||||
|
||||
|
||||
# TODO(cheliu): Apply a common toolset interface
class ApplicationIntegrationToolset:
  """ApplicationIntegrationToolset generates tools from a given Application

  Integration or Integration Connector resource.
  Example Usage:
  ```
  # Get all available tools for an integration with api trigger
  application_integration_toolset = ApplicationIntegrationToolset(

      project="test-project",
      location="us-central1"
      integration="test-integration",
      trigger="api_trigger/test_trigger",
      service_account_credentials={...},
  )

  # Get all available tools for a connection using entity operations and
  # actions
  # Note: Find the list of supported entity operations and actions for a
  connection
  # using integration connector apis:
  #
  https://cloud.google.com/integration-connectors/docs/reference/rest/v1/projects.locations.connections.connectionSchemaMetadata
  application_integration_toolset = ApplicationIntegrationToolset(
      project="test-project",
      location="us-central1"
      connection="test-connection",
      entity_operations=["EntityId1": ["LIST","CREATE"], "EntityId2": []],
      #empty list for actions means all operations on the entity are supported
      actions=["action1"],
      service_account_credentials={...},
  )

  # Get all available tools
  agent = LlmAgent(tools=[
      ...
      *application_integration_toolset.get_tools(),
  ])
  ```
  """

  def __init__(
      self,
      project: str,
      location: str,
      integration: Optional[str] = None,
      trigger: Optional[str] = None,
      connection: Optional[str] = None,
      # NOTE(review): annotated Optional[str], but the docstring and example
      # pass a mapping of entity -> operations — confirm the intended type.
      entity_operations: Optional[str] = None,
      # NOTE(review): annotated Optional[str], but the example passes a list
      # of action names — confirm the intended type.
      actions: Optional[str] = None,
      # Optional parameter for the toolset. This is prepended to the generated
      # tool/python function name.
      tool_name: Optional[str] = "",
      # Optional parameter for the toolset. This is appended to the generated
      # tool/python function description.
      tool_instructions: Optional[str] = "",
      service_account_json: Optional[str] = None,
  ):
    """Initializes the ApplicationIntegrationToolset.

    Fetches an OpenAPI spec for either an Application Integration (when
    `integration` and `trigger` are set) or an Integration Connector (when
    `connection` plus `entity_operations`/`actions` are set), then converts
    the spec into RestApiTool instances.

    Args:
      project: The GCP project ID.
      location: The GCP location.
      integration: The integration name.
      trigger: The trigger name.
      connection: The connection name.
      entity_operations: The entity operations supported by the connection.
      actions: The actions supported by the connection.
      tool_name: The name of the tool.
      tool_instructions: The instructions for the tool.
      service_account_json: The service account configuration as a dictionary.
        Required if not using default service credential. Used for fetching
        the Application Integration or Integration Connector resource.

    Raises:
      ValueError: If neither integration and trigger nor connection and
        (entity_operations or actions) is provided.
      Exception: If there is an error during the initialization of the
        integration or connection client.
    """
    self.project = project
    self.location = location
    self.integration = integration
    self.trigger = trigger
    self.connection = connection
    self.entity_operations = entity_operations
    self.actions = actions
    self.tool_name = tool_name
    self.tool_instructions = tool_instructions
    self.service_account_json = service_account_json
    # Populated by _parse_spec_to_tools; maps generated tool name -> tool.
    self.generated_tools: Dict[str, RestApiTool] = {}

    integration_client = IntegrationClient(
        project,
        location,
        integration,
        trigger,
        connection,
        entity_operations,
        actions,
        service_account_json,
    )
    if integration and trigger:
      spec = integration_client.get_openapi_spec_for_integration()
    elif connection and (entity_operations or actions):
      connections_client = ConnectionsClient(
          project, location, connection, service_account_json
      )
      connection_details = connections_client.get_connection_details()
      # Bake the connection routing details into the tool instructions so the
      # model never asks the user for values the toolset already knows.
      tool_instructions += (
          "ALWAYS use serviceName = "
          + connection_details["serviceName"]
          + ", host = "
          + connection_details["host"]
          + " and the connection name = "
          + f"projects/{project}/locations/{location}/connections/{connection} when"
          " using this tool"
          + ". DONOT ask the user for these values as you already have those."
      )
      spec = integration_client.get_openapi_spec_for_connection(
          tool_name,
          tool_instructions,
      )
    else:
      raise ValueError(
          "Either (integration and trigger) or (connection and"
          " (entity_operations or actions)) should be provided."
      )
    self._parse_spec_to_tools(spec)

  def _parse_spec_to_tools(self, spec_dict):
    """Parses the spec dict to a list of RestApiTool."""
    if self.service_account_json:
      # Explicit service account: derive both the auth scheme and credential
      # from the provided key.
      sa_credential = ServiceAccountCredential.model_validate_json(
          self.service_account_json
      )
      service_account = ServiceAccount(
          service_account_credential=sa_credential,
          scopes=["https://www.googleapis.com/auth/cloud-platform"],
      )
      auth_scheme, auth_credential = service_account_scheme_credential(
          config=service_account
      )
    else:
      # No key provided: fall back to application-default credentials with a
      # bearer-token scheme.
      auth_credential = AuthCredential(
          auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
          service_account=ServiceAccount(
              use_default_credential=True,
              scopes=["https://www.googleapis.com/auth/cloud-platform"],
          ),
      )
      auth_scheme = HTTPBearer(bearerFormat="JWT")
    tools = OpenAPIToolset(
        spec_dict=spec_dict,
        auth_credential=auth_credential,
        auth_scheme=auth_scheme,
    ).get_tools()
    for tool in tools:
      self.generated_tools[tool.name] = tool

  def get_tools(self) -> List[RestApiTool]:
    """Returns the RestApiTool instances generated from the fetched spec."""
    return list(self.generated_tools.values())
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,903 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import google.auth
|
||||
from google.auth import default as default_service_credential
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import requests
|
||||
|
||||
|
||||
class ConnectionsClient:
|
||||
"""Utility class for interacting with Google Cloud Connectors API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
project: str,
|
||||
location: str,
|
||||
connection: str,
|
||||
service_account_json: Optional[str] = None,
|
||||
):
|
||||
"""Initializes the ConnectionsClient.
|
||||
|
||||
Args:
|
||||
project: The Google Cloud project ID.
|
||||
location: The Google Cloud location (e.g., us-central1).
|
||||
connection: The connection name.
|
||||
service_account_json: The service account configuration as a dictionary.
|
||||
Required if not using default service credential. Used for fetching
|
||||
connection details.
|
||||
"""
|
||||
self.project = project
|
||||
self.location = location
|
||||
self.connection = connection
|
||||
self.connector_url = "https://connectors.googleapis.com"
|
||||
self.service_account_json = service_account_json
|
||||
self.credential_cache = None
|
||||
|
||||
def get_connection_details(self) -> Dict[str, Any]:
  """Retrieves service details (service name and host) for a given connection.

  Also returns if auth override is enabled for the connection.

  Returns:
    dict: keys "serviceName", "host", and "authOverrideEnabled".

  Raises:
    PermissionError: If there are credential issues.
    ValueError: If there's a request error.
    Exception: For any other unexpected errors.
  """
  url = (
      f"{self.connector_url}/v1/projects/{self.project}/locations/"
      f"{self.location}/connections/{self.connection}?view=BASIC"
  )
  data = self._execute_api_call(url).json()

  host = data.get("host", "")
  # A non-empty host indicates a TLS setup, which uses a different
  # service-directory field.
  if host:
    directory = data.get("tlsServiceDirectory", "")
  else:
    directory = data.get("serviceDirectory", "")

  return {
      "serviceName": directory,
      "host": host,
      "authOverrideEnabled": data.get("authOverrideEnabled", False),
  }
|
||||
|
||||
def get_entity_schema_and_operations(
    self, entity: str
) -> Tuple[Dict[str, Any], List[str]]:
  """Retrieves the JSON schema for a given entity in a connection.

  The connector API answers with a long-running operation, which is polled
  until it completes before the schema is read out.

  Args:
    entity (str): The entity name.

  Returns:
    tuple: A tuple containing (schema, operations).

  Raises:
    PermissionError: If there are credential issues.
    ValueError: If there's a request or processing error.
    Exception: For any other unexpected errors.
  """
  url = (
      f"{self.connector_url}/v1/projects/{self.project}/locations/"
      f"{self.location}/connections/{self.connection}/"
      f"connectionSchemaMetadata:getEntityType?entityId={entity}"
  )
  initial_response = self._execute_api_call(url)
  operation_id = initial_response.json().get("name")
  if not operation_id:
    raise ValueError(
        f"Failed to get entity schema and operations for entity: {entity}"
    )

  result = self._poll_operation(operation_id).get("response", {})
  return result.get("jsonSchema", {}), result.get("operations", [])
|
||||
|
||||
def get_action_schema(self, action: str) -> Dict[str, Any]:
  """Retrieves the input and output JSON schema for a given action in a connection.

  The connector API answers with a long-running operation, which is polled
  until it completes before the schemas are read out.

  Args:
    action (str): The action name.

  Returns:
    dict: keys "inputSchema", "outputSchema", "description", "displayName".

  Raises:
    PermissionError: If there are credential issues.
    ValueError: If there's a request or processing error.
    Exception: For any other unexpected errors.
  """
  url = (
      f"{self.connector_url}/v1/projects/{self.project}/locations/"
      f"{self.location}/connections/{self.connection}/"
      f"connectionSchemaMetadata:getAction?actionId={action}"
  )
  operation_id = self._execute_api_call(url).json().get("name")
  if not operation_id:
    raise ValueError(f"Failed to get action schema for action: {action}")

  payload = self._poll_operation(operation_id).get("response", {})
  return {
      "inputSchema": payload.get("inputJsonSchema", {}),
      "outputSchema": payload.get("outputJsonSchema", {}),
      "description": payload.get("description", ""),
      "displayName": payload.get("displayName", ""),
  }
|
||||
|
||||
@staticmethod
|
||||
def get_connector_base_spec() -> Dict[str, Any]:
  """Returns the base OpenAPI 3.0 spec shared by every connector toolset.

  The spec carries the common component schemas (operation, entityId,
  paging, query, timeout, ...) plus the google_auth security scheme;
  per-entity and per-action paths/schemas are added later by callers.
  """
  operation_description = (
      "Operation to execute. Possible values are LIST_ENTITIES, GET_ENTITY,"
      " CREATE_ENTITY, UPDATE_ENTITY, DELETE_ENTITY in case of entities."
      " EXECUTE_ACTION in case of actions. and EXECUTE_QUERY in case of"
      " custom queries."
  )
  schemas: Dict[str, Any] = {
      "operation": {
          "type": "string",
          "default": "LIST_ENTITIES",
          "description": operation_description,
      },
      "entityId": {"type": "string", "description": "Name of the entity"},
      "connectorInputPayload": {"type": "object"},
      "filterClause": {
          "type": "string",
          "default": "",
          "description": "WHERE clause in SQL query",
      },
      "pageSize": {
          "type": "integer",
          "default": 50,
          "description": "Number of entities to return in the response",
      },
      "pageToken": {
          "type": "string",
          "default": "",
          "description": "Page token to return the next page of entities",
      },
      "connectionName": {
          "type": "string",
          "default": "",
          "description": "Connection resource name to run the query for",
      },
      "serviceName": {
          "type": "string",
          "default": "",
          "description": "Service directory for the connection",
      },
      "host": {
          "type": "string",
          "default": "",
          "description": "Host name incase of tls service directory",
      },
      "entity": {
          "type": "string",
          "default": "Issues",
          "description": "Entity to run the query for",
      },
      "action": {
          "type": "string",
          "default": "ExecuteCustomQuery",
          "description": "Action to run the query for",
      },
      "query": {
          "type": "string",
          "default": "",
          "description": "Custom Query to execute on the connection",
      },
      "dynamicAuthConfig": {
          "type": "object",
          "default": {},
          "description": "Dynamic auth config for the connection",
      },
      "timeout": {
          "type": "integer",
          "default": 120,
          "description": "Timeout in seconds for execution of custom query",
      },
      "connectorOutputPayload": {"type": "object"},
      "nextPageToken": {"type": "string"},
      "execute-connector_Response": {
          "required": ["connectorOutputPayload"],
          "type": "object",
          "properties": {
              "connectorOutputPayload": {
                  "$ref": "#/components/schemas/connectorOutputPayload"
              },
              "nextPageToken": {"$ref": "#/components/schemas/nextPageToken"},
          },
      },
  }
  security_schemes = {
      "google_auth": {
          "type": "oauth2",
          "flows": {
              "implicit": {
                  "authorizationUrl": "https://accounts.google.com/o/oauth2/auth",
                  "scopes": {
                      "https://www.googleapis.com/auth/cloud-platform": (
                          "Auth for google cloud services"
                      )
                  },
              }
          },
      }
  }
  return {
      "openapi": "3.0.1",
      "info": {
          "title": "ExecuteConnection",
          "description": "This tool can execute a query on connection",
          "version": "4",
      },
      "servers": [{"url": "https://integrations.googleapis.com"}],
      "security": [
          {"google_auth": ["https://www.googleapis.com/auth/cloud-platform"]}
      ],
      "paths": {},
      "components": {
          "schemas": schemas,
          "securitySchemes": security_schemes,
      },
  }
|
||||
|
||||
@staticmethod
|
||||
def get_action_operation(
    action: str,
    operation: str,
    action_display_name: str,
    tool_name: str = "",
    tool_instructions: str = "",
) -> Dict[str, Any]:
  """Builds the OpenAPI path item for executing a connector action.

  Args:
    action: The connector action id (e.g. "ExecuteCustomQuery").
    operation: Either "EXECUTE_ACTION" or "EXECUTE_QUERY".
    action_display_name: Display name used for schema and operation ids.
    tool_name: Prefix for the generated operationId.
    tool_instructions: Extra text appended to the LLM-facing description.

  Returns:
    dict: An OpenAPI path item with a single "post" operation.
  """
  description = (
      f'Use this tool with action = "{action}" and operation ='
      f' "{operation}" only. Dont ask these values from user.'
  )
  if operation == "EXECUTE_QUERY":
    # Custom queries get extra guidance about paging/timeout defaults.
    description += (
        " Use pageSize = 50 and timeout = 120 until user specifies a"
        " different value otherwise. If user provides a query in natural"
        " language, convert it to SQL query and then execute it using the"
        " tool."
    )
  request_schema = {
      "$ref": f"#/components/schemas/{action_display_name}_Request"
  }
  response_schema = {
      "$ref": f"#/components/schemas/{action_display_name}_Response"
  }
  return {
      "post": {
          "summary": f"{action_display_name}",
          "description": f"{description} {tool_instructions}",
          "operationId": f"{tool_name}_{action_display_name}",
          "requestBody": {
              "content": {"application/json": {"schema": request_schema}}
          },
          "responses": {
              "200": {
                  "description": "Success response",
                  "content": {
                      "application/json": {"schema": response_schema}
                  },
              }
          },
      }
  }
|
||||
|
||||
@staticmethod
|
||||
def list_operation(
    entity: str,
    schema_as_string: str = "",
    tool_name: str = "",
    tool_instructions: str = "",
) -> Dict[str, Any]:
  """Builds the OpenAPI path item for LIST_ENTITIES on *entity*.

  Args:
    entity: Entity name the list operation targets.
    schema_as_string: JSON schema of the entity, embedded in the response
      description so the LLM knows the row shape.
    tool_name: Prefix for the generated operationId.
    tool_instructions: Extra text appended to the description.

  Returns:
    dict: An OpenAPI path item with a single "post" operation.
  """
  description = (
      f"Returns all entities of type {entity}. Use this tool with entity ="
      f' "{entity}" and operation = "LIST_ENTITIES" only. Dont ask these'
      ' values from user. Always use "" as filter clause and "" as page'
      " token and 50 as page size until user specifies a different value"
      " otherwise. Use single quotes for strings in filter clause."
      f" {tool_instructions}"
  )
  request_schema = {"$ref": f"#/components/schemas/list_{entity}_Request"}
  response_schema = {
      "description": (
          f"Returns a list of {entity} of json schema: {schema_as_string}"
      ),
      "$ref": "#/components/schemas/execute-connector_Response",
  }
  return {
      "post": {
          "summary": f"List {entity}",
          "description": description,
          "operationId": f"{tool_name}_list_{entity}",
          "requestBody": {
              "content": {"application/json": {"schema": request_schema}}
          },
          "responses": {
              "200": {
                  "description": "Success response",
                  "content": {
                      "application/json": {"schema": response_schema}
                  },
              }
          },
      }
  }
|
||||
|
||||
@staticmethod
|
||||
def get_operation(
    entity: str,
    schema_as_string: str = "",
    tool_name: str = "",
    tool_instructions: str = "",
) -> Dict[str, Any]:
  """Builds the OpenAPI path item for GET_ENTITY on *entity*.

  Args:
    entity: Entity name the get operation targets.
    schema_as_string: JSON schema of the entity, embedded in the response
      description.
    tool_name: Prefix for the generated operationId.
    tool_instructions: Extra text appended to the description.

  Returns:
    dict: An OpenAPI path item with a single "post" operation.
  """
  description = (
      f"Returns the details of the {entity}. Use this tool with entity ="
      f' "{entity}" and operation = "GET_ENTITY" only. Dont ask these values'
      f" from user. {tool_instructions}"
  )
  request_schema = {"$ref": f"#/components/schemas/get_{entity}_Request"}
  response_schema = {
      "description": f"Returns {entity} of json schema: {schema_as_string}",
      "$ref": "#/components/schemas/execute-connector_Response",
  }
  return {
      "post": {
          "summary": f"Get {entity}",
          "description": description,
          "operationId": f"{tool_name}_get_{entity}",
          "requestBody": {
              "content": {"application/json": {"schema": request_schema}}
          },
          "responses": {
              "200": {
                  "description": "Success response",
                  "content": {
                      "application/json": {"schema": response_schema}
                  },
              }
          },
      }
  }
|
||||
|
||||
@staticmethod
|
||||
def create_operation(
    entity: str, tool_name: str = "", tool_instructions: str = ""
) -> Dict[str, Any]:
  """Builds the OpenAPI path item for CREATE_ENTITY on *entity*.

  Args:
    entity: Entity name the create operation targets.
    tool_name: Prefix for the generated operationId.
    tool_instructions: Extra text appended to the description.

  Returns:
    dict: An OpenAPI path item with a single "post" operation.
  """
  description = (
      f"Creates a new entity of type {entity}. Use this tool with entity ="
      f' "{entity}" and operation = "CREATE_ENTITY" only. Dont ask these'
      " values from user. Follow the schema of the entity provided in the"
      f" instructions to create {entity}. {tool_instructions}"
  )
  request_schema = {"$ref": f"#/components/schemas/create_{entity}_Request"}
  response_schema = {"$ref": "#/components/schemas/execute-connector_Response"}
  return {
      "post": {
          "summary": f"Create {entity}",
          "description": description,
          "operationId": f"{tool_name}_create_{entity}",
          "requestBody": {
              "content": {"application/json": {"schema": request_schema}}
          },
          "responses": {
              "200": {
                  "description": "Success response",
                  "content": {
                      "application/json": {"schema": response_schema}
                  },
              }
          },
      }
  }
|
||||
|
||||
@staticmethod
|
||||
def update_operation(
    entity: str, tool_name: str = "", tool_instructions: str = ""
) -> Dict[str, Any]:
  """Builds the OpenAPI path item for UPDATE_ENTITY on *entity*.

  Args:
    entity: Entity name the update operation targets.
    tool_name: Prefix for the generated operationId.
    tool_instructions: Extra text appended to the description.

  Returns:
    dict: An OpenAPI path item with a single "post" operation.
  """
  description = (
      f"Updates an entity of type {entity}. Use this tool with entity ="
      f' "{entity}" and operation = "UPDATE_ENTITY" only. Dont ask these'
      " values from user. Use entityId to uniquely identify the entity to"
      " update. Follow the schema of the entity provided in the"
      f" instructions to update {entity}. {tool_instructions}"
  )
  request_schema = {"$ref": f"#/components/schemas/update_{entity}_Request"}
  response_schema = {"$ref": "#/components/schemas/execute-connector_Response"}
  return {
      "post": {
          "summary": f"Update {entity}",
          "description": description,
          "operationId": f"{tool_name}_update_{entity}",
          "requestBody": {
              "content": {"application/json": {"schema": request_schema}}
          },
          "responses": {
              "200": {
                  "description": "Success response",
                  "content": {
                      "application/json": {"schema": response_schema}
                  },
              }
          },
      }
  }
|
||||
|
||||
@staticmethod
|
||||
def delete_operation(
    entity: str, tool_name: str = "", tool_instructions: str = ""
) -> Dict[str, Any]:
  """Builds the OpenAPI path item for DELETE_ENTITY on *entity*.

  Args:
    entity: Entity name the delete operation targets.
    tool_name: Prefix for the generated operationId.
    tool_instructions: Extra text appended to the description.

  Returns:
    dict: An OpenAPI path item with a single "post" operation.
  """
  description = (
      f"Deletes an entity of type {entity}. Use this tool with entity ="
      f' "{entity}" and operation = "DELETE_ENTITY" only. Dont ask these'
      f" values from user. {tool_instructions}"
  )
  request_schema = {"$ref": f"#/components/schemas/delete_{entity}_Request"}
  response_schema = {"$ref": "#/components/schemas/execute-connector_Response"}
  return {
      "post": {
          "summary": f"Delete {entity}",
          "description": description,
          "operationId": f"{tool_name}_delete_{entity}",
          "requestBody": {
              "content": {"application/json": {"schema": request_schema}}
          },
          "responses": {
              "200": {
                  "description": "Success response",
                  "content": {
                      "application/json": {"schema": response_schema}
                  },
              }
          },
      }
  }
|
||||
|
||||
@staticmethod
|
||||
def create_operation_request(entity: str) -> Dict[str, Any]:
  """Request schema for a CREATE_ENTITY call on *entity*."""
  required = [
      "connectorInputPayload",
      "operation",
      "connectionName",
      "serviceName",
      "host",
      "entity",
  ]
  properties = {
      name: {"$ref": f"#/components/schemas/{name}"} for name in required
  }
  # The input payload schema is generated per entity.
  properties["connectorInputPayload"] = {
      "$ref": f"#/components/schemas/connectorInputPayload_{entity}"
  }
  return {"type": "object", "required": required, "properties": properties}
|
||||
|
||||
@staticmethod
|
||||
def update_operation_request(entity: str) -> Dict[str, Any]:
  """Request schema for an UPDATE_ENTITY call on *entity*."""
  required = [
      "connectorInputPayload",
      "entityId",
      "operation",
      "connectionName",
      "serviceName",
      "host",
      "entity",
  ]
  properties = {
      name: {"$ref": f"#/components/schemas/{name}"} for name in required
  }
  # The input payload schema is generated per entity.
  properties["connectorInputPayload"] = {
      "$ref": f"#/components/schemas/connectorInputPayload_{entity}"
  }
  return {"type": "object", "required": required, "properties": properties}
|
||||
|
||||
@staticmethod
|
||||
def get_operation_request() -> Dict[str, Any]:
  """Request schema for a GET_ENTITY call."""
  required = [
      "entityId",
      "operation",
      "connectionName",
      "serviceName",
      "host",
      "entity",
  ]
  return {
      "type": "object",
      "required": required,
      "properties": {
          name: {"$ref": f"#/components/schemas/{name}"} for name in required
      },
  }
|
||||
|
||||
@staticmethod
|
||||
def delete_operation_request() -> Dict[str, Any]:
  """Request schema for a DELETE_ENTITY call (same shape as GET_ENTITY)."""
  required = [
      "entityId",
      "operation",
      "connectionName",
      "serviceName",
      "host",
      "entity",
  ]
  return {
      "type": "object",
      "required": required,
      "properties": {
          name: {"$ref": f"#/components/schemas/{name}"} for name in required
      },
  }
|
||||
|
||||
@staticmethod
|
||||
def list_operation_request() -> Dict[str, Any]:
  """Request schema for a LIST_ENTITIES call.

  Paging/filter fields are present but optional; the execution fields are
  required.
  """
  required = ["operation", "connectionName", "serviceName", "host", "entity"]
  optional = ["filterClause", "pageSize", "pageToken"]
  return {
      "type": "object",
      "required": required,
      "properties": {
          name: {"$ref": f"#/components/schemas/{name}"}
          for name in optional + required
      },
  }
|
||||
|
||||
@staticmethod
|
||||
def action_request(action: str) -> Dict[str, Any]:
  """Request schema for an EXECUTE_ACTION call on *action*."""
  required = [
      "operation",
      "connectionName",
      "serviceName",
      "host",
      "action",
      "connectorInputPayload",
  ]
  properties = {
      name: {"$ref": f"#/components/schemas/{name}"} for name in required
  }
  # The input payload schema is generated per action.
  properties["connectorInputPayload"] = {
      "$ref": f"#/components/schemas/connectorInputPayload_{action}"
  }
  return {"type": "object", "required": required, "properties": properties}
|
||||
|
||||
@staticmethod
|
||||
def action_response(action: str) -> Dict[str, Any]:
  """Response schema for an EXECUTE_ACTION call on *action*."""
  output_ref = f"#/components/schemas/connectorOutputPayload_{action}"
  return {
      "type": "object",
      "properties": {"connectorOutputPayload": {"$ref": output_ref}},
  }
|
||||
|
||||
@staticmethod
|
||||
def execute_custom_query_request() -> Dict[str, Any]:
  """Request schema for an EXECUTE_QUERY (custom SQL) call."""
  required = [
      "operation",
      "connectionName",
      "serviceName",
      "host",
      "action",
      "query",
      "timeout",
      "pageSize",
  ]
  return {
      "type": "object",
      "required": required,
      "properties": {
          name: {"$ref": f"#/components/schemas/{name}"} for name in required
      },
  }
|
||||
|
||||
def connector_payload(self, json_schema: Dict[str, Any]) -> Dict[str, Any]:
  """Converts *json_schema* (JSON schema) into an OpenAPI schema dict.

  Thin public wrapper around _convert_json_schema_to_openapi_schema.
  """
  return self._convert_json_schema_to_openapi_schema(json_schema)
|
||||
|
||||
def _convert_json_schema_to_openapi_schema(self, json_schema):
  """Recursively converts a JSON schema dict to an OpenAPI schema dict.

  Handles "description", single or list-valued "type" (a "null" member in
  the list maps to OpenAPI's nullable=True), nested object "properties" and
  array "items".

  Args:
    json_schema (dict): The input JSON schema dictionary.

  Returns:
    dict: The converted OpenAPI schema dictionary.
  """
  result = {}
  convert = self._convert_json_schema_to_openapi_schema

  if "description" in json_schema:
    result["description"] = json_schema["description"]

  if "type" in json_schema:
    type_value = json_schema["type"]
    if isinstance(type_value, list):
      # JSON schema allows a union of types; OpenAPI 3.0 does not. Map a
      # "null" member to nullable=True and keep the first remaining type.
      if "null" in type_value:
        result["nullable"] = True
        remaining = [t for t in type_value if t != "null"]
        if remaining:
          result["type"] = remaining[0]
      else:
        result["type"] = type_value[0]
    else:
      result["type"] = type_value

  if result.get("type") == "object" and "properties" in json_schema:
    result["properties"] = {
        name: convert(sub) for name, sub in json_schema["properties"].items()
    }
  elif result.get("type") == "array" and "items" in json_schema:
    items = json_schema["items"]
    if isinstance(items, list):
      result["items"] = [convert(item) for item in items]
    else:
      result["items"] = convert(items)

  return result
|
||||
|
||||
def _get_access_token(self) -> str:
  """Gets the access token for the service account.

  A refreshed credential is cached on the instance and reused until it
  expires.

  Returns:
    The access token.

  Raises:
    ValueError: If no service account JSON was provided and application
      default credentials could not be obtained.
  """
  if self.credential_cache and not self.credential_cache.expired:
    return self.credential_cache.token

  if self.service_account_json:
    credentials = service_account.Credentials.from_service_account_info(
        json.loads(self.service_account_json),
        scopes=["https://www.googleapis.com/auth/cloud-platform"],
    )
  else:
    try:
      credentials, _ = default_service_credential()
    except Exception:  # pylint: disable=broad-except
      # Fix: was a bare "except:", which also swallows KeyboardInterrupt
      # and SystemExit. Any credential failure falls through to the
      # explicit ValueError below.
      credentials = None

  if not credentials:
    raise ValueError(
        "Please provide a service account that has the required permissions"
        " to access the connection."
    )

  credentials.refresh(Request())
  self.credential_cache = credentials
  return credentials.token
|
||||
|
||||
def _execute_api_call(self, url):
  """Issues an authenticated GET request against *url*.

  Args:
    url (str): The URL to call.

  Returns:
    requests.Response: The response object from the API call.

  Raises:
    PermissionError: If there are credential issues.
    ValueError: If there's a request error.
    Exception: For any other unexpected errors.
  """
  try:
    token = self._get_access_token()
    response = requests.get(
        url,
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
        },
    )
    response.raise_for_status()
    return response

  except google.auth.exceptions.DefaultCredentialsError as e:
    raise PermissionError(f"Credentials error: {e}") from e

  except requests.exceptions.RequestException as e:
    # NOTE(review): client errors are detected by substring match on the
    # exception text (requests embeds the status code in the message),
    # not by inspecting e.response.status_code.
    message = str(e)
    client_error_markers = ("404", "Not found", "400", "Bad request")
    if any(marker in message for marker in client_error_markers):
      raise ValueError(
          "Invalid request. Please check the provided"
          f" values of project({self.project}), location({self.location}),"
          f" connection({self.connection})."
      ) from e
    raise ValueError(f"Request error: {e}") from e

  except Exception as e:
    raise Exception(f"An unexpected error occurred: {e}") from e
|
||||
|
||||
def _poll_operation(self, operation_id: str) -> Dict[str, Any]:
  """Polls an operation until it is done.

  NOTE: there is no overall timeout — if the service never reports
  done=True this blocks indefinitely (TODO: consider a deadline).

  Args:
    operation_id: The ID of the operation to poll.

  Returns:
    The final response of the operation.

  Raises:
    PermissionError: If there are credential issues.
    ValueError: If there's a request error.
    Exception: For any other unexpected errors.
  """
  # The URL is loop-invariant; build it once.
  get_operation_url = f"{self.connector_url}/v1/{operation_id}"
  while True:
    response = self._execute_api_call(get_operation_url)
    operation_response: Dict[str, Any] = response.json()
    if operation_response.get("done", False):
      # Fix: return immediately when done instead of sleeping one extra
      # second after completion (the old loop slept before re-checking).
      return operation_response
    time.sleep(1)
|
||||
@@ -0,0 +1,254 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
|
||||
import google.auth
|
||||
from google.auth import default as default_service_credential
|
||||
import google.auth.transport.requests
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import requests
|
||||
|
||||
|
||||
class IntegrationClient:
|
||||
"""A client for interacting with Google Cloud Application Integration.
|
||||
|
||||
This class provides methods for retrieving OpenAPI spec for an integration or
|
||||
a connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
project: str,
|
||||
location: str,
|
||||
integration: Optional[str] = None,
|
||||
trigger: Optional[str] = None,
|
||||
connection: Optional[str] = None,
|
||||
entity_operations: Optional[dict[str, list[str]]] = None,
|
||||
actions: Optional[list[str]] = None,
|
||||
service_account_json: Optional[str] = None,
|
||||
):
|
||||
"""Initializes the ApplicationIntegrationClient.
|
||||
|
||||
Args:
|
||||
project: The Google Cloud project ID.
|
||||
location: The Google Cloud location (e.g., us-central1).
|
||||
integration: The integration name.
|
||||
trigger: The trigger ID for the integration.
|
||||
connection: The connection name.
|
||||
entity_operations: A dictionary mapping entity names to a list of
|
||||
operations (e.g., LIST, CREATE, UPDATE, DELETE, GET).
|
||||
actions: List of actions.
|
||||
service_account_json: The service account configuration as a dictionary.
|
||||
Required if not using default service credential. Used for fetching
|
||||
connection details.
|
||||
"""
|
||||
self.project = project
|
||||
self.location = location
|
||||
self.integration = integration
|
||||
self.trigger = trigger
|
||||
self.connection = connection
|
||||
self.entity_operations = (
|
||||
entity_operations if entity_operations is not None else {}
|
||||
)
|
||||
self.actions = actions if actions is not None else []
|
||||
self.service_account_json = service_account_json
|
||||
self.credential_cache = None
|
||||
|
||||
def get_openapi_spec_for_integration(self):
|
||||
"""Gets the OpenAPI spec for the integration.
|
||||
|
||||
Returns:
|
||||
dict: The OpenAPI spec as a dictionary.
|
||||
Raises:
|
||||
PermissionError: If there are credential issues.
|
||||
ValueError: If there's a request error or processing error.
|
||||
Exception: For any other unexpected errors.
|
||||
"""
|
||||
try:
|
||||
url = f"https://{self.location}-integrations.googleapis.com/v1/projects/{self.project}/locations/{self.location}:generateOpenApiSpec"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self._get_access_token()}",
|
||||
}
|
||||
data = {
|
||||
"apiTriggerResources": [
|
||||
{
|
||||
"integrationResource": self.integration,
|
||||
"triggerId": [self.trigger],
|
||||
},
|
||||
],
|
||||
"fileFormat": "JSON",
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
spec = response.json().get("openApiSpec", {})
|
||||
return json.loads(spec)
|
||||
except google.auth.exceptions.DefaultCredentialsError as e:
|
||||
raise PermissionError(f"Credentials error: {e}") from e
|
||||
except requests.exceptions.RequestException as e:
|
||||
if (
|
||||
"404" in str(e)
|
||||
or "Not found" in str(e)
|
||||
or "400" in str(e)
|
||||
or "Bad request" in str(e)
|
||||
):
|
||||
raise ValueError(
|
||||
"Invalid request. Please check the provided values of"
|
||||
f" project({self.project}), location({self.location}),"
|
||||
f" integration({self.integration}) and trigger({self.trigger})."
|
||||
) from e
|
||||
raise ValueError(f"Request error: {e}") from e
|
||||
except Exception as e:
|
||||
raise Exception(f"An unexpected error occurred: {e}") from e
|
||||
|
||||
def get_openapi_spec_for_connection(self, tool_name="", tool_instructions=""):
|
||||
"""Gets the OpenAPI spec for the connection.
|
||||
|
||||
Returns:
|
||||
dict: The OpenAPI spec as a dictionary.
|
||||
Raises:
|
||||
ValueError: If there's an error retrieving the OpenAPI spec.
|
||||
PermissionError: If there are credential issues.
|
||||
Exception: For any other unexpected errors.
|
||||
"""
|
||||
# Application Integration needs to be provisioned in the same region as connection and an integration with name "ExecuteConnection" and trigger "api_trigger/ExecuteConnection" should be created as per the documentation.
|
||||
integration_name = "ExecuteConnection"
|
||||
connections_client = ConnectionsClient(
|
||||
self.project,
|
||||
self.location,
|
||||
self.connection,
|
||||
self.service_account_json,
|
||||
)
|
||||
if not self.entity_operations and not self.actions:
|
||||
raise ValueError(
|
||||
"No entity operations or actions provided. Please provide at least"
|
||||
" one of them."
|
||||
)
|
||||
connector_spec = connections_client.get_connector_base_spec()
|
||||
for entity, operations in self.entity_operations.items():
|
||||
schema, supported_operations = (
|
||||
connections_client.get_entity_schema_and_operations(entity)
|
||||
)
|
||||
if not operations:
|
||||
operations = supported_operations
|
||||
json_schema_as_string = json.dumps(schema)
|
||||
entity_lower = entity
|
||||
connector_spec["components"]["schemas"][
|
||||
f"connectorInputPayload_{entity_lower}"
|
||||
] = connections_client.connector_payload(schema)
|
||||
for operation in operations:
|
||||
operation_lower = operation.lower()
|
||||
path = f"/v2/projects/{self.project}/locations/{self.location}/integrations/{integration_name}:execute?triggerId=api_trigger/{integration_name}#{operation_lower}_{entity_lower}"
|
||||
if operation_lower == "create":
|
||||
connector_spec["paths"][path] = connections_client.create_operation(
|
||||
entity_lower, tool_name, tool_instructions
|
||||
)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"create_{entity_lower}_Request"
|
||||
] = connections_client.create_operation_request(entity_lower)
|
||||
elif operation_lower == "update":
|
||||
connector_spec["paths"][path] = connections_client.update_operation(
|
||||
entity_lower, tool_name, tool_instructions
|
||||
)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"update_{entity_lower}_Request"
|
||||
] = connections_client.update_operation_request(entity_lower)
|
||||
elif operation_lower == "delete":
|
||||
connector_spec["paths"][path] = connections_client.delete_operation(
|
||||
entity_lower, tool_name, tool_instructions
|
||||
)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"delete_{entity_lower}_Request"
|
||||
] = connections_client.delete_operation_request()
|
||||
elif operation_lower == "list":
|
||||
connector_spec["paths"][path] = connections_client.list_operation(
|
||||
entity_lower, json_schema_as_string, tool_name, tool_instructions
|
||||
)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"list_{entity_lower}_Request"
|
||||
] = connections_client.list_operation_request()
|
||||
elif operation_lower == "get":
|
||||
connector_spec["paths"][path] = connections_client.get_operation(
|
||||
entity_lower, json_schema_as_string, tool_name, tool_instructions
|
||||
)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"get_{entity_lower}_Request"
|
||||
] = connections_client.get_operation_request()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid operation: {operation} for entity: {entity}"
|
||||
)
|
||||
for action in self.actions:
|
||||
action_details = connections_client.get_action_schema(action)
|
||||
input_schema = action_details["inputSchema"]
|
||||
output_schema = action_details["outputSchema"]
|
||||
# Remove spaces from the display name to generate valid spec
|
||||
action_display_name = action_details["displayName"].replace(" ", "")
|
||||
operation = "EXECUTE_ACTION"
|
||||
if action == "ExecuteCustomQuery":
|
||||
connector_spec["components"]["schemas"][
|
||||
f"{action_display_name}_Request"
|
||||
] = connections_client.execute_custom_query_request()
|
||||
operation = "EXECUTE_QUERY"
|
||||
else:
|
||||
connector_spec["components"]["schemas"][
|
||||
f"{action_display_name}_Request"
|
||||
] = connections_client.action_request(action_display_name)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"connectorInputPayload_{action_display_name}"
|
||||
] = connections_client.connector_payload(input_schema)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"connectorOutputPayload_{action_display_name}"
|
||||
] = connections_client.connector_payload(output_schema)
|
||||
connector_spec["components"]["schemas"][
|
||||
f"{action_display_name}_Response"
|
||||
] = connections_client.action_response(action_display_name)
|
||||
path = f"/v2/projects/{self.project}/locations/{self.location}/integrations/{integration_name}:execute?triggerId=api_trigger/{integration_name}#{action}"
|
||||
connector_spec["paths"][path] = connections_client.get_action_operation(
|
||||
action, operation, action_display_name, tool_name, tool_instructions
|
||||
)
|
||||
return connector_spec
|
||||
|
||||
def _get_access_token(self) -> str:
  """Gets the access token for the service account or using default credentials.

  A valid cached credential is reused; otherwise credentials are built from
  the configured service account JSON or from the application default
  credentials.

  Returns:
    The access token.

  Raises:
    ValueError: If no usable credentials could be obtained.
  """
  # Reuse the cached credentials while they are still valid.
  if self.credential_cache and not self.credential_cache.expired:
    return self.credential_cache.token

  if self.service_account_json:
    credentials = service_account.Credentials.from_service_account_info(
        json.loads(self.service_account_json),
        scopes=["https://www.googleapis.com/auth/cloud-platform"],
    )
  else:
    try:
      credentials, _ = default_service_credential()
    except Exception:
      # Was a bare `except:`, which also swallowed SystemExit and
      # KeyboardInterrupt; only credential-lookup failures should fall
      # through to the ValueError below.
      credentials = None

  if not credentials:
    raise ValueError(
        "Please provide a service account that has the required permissions"
        " to access the connection."
    )

  credentials.refresh(Request())
  self.credential_cache = credentials
  return credentials.token
|
||||
144
.venv/lib/python3.10/site-packages/google/adk/tools/base_tool.py
Normal file
144
.venv/lib/python3.10/site-packages/google/adk/tools/base_tool.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deprecated import deprecated
|
||||
from google.genai import types
|
||||
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models.llm_request import LlmRequest
|
||||
|
||||
|
||||
class BaseTool(ABC):
  """The base class for all tools."""

  name: str
  """The name of the tool."""
  description: str
  """The description of the tool."""

  is_long_running: bool = False
  """Whether the tool is a long running operation, which typically returns a
  resource id first and finishes the operation later."""

  def __init__(self, *, name, description, is_long_running: bool = False):
    self.name = name
    self.description = description
    self.is_long_running = is_long_running

  def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
    """Gets the OpenAPI specification of this tool in the form of a FunctionDeclaration.

    NOTE
    - Required if subclass uses the default implementation of
      `process_llm_request` to add function declaration to LLM request.
    - Otherwise, can be skipped, e.g. for a built-in GoogleSearch tool for
      Gemini.

    Returns:
      The FunctionDeclaration of this tool, or None if it doesn't need to be
      added to LlmRequest.config.
    """
    # Base implementation declares nothing; subclasses override as needed.
    return None

  async def run_async(
      self, *, args: dict[str, Any], tool_context: ToolContext
  ) -> Any:
    """Runs the tool with the given arguments and context.

    NOTE
    - Required if this tool needs to run at the client side.
    - Otherwise, can be skipped, e.g. for a built-in GoogleSearch tool for
      Gemini.

    Args:
      args: The LLM-filled arguments.
      tool_context: The context of the tool.

    Returns:
      The result of running the tool.
    """
    raise NotImplementedError(f'{type(self)} is not implemented')

  async def process_llm_request(
      self, *, tool_context: ToolContext, llm_request: LlmRequest
  ) -> None:
    """Processes the outgoing LLM request for this tool.

    Use cases:
    - Most common use case is adding this tool to the LLM request.
    - Some tools may just preprocess the LLM request before it's sent out.

    Args:
      tool_context: The context of the tool.
      llm_request: The outgoing LLM request, mutable this method.
    """
    function_declaration = self._get_declaration()
    if function_declaration is None:
      # Nothing to declare (e.g. model-native tools); leave the request as-is.
      return

    llm_request.tools_dict[self.name] = self
    # Prefer merging into an existing Tool that already carries declarations
    # so the request ends up with a single function-declaration Tool entry.
    existing_tool = _find_tool_with_function_declarations(llm_request)
    if existing_tool:
      if existing_tool.function_declarations is None:
        existing_tool.function_declarations = []
      existing_tool.function_declarations.append(function_declaration)
      return

    if not llm_request.config:
      llm_request.config = types.GenerateContentConfig()
    if not llm_request.config.tools:
      llm_request.config.tools = []
    llm_request.config.tools.append(
        types.Tool(function_declarations=[function_declaration])
    )

  @property
  def _api_variant(self) -> str:
    # The Vertex AI backend is selected via an environment flag; anything
    # other than 'true'/'1' (case-insensitive) means the Google AI backend.
    flag = os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower()
    if flag in ('true', '1'):
      return 'VERTEX_AI'
    return 'GOOGLE_AI'
|
||||
|
||||
|
||||
def _find_tool_with_function_declarations(
|
||||
llm_request: LlmRequest,
|
||||
) -> Optional[types.Tool]:
|
||||
# TODO: add individual tool with declaration and merge in google_llm.py
|
||||
if not llm_request.config or not llm_request.config.tools:
|
||||
return None
|
||||
|
||||
return next(
|
||||
(
|
||||
tool
|
||||
for tool in llm_request.config.tools
|
||||
if isinstance(tool, types.Tool) and tool.function_declarations
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -0,0 +1,59 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_tool import BaseTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import LlmRequest
|
||||
|
||||
|
||||
class BuiltInCodeExecutionTool(BaseTool):
  """A built-in code execution tool that is automatically invoked by Gemini 2 models.

  This tool operates internally within the model and does not require or perform
  local code execution.
  """

  def __init__(self):
    # Name and description are not used because this is a model built-in tool.
    super().__init__(name='code_execution', description='code_execution')

  @override
  async def process_llm_request(
      self,
      *,
      tool_context: ToolContext,
      llm_request: LlmRequest,
  ) -> None:
    # Only Gemini 2 models expose native code execution; reject anything else.
    model_name = llm_request.model
    if not (model_name and model_name.startswith('gemini-2')):
      raise ValueError(
          f'Code execution tool is not supported for model {llm_request.model}'
      )
    llm_request.config = llm_request.config or types.GenerateContentConfig()
    llm_request.config.tools = llm_request.config.tools or []
    llm_request.config.tools.append(
        types.Tool(code_execution=types.ToolCodeExecution())
    )


# Singleton instance; the model invokes the tool internally, so one shared
# instance is sufficient.
built_in_code_execution = BuiltInCodeExecutionTool()
|
||||
@@ -0,0 +1,72 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from . import _automatic_function_calling_util
|
||||
from .function_tool import FunctionTool
|
||||
|
||||
try:
  from crewai.tools import BaseTool as CrewaiBaseTool
except ImportError as e:
  import sys

  # Give a targeted hint: on pre-3.10 interpreters the root cause is the
  # Python version; otherwise the optional extra is simply not installed.
  if sys.version_info < (3, 10):
    raise ImportError(
        "Crewai Tools require Python 3.10+. Please upgrade your Python version."
    ) from e
  raise ImportError(
      "Crewai Tools require pip install 'google-adk[extensions]'."
  ) from e
|
||||
|
||||
|
||||
class CrewaiTool(FunctionTool):
  """Use this class to wrap a CrewAI tool.

  If the original tool name and description are not suitable, you can override
  them in the constructor.
  """

  tool: CrewaiBaseTool
  """The wrapped CrewAI tool."""

  def __init__(self, tool: CrewaiBaseTool, *, name: str, description: str):
    super().__init__(tool.run)
    self.tool = tool
    # Explicit overrides win; otherwise fall back to the CrewAI tool's own
    # metadata (FunctionTool.__init__ already seeded name/description from
    # tool.run's __name__/__doc__).
    if name:
      self.name = name
    elif tool.name:
      # Right now, CrewAI tool name contains white spaces. White spaces are
      # not supported in our framework. So we replace them with "_".
      self.name = tool.name.replace(" ", "_").lower()
    if description:
      self.description = description
    elif tool.description:
      self.description = tool.description

  @override
  def _get_declaration(self) -> types.FunctionDeclaration:
    """Build the function declaration for the tool."""
    return _automatic_function_calling_util.build_function_declaration_for_params_for_crewai(
        False,
        self.name,
        self.description,
        self.func,
        self.tool.args_schema.model_json_schema(),
    )
|
||||
@@ -0,0 +1,62 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from typing_extensions import override
|
||||
|
||||
from ..examples import example_util
|
||||
from ..examples.base_example_provider import BaseExampleProvider
|
||||
from ..examples.example import Example
|
||||
from .base_tool import BaseTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models.llm_request import LlmRequest
|
||||
|
||||
|
||||
class ExampleTool(BaseTool):
  """A tool that adds (few-shot) examples to the LLM request.

  Attributes:
    examples: The examples to add to the LLM request.
  """

  def __init__(self, examples: Union[list[Example], BaseExampleProvider]):
    # Name and description are not used because this tool only changes
    # llm_request.
    super().__init__(name='example_tool', description='example tool')
    # A raw list is validated/coerced into Example models; a provider object
    # is stored as-is and queried later.
    if isinstance(examples, list):
      self.examples = TypeAdapter(list[Example]).validate_python(examples)
    else:
      self.examples = examples

  @override
  async def process_llm_request(
      self, *, tool_context: ToolContext, llm_request: LlmRequest
  ) -> None:
    parts = tool_context.user_content.parts
    if not parts or not parts[0].text:
      # Examples are keyed off the user's leading text part; nothing to do
      # when it is absent.
      return

    query = parts[0].text
    llm_request.append_instructions(
        [example_util.build_example_si(self.examples, query, llm_request.model)]
    )
|
||||
@@ -0,0 +1,23 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .tool_context import ToolContext
|
||||
|
||||
|
||||
def exit_loop(tool_context: ToolContext):
  """Exits the loop.

  Call this function only when you are instructed to do so.
  """
  # NOTE: the docstring above is model-facing — FunctionTool exposes
  # func.__doc__ as the tool description — so keep its wording intact.
  # Escalating signals the enclosing agent to stop iterating.
  tool_context.actions.escalate = True
|
||||
@@ -0,0 +1,307 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import types as typing_types
|
||||
from typing import _GenericAlias
|
||||
from typing import Any
|
||||
from typing import get_args
|
||||
from typing import get_origin
|
||||
from typing import Literal
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
import pydantic
|
||||
|
||||
# Mapping from supported Python builtin annotation types to the genai schema
# types used when building FunctionDeclaration schemas.
_py_builtin_type_to_schema_type = {
    str: types.Type.STRING,
    int: types.Type.INTEGER,
    float: types.Type.NUMBER,
    bool: types.Type.BOOLEAN,
    list: types.Type.ARRAY,
    dict: types.Type.OBJECT,
}

logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_builtin_primitive_or_compound(
    annotation: inspect.Parameter.annotation,
) -> bool:
  """Returns True if the annotation is a directly supported Python builtin.

  Supported builtins are the keys of `_py_builtin_type_to_schema_type`
  (str, int, float, bool, list, dict).
  """
  # Membership on the dict itself is equivalent to — and cheaper than —
  # the original `in _py_builtin_type_to_schema_type.keys()`.
  return annotation in _py_builtin_type_to_schema_type
|
||||
|
||||
|
||||
def _raise_for_any_of_if_mldev(schema: types.Schema):
|
||||
if schema.any_of:
|
||||
raise ValueError(
|
||||
'AnyOf is not supported in function declaration schema for Google AI.'
|
||||
)
|
||||
|
||||
|
||||
def _update_for_default_if_mldev(schema: types.Schema):
  """Strips default values from the schema when targeting the Google AI API."""
  if schema.default is None:
    return
  # TODO(kech): Remove this workaround once mldev supports default value.
  schema.default = None
  logger.warning(
      'Default value is not supported in function declaration schema for'
      ' Google AI.'
  )
|
||||
|
||||
|
||||
def _raise_if_schema_unsupported(variant: str, schema: types.Schema):
  """Validates/normalizes the schema against API-variant limitations.

  For the non-Vertex (Google AI / mldev) variant, any_of schemas are rejected
  and default values are stripped in place with a warning.

  Args:
    variant: The API variant name; 'VERTEX_AI' skips the checks.
    schema: The schema to check and normalize in place.
  """
  # `variant != 'VERTEX_AI'` replaces the original `not variant == ...`,
  # which is the same comparison in idiomatic form.
  if variant != 'VERTEX_AI':
    _raise_for_any_of_if_mldev(schema)
    _update_for_default_if_mldev(schema)
|
||||
|
||||
|
||||
def _is_default_value_compatible(
    default_value: Any, annotation: inspect.Parameter.annotation
) -> bool:
  """Returns True if `default_value` is a valid instance of `annotation`.

  Handles builtins, Union/Optional, dict/list generics (including lists of
  union element types) and Literal; any other annotation yields False.
  """
  # None type is expected to be handled external to this function
  # Plain builtins (str, int, float, bool, list, dict): direct isinstance.
  if _is_builtin_primitive_or_compound(annotation):
    return isinstance(default_value, annotation)

  # Generic aliases (typing and builtin forms) and PEP 604 unions
  # (types.UnionType requires Python 3.10+).
  if (
      isinstance(annotation, _GenericAlias)
      or isinstance(annotation, typing_types.GenericAlias)
      or isinstance(annotation, typing_types.UnionType)
  ):
    origin = get_origin(annotation)
    # Union: compatible if the value matches any member type.
    if origin in (Union, typing_types.UnionType):
      return any(
          _is_default_value_compatible(default_value, arg)
          for arg in get_args(annotation)
      )

    # dict[...] only checks the container type, not key/value types.
    if origin is dict:
      return isinstance(default_value, dict)

    if origin is list:
      if not isinstance(default_value, list):
        return False
      # most tricky case, element in list is union type
      # need to apply any logic within all
      # see test case test_generic_alias_complex_array_with_default_value
      # a: typing.List[int | str | float | bool]
      # default_value: [1, 'a', 1.1, True]
      return all(
          any(
              _is_default_value_compatible(item, arg)
              for arg in get_args(annotation)
          )
          for item in default_value
      )

    # Literal[...]: the default must be one of the listed values.
    if origin is Literal:
      return default_value in get_args(annotation)

  # return False for any other unrecognized annotation
  # let caller handle the raise
  return False
|
||||
|
||||
|
||||
def _parse_schema_from_parameter(
    variant: str, param: inspect.Parameter, func_name: str
) -> types.Schema:
  """parse schema from parameter.

  from the simplest case to the most complex case.
  """
  schema = types.Schema()
  default_value_error_msg = (
      f'Default value {param.default} of parameter {param} of function'
      f' {func_name} is not compatible with the parameter annotation'
      f' {param.annotation}.'
  )
  # Case 1: a plain builtin (str/int/float/bool/list/dict).
  if _is_builtin_primitive_or_compound(param.annotation):
    if param.default is not inspect.Parameter.empty:
      if not _is_default_value_compatible(param.default, param.annotation):
        raise ValueError(default_value_error_msg)
      schema.default = param.default
    schema.type = _py_builtin_type_to_schema_type[param.annotation]
    _raise_if_schema_unsupported(variant, schema)
    return schema
  # Case 2: a Union whose members are all builtins (or NoneType), e.g. the
  # PEP 604 form `int | str | None`.
  if (
      get_origin(param.annotation) is Union
      # only parse simple UnionType, example int | str | float | bool
      # complex types.UnionType will be invoked in raise branch
      and all(
          (_is_builtin_primitive_or_compound(arg) or arg is type(None))
          for arg in get_args(param.annotation)
      )
  ):
    schema.type = types.Type.OBJECT
    schema.any_of = []
    unique_types = set()
    for arg in get_args(param.annotation):
      if arg.__name__ == 'NoneType':  # Optional type
        schema.nullable = True
        continue
      # Recurse with a synthetic parameter carrying the member annotation.
      schema_in_any_of = _parse_schema_from_parameter(
          variant,
          inspect.Parameter(
              'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
          ),
          func_name,
      )
      # Deduplicate structurally identical member schemas by their JSON dump.
      if (
          schema_in_any_of.model_dump_json(exclude_none=True)
          not in unique_types
      ):
        schema.any_of.append(schema_in_any_of)
        unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
    if len(schema.any_of) == 1:  # param: list | None -> Array
      schema.type = schema.any_of[0].type
      schema.any_of = None
    if (
        param.default is not inspect.Parameter.empty
        and param.default is not None
    ):
      if not _is_default_value_compatible(param.default, param.annotation):
        raise ValueError(default_value_error_msg)
      schema.default = param.default
    _raise_if_schema_unsupported(variant, schema)
    return schema
  # Case 3: generic aliases — dict[...], Literal[...], list[...],
  # typing.Union[...] (with non-builtin members).
  if isinstance(param.annotation, _GenericAlias) or isinstance(
      param.annotation, typing_types.GenericAlias
  ):
    origin = get_origin(param.annotation)
    args = get_args(param.annotation)
    if origin is dict:
      # Key/value types are not modeled; only the container type is kept.
      schema.type = types.Type.OBJECT
      if param.default is not inspect.Parameter.empty:
        if not _is_default_value_compatible(param.default, param.annotation):
          raise ValueError(default_value_error_msg)
        schema.default = param.default
      _raise_if_schema_unsupported(variant, schema)
      return schema
    if origin is Literal:
      # Literal maps to a string enum, so every member must be a string.
      if not all(isinstance(arg, str) for arg in args):
        raise ValueError(
            f'Literal type {param.annotation} must be a list of strings.'
        )
      schema.type = types.Type.STRING
      schema.enum = list(args)
      if param.default is not inspect.Parameter.empty:
        if not _is_default_value_compatible(param.default, param.annotation):
          raise ValueError(default_value_error_msg)
        schema.default = param.default
      _raise_if_schema_unsupported(variant, schema)
      return schema
    if origin is list:
      schema.type = types.Type.ARRAY
      # Element schema is parsed from the (single) type argument.
      schema.items = _parse_schema_from_parameter(
          variant,
          inspect.Parameter(
              'item',
              inspect.Parameter.POSITIONAL_OR_KEYWORD,
              annotation=args[0],
          ),
          func_name,
      )
      if param.default is not inspect.Parameter.empty:
        if not _is_default_value_compatible(param.default, param.annotation):
          raise ValueError(default_value_error_msg)
        schema.default = param.default
      _raise_if_schema_unsupported(variant, schema)
      return schema
    if origin is Union:
      schema.any_of = []
      schema.type = types.Type.OBJECT
      unique_types = set()
      for arg in args:
        if arg.__name__ == 'NoneType':  # Optional type
          schema.nullable = True
          continue
        schema_in_any_of = _parse_schema_from_parameter(
            variant,
            inspect.Parameter(
                'item',
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=arg,
            ),
            func_name,
        )
        if (
            len(param.annotation.__args__) == 2
            and type(None) in param.annotation.__args__
        ):  # Optional type
          for optional_arg in param.annotation.__args__:
            if (
                hasattr(optional_arg, '__origin__')
                and optional_arg.__origin__ is list
            ):
              # Optional type with list, for example Optional[list[str]]
              schema.items = schema_in_any_of.items
        if (
            schema_in_any_of.model_dump_json(exclude_none=True)
            not in unique_types
        ):
          schema.any_of.append(schema_in_any_of)
          unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
      if len(schema.any_of) == 1:  # param: Union[List, None] -> Array
        schema.type = schema.any_of[0].type
        schema.any_of = None
      if (
          param.default is not None
          and param.default is not inspect.Parameter.empty
      ):
        if not _is_default_value_compatible(param.default, param.annotation):
          raise ValueError(default_value_error_msg)
        schema.default = param.default
      _raise_if_schema_unsupported(variant, schema)
      return schema
    # all other generic alias will be invoked in raise branch
  # Case 4: a pydantic model — mapped to an OBJECT with per-field schemas.
  if (
      inspect.isclass(param.annotation)
      # for user defined class, we only support pydantic model
      and issubclass(param.annotation, pydantic.BaseModel)
  ):
    if (
        param.default is not inspect.Parameter.empty
        and param.default is not None
    ):
      # NOTE(review): unlike the branches above, the default is NOT validated
      # against the annotation here — confirm whether that is intentional.
      schema.default = param.default
    schema.type = types.Type.OBJECT
    schema.properties = {}
    for field_name, field_info in param.annotation.model_fields.items():
      schema.properties[field_name] = _parse_schema_from_parameter(
          variant,
          inspect.Parameter(
              field_name,
              inspect.Parameter.POSITIONAL_OR_KEYWORD,
              annotation=field_info.annotation,
          ),
          func_name,
      )
    _raise_if_schema_unsupported(variant, schema)
    return schema
  # Anything else is unsupported for automatic function calling.
  raise ValueError(
      f'Failed to parse the parameter {param} of function {func_name} for'
      ' automatic function calling. Automatic function calling works best with'
      ' simpler function signature schema,consider manually parse your'
      f' function declaration for function {func_name}.'
  )
|
||||
|
||||
|
||||
def _get_required_fields(schema: types.Schema) -> list[str]:
|
||||
if not schema.properties:
|
||||
return
|
||||
return [
|
||||
field_name
|
||||
for field_name, field_schema in schema.properties.items()
|
||||
if not field_schema.nullable and field_schema.default is None
|
||||
]
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ._automatic_function_calling_util import build_function_declaration
|
||||
from .base_tool import BaseTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
|
||||
class FunctionTool(BaseTool):
  """A tool that wraps a user-defined Python function.

  Attributes:
    func: The function to wrap.
  """

  def __init__(self, func: Callable[..., Any]):
    # The tool's name/description mirror the function's name and docstring.
    super().__init__(name=func.__name__, description=func.__doc__)
    self.func = func

  @override
  def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
    """Derives a FunctionDeclaration from the wrapped function's signature."""
    declaration_dict = build_function_declaration(
        func=self.func,
        # The model doesn't understand the function context.
        # input_stream is for streaming tool
        ignore_params=['tool_context', 'input_stream'],
        variant=self._api_variant,
    )
    return types.FunctionDeclaration.model_validate(declaration_dict)

  @override
  async def run_async(
      self, *, args: dict[str, Any], tool_context: ToolContext
  ) -> Any:
    """Invokes the wrapped function (sync or async) with the LLM's arguments."""
    call_args = dict(args)
    # Inject the tool context only when the function declares it.
    if 'tool_context' in inspect.signature(self.func).parameters:
      call_args['tool_context'] = tool_context

    if inspect.iscoroutinefunction(self.func):
      result = await self.func(**call_args)
    else:
      result = self.func(**call_args)
    # Falsy results (None in particular) are normalized to an empty dict.
    return result or {}

  # TODO(hangfei): fix call live for function stream.
  async def _call_live(
      self,
      *,
      args: dict[str, Any],
      tool_context: ToolContext,
      invocation_context,
  ) -> Any:
    """Streams items from the wrapped async-generator function."""
    call_args = dict(args)
    signature = inspect.signature(self.func)
    active = invocation_context.active_streaming_tools
    # Forward the live input stream when one is registered for this tool.
    if self.name in active and active[self.name].stream:
      call_args['input_stream'] = active[self.name].stream
    if 'tool_context' in signature.parameters:
      call_args['tool_context'] = tool_context
    async for item in self.func(**call_args):
      yield item
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from .long_running_tool import LongRunningFunctionTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
|
||||
def get_user_choice(
    options: list[str], tool_context: ToolContext
) -> Optional[str]:
  """Provides the options to the user and asks them to choose one."""
  # NOTE: the docstring above is model-facing (it becomes the tool
  # description), so keep its wording intact.
  # Returns None immediately; as a long-running tool the actual selection is
  # presumably delivered later by the client — summarization of this empty
  # result is skipped.
  tool_context.actions.skip_summarization = True
  return None


# Wrapped as a LongRunningFunctionTool so the framework treats the call as a
# pending operation rather than a completed one.
get_user_choice_tool = LongRunningFunctionTool(func=get_user_choice)
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Public, lazily-loaded tool sets exposed by this package.
__all__ = [
    'bigquery_tool_set',
    'calendar_tool_set',
    'gmail_tool_set',
    'youtube_tool_set',
    'slides_tool_set',
    'sheets_tool_set',
    'docs_tool_set',
]

# Nothing is imported here automatically
# Each tool set will only be imported when accessed

# Module-level caches backing the lazy loader; each stays None until the
# corresponding tool set is first requested via module attribute access.
_bigquery_tool_set = None
_calendar_tool_set = None
_gmail_tool_set = None
_youtube_tool_set = None
_slides_tool_set = None
_sheets_tool_set = None
_docs_tool_set = None
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
|
||||
|
||||
match name:
|
||||
case 'bigquery_tool_set':
|
||||
if _bigquery_tool_set is None:
|
||||
from .google_api_tool_sets import bigquery_tool_set as bigquery
|
||||
|
||||
_bigquery_tool_set = bigquery
|
||||
return _bigquery_tool_set
|
||||
|
||||
case 'calendar_tool_set':
|
||||
if _calendar_tool_set is None:
|
||||
from .google_api_tool_sets import calendar_tool_set as calendar
|
||||
|
||||
_calendar_tool_set = calendar
|
||||
return _calendar_tool_set
|
||||
|
||||
case 'gmail_tool_set':
|
||||
if _gmail_tool_set is None:
|
||||
from .google_api_tool_sets import gmail_tool_set as gmail
|
||||
|
||||
_gmail_tool_set = gmail
|
||||
return _gmail_tool_set
|
||||
|
||||
case 'youtube_tool_set':
|
||||
if _youtube_tool_set is None:
|
||||
from .google_api_tool_sets import youtube_tool_set as youtube
|
||||
|
||||
_youtube_tool_set = youtube
|
||||
return _youtube_tool_set
|
||||
|
||||
case 'slides_tool_set':
|
||||
if _slides_tool_set is None:
|
||||
from .google_api_tool_sets import slides_tool_set as slides
|
||||
|
||||
_slides_tool_set = slides
|
||||
return _slides_tool_set
|
||||
|
||||
case 'sheets_tool_set':
|
||||
if _sheets_tool_set is None:
|
||||
from .google_api_tool_sets import sheets_tool_set as sheets
|
||||
|
||||
_sheets_tool_set = sheets
|
||||
return _sheets_tool_set
|
||||
|
||||
case 'docs_tool_set':
|
||||
if _docs_tool_set is None:
|
||||
from .google_api_tool_sets import docs_tool_set as docs
|
||||
|
||||
_docs_tool_set = docs
|
||||
return _docs_tool_set
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,59 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from typing_extensions import override
|
||||
|
||||
from ...auth import AuthCredential
|
||||
from ...auth import AuthCredentialTypes
|
||||
from ...auth import OAuth2Auth
|
||||
from .. import BaseTool
|
||||
from ..openapi_tool import RestApiTool
|
||||
from ..tool_context import ToolContext
|
||||
|
||||
|
||||
class GoogleApiTool(BaseTool):
  """Adapts a ``RestApiTool`` for use as a Google API tool.

  Declaration building and execution are delegated to the wrapped tool;
  this class mirrors its name/description and lets callers attach OpenID
  Connect client credentials.
  """

  def __init__(self, rest_api_tool: RestApiTool):
    super().__init__(
        name=rest_api_tool.name,
        description=rest_api_tool.description,
        is_long_running=rest_api_tool.is_long_running,
    )
    self.rest_api_tool = rest_api_tool

  @override
  def _get_declaration(self) -> FunctionDeclaration:
    """Delegates function-declaration building to the wrapped tool."""
    return self.rest_api_tool._get_declaration()

  @override
  async def run_async(
      self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
  ) -> Dict[str, Any]:
    """Executes the wrapped REST tool and returns its result."""
    result = await self.rest_api_tool.run_async(
        args=args, tool_context=tool_context
    )
    return result

  def configure_auth(self, client_id: str, client_secret: str):
    """Attaches OIDC OAuth2 client credentials to the wrapped tool."""
    oauth2_auth = OAuth2Auth(
        client_id=client_id,
        client_secret=client_secret,
    )
    self.rest_api_tool.auth_credential = AuthCredential(
        auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
        oauth2=oauth2_auth,
    )
|
||||
@@ -0,0 +1,110 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Final
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
|
||||
from ...auth import OpenIdConnectWithConfig
|
||||
from ..openapi_tool import OpenAPIToolset
|
||||
from ..openapi_tool import RestApiTool
|
||||
from .google_api_tool import GoogleApiTool
|
||||
from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
|
||||
|
||||
|
||||
class GoogleApiToolSet:
  """Google API Tool Set.

  Wraps a collection of RestApiTool instances as GoogleApiTool objects and
  provides lookup, shared OAuth configuration, and construction from a
  Google API discovery document.
  """

  def __init__(self, tools: List[RestApiTool]):
    # Each REST tool is wrapped so auth can be configured uniformly.
    self.tools: Final[List[GoogleApiTool]] = [
        GoogleApiTool(tool) for tool in tools
    ]

  def get_tools(self) -> List[GoogleApiTool]:
    """Get all tools in the toolset."""
    return self.tools

  def get_tool(self, tool_name: str) -> Optional[GoogleApiTool]:
    """Get a tool by name."""
    # Returns the first tool whose name matches, or None if absent.
    matching_tool = filter(lambda t: t.name == tool_name, self.tools)
    return next(matching_tool, None)

  @staticmethod
  def _load_tool_set_with_oidc_auth(
      spec_file: Optional[str] = None,
      spec_dict: Optional[dict[str, Any]] = None,
      scopes: Optional[list[str]] = None,
  ) -> OpenAPIToolset:
    """Builds an OpenAPIToolset with Google OIDC auth endpoints attached.

    Exactly one of spec_file or spec_dict is expected. When spec_file is
    given it is resolved relative to the *caller's* directory via
    inspect.stack() — do not add intermediate wrappers around this method,
    as that would change which frame the path is resolved against.
    """
    spec_str = None
    if spec_file:
      # Get the frame of the caller
      caller_frame = inspect.stack()[1]
      # Get the filename of the caller
      caller_filename = caller_frame.filename
      # Get the directory of the caller
      caller_dir = os.path.dirname(os.path.abspath(caller_filename))
      # Join the directory path with the filename
      yaml_path = os.path.join(caller_dir, spec_file)
      with open(yaml_path, 'r', encoding='utf-8') as file:
        spec_str = file.read()
    tool_set = OpenAPIToolset(
        spec_dict=spec_dict,
        spec_str=spec_str,
        spec_str_type='yaml',
        # Standard Google OAuth2 / OpenID Connect endpoints.
        auth_scheme=OpenIdConnectWithConfig(
            authorization_endpoint=(
                'https://accounts.google.com/o/oauth2/v2/auth'
            ),
            token_endpoint='https://oauth2.googleapis.com/token',
            userinfo_endpoint=(
                'https://openidconnect.googleapis.com/v1/userinfo'
            ),
            revocation_endpoint='https://oauth2.googleapis.com/revoke',
            token_endpoint_auth_methods_supported=[
                'client_secret_post',
                'client_secret_basic',
            ],
            grant_types_supported=['authorization_code'],
            scopes=scopes,
        ),
    )
    return tool_set

  def configure_auth(self, client_id: str, client_secret: str):
    """Applies the same OAuth2 client credentials to every tool."""
    for tool in self.tools:
      tool.configure_auth(client_id, client_secret)

  @classmethod
  def load_tool_set(
      cls: Type[GoogleApiToolSet],
      api_name: str,
      api_version: str,
  ) -> GoogleApiToolSet:
    """Builds a tool set by converting a Google API discovery document.

    Args:
      api_name: Google API name (e.g. 'calendar').
      api_version: API version (e.g. 'v3').

    Returns:
      A GoogleApiToolSet backed by the converted OpenAPI spec.
    """
    spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
    # NOTE(review): only the first declared OAuth scope is used; APIs that
    # require multiple scopes may need all of them — confirm intent.
    scope = list(
        spec_dict['components']['securitySchemes']['oauth2']['flows'][
            'authorizationCode'
        ]['scopes'].keys()
    )[0]
    return cls(
        cls._load_tool_set_with_oidc_auth(
            spec_dict=spec_dict, scopes=[scope]
        ).get_tools()
    )
|
||||
@@ -0,0 +1,112 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from .google_api_tool_set import GoogleApiToolSet
|
||||
|
||||
logger = logging.getLogger(__name__)

# Lazy-initialization caches: each variable holds the corresponding tool set
# after its first access through __getattr__, so subsequent lookups reuse the
# same instance instead of rebuilding it from the discovery service.
_bigquery_tool_set = None
_calendar_tool_set = None
_gmail_tool_set = None
_youtube_tool_set = None
_slides_tool_set = None
_sheets_tool_set = None
_docs_tool_set = None
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
"""This method dynamically loads and returns GoogleApiToolSet instances for
|
||||
|
||||
various Google APIs. It uses a lazy loading approach, initializing each
|
||||
tool set only when it is first requested. This avoids unnecessary loading
|
||||
of tool sets that are not used in a given session.
|
||||
|
||||
Args:
|
||||
name (str): The name of the tool set to retrieve (e.g.,
|
||||
"bigquery_tool_set").
|
||||
|
||||
Returns:
|
||||
GoogleApiToolSet: The requested tool set instance.
|
||||
|
||||
Raises:
|
||||
AttributeError: If the requested tool set name is not recognized.
|
||||
"""
|
||||
global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
|
||||
|
||||
match name:
|
||||
case "bigquery_tool_set":
|
||||
if _bigquery_tool_set is None:
|
||||
_bigquery_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="bigquery",
|
||||
api_version="v2",
|
||||
)
|
||||
|
||||
return _bigquery_tool_set
|
||||
|
||||
case "calendar_tool_set":
|
||||
if _calendar_tool_set is None:
|
||||
_calendar_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="calendar",
|
||||
api_version="v3",
|
||||
)
|
||||
|
||||
return _calendar_tool_set
|
||||
|
||||
case "gmail_tool_set":
|
||||
if _gmail_tool_set is None:
|
||||
_gmail_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="gmail",
|
||||
api_version="v1",
|
||||
)
|
||||
|
||||
return _gmail_tool_set
|
||||
|
||||
case "youtube_tool_set":
|
||||
if _youtube_tool_set is None:
|
||||
_youtube_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="youtube",
|
||||
api_version="v3",
|
||||
)
|
||||
|
||||
return _youtube_tool_set
|
||||
|
||||
case "slides_tool_set":
|
||||
if _slides_tool_set is None:
|
||||
_slides_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="slides",
|
||||
api_version="v1",
|
||||
)
|
||||
|
||||
return _slides_tool_set
|
||||
|
||||
case "sheets_tool_set":
|
||||
if _sheets_tool_set is None:
|
||||
_sheets_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="sheets",
|
||||
api_version="v4",
|
||||
)
|
||||
|
||||
return _sheets_tool_set
|
||||
|
||||
case "docs_tool_set":
|
||||
if _docs_tool_set is None:
|
||||
_docs_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="docs",
|
||||
api_version="v1",
|
||||
)
|
||||
|
||||
return _docs_tool_set
|
||||
@@ -0,0 +1,523 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
# Google API client
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.discovery import Resource
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
# Configure logging
# NOTE(review): basicConfig at import time mutates the process-wide root
# logger. Fine for this module's CLI entry point, but worth confirming for
# consumers that import this module as a library.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
||||
|
||||
|
||||
class GoogleApiToOpenApiConverter:
  """Converts Google API Discovery documents to OpenAPI v3 format."""

  def __init__(self, api_name: str, api_version: str):
    """Initialize the converter with the API name and version.

    Args:
      api_name: The name of the Google API (e.g., "calendar")
      api_version: The version of the API (e.g., "v3")
    """
    self.api_name = api_name
    self.api_version = api_version
    # Populated by fetch_google_api_spec(); None until then.
    self.google_api_resource = None
    self.google_api_spec = None
    # Skeleton of the OpenAPI document, filled in by convert().
    self.openapi_spec = {
        "openapi": "3.0.0",
        "info": {},
        "servers": [],
        "paths": {},
        "components": {"schemas": {}, "securitySchemes": {}},
    }

  def fetch_google_api_spec(self) -> None:
    """Fetches the Google API specification using discovery service."""
    try:
      logger.info(
          "Fetching Google API spec for %s %s", self.api_name, self.api_version
      )
      # Build a resource object for the specified API
      self.google_api_resource = build(self.api_name, self.api_version)

      # Access the underlying API discovery document
      # NOTE(review): _rootDesc is a private googleapiclient attribute and
      # could change between client versions — confirm before upgrading.
      self.google_api_spec = self.google_api_resource._rootDesc

      if not self.google_api_spec:
        raise ValueError("Failed to retrieve API specification")

      logger.info("Successfully fetched %s API specification", self.api_name)
    except HttpError as e:
      logger.error("HTTP Error: %s", e)
      raise
    except Exception as e:
      logger.error("Error fetching API spec: %s", e)
      raise

  def convert(self) -> Dict[str, Any]:
    """Convert the Google API spec to OpenAPI v3 format.

    Fetches the discovery document on first call, then converts each
    section in place into self.openapi_spec.

    Returns:
      Dict containing the converted OpenAPI v3 specification
    """
    if not self.google_api_spec:
      self.fetch_google_api_spec()

    # Convert basic API information
    self._convert_info()

    # Convert server information
    self._convert_servers()

    # Convert authentication/authorization schemes
    self._convert_security_schemes()

    # Convert schemas (models)
    self._convert_schemas()

    # Convert endpoints/paths
    self._convert_resources(self.google_api_spec.get("resources", {}))

    # Convert top-level methods, if any
    self._convert_methods(self.google_api_spec.get("methods", {}), "/")

    return self.openapi_spec

  def _convert_info(self) -> None:
    """Convert basic API information."""
    self.openapi_spec["info"] = {
        "title": self.google_api_spec.get("title", f"{self.api_name} API"),
        "description": self.google_api_spec.get("description", ""),
        "version": self.google_api_spec.get("version", self.api_version),
        "contact": {},
        "termsOfService": self.google_api_spec.get("documentationLink", ""),
    }

    # Add documentation links if available
    docs_link = self.google_api_spec.get("documentationLink")
    if docs_link:
      self.openapi_spec["externalDocs"] = {
          "description": "API Documentation",
          "url": docs_link,
      }

  def _convert_servers(self) -> None:
    """Convert server information."""
    # rootUrl + servicePath is the discovery document's base URL.
    base_url = self.google_api_spec.get(
        "rootUrl", ""
    ) + self.google_api_spec.get("servicePath", "")

    # Remove trailing slash if present
    if base_url.endswith("/"):
      base_url = base_url[:-1]

    self.openapi_spec["servers"] = [{
        "url": base_url,
        "description": f"{self.api_name} {self.api_version} API",
    }]

  def _convert_security_schemes(self) -> None:
    """Convert authentication and authorization schemes."""
    auth = self.google_api_spec.get("auth", {})
    oauth2 = auth.get("oauth2", {})

    if oauth2:
      # Handle OAuth2
      scopes = oauth2.get("scopes", {})
      formatted_scopes = {}

      for scope, scope_info in scopes.items():
        formatted_scopes[scope] = scope_info.get("description", "")

      self.openapi_spec["components"]["securitySchemes"]["oauth2"] = {
          "type": "oauth2",
          "description": "OAuth 2.0 authentication",
          "flows": {
              "authorizationCode": {
                  "authorizationUrl": (
                      "https://accounts.google.com/o/oauth2/auth"
                  ),
                  "tokenUrl": "https://oauth2.googleapis.com/token",
                  "scopes": formatted_scopes,
              }
          },
      }

    # Add API key authentication (most Google APIs support this)
    # NOTE: added unconditionally, even when the discovery doc declares
    # only OAuth2.
    self.openapi_spec["components"]["securitySchemes"]["apiKey"] = {
        "type": "apiKey",
        "in": "query",
        "name": "key",
        "description": "API key for accessing this API",
    }

    # Create global security requirement
    self.openapi_spec["security"] = [
        {"oauth2": list(formatted_scopes.keys())} if oauth2 else {},
        {"apiKey": []},
    ]

  def _convert_schemas(self) -> None:
    """Convert schema definitions (models)."""
    schemas = self.google_api_spec.get("schemas", {})

    for schema_name, schema_def in schemas.items():
      converted_schema = self._convert_schema_object(schema_def)
      self.openapi_spec["components"]["schemas"][schema_name] = converted_schema

  def _convert_schema_object(
      self, schema_def: Dict[str, Any]
  ) -> Dict[str, Any]:
    """Recursively convert a Google API schema object to OpenAPI schema.

    Args:
      schema_def: Google API schema definition

    Returns:
      Converted OpenAPI schema object
    """
    result = {}

    # Convert the type
    if "type" in schema_def:
      gtype = schema_def["type"]
      if gtype == "object":
        result["type"] = "object"

        # Handle properties
        if "properties" in schema_def:
          result["properties"] = {}
          for prop_name, prop_def in schema_def["properties"].items():
            result["properties"][prop_name] = self._convert_schema_object(
                prop_def
            )

        # Handle required fields
        # Discovery marks "required" on each property; OpenAPI lists them
        # at the enclosing object level instead.
        required_fields = []
        for prop_name, prop_def in schema_def.get("properties", {}).items():
          if prop_def.get("required", False):
            required_fields.append(prop_name)
        if required_fields:
          result["required"] = required_fields

      elif gtype == "array":
        result["type"] = "array"
        if "items" in schema_def:
          result["items"] = self._convert_schema_object(schema_def["items"])

      elif gtype == "any":
        # OpenAPI doesn't have direct "any" type
        # Use oneOf with multiple options as alternative
        result["oneOf"] = [
            {"type": "object"},
            {"type": "array"},
            {"type": "string"},
            {"type": "number"},
            {"type": "boolean"},
            {"type": "null"},
        ]

      else:
        # Handle other primitive types
        result["type"] = gtype

    # Handle references
    if "$ref" in schema_def:
      ref = schema_def["$ref"]
      # Google refs use "#" at start, OpenAPI uses "#/components/schemas/"
      if ref.startswith("#"):
        ref = ref.replace("#", "#/components/schemas/")
      else:
        ref = "#/components/schemas/" + ref
      result["$ref"] = ref

    # Handle format
    if "format" in schema_def:
      result["format"] = schema_def["format"]

    # Handle enum values
    if "enum" in schema_def:
      result["enum"] = schema_def["enum"]

    # Handle description
    if "description" in schema_def:
      result["description"] = schema_def["description"]

    # Handle pattern
    if "pattern" in schema_def:
      result["pattern"] = schema_def["pattern"]

    # Handle default value
    if "default" in schema_def:
      result["default"] = schema_def["default"]

    return result

  def _convert_resources(
      self, resources: Dict[str, Any], parent_path: str = ""
  ) -> None:
    """Recursively convert all resources and their methods.

    Args:
      resources: Dictionary of resources from the Google API spec
      parent_path: The parent path prefix for nested resources
    """
    for resource_name, resource_data in resources.items():
      # Process methods for this resource
      resource_path = f"{parent_path}/{resource_name}"
      methods = resource_data.get("methods", {})
      self._convert_methods(methods, resource_path)

      # Process nested resources recursively
      nested_resources = resource_data.get("resources", {})
      if nested_resources:
        self._convert_resources(nested_resources, resource_path)

  def _convert_methods(
      self, methods: Dict[str, Any], resource_path: str
  ) -> None:
    """Convert methods for a specific resource path.

    Args:
      methods: Dictionary of methods from the Google API spec
      resource_path: The path of the resource these methods belong to
    """
    for method_name, method_data in methods.items():
      http_method = method_data.get("httpMethod", "GET").lower()

      # Determine the actual endpoint path
      # Google often has the format something like 'users.messages.list'
      # flatPath is preferred as it provides the actual path, while path
      # might contain variables like {+projectId}
      rest_path = method_data.get("flatPath", method_data.get("path", "/"))
      if not rest_path.startswith("/"):
        rest_path = "/" + rest_path

      path_params = self._extract_path_parameters(rest_path)

      # Create path entry if it doesn't exist
      if rest_path not in self.openapi_spec["paths"]:
        self.openapi_spec["paths"][rest_path] = {}

      # Add the operation for this method
      self.openapi_spec["paths"][rest_path][http_method] = (
          self._convert_operation(method_data, path_params)
      )

  def _extract_path_parameters(self, path: str) -> List[str]:
    """Extract path parameters from a URL path.

    Args:
      path: The URL path with path parameters

    Returns:
      List of parameter names
    """
    params = []
    segments = path.split("/")

    for segment in segments:
      # Google APIs often use {param} format for path parameters
      if segment.startswith("{") and segment.endswith("}"):
        param_name = segment[1:-1]
        params.append(param_name)

    return params

  def _convert_operation(
      self, method_data: Dict[str, Any], path_params: List[str]
  ) -> Dict[str, Any]:
    """Convert a Google API method to an OpenAPI operation.

    Args:
      method_data: Google API method data
      path_params: List of path parameter names

    Returns:
      OpenAPI operation object
    """
    operation = {
        "operationId": method_data.get("id", ""),
        "summary": method_data.get("description", ""),
        "description": method_data.get("description", ""),
        "parameters": [],
        # Generic response set; the 200 entry gains a schema below when
        # the method declares a response $ref.
        "responses": {
            "200": {"description": "Successful operation"},
            "400": {"description": "Bad request"},
            "401": {"description": "Unauthorized"},
            "403": {"description": "Forbidden"},
            "404": {"description": "Not found"},
            "500": {"description": "Server error"},
        },
    }

    # Add path parameters
    for param_name in path_params:
      param = {
          "name": param_name,
          "in": "path",
          "required": True,
          "schema": {"type": "string"},
      }
      operation["parameters"].append(param)

    # Add query parameters
    for param_name, param_data in method_data.get("parameters", {}).items():
      # Skip parameters already included in path
      if param_name in path_params:
        continue

      param = {
          "name": param_name,
          "in": "query",
          "description": param_data.get("description", ""),
          "required": param_data.get("required", False),
          "schema": self._convert_parameter_schema(param_data),
      }
      operation["parameters"].append(param)

    # Handle request body
    if "request" in method_data:
      request_ref = method_data.get("request", {}).get("$ref", "")
      if request_ref:
        if request_ref.startswith("#"):
          # Convert Google's reference format to OpenAPI format
          openapi_ref = request_ref.replace("#", "#/components/schemas/")
        else:
          openapi_ref = "#/components/schemas/" + request_ref
        operation["requestBody"] = {
            "description": "Request body",
            "content": {"application/json": {"schema": {"$ref": openapi_ref}}},
            "required": True,
        }

    # Handle response body
    if "response" in method_data:
      response_ref = method_data.get("response", {}).get("$ref", "")
      if response_ref:
        if response_ref.startswith("#"):
          # Convert Google's reference format to OpenAPI format
          openapi_ref = response_ref.replace("#", "#/components/schemas/")
        else:
          openapi_ref = "#/components/schemas/" + response_ref
        operation["responses"]["200"]["content"] = {
            "application/json": {"schema": {"$ref": openapi_ref}}
        }

    # Add scopes if available
    scopes = method_data.get("scopes", [])
    if scopes:
      # Add method-specific security requirement if different from global
      operation["security"] = [{"oauth2": scopes}]

    return operation

  def _convert_parameter_schema(
      self, param_data: Dict[str, Any]
  ) -> Dict[str, Any]:
    """Convert a parameter definition to an OpenAPI schema.

    Args:
      param_data: Google API parameter data

    Returns:
      OpenAPI schema for the parameter
    """
    schema = {}

    # Convert type
    param_type = param_data.get("type", "string")
    schema["type"] = param_type

    # Handle enum values
    if "enum" in param_data:
      schema["enum"] = param_data["enum"]

    # Handle format
    if "format" in param_data:
      schema["format"] = param_data["format"]

    # Handle default value
    if "default" in param_data:
      schema["default"] = param_data["default"]

    # Handle pattern
    if "pattern" in param_data:
      schema["pattern"] = param_data["pattern"]

    return schema

  def save_openapi_spec(self, output_path: str) -> None:
    """Save the OpenAPI specification to a file.

    Args:
      output_path: Path where the OpenAPI spec should be saved
    """
    with open(output_path, "w", encoding="utf-8") as f:
      json.dump(self.openapi_spec, f, indent=2)
    logger.info("OpenAPI specification saved to %s", output_path)
|
||||
|
||||
|
||||
def main():
  """Command line interface for the converter."""
  arg_parser = argparse.ArgumentParser(
      description=(
          "Convert Google API Discovery documents to OpenAPI v3 specifications"
      )
  )
  arg_parser.add_argument(
      "api_name", help="Name of the Google API (e.g., 'calendar')"
  )
  arg_parser.add_argument("api_version", help="Version of the API (e.g., 'v3')")
  arg_parser.add_argument(
      "--output",
      "-o",
      default="openapi_spec.json",
      help="Output file path for the OpenAPI specification",
  )

  cli_args = arg_parser.parse_args()

  try:
    # Create and run the converter, then persist the result.
    api_converter = GoogleApiToOpenApiConverter(
        cli_args.api_name, cli_args.api_version
    )
    api_converter.convert()
    api_converter.save_openapi_spec(cli_args.output)
    print(
        f"Successfully converted {cli_args.api_name} {cli_args.api_version} to"
        " OpenAPI v3"
    )
    print(f"Output saved to {cli_args.output}")
  except Exception as err:
    # CLI boundary: report the failure and signal it via the exit status.
    logger.error("Conversion failed: %s", err)
    return 1

  return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # Fixed: main() returns 0 on success and 1 on failure, but its return
  # value was previously discarded, so the process always exited 0.
  # SystemExit propagates the status code to the shell without an import.
  raise SystemExit(main())
|
||||
@@ -0,0 +1,68 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_tool import BaseTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import LlmRequest
|
||||
|
||||
|
||||
class GoogleSearchTool(BaseTool):
  """A built-in tool that is automatically invoked by Gemini 2 models to retrieve search results from Google Search.

  This tool operates internally within the model and does not require or perform
  local code execution.
  """

  def __init__(self):
    # Name and description are not used because this is a model built-in tool.
    super().__init__(name='google_search', description='google_search')

  @override
  async def process_llm_request(
      self,
      *,
      tool_context: ToolContext,
      llm_request: LlmRequest,
  ) -> None:
    """Appends the model-appropriate search tool to the outgoing request.

    Gemini 1.x only supports search retrieval as the sole tool, so any
    previously configured tools make the request invalid; Gemini 2.x uses
    the GoogleSearch tool and may combine it with others.

    Raises:
      ValueError: If combined with other tools on Gemini 1.x, or if the
        model is not a supported Gemini model.
    """
    llm_request.config = llm_request.config or types.GenerateContentConfig()
    llm_request.config.tools = llm_request.config.tools or []
    if llm_request.model and llm_request.model.startswith('gemini-1'):
      if llm_request.config.tools:
        # Fixed: removed a leftover debug print of the configured tools
        # that leaked to stdout before the error was raised.
        raise ValueError(
            'Google search tool can not be used with other tools in Gemini 1.x.'
        )
      llm_request.config.tools.append(
          types.Tool(google_search_retrieval=types.GoogleSearchRetrieval())
      )
    elif llm_request.model and llm_request.model.startswith('gemini-2'):
      llm_request.config.tools.append(
          types.Tool(google_search=types.GoogleSearch())
      )
    else:
      raise ValueError(
          f'Google search tool is not supported for model {llm_request.model}'
      )


google_search = GoogleSearchTool()
|
||||
@@ -0,0 +1,86 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from . import _automatic_function_calling_util
|
||||
from .function_tool import FunctionTool
|
||||
|
||||
|
||||
class LangchainTool(FunctionTool):
  """Use this class to wrap a langchain tool.

  If the original tool name and description are not suitable, you can override
  them in the constructor.
  """

  # The wrapped langchain tool (any langchain tool object; typed Any to
  # avoid a hard dependency on langchain at import time).
  tool: Any
  """The wrapped langchain tool."""

  def __init__(self, tool: Any):
    # The langchain tool's _run method becomes the callable of this
    # FunctionTool.
    super().__init__(tool._run)
    self.tool = tool
    # Prefer the langchain tool's own name/description when present.
    if tool.name:
      self.name = tool.name
    if tool.description:
      self.description = tool.description

  @model_validator(mode='before')
  @classmethod
  def populate_name(cls, data: Any) -> Any:
    # Override this to not use function's signature name as it's
    # mostly "run" or "invoke" for thir-party tools.
    return data

  @override
  def _get_declaration(self) -> types.FunctionDeclaration:
    """Build the function declaration for the tool."""
    # Imported lazily so langchain is only required when a declaration is
    # actually built.
    from langchain.agents import Tool
    from langchain_core.tools import BaseTool

    # There are two types of tools:
    # 1. BaseTool: the tool is defined in langchain.tools.
    # 2. Other tools: the tool doesn't inherit any class but follow some
    #    conventions, like having a "run" method.
    if isinstance(self.tool, BaseTool):
      tool_wrapper = Tool(
          name=self.name,
          func=self.func,
          description=self.description,
      )
      if self.tool.args_schema:
        tool_wrapper.args_schema = self.tool.args_schema
      function_declaration = _automatic_function_calling_util.build_function_declaration_for_langchain(
          False,
          self.name,
          self.description,
          tool_wrapper.func,
          tool_wrapper.args,
      )
      return function_declaration
    else:
      # Need to provide a way to override the function names and descriptions
      # as the original function names are mostly ".run" and the descriptions
      # may not meet users' needs.
      function_declaration = (
          _automatic_function_calling_util.build_function_declaration(
              func=self.tool.run,
          )
      )
      return function_declaration
|
||||
@@ -0,0 +1,113 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_tool import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models.llm_request import LlmRequest
|
||||
from .tool_context import ToolContext
|
||||
|
||||
|
||||
class LoadArtifactsTool(BaseTool):
  """A tool that loads the artifacts and adds them to the session."""

  def __init__(self):
    super().__init__(
        name='load_artifacts',
        description='Loads the artifacts and adds them to the session.',
    )

  def _get_declaration(self) -> types.FunctionDeclaration | None:
    # Declares a single parameter: the list of artifact names the model wants
    # loaded.
    return types.FunctionDeclaration(
        name=self.name,
        description=self.description,
        parameters=types.Schema(
            type=types.Type.OBJECT,
            properties={
                'artifact_names': types.Schema(
                    type=types.Type.ARRAY,
                    items=types.Schema(
                        type=types.Type.STRING,
                    ),
                )
            },
        ),
    )

  @override
  async def run_async(
      self, *, args: dict[str, Any], tool_context: ToolContext
  ) -> Any:
    # Echo the requested names back; the artifact contents themselves are
    # attached to the request later in process_llm_request().
    artifact_names: list[str] = args.get('artifact_names', [])
    return {'artifact_names': artifact_names}

  @override
  async def process_llm_request(
      self, *, tool_context: ToolContext, llm_request: LlmRequest
  ) -> None:
    await super().process_llm_request(
        tool_context=tool_context,
        llm_request=llm_request,
    )
    self._append_artifacts_to_llm_request(
        tool_context=tool_context, llm_request=llm_request
    )

  def _append_artifacts_to_llm_request(
      self, *, tool_context: ToolContext, llm_request: LlmRequest
  ):
    # No-op when the session has no artifacts.
    artifact_names = tool_context.list_artifacts()
    if not artifact_names:
      return

    # Tell the model about the available artifacts.
    llm_request.append_instructions([f"""You have a list of artifacts:
{json.dumps(artifact_names)}

When the user asks questions about any of the artifacts, you should call the
`load_artifacts` function to load the artifact. Do not generate any text other
than the function call.
"""])

    # Attach the content of the artifacts if the model requests them.
    # This only adds the content to the model request, instead of the session.
    if llm_request.contents and llm_request.contents[-1].parts:
      function_response = llm_request.contents[-1].parts[0].function_response
      if function_response and function_response.name == 'load_artifacts':
        artifact_names = function_response.response['artifact_names']
        for artifact_name in artifact_names:
          artifact = tool_context.load_artifact(artifact_name)
          llm_request.contents.append(
              types.Content(
                  role='user',
                  parts=[
                      types.Part.from_text(
                          text=f'Artifact {artifact_name} is:'
                      ),
                      artifact,
                  ],
              )
          )
|
||||
|
||||
|
||||
# Module-level singleton; re-exported by the package as `load_artifacts`.
load_artifacts_tool = LoadArtifactsTool()
|
||||
@@ -0,0 +1,81 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .function_tool import FunctionTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..memory.base_memory_service import MemoryResult
|
||||
from ..models import LlmRequest
|
||||
|
||||
|
||||
def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]':
  """Searches the current user's memory and returns the matching entries.

  Args:
    query: Free-text query used to search the memory store.

  Returns:
    A list of memory results.
  """
  search_result = tool_context.search_memory(query)
  return search_result.memories
|
||||
|
||||
|
||||
class LoadMemoryTool(FunctionTool):
  """A tool that loads the memory for the current user."""

  def __init__(self):
    super().__init__(load_memory)

  @override
  def _get_declaration(self) -> types.FunctionDeclaration | None:
    # Only `query` is declared to the model; `tool_context` is injected by the
    # framework and must not appear in the schema.
    return types.FunctionDeclaration(
        name=self.name,
        description=self.description,
        parameters=types.Schema(
            type=types.Type.OBJECT,
            properties={
                'query': types.Schema(
                    type=types.Type.STRING,
                )
            },
        ),
    )

  @override
  async def process_llm_request(
      self,
      *,
      tool_context: ToolContext,
      llm_request: LlmRequest,
  ) -> None:
    await super().process_llm_request(
        tool_context=tool_context, llm_request=llm_request
    )
    # Tell the model about the memory.
    llm_request.append_instructions(["""
You have memory. You can use it to answer questions. If any questions need
you to look up the memory, you should call load_memory function with a query.
"""])
|
||||
|
||||
|
||||
# Module-level singleton; re-exported by the package as `load_memory`.
load_memory_tool = LoadMemoryTool()
|
||||
@@ -0,0 +1,41 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tool for web browse."""
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def load_web_page(url: str) -> str:
  """Fetches the content in the url and returns the text in it.

  Args:
    url (str): The url to browse.

  Returns:
    str: The text content of the url, with very short lines (3 words or
      fewer) filtered out, or an error message if the fetch failed.
  """
  # Imported lazily so bs4 is only required when the tool is actually used.
  from bs4 import BeautifulSoup

  # `requests` has no default timeout: without one, a stalled server would
  # hang the calling agent indefinitely. 30s bounds the wait.
  response = requests.get(url, timeout=30)

  if response.status_code == 200:
    soup = BeautifulSoup(response.content, 'lxml')
    text = soup.get_text(separator='\n', strip=True)
  else:
    text = f'Failed to fetch url: {url}'

  # Split the text into lines, filtering out very short lines
  # (e.g., single words or short subtitles)
  return '\n'.join(line for line in text.splitlines() if len(line.split()) > 3)
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from .function_tool import FunctionTool
|
||||
|
||||
|
||||
class LongRunningFunctionTool(FunctionTool):
  """A function tool that returns the result asynchronously.

  This tool is used for long-running operations that may take a significant
  amount of time to complete. The framework will call the function. Once the
  function returns, the response will be returned asynchronously to the
  framework which is identified by the function_call_id.

  Example:
    ```python
    tool = LongRunningFunctionTool(a_long_running_function)
    ```

  Attributes:
    is_long_running: Whether the tool is a long running operation.
  """

  def __init__(self, func: Callable):
    super().__init__(func)
    # Mark the wrapped function as long-running so the framework treats the
    # result as delivered asynchronously rather than inline.
    self.is_long_running = True
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Public names are appended below only when the optional `mcp` dependency is
# importable; otherwise this package intentionally exposes nothing.
__all__ = []

try:
  from .conversion_utils import adk_to_mcp_tool_type, gemini_to_json_schema
  from .mcp_tool import MCPTool
  from .mcp_toolset import MCPToolset

  __all__.extend([
      'adk_to_mcp_tool_type',
      'gemini_to_json_schema',
      'MCPTool',
      'MCPToolset',
  ])

except ImportError as e:
  import logging
  import sys

  logger = logging.getLogger(__name__)

  # Distinguish an unsupported interpreter (mcp needs Python 3.10+) from a
  # simply-missing optional dependency; only the former warrants a warning.
  if sys.version_info < (3, 10):
    logger.warning(
        'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
        ' version.'
    )
  else:
    logger.debug('MCP Tool is not installed')
    logger.debug(e)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,161 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict
|
||||
from google.genai.types import Schema, Type
|
||||
import mcp.types as mcp_types
|
||||
from ..base_tool import BaseTool
|
||||
|
||||
|
||||
def adk_to_mcp_tool_type(tool: BaseTool) -> mcp_types.Tool:
  """Convert a Tool in ADK into MCP tool type.

  This function transforms an ADK tool definition into its equivalent
  representation in the MCP (Model Control Plane) system.

  Args:
    tool: The ADK tool to convert. It should be an instance of a class derived
      from `BaseTool`.

  Returns:
    An object of MCP Tool type, representing the converted tool.

  Examples:
    # Assuming 'my_tool' is an instance of a BaseTool derived class
    mcp_tool = adk_to_mcp_tool_type(my_tool)
    print(mcp_tool)
  """
  tool_declaration = tool._get_declaration()
  if not tool_declaration:
    input_schema = {}
  else:
    # Reuse the declaration fetched above; the original code called
    # _get_declaration() a second time here, rebuilding it from scratch for
    # tool implementations that construct the declaration on every call.
    input_schema = gemini_to_json_schema(tool_declaration.parameters)
  return mcp_types.Tool(
      name=tool.name,
      description=tool.description,
      inputSchema=input_schema,
  )
|
||||
|
||||
|
||||
def gemini_to_json_schema(gemini_schema: Schema) -> Dict[str, Any]:
  """Converts a Gemini Schema object into a JSON Schema dictionary.

  Args:
    gemini_schema: An instance of the Gemini Schema class.

  Returns:
    A dictionary representing the equivalent JSON Schema.

  Raises:
    TypeError: If the input is not an instance of the expected Schema class.
    ValueError: If an invalid Gemini Type enum value is encountered.
  """
  if not isinstance(gemini_schema, Schema):
    raise TypeError(
        f"Input must be an instance of Schema, got {type(gemini_schema)}"
    )

  json_schema_dict: Dict[str, Any] = {}

  def _copy_present_fields(mappings: Dict[str, str]) -> None:
    # Copies each Gemini field that is set (not None) to its JSON Schema key.
    for gemini_key, json_key in mappings.items():
      value = getattr(gemini_schema, gemini_key, None)
      if value is not None:
        json_schema_dict[json_key] = value

  # Map Type
  gemini_type = getattr(gemini_schema, "type", None)
  if gemini_type and gemini_type != Type.TYPE_UNSPECIFIED:
    json_schema_dict["type"] = gemini_type.lower()
  else:
    json_schema_dict["type"] = "null"

  # Map Nullable. Identity check (`is True`) rather than `== True`: only a
  # genuine boolean True sets the flag (PEP 8 E712).
  if getattr(gemini_schema, "nullable", None) is True:
    json_schema_dict["nullable"] = True

  # --- Map direct fields ---
  _copy_present_fields({
      "title": "title",
      "description": "description",
      "default": "default",
      "enum": "enum",
      "format": "format",
      "example": "example",
  })

  # String validation
  if gemini_type == Type.STRING:
    _copy_present_fields({
        "pattern": "pattern",
        "min_length": "minLength",
        "max_length": "maxLength",
    })

  # Number/Integer validation
  if gemini_type in (Type.NUMBER, Type.INTEGER):
    _copy_present_fields({
        "minimum": "minimum",
        "maximum": "maximum",
    })

  # Array validation (Recursive call for items)
  if gemini_type == Type.ARRAY:
    items_schema = getattr(gemini_schema, "items", None)
    if items_schema is not None:
      json_schema_dict["items"] = gemini_to_json_schema(items_schema)

    _copy_present_fields({
        "min_items": "minItems",
        "max_items": "maxItems",
    })

  # Object validation (Recursive call for properties)
  if gemini_type == Type.OBJECT:
    properties_dict = getattr(gemini_schema, "properties", None)
    if properties_dict is not None:
      json_schema_dict["properties"] = {
          prop_name: gemini_to_json_schema(prop_schema)
          for prop_name, prop_schema in properties_dict.items()
      }

    _copy_present_fields({
        "required": "required",
        "min_properties": "minProperties",
        "max_properties": "maxProperties",
        # Note: Ignoring 'property_ordering' as it's not standard JSON Schema
    })

  # Map anyOf (Recursive call for subschemas)
  any_of_list = getattr(gemini_schema, "any_of", None)
  if any_of_list is not None:
    json_schema_dict["anyOf"] = [
        gemini_to_json_schema(sub_schema) for sub_schema in any_of_list
    ]

  return json_schema_dict
|
||||
@@ -0,0 +1,176 @@
|
||||
from contextlib import AsyncExitStack
|
||||
import functools
|
||||
import sys
|
||||
from typing import Any, TextIO
|
||||
import anyio
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
except ImportError as e:
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
raise ImportError(
|
||||
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
|
||||
' version.'
|
||||
) from e
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
class SseServerParams(BaseModel):
  """Parameters for the MCP SSE connection.

  See MCP SSE Client documentation for more details.
  https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
  """

  # URL of the SSE endpoint to connect to.
  url: str
  # Optional HTTP headers sent with the connection request.
  headers: dict[str, Any] | None = None
  # Connection timeout, in seconds.
  timeout: float = 5
  # Timeout for reading SSE events, in seconds (default: 5 minutes).
  sse_read_timeout: float = 60 * 5
|
||||
|
||||
|
||||
def retry_on_closed_resource(async_reinit_func_name: str):
  """Decorator to automatically reinitialize session and retry action.

  When MCP session was closed, the decorator will automatically recreate the
  session and retry the action with the same parameters.

  Note:
    1. async_reinit_func_name is the name of the class member function that
       reinitializes the MCP session.
    2. Both the decorated function and the async_reinit_func_name must be async
       functions.

  Usage:
    class MCPTool:
      ...
      async def create_session(self):
        self.session = ...

      @retry_on_closed_resource('create_session')
      async def use_session(self):
        await self.session.call_tool()

  Args:
    async_reinit_func_name: The name of the async function to recreate session.

  Returns:
    The decorated function.
  """

  def decorator(func):
    @functools.wraps(
        func
    )  # Preserves original function metadata (name, docstring)
    async def wrapper(self, *args, **kwargs):
      try:
        return await func(self, *args, **kwargs)
      except anyio.ClosedResourceError:
        # The session's transport is gone: reinitialize once via the named
        # method, then retry the call a single time.
        try:
          if hasattr(self, async_reinit_func_name) and callable(
              getattr(self, async_reinit_func_name)
          ):
            async_init_fn = getattr(self, async_reinit_func_name)
            await async_init_fn()
          else:
            raise ValueError(
                f'Function {async_reinit_func_name} does not exist in decorated'
                ' class. Please check the function name in'
                ' retry_on_closed_resource decorator.'
            )
        except Exception as reinit_err:
          raise RuntimeError(
              f'Error reinitializing: {reinit_err}'
          ) from reinit_err
        # Retry after successful reinitialization; a second ClosedResourceError
        # here propagates to the caller (no infinite retry loop).
        return await func(self, *args, **kwargs)

    return wrapper

  return decorator
|
||||
|
||||
|
||||
class MCPSessionManager:
  """Manages MCP client sessions.

  This class provides methods for creating and initializing MCP client sessions,
  handling different connection parameters (Stdio and SSE).
  """

  def __init__(
      self,
      connection_params: StdioServerParameters | SseServerParams,
      exit_stack: AsyncExitStack,
      errlog: TextIO = sys.stderr,
  ) -> None:
    """Initializes the MCP session manager.

    Example usage:
    ```
    mcp_session_manager = MCPSessionManager(
        connection_params=connection_params,
        exit_stack=exit_stack,
    )
    session = await mcp_session_manager.create_session()
    ```

    Args:
      connection_params: Parameters for the MCP connection (Stdio or SSE).
      exit_stack: AsyncExitStack to manage the session lifecycle.
      errlog: (Optional) TextIO stream for error logging. Use only for
        initializing a local stdio MCP session.
    """
    self.connection_params = connection_params
    self.exit_stack = exit_stack
    self.errlog = errlog

  async def create_session(self) -> ClientSession:
    """Creates and initializes a session using this manager's parameters."""
    return await MCPSessionManager.initialize_session(
        connection_params=self.connection_params,
        exit_stack=self.exit_stack,
        errlog=self.errlog,
    )

  @classmethod
  async def initialize_session(
      cls,
      *,
      connection_params: StdioServerParameters | SseServerParams,
      exit_stack: AsyncExitStack,
      errlog: TextIO = sys.stderr,
  ) -> ClientSession:
    """Initializes an MCP client session.

    Args:
      connection_params: Parameters for the MCP connection (Stdio or SSE).
      exit_stack: AsyncExitStack to manage the session lifecycle.
      errlog: (Optional) TextIO stream for error logging. Use only for
        initializing a local stdio MCP session.

    Returns:
      ClientSession: The initialized MCP client session.
    """
    # Choose the transport client based on the parameter type.
    if isinstance(connection_params, StdioServerParameters):
      client = stdio_client(server=connection_params, errlog=errlog)
    elif isinstance(connection_params, SseServerParams):
      client = sse_client(
          url=connection_params.url,
          headers=connection_params.headers,
          timeout=connection_params.timeout,
          sse_read_timeout=connection_params.sse_read_timeout,
      )
    else:
      raise ValueError(
          'Unable to initialize connection. Connection should be'
          ' StdioServerParameters or SseServerParams, but got'
          f' {connection_params}'
      )

    # Both the transport and the session are registered on the caller-provided
    # exit stack, which owns their lifetime.
    transports = await exit_stack.enter_async_context(client)
    session = await exit_stack.enter_async_context(ClientSession(*transports))
    await session.initialize()
    return session
|
||||
@@ -0,0 +1,126 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from typing_extensions import override
|
||||
|
||||
from .mcp_session_manager import MCPSessionManager, retry_on_closed_resource
|
||||
|
||||
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
||||
# their Python version to 3.10 if it fails.
|
||||
try:
|
||||
from mcp import ClientSession
|
||||
from mcp.types import Tool as McpBaseTool
|
||||
except ImportError as e:
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
raise ImportError(
|
||||
"MCP Tool requires Python 3.10 or above. Please upgrade your Python"
|
||||
" version."
|
||||
) from e
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
from ..base_tool import BaseTool
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_schemes import AuthScheme
|
||||
from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
||||
from ..tool_context import ToolContext
|
||||
|
||||
|
||||
class MCPTool(BaseTool):
  """Turns a MCP Tool into a Vertex Agent Framework Tool.

  Internally, the tool initializes from a MCP Tool, and uses the MCP Session to
  call the tool.
  """

  def __init__(
      self,
      mcp_tool: McpBaseTool,
      mcp_session: ClientSession,
      mcp_session_manager: MCPSessionManager,
      auth_scheme: Optional[AuthScheme] = None,
      # `Optional[X]` already includes None; the original `Optional[X] | None`
      # was redundant.
      auth_credential: Optional[AuthCredential] = None,
  ):
    """Initializes a MCPTool.

    This tool wraps a MCP Tool interface and an active MCP Session. It invokes
    the MCP Tool through executing the tool from remote MCP Session.

    Example:
      tool = MCPTool(mcp_tool=mcp_tool, mcp_session=mcp_session)

    Args:
      mcp_tool: The MCP tool to wrap.
      mcp_session: The MCP session to use to call the tool.
      auth_scheme: The authentication scheme to use.
      auth_credential: The authentication credential to use.

    Raises:
      ValueError: If mcp_tool or mcp_session is None.
    """
    if mcp_tool is None:
      raise ValueError("mcp_tool cannot be None")
    if mcp_session is None:
      raise ValueError("mcp_session cannot be None")
    # NOTE(review): BaseTool.__init__ is not invoked; name/description are set
    # directly on the instance — confirm BaseTool tolerates this.
    self.name = mcp_tool.name
    self.description = mcp_tool.description if mcp_tool.description else ""
    self.mcp_tool = mcp_tool
    self.mcp_session = mcp_session
    self.mcp_session_manager = mcp_session_manager
    # TODO(cheliu): Support passing auth to MCP Server.
    self.auth_scheme = auth_scheme
    self.auth_credential = auth_credential

  async def _reinitialize_session(self):
    # Recreates the MCP session after the transport was closed; used by the
    # retry_on_closed_resource decorator on run_async.
    self.mcp_session = await self.mcp_session_manager.create_session()

  @override
  def _get_declaration(self) -> FunctionDeclaration:
    """Gets the function declaration for the tool.

    Returns:
      FunctionDeclaration: The Gemini function declaration for the tool.
    """
    schema_dict = self.mcp_tool.inputSchema
    parameters = to_gemini_schema(schema_dict)
    function_decl = FunctionDeclaration(
        name=self.name, description=self.description, parameters=parameters
    )
    return function_decl

  @override
  @retry_on_closed_resource("_reinitialize_session")
  async def run_async(self, *, args, tool_context: ToolContext):
    """Runs the tool asynchronously.

    Args:
      args: The arguments as a dict to pass to the tool.
      tool_context: The tool context from upper level ADK agent.

    Returns:
      Any: The response from the tool.
    """
    # TODO(cheliu): Support passing tool context to MCP Server.
    try:
      response = await self.mcp_session.call_tool(self.name, arguments=args)
      return response
    except Exception as e:
      print(e)
      # Bare `raise` re-raises the active exception with its original
      # traceback intact; `raise e` appended a redundant re-raise frame.
      raise
|
||||
@@ -0,0 +1,266 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import AsyncExitStack
|
||||
import sys
|
||||
from types import TracebackType
|
||||
from typing import List, Optional, TextIO, Tuple, Type
|
||||
|
||||
from .mcp_session_manager import MCPSessionManager, SseServerParams, retry_on_closed_resource
|
||||
|
||||
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
||||
# their Python version to 3.10 if it fails.
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.types import ListToolsResult
|
||||
except ImportError as e:
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
raise ImportError(
|
||||
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
|
||||
' version.'
|
||||
) from e
|
||||
else:
|
||||
raise e
|
||||
|
||||
from .mcp_tool import MCPTool
|
||||
|
||||
|
||||
class MCPToolset:
  """Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.

  Usage:
    Example 1: (using from_server helper):
    ```
    async def load_tools():
      return await MCPToolset.from_server(
          connection_params=StdioServerParameters(
              command='npx',
              args=["-y", "@modelcontextprotocol/server-filesystem"],
          )
      )

    # Use the tools in an LLM agent
    tools, exit_stack = await load_tools()
    agent = LlmAgent(
        tools=tools
    )
    ...
    await exit_stack.aclose()
    ```

    Example 2: (using `async with`):
    ```
    async def load_tools():
      async with MCPToolset(
          connection_params=SseServerParams(url="http://0.0.0.0:8090/sse")
      ) as toolset:
        tools = await toolset.load_tools()

        agent = LlmAgent(
            ...
            tools=tools
        )
    ```

    Example 3: (provide AsyncExitStack):
    ```
    async def load_tools():
      async_exit_stack = AsyncExitStack()
      toolset = MCPToolset(
          connection_params=StdioServerParameters(...),
      )
      async_exit_stack.enter_async_context(toolset)
      tools = await toolset.load_tools()
      agent = LlmAgent(
          ...
          tools=tools
      )
      ...
      await async_exit_stack.aclose()
    ```

  Attributes:
    connection_params: The connection parameters to the MCP server. Can be
      either `StdioServerParameters` or `SseServerParams`.
    exit_stack: The async exit stack to manage the connection to the MCP
      server.
    session: The MCP session being initialized with the connection.
  """

  def __init__(
      self,
      *,
      connection_params: StdioServerParameters | SseServerParams,
      errlog: TextIO = sys.stderr,
      exit_stack: Optional[AsyncExitStack] = None,
  ):
    """Initializes the MCPToolset.

    See the class docstring for usage examples.

    Args:
      connection_params: The connection parameters to the MCP server. Can be:
        `StdioServerParameters` for using local mcp server (e.g. using `npx` or
        `python3`); or `SseServerParams` for a local/remote SSE server.
      errlog: The stream MCP server errors are written to. Defaults to
        `sys.stderr`.
      exit_stack: The async exit stack that manages the server connection. If
        not provided, a fresh `AsyncExitStack` is created for this instance.

    Raises:
      ValueError: If `connection_params` is missing.
    """
    if not connection_params:
      raise ValueError('Missing connection params in MCPToolset.')
    self.connection_params = connection_params
    self.errlog = errlog
    # BUG FIX: the original signature used `exit_stack=AsyncExitStack()`, a
    # mutable default evaluated once at definition time and therefore shared
    # by every MCPToolset instance that did not pass its own stack. Create a
    # per-instance stack instead.
    self.exit_stack = exit_stack if exit_stack is not None else AsyncExitStack()

    self.session_manager = MCPSessionManager(
        connection_params=self.connection_params,
        exit_stack=self.exit_stack,
        errlog=self.errlog,
    )

  @classmethod
  async def from_server(
      cls,
      *,
      connection_params: StdioServerParameters | SseServerParams,
      async_exit_stack: Optional[AsyncExitStack] = None,
      errlog: TextIO = sys.stderr,
  ) -> Tuple[List[MCPTool], AsyncExitStack]:
    """Retrieve all tools from the MCP connection.

    Usage:
    ```
    async def load_tools():
      tools, exit_stack = await MCPToolset.from_server(
          connection_params=StdioServerParameters(
              command='npx',
              args=["-y", "@modelcontextprotocol/server-filesystem"],
          )
      )
    ```

    Args:
      connection_params: The connection parameters to the MCP server.
      async_exit_stack: The async exit stack to use. If not provided, a new
        AsyncExitStack will be created.
      errlog: The stream MCP server errors are written to. Defaults to
        `sys.stderr`.

    Returns:
      A tuple of the list of MCPTools and the AsyncExitStack.
      - tools: The list of MCPTools.
      - async_exit_stack: The AsyncExitStack used to manage the connection to
        the MCP server. Use `await async_exit_stack.aclose()` to close the
        connection when server shuts down.
    """
    async_exit_stack = async_exit_stack or AsyncExitStack()
    toolset = cls(
        connection_params=connection_params,
        exit_stack=async_exit_stack,
        errlog=errlog,
    )

    # Registering the toolset on the stack ensures the MCP connection is torn
    # down when the caller closes the stack.
    await async_exit_stack.enter_async_context(toolset)
    tools = await toolset.load_tools()
    return (tools, async_exit_stack)

  async def _initialize(self) -> ClientSession:
    """Connects to the MCP Server and initializes the ClientSession."""
    self.session = await self.session_manager.create_session()
    return self.session

  async def _exit(self):
    """Closes the connection to MCP Server."""
    await self.exit_stack.aclose()

  @retry_on_closed_resource('_initialize')
  async def load_tools(self) -> List[MCPTool]:
    """Loads all tools from the MCP Server.

    Returns:
      A list of MCPTools imported from the MCP Server.
    """
    tools_response: ListToolsResult = await self.session.list_tools()
    return [
        MCPTool(
            mcp_tool=tool,
            mcp_session=self.session,
            mcp_session_manager=self.session_manager,
        )
        for tool in tools_response.tools
    ]

  async def __aenter__(self):
    # The original wrapped this in `try: ... except Exception as e: raise e`,
    # which only re-raised the same exception; let it propagate naturally.
    await self._initialize()
    return self

  async def __aexit__(
      self,
      exc_type: Optional[Type[BaseException]],
      exc: Optional[BaseException],
      tb: Optional[TracebackType],
  ) -> None:
    await self._exit()
|
||||
@@ -0,0 +1,21 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .openapi_spec_parser import OpenAPIToolset
|
||||
from .openapi_spec_parser import RestApiTool
|
||||
|
||||
__all__ = [
|
||||
'OpenAPIToolset',
|
||||
'RestApiTool',
|
||||
]
|
||||
Binary file not shown.
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import auth_helpers
|
||||
|
||||
__all__ = [
|
||||
'auth_helpers',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,498 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from fastapi.openapi.models import HTTPBase
|
||||
from fastapi.openapi.models import HTTPBearer
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OpenIdConnect
|
||||
from fastapi.openapi.models import Schema
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
import requests
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_credential import AuthCredentialTypes
|
||||
from ....auth.auth_credential import HttpAuth
|
||||
from ....auth.auth_credential import HttpCredentials
|
||||
from ....auth.auth_credential import OAuth2Auth
|
||||
from ....auth.auth_credential import ServiceAccount
|
||||
from ....auth.auth_credential import ServiceAccountCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ....auth.auth_schemes import AuthSchemeType
|
||||
from ....auth.auth_schemes import OpenIdConnectWithConfig
|
||||
from ..common.common import ApiParameter
|
||||
|
||||
|
||||
class OpenIdConfig(BaseModel):
  """Represents OpenID Connect configuration.

  Attributes:
    client_id: The client ID.
    auth_uri: The authorization URI.
    token_uri: The token URI.
    client_secret: The client secret.
    redirect_uri: The redirect URI registered with the provider, if any.

  Example:
    config = OpenIdConfig(
        client_id="your_client_id",
        auth_uri="https://accounts.google.com/o/oauth2/auth",
        token_uri="https://oauth2.googleapis.com/token",
        client_secret="your_client_secret",
        redirect_uri="https://your.app/oauth2/callback",
    )
  """

  client_id: str
  auth_uri: str
  token_uri: str
  client_secret: str
  # NOTE(review): no default is declared, so pydantic treats this field as
  # required even though it is annotated Optional — confirm whether
  # `= None` was intended.
  redirect_uri: Optional[str]
|
||||
|
||||
|
||||
def token_to_scheme_credential(
    token_type: Literal["apikey", "oauth2Token"],
    location: Optional[Literal["header", "query", "cookie"]] = None,
    name: Optional[str] = None,
    credential_value: Optional[str] = None,
) -> Tuple[AuthScheme, AuthCredential]:
  """Creates a AuthScheme and AuthCredential for API key or bearer token.

  Examples:
  ```
  # API Key in header
  auth_scheme, auth_credential = token_to_scheme_credential("apikey", "header",
  "X-API-Key", "your_api_key_value")

  # API Key in query parameter
  auth_scheme, auth_credential = token_to_scheme_credential("apikey", "query",
  "api_key", "your_api_key_value")

  # OAuth2 Bearer Token in Authorization header
  auth_scheme, auth_credential = token_to_scheme_credential("oauth2Token",
  "header", "Authorization", "your_bearer_token_value")
  ```

  Args:
    token_type: 'apikey' or 'oauth2Token'.
    location: 'header', 'query', or 'cookie' (ignored for oauth2Token, which
      always uses the Authorization header).
    name: The name of the header, query parameter, or cookie.
    credential_value: The value of the API Key/ Token. If omitted, only the
      scheme is built and the returned credential is None.

  Returns:
    Tuple: (AuthScheme, AuthCredential)

  Raises:
    ValueError: For invalid token_type or location.
  """
  if token_type == "apikey":
    in_: APIKeyIn
    if location == "header":
      in_ = APIKeyIn.header
    elif location == "query":
      in_ = APIKeyIn.query
    elif location == "cookie":
      in_ = APIKeyIn.cookie
    else:
      raise ValueError(f"Invalid location for apiKey: {location}")
    # "in" is a Python keyword, so the model must be built from a dict.
    auth_scheme = APIKey(**{
        "type": AuthSchemeType.apiKey,
        "in": in_,
        "name": name,
    })
    if credential_value:
      auth_credential = AuthCredential(
          auth_type=AuthCredentialTypes.API_KEY, api_key=credential_value
      )
    else:
      auth_credential = None

    return auth_scheme, auth_credential

  elif token_type == "oauth2Token":
    # ignore location. OAuth2 Bearer Token is always in Authorization header.
    auth_scheme = HTTPBearer(
        bearerFormat="JWT"
    )  # Common format, can be omitted.
    if credential_value:
      auth_credential = AuthCredential(
          auth_type=AuthCredentialTypes.HTTP,
          http=HttpAuth(
              scheme="bearer",
              credentials=HttpCredentials(token=credential_value),
          ),
      )
    else:
      auth_credential = None

    return auth_scheme, auth_credential

  else:
    # BUG FIX: the original interpolated the *builtin* `type` here instead of
    # the `token_type` argument, producing a useless error message like
    # "Invalid security scheme type: <class 'type'>".
    raise ValueError(f"Invalid security scheme type: {token_type}")
|
||||
|
||||
|
||||
def service_account_dict_to_scheme_credential(
    config: Dict[str, Any],
    scopes: List[str],
) -> Tuple[AuthScheme, AuthCredential]:
  """Creates AuthScheme and AuthCredential for Google Service Account.

  Returns a bearer token scheme, and a service account credential.

  Args:
    config: A dict holding the Google Service Account configuration (the
      parsed service-account key JSON).
    scopes: A list of scopes to be used.

  Returns:
    Tuple: (AuthScheme, AuthCredential)
  """
  # model_construct skips validation: the key dict is taken as-is.
  sa_credential = ServiceAccountCredential.model_construct(**config)
  credential = AuthCredential(
      auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
      service_account=ServiceAccount(
          service_account_credential=sa_credential,
          scopes=scopes,
      ),
  )
  return HTTPBearer(bearerFormat="JWT"), credential
|
||||
|
||||
|
||||
def service_account_scheme_credential(
    config: ServiceAccount,
) -> Tuple[AuthScheme, AuthCredential]:
  """Creates AuthScheme and AuthCredential for Google Service Account.

  Returns a bearer token scheme, and a service account credential.

  Args:
    config: A ServiceAccount object containing the Google Service Account
      configuration.

  Returns:
    Tuple: (AuthScheme, AuthCredential)
  """
  credential = AuthCredential(
      auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
      service_account=config,
  )
  return HTTPBearer(bearerFormat="JWT"), credential
|
||||
|
||||
|
||||
def openid_dict_to_scheme_credential(
    config_dict: Dict[str, Any],
    scopes: List[str],
    credential_dict: Dict[str, Any],
) -> Tuple[OpenIdConnectWithConfig, AuthCredential]:
  """Constructs OpenID scheme and credential from configuration and credential dictionaries.

  Args:
    config_dict: Dictionary containing OpenID Connect configuration, must
      include at least 'authorization_endpoint' and 'token_endpoint'.
    scopes: List of scopes to be used.
    credential_dict: Dictionary containing credential information, must
      include 'client_id' and 'client_secret'. May optionally include
      'redirect_uri'. A single-entry wrapper dict (as in a key file
      downloaded from Google OAuth config, e.g. {"web": {...}}) is
      unwrapped automatically.

  Returns:
    Tuple: (OpenIdConnectWithConfig, AuthCredential)

  Raises:
    ValueError: If required fields are missing in the input dictionaries.
  """

  # BUG FIX: work on a shallow copy — the original wrote 'scopes' and
  # 'openIdConnectUrl' directly into the caller's config_dict.
  config = dict(config_dict)

  # Validate and create the OpenIdConnectWithConfig scheme
  try:
    config["scopes"] = scopes
    # If user provides the OpenID Config as a static dict, it may not contain
    # openIdConnect URL.
    config.setdefault("openIdConnectUrl", "")
    openid_scheme = OpenIdConnectWithConfig.model_validate(config)
  except ValidationError as e:
    raise ValueError(f"Invalid OpenID Connect configuration: {e}") from e

  # Attempt to adjust credential_dict if this is a key downloaded from Google
  # OAuth config
  if len(list(credential_dict.values())) == 1:
    credential_value = list(credential_dict.values())[0]
    if "client_id" in credential_value and "client_secret" in credential_value:
      credential_dict = credential_value

  # Validate credential_dict
  required_credential_fields = ["client_id", "client_secret"]
  missing_fields = [
      field
      for field in required_credential_fields
      if field not in credential_dict
  ]
  if missing_fields:
    raise ValueError(
        "Missing required fields in credential_dict:"
        f" {', '.join(missing_fields)}"
    )

  # Construct AuthCredential
  auth_credential = AuthCredential(
      auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
      oauth2=OAuth2Auth(
          client_id=credential_dict["client_id"],
          client_secret=credential_dict["client_secret"],
          redirect_uri=credential_dict.get("redirect_uri", None),
      ),
  )

  return openid_scheme, auth_credential
|
||||
|
||||
|
||||
def openid_url_to_scheme_credential(
    openid_url: str, scopes: List[str], credential_dict: Dict[str, Any]
) -> Tuple[OpenIdConnectWithConfig, AuthCredential]:
  """Constructs OpenID scheme and credential from OpenID URL, scopes, and credential dictionary.

  Fetches OpenID configuration from the provided URL.

  Args:
    openid_url: The OpenID Connect discovery URL.
    scopes: List of scopes to be used.
    credential_dict: Dictionary containing credential information, must
      include at least "client_id" and "client_secret", may optionally include
      "redirect_uri" and "scope"

  Returns:
    Tuple: (AuthScheme, AuthCredential)

  Raises:
    ValueError: If the OpenID URL is invalid, fetching fails, or required
      fields are missing.
    requests.exceptions.RequestException: If there's an error during the
      HTTP request.
  """
  try:
    discovery_response = requests.get(openid_url, timeout=10)
    discovery_response.raise_for_status()
    openid_config = discovery_response.json()
  except requests.exceptions.RequestException as e:
    raise ValueError(
        f"Failed to fetch OpenID configuration from {openid_url}: {e}"
    ) from e
  except ValueError as e:
    # .json() raises a ValueError subclass on malformed JSON.
    raise ValueError(
        "Invalid JSON response from OpenID configuration endpoint"
        f" {openid_url}: {e}"
    ) from e

  # Record the discovery URL on the config before building the scheme.
  openid_config["openIdConnectUrl"] = openid_url

  return openid_dict_to_scheme_credential(openid_config, scopes, credential_dict)
|
||||
|
||||
|
||||
INTERNAL_AUTH_PREFIX = "_auth_prefix_vaf_"
|
||||
|
||||
|
||||
def credential_to_param(
    auth_scheme: AuthScheme,
    auth_credential: AuthCredential,
) -> Tuple[Optional[ApiParameter], Optional[Dict[str, Any]]]:
  """Converts AuthCredential and AuthScheme to a Parameter and a dictionary for additional kwargs.

  This function now supports all credential types returned by the exchangers:
  - API Key
  - HTTP Bearer (for Bearer tokens, OAuth2, Service Account, OpenID Connect)
  - OAuth2 and OpenID Connect (returns None, None, as the token is now a Bearer
    token)
  - Service Account (returns None, None, as the token is now a Bearer token)

  Args:
    auth_scheme: The AuthScheme object.
    auth_credential: The AuthCredential object.

  Returns:
    Tuple: (ApiParameter, Dict[str, Any]), or (None, None) when no parameter
    needs to be injected for this scheme/credential combination.

  Raises:
    ValueError: If the API key location or the scheme/credential combination
      is invalid.
    NotImplementedError: For HTTP Basic credentials (username/password).
  """
  # Nothing to inject when there is no credential at all.
  if not auth_credential:
    return None, None

  # Branch 1: API key — inject it as a header/query/cookie parameter.
  if (
      auth_scheme.type_ == AuthSchemeType.apiKey
      and auth_credential
      and auth_credential.api_key
  ):
    param_name = auth_scheme.name or ""
    # Prefix avoids collisions with real operation parameters.
    python_name = INTERNAL_AUTH_PREFIX + param_name
    if auth_scheme.in_ == APIKeyIn.header:
      param_location = "header"
    elif auth_scheme.in_ == APIKeyIn.query:
      param_location = "query"
    elif auth_scheme.in_ == APIKeyIn.cookie:
      param_location = "cookie"
    else:
      raise ValueError(f"Invalid API Key location: {auth_scheme.in_}")

    param = ApiParameter(
        original_name=param_name,
        param_location=param_location,
        param_schema=Schema(type="string"),
        description=auth_scheme.description or "",
        py_name=python_name,
    )
    kwargs = {param.py_name: auth_credential.api_key}
    return param, kwargs

  # TODO(cheliu): Split handling for OpenIDConnect scheme and native HTTPBearer
  # Scheme
  # Branch 2: HTTP credential — bearer tokens become an Authorization header.
  elif (
      auth_credential and auth_credential.auth_type == AuthCredentialTypes.HTTP
  ):
    if (
        auth_credential
        and auth_credential.http
        and auth_credential.http.credentials
        and auth_credential.http.credentials.token
    ):
      param = ApiParameter(
          original_name="Authorization",
          param_location="header",
          param_schema=Schema(type="string"),
          description=auth_scheme.description or "Bearer token",
          py_name=INTERNAL_AUTH_PREFIX + "Authorization",
      )
      kwargs = {
          param.py_name: f"Bearer {auth_credential.http.credentials.token}"
      }
      return param, kwargs
    elif (
        auth_credential
        and auth_credential.http
        and auth_credential.http.credentials
        and (
            auth_credential.http.credentials.username
            or auth_credential.http.credentials.password
        )
    ):
      # Basic Auth is explicitly NOT supported
      raise NotImplementedError("Basic Authentication is not supported.")
    else:
      raise ValueError("Invalid HTTP auth credentials")

  # Service Account tokens, OAuth2 Tokens and OpenID Tokens are now handled as
  # Bearer tokens.
  # Branch 3: OAuth2 / OpenID Connect scheme — only injectable if the
  # credential already carries an exchanged bearer token; otherwise nothing
  # to inject yet.
  elif (auth_scheme.type_ == AuthSchemeType.oauth2 and auth_credential) or (
      auth_scheme.type_ == AuthSchemeType.openIdConnect and auth_credential
  ):
    if (
        auth_credential.http
        and auth_credential.http.credentials
        and auth_credential.http.credentials.token
    ):
      param = ApiParameter(
          original_name="Authorization",
          param_location="header",
          param_schema=Schema(type="string"),
          description=auth_scheme.description or "Bearer token",
          py_name=INTERNAL_AUTH_PREFIX + "Authorization",
      )
      kwargs = {
          param.py_name: f"Bearer {auth_credential.http.credentials.token}"
      }
      return param, kwargs
    return None, None
  else:
    raise ValueError("Invalid security scheme and credential combination")
|
||||
|
||||
|
||||
def dict_to_auth_scheme(data: Dict[str, Any]) -> AuthScheme:
  """Converts a dictionary to a FastAPI AuthScheme object.

  Args:
    data: The dictionary representing the security scheme.

  Returns:
    A AuthScheme object (APIKey, HTTPBase, OAuth2, OpenIdConnect, or
    HTTPBearer).

  Raises:
    ValueError: If the 'type' field is missing or invalid, or if the
      dictionary cannot be converted to the corresponding Pydantic model.

  Example:
    ```python
    api_key_scheme = dict_to_auth_scheme({
        "type": "apiKey",
        "in": "header",
        "name": "X-API-Key",
    })

    bearer_scheme = dict_to_auth_scheme({
        "type": "http",
        "scheme": "bearer",
        "bearerFormat": "JWT",
    })

    oauth2_scheme = dict_to_auth_scheme({
        "type": "oauth2",
        "flows": {
            "authorizationCode": {
                "authorizationUrl": "https://example.com/auth",
                "tokenUrl": "https://example.com/token",
            }
        }
    })

    openid_scheme = dict_to_auth_scheme({
        "type": "openIdConnect",
        "openIdConnectUrl": "https://example.com/.well-known/openid-configuration"
    })
    ```
  """
  if "type" not in data:
    raise ValueError("Missing 'type' field in security scheme dictionary.")

  security_type = data["type"]
  # Dispatch table for the schemes that map 1:1 onto a model; "http" needs a
  # second look at the "scheme" field and is handled separately below.
  simple_models = {
      "apiKey": APIKey,
      "oauth2": OAuth2,
      "openIdConnect": OpenIdConnect,
  }
  try:
    if security_type == "http":
      model = HTTPBearer if data.get("scheme") == "bearer" else HTTPBase
      return model.model_validate(data)
    if security_type in simple_models:
      return simple_models[security_type].model_validate(data)
    # Not a ValidationError, so this propagates past the handler below.
    raise ValueError(f"Invalid security scheme type: {security_type}")
  except ValidationError as e:
    raise ValueError(f"Invalid security scheme data: {e}") from e
|
||||
@@ -0,0 +1,25 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .auto_auth_credential_exchanger import AutoAuthCredentialExchanger
|
||||
from .base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
from .oauth2_exchanger import OAuth2CredentialExchanger
|
||||
from .service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
|
||||
__all__ = [
|
||||
'AutoAuthCredentialExchanger',
|
||||
'BaseAuthCredentialExchanger',
|
||||
'OAuth2CredentialExchanger',
|
||||
'ServiceAccountCredentialExchanger',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,105 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
|
||||
from .....auth.auth_credential import AuthCredential
|
||||
from .....auth.auth_credential import AuthCredentialTypes
|
||||
from .....auth.auth_schemes import AuthScheme
|
||||
from .base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
from .oauth2_exchanger import OAuth2CredentialExchanger
|
||||
from .service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
|
||||
|
||||
class AutoAuthCredentialExchanger(BaseAuthCredentialExchanger):
  """Automatically selects the appropriate credential exchanger based on the auth scheme.

  Optionally, an override can be provided to use a specific exchanger for a
  given credential type.

  Example (common case):
  ```
  exchanger = AutoAuthCredentialExchanger()
  auth_credential = exchanger.exchange_credential(
      auth_scheme=service_account_scheme,
      auth_credential=service_account_credential,
  )
  # Returns an oauth token in the form of a bearer token.
  ```

  Example (use CustomAuthExchanger for OAuth2):
  ```
  exchanger = AutoAuthCredentialExchanger(
      custom_exchangers={
          AuthCredentialTypes.OAUTH2: CustomAuthExchanger,
      }
  )
  ```

  Attributes:
    exchangers: A dictionary mapping credential type (`AuthCredentialTypes`)
      to credential exchanger class.
  """

  def __init__(
      self,
      custom_exchangers: Optional[
          Dict[str, Type[BaseAuthCredentialExchanger]]
      ] = None,
  ):
    """Initializes the AutoAuthCredentialExchanger.

    Args:
      custom_exchangers: Optional dictionary for adding or overriding auth
        exchangers. The key is the credential type (`AuthCredentialTypes`),
        and the value is the credential exchanger class.
    """
    # OAuth2 and OpenID Connect share one exchanger.
    self.exchangers = {
        AuthCredentialTypes.OAUTH2: OAuth2CredentialExchanger,
        AuthCredentialTypes.OPEN_ID_CONNECT: OAuth2CredentialExchanger,
        AuthCredentialTypes.SERVICE_ACCOUNT: ServiceAccountCredentialExchanger,
    }

    if custom_exchangers:
      self.exchangers.update(custom_exchangers)

  def exchange_credential(
      self,
      auth_scheme: AuthScheme,
      auth_credential: Optional[AuthCredential] = None,
  ) -> Optional[AuthCredential]:
    """Automatically exchanges for the credential uses the appropriate credential exchanger.

    Args:
      auth_scheme: The security scheme.
      auth_credential: Optional. The authentication credential.

    Returns:
      A new AuthCredential object containing the exchanged credential, or the
      original credential when no exchanger is registered for its type, or
      None when no credential was supplied.
    """
    if not auth_credential:
      return None

    # The guard above already ensures auth_credential is truthy, so the
    # original's `auth_credential.auth_type if auth_credential else None`
    # conditional was redundant.
    exchanger_class = self.exchangers.get(auth_credential.auth_type)

    if not exchanger_class:
      # No registered exchanger (e.g. API key): pass through unchanged.
      return auth_credential

    exchanger = exchanger_class()
    return exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
@@ -0,0 +1,55 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import Optional
|
||||
|
||||
from .....auth.auth_credential import (
|
||||
AuthCredential,
|
||||
)
|
||||
from .....auth.auth_schemes import AuthScheme
|
||||
|
||||
|
||||
class AuthCredentialMissingError(Exception):
  """Exception raised when required authentication credentials are missing."""

  def __init__(self, message: str):
    # Keep the text both in Exception.args (via super) and as an attribute so
    # callers can read it directly.
    self.message = message
    super().__init__(message)
|
||||
|
||||
|
||||
class BaseAuthCredentialExchanger:
  """Base class for authentication credential exchangers."""

  # NOTE(review): this class does not inherit from abc.ABC, so the
  # @abc.abstractmethod marker below does not prevent instantiation; the
  # effective runtime guard is the NotImplementedError raised in the body.
  @abc.abstractmethod
  def exchange_credential(
      self,
      auth_scheme: AuthScheme,
      auth_credential: Optional[AuthCredential] = None,
  ) -> AuthCredential:
    """Exchanges the provided authentication credential for a usable token/credential.

    Args:
      auth_scheme: The security scheme.
      auth_credential: The authentication credential.

    Returns:
      An updated AuthCredential object containing the fetched credential.
      For simple schemes like API key, it may return the original credential
      if no exchange is needed.

    Raises:
      NotImplementedError: If the method is not implemented by a subclass.
    """
    raise NotImplementedError("Subclasses must implement exchange_credential.")
|
||||
@@ -0,0 +1,117 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Credential fetcher for OpenID Connect."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .....auth.auth_credential import AuthCredential
|
||||
from .....auth.auth_credential import AuthCredentialTypes
|
||||
from .....auth.auth_credential import HttpAuth
|
||||
from .....auth.auth_credential import HttpCredentials
|
||||
from .....auth.auth_schemes import AuthScheme
|
||||
from .....auth.auth_schemes import AuthSchemeType
|
||||
from .base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
|
||||
|
||||
class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
  """Fetches credentials for OAuth2 and OpenID Connect."""

  def _check_scheme_credential_type(
      self,
      auth_scheme: AuthScheme,
      auth_credential: Optional[AuthCredential] = None,
  ):
    """Validates that the scheme/credential pair is usable for OAuth2.

    Raises:
      ValueError: If the credential is missing, the scheme is not OAuth2 /
        OpenID Connect, or the credential carries neither oauth2 nor http
        configuration.
    """
    if not auth_credential:
      raise ValueError(
          "auth_credential is empty. Please create AuthCredential using"
          " OAuth2Auth."
      )

    if auth_scheme.type_ not in (
        AuthSchemeType.openIdConnect,
        AuthSchemeType.oauth2,
    ):
      raise ValueError(
          "Invalid security scheme, expect AuthSchemeType.openIdConnect or "
          f"AuthSchemeType.oauth2 auth scheme, but got {auth_scheme.type_}"
      )

    # An http credential is also acceptable: it means the token was already
    # exchanged into a bearer token earlier (see exchange_credential below).
    if not auth_credential.oauth2 and not auth_credential.http:
      raise ValueError(
          "auth_credential is not configured with oauth2. Please"
          " create AuthCredential and set OAuth2Auth."
      )

  def generate_auth_token(
      self,
      auth_credential: Optional[AuthCredential] = None,
  ) -> AuthCredential:
    """Wraps an OAuth2 access token into an HTTP bearer credential.

    Args:
      auth_credential: The auth credential. Expected to carry an oauth2
        configuration; callers must have validated this already (see
        `_check_scheme_credential_type`).

    Returns:
      An AuthCredential object containing the HTTP bearer access token. If the
      HTTP bearer token cannot be generated, return the original credential.
    """
    # No access token yet — nothing to wrap; hand back the original so the
    # caller can decide how to obtain one.
    if not auth_credential.oauth2.access_token:
      return auth_credential

    # Return the access token as a bearer token.
    updated_credential = AuthCredential(
        auth_type=AuthCredentialTypes.HTTP,  # Store as a bearer token
        http=HttpAuth(
            scheme="bearer",
            credentials=HttpCredentials(
                token=auth_credential.oauth2.access_token
            ),
        ),
    )
    return updated_credential

  def exchange_credential(
      self,
      auth_scheme: AuthScheme,
      auth_credential: Optional[AuthCredential] = None,
  ) -> AuthCredential:
    """Exchanges the OpenID Connect auth credential for an access token or an auth URI.

    Args:
      auth_scheme: The auth scheme.
      auth_credential: The auth credential.

    Returns:
      An AuthCredential object containing the HTTP Bearer access token, or
      the original credential when it is already a bearer token.
      NOTE(review): returns None when no access token has been obtained yet,
      despite the AuthCredential return annotation — confirm callers handle
      None (e.g. to trigger the authorization flow).

    Raises:
      ValueError: If the auth scheme or auth credential is invalid.
    """
    # TODO(cheliu): Implement token refresh flow

    self._check_scheme_credential_type(auth_scheme, auth_credential)

    # If token is already HTTPBearer token, do nothing assuming that this token
    # is valid.
    if auth_credential.http:
      return auth_credential

    # If access token is exchanged, exchange a HTTPBearer token.
    if auth_credential.oauth2.access_token:
      return self.generate_auth_token(auth_credential)

    return None
|
||||
@@ -0,0 +1,97 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Credential fetcher for Google Service Account."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import google.auth
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import google.oauth2.credentials
|
||||
|
||||
from .....auth.auth_credential import (
|
||||
AuthCredential,
|
||||
AuthCredentialTypes,
|
||||
HttpAuth,
|
||||
HttpCredentials,
|
||||
)
|
||||
from .....auth.auth_schemes import AuthScheme
|
||||
from .base_credential_exchanger import AuthCredentialMissingError, BaseAuthCredentialExchanger
|
||||
|
||||
|
||||
class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
  """Fetches credentials for Google Service Account.

  Uses the default service credential if `use_default_credential = True`.
  Otherwise, uses the service account credential provided in the auth
  credential.
  """

  def exchange_credential(
      self,
      auth_scheme: AuthScheme,
      auth_credential: Optional[AuthCredential] = None,
  ) -> AuthCredential:
    """Exchanges the service account auth credential for an access token.

    If auth_credential contains a service account credential, it will be used
    to fetch an access token. Otherwise, the default service credential will be
    used for fetching an access token.

    Args:
      auth_scheme: The auth scheme.
      auth_credential: The auth credential.

    Returns:
      An AuthCredential in HTTPBearer format, containing the access token.

    Raises:
      AuthCredentialMissingError: If the service account credentials are
        missing, or if the token exchange fails for any reason.
    """
    # Reject early when neither an explicit service account credential nor
    # the use-default flag is present — there is nothing to exchange.
    if (
        auth_credential is None
        or auth_credential.service_account is None
        or (
            auth_credential.service_account.service_account_credential is None
            and not auth_credential.service_account.use_default_credential
        )
    ):
      raise AuthCredentialMissingError(
          "Service account credentials are missing. Please provide them, or set"
          " `use_default_credential = True` to use application default"
          " credential in a hosted service like Cloud Run."
      )

    try:
      if auth_credential.service_account.use_default_credential:
        # Application Default Credentials (e.g. the runtime identity on
        # Cloud Run / GCE / GKE).
        credentials, _ = google.auth.default()
      else:
        config = auth_credential.service_account
        credentials = service_account.Credentials.from_service_account_info(
            config.service_account_credential.model_dump(), scopes=config.scopes
        )

      # Force a token fetch/refresh so `credentials.token` is populated.
      credentials.refresh(Request())

      updated_credential = AuthCredential(
          auth_type=AuthCredentialTypes.HTTP,  # Store as a bearer token
          http=HttpAuth(
              scheme="bearer",
              credentials=HttpCredentials(token=credentials.token),
          ),
      )
      return updated_credential

    except Exception as e:
      # NOTE(review): this wraps *any* failure (network error, malformed key,
      # refresh failure) as AuthCredentialMissingError — confirm that
      # "missing" is the intended classification for all of these.
      raise AuthCredentialMissingError(
          f"Failed to exchange service account token: {e}"
      ) from e
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import common
|
||||
|
||||
__all__ = [
|
||||
'common',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,300 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import keyword
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from fastapi.openapi.models import Response
|
||||
from fastapi.openapi.models import Schema
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_serializer
|
||||
|
||||
|
||||
def to_snake_case(text: str) -> str:
  """Converts a string into snake_case.

  Handles lowerCamelCase, UpperCamelCase, or space-separated case, acronyms
  (e.g., "REST API") and consecutive uppercase letters correctly. Also handles
  mixed cases with and without spaces.

  Examples:
    ```
    to_snake_case('camelCase') -> 'camel_case'
    to_snake_case('UpperCamelCase') -> 'upper_camel_case'
    to_snake_case('space separated') -> 'space_separated'
    ```

  Args:
    text: The input string.

  Returns:
    The snake_case version of the string.
  """
  # Collapse every run of spaces / punctuation into a single underscore.
  underscored = re.sub(r'[^a-zA-Z0-9]+', '_', text)

  # Break lowerCamelCase boundaries: "camelCase" -> "camel_Case".
  underscored = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', underscored)
  # Break acronym boundaries: "HTTPResponse" -> "HTTP_Response".
  underscored = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', underscored)

  # Normalize: lowercase, squeeze repeated underscores, trim both ends.
  return re.sub(r'_+', '_', underscored.lower()).strip('_')
|
||||
|
||||
|
||||
def rename_python_keywords(s: str, prefix: str = 'param_') -> str:
  """Renames Python keywords by adding a prefix.

  Example:
    ```
    rename_python_keywords('if') -> 'param_if'
    rename_python_keywords('for') -> 'param_for'
    ```

  Args:
    s: The input string.
    prefix: The prefix to add to the keyword.

  Returns:
    The renamed string.
  """
  # Non-keywords pass through untouched; only reserved words are prefixed.
  return prefix + s if keyword.iskeyword(s) else s
|
||||
|
||||
|
||||
class ApiParameter(BaseModel):
  """Data class representing a function parameter.

  Wraps an OpenAPI parameter (name, location, schema) and derives a
  Python-friendly name plus type information for code generation.
  """

  # Name exactly as it appears in the OpenAPI spec.
  original_name: str
  # Where the parameter lives in the request (e.g. path / query / body).
  param_location: str
  # Either a Schema object or its JSON string form; normalized to a Schema
  # in model_post_init.
  param_schema: Union[str, Schema]
  description: Optional[str] = ''
  # Python identifier derived from original_name (snake_cased, keyword-safe).
  py_name: Optional[str] = ''
  # Derived fields, populated in model_post_init from param_schema.
  type_value: type[Any] = Field(default=None, init_var=False)
  type_hint: str = Field(default=None, init_var=False)

  def model_post_init(self, _: Any):
    """Derives py_name, description, and type information after validation.

    NOTE(review): pydantic ignores the return value of model_post_init; the
    trailing `return self` is harmless but has no effect.
    """
    self.py_name = (
        self.py_name
        if self.py_name
        else rename_python_keywords(to_snake_case(self.original_name))
    )
    # Accept a JSON-encoded schema string and normalize it to a Schema object.
    if isinstance(self.param_schema, str):
      self.param_schema = Schema.model_validate_json(self.param_schema)

    # Fall back to the schema's own description when none was given.
    self.description = self.description or self.param_schema.description or ''
    self.type_value = TypeHintHelper.get_type_value(self.param_schema)
    self.type_hint = TypeHintHelper.get_type_hint(self.param_schema)
    return self

  @model_serializer
  def _serialize(self):
    """Serializes only the spec-derived fields; type_value/type_hint are
    recomputed on deserialization, so they are deliberately omitted."""
    return {
        'original_name': self.original_name,
        'param_location': self.param_location,
        'param_schema': self.param_schema,
        'description': self.description,
        'py_name': self.py_name,
    }

  def __str__(self):
    """Renders the parameter as a `name: type` signature fragment."""
    return f'{self.py_name}: {self.type_hint}'

  def to_arg_string(self):
    """Converts the parameter to an argument string for function call."""
    return f'{self.py_name}={self.py_name}'

  def to_dict_property(self):
    """Converts the parameter to a key:value string for dict property."""
    return f'"{self.py_name}": {self.py_name}'

  def to_pydoc_string(self):
    """Converts the parameter to a PyDoc parameter docstr."""
    return PydocHelper.generate_param_doc(self)
|
||||
|
||||
|
||||
class TypeHintHelper:
  """Helper class for generating type hints.

  `get_type_value` returns an actual Python typing object; `get_type_hint`
  returns the equivalent source-code string. The two are kept consistent so
  generated signatures and runtime types agree.
  """

  @staticmethod
  def get_type_value(schema: 'Schema') -> Any:
    """Generates the Python type value for a given parameter.

    Args:
      schema: The OpenAPI schema of the parameter.

    Returns:
      The Python type object (e.g. int, List[str], Dict[str, Any]) that
      corresponds to the schema's `type`; Any when the type is unknown.
    """
    param_type = schema.type if schema.type else Any

    if param_type == 'integer':
      return int
    elif param_type == 'number':
      return float
    elif param_type == 'boolean':
      return bool
    elif param_type == 'string':
      return str
    elif param_type == 'array':
      items_type = Any
      if schema.items and schema.items.type:
        items_type = schema.items.type

      if items_type == 'object':
        return List[Dict[str, Any]]
      else:
        type_map = {
            'integer': int,
            'number': float,
            'boolean': bool,
            'string': str,
            'object': Dict[str, Any],
            'array': List[Any],
        }
        # Bug fix: the fallback previously used the *string* 'Any', which
        # produced the forward reference List['Any'] instead of List[Any].
        return List[type_map.get(items_type, Any)]
    elif param_type == 'object':
      return Dict[str, Any]
    else:
      return Any

  @staticmethod
  def get_type_hint(schema: 'Schema') -> str:
    """Generates the Python type in string for a given parameter.

    Args:
      schema: The OpenAPI schema of the parameter.

    Returns:
      The type hint as source text (e.g. 'int', 'List[str]'); 'Any' when the
      type is unknown.
    """
    param_type = schema.type if schema.type else 'Any'

    if param_type == 'integer':
      return 'int'
    elif param_type == 'number':
      return 'float'
    elif param_type == 'boolean':
      return 'bool'
    elif param_type == 'string':
      return 'str'
    elif param_type == 'array':
      items_type = 'Any'
      if schema.items and schema.items.type:
        items_type = schema.items.type

      if items_type == 'object':
        return 'List[Dict[str, Any]]'
      else:
        type_map = {
            'integer': 'int',
            'number': 'float',
            'boolean': 'bool',
            'string': 'str',
            # Consistency fix: mirror get_type_value so a nested array hints
            # as 'List[List[Any]]' (it previously fell back to 'List[Any]',
            # disagreeing with the runtime type List[List[Any]]).
            'array': 'List[Any]',
        }
        return f"List[{type_map.get(items_type, 'Any')}]"
    elif param_type == 'object':
      return 'Dict[str, Any]'
    else:
      return 'Any'
|
||||
|
||||
|
||||
class PydocHelper:
  """Helper class for generating PyDoc strings."""

  @staticmethod
  def generate_param_doc(
      param: ApiParameter,
  ) -> str:
    """Generates a parameter documentation string.

    Args:
      param: ApiParameter - The parameter to generate the documentation for.

    Returns:
      str: The generated parameter Python documentation string.
    """
    description = param.description.strip() if param.description else ''
    param_doc = f'{param.py_name} ({param.type_hint}): {description}'

    # Object parameters additionally list each property with its own type
    # hint and description, one indented line per property.
    if param.param_schema.type == 'object':
      properties = param.param_schema.properties
      if properties:
        param_doc += ' Object properties:\n'
        for prop_name, prop_details in properties.items():
          prop_desc = prop_details.description or ''
          prop_type = TypeHintHelper.get_type_hint(prop_details)
          param_doc += f'  {prop_name} ({prop_type}): {prop_desc}\n'

    return param_doc

  @staticmethod
  def generate_return_doc(responses: Dict[str, Response]) -> str:
    """Generates a return value documentation string.

    Args:
      responses: Dict[str, TypedDict[Response]] - Response in an OpenAPI
        Operation

    Returns:
      str: The generated return value Python documentation string.

    NOTE(review): `int(item[0])` assumes every response key is numeric —
    OpenAPI also allows the key 'default', which would raise ValueError here.
    Confirm callers filter such keys first.
    """
    return_doc = ''

    # Only consider 2xx responses for return type hinting.
    # Returns the 2xx response with the smallest status code number and with
    # content defined.
    sorted_responses = sorted(responses.items(), key=lambda item: int(item[0]))
    qualified_response = next(
        filter(
            lambda r: r[0].startswith('2') and r[1].content,
            sorted_responses,
        ),
        None,
    )
    if not qualified_response:
      return ''
    response_details = qualified_response[1]

    description = (response_details.description or '').strip()
    content = response_details.content or {}

    # Generate return type hint and properties for the first response type.
    # TODO(cheliu): Handle multiple content types.
    for _, schema_details in content.items():
      # NOTE(review): the `or {}` fallback yields a plain dict, which lacks
      # the `.type` attribute read below — verify schema_ is never None here.
      schema = schema_details.schema_ or {}

      # Use a dummy Parameter object for return type hinting.
      dummy_param = ApiParameter(
          original_name='', param_location='', param_schema=schema
      )
      return_doc = f'Returns ({dummy_param.type_hint}): {description}'

      # Non-object returns need no property listing; objects get one indented
      # line per property, mirroring generate_param_doc.
      response_type = schema.type or 'Any'
      if response_type != 'object':
        break
      properties = schema.properties
      if not properties:
        break
      return_doc += ' Object properties:\n'
      for prop_name, prop_details in properties.items():
        prop_desc = prop_details.description or ''
        prop_type = TypeHintHelper.get_type_hint(prop_details)
        return_doc += f'  {prop_name} ({prop_type}): {prop_desc}\n'
      break

    return return_doc
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .openapi_spec_parser import OpenApiSpecParser, OperationEndpoint, ParsedOperation
|
||||
from .openapi_toolset import OpenAPIToolset
|
||||
from .operation_parser import OperationParser
|
||||
from .rest_api_tool import AuthPreparationState, RestApiTool, snake_to_lower_camel, to_gemini_schema
|
||||
from .tool_auth_handler import ToolAuthHandler
|
||||
|
||||
__all__ = [
|
||||
'OpenApiSpecParser',
|
||||
'OperationEndpoint',
|
||||
'ParsedOperation',
|
||||
'OpenAPIToolset',
|
||||
'OperationParser',
|
||||
'RestApiTool',
|
||||
'to_gemini_schema',
|
||||
'snake_to_lower_camel',
|
||||
'AuthPreparationState',
|
||||
'ToolAuthHandler',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user