Agent Development Kit(ADK)

An easy-to-use and powerful framework to build AI agents.
This commit is contained in:
hangfei
2025-04-08 17:22:09 +00:00
parent f92478bd5c
commit 9827820143
299 changed files with 44398 additions and 2 deletions

View File

@@ -0,0 +1,32 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .openapi_spec_parser import OpenApiSpecParser, OperationEndpoint, ParsedOperation
from .openapi_toolset import OpenAPIToolset
from .operation_parser import OperationParser
from .rest_api_tool import AuthPreparationState, RestApiTool, snake_to_lower_camel, to_gemini_schema
from .tool_auth_handler import ToolAuthHandler
__all__ = [
'OpenApiSpecParser',
'OperationEndpoint',
'ParsedOperation',
'OpenAPIToolset',
'OperationParser',
'RestApiTool',
'to_gemini_schema',
'snake_to_lower_camel',
'AuthPreparationState',
'ToolAuthHandler',
]

View File

@@ -0,0 +1,231 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from fastapi.openapi.models import Operation
from pydantic import BaseModel
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ..common.common import ApiParameter
from ..common.common import to_snake_case
from .operation_parser import OperationParser
class OperationEndpoint(BaseModel):
base_url: str
path: str
method: str
class ParsedOperation(BaseModel):
name: str
description: str
endpoint: OperationEndpoint
operation: Operation
parameters: List[ApiParameter]
return_value: ApiParameter
auth_scheme: Optional[AuthScheme] = None
auth_credential: Optional[AuthCredential] = None
additional_context: Optional[Any] = None
class OpenApiSpecParser:
"""Generates Python code, JSON schema, and callables for an OpenAPI operation.
This class takes an OpenApiOperation object and provides methods to generate:
1. A string representation of a Python function that handles the operation.
2. A JSON schema representing the input parameters of the operation.
3. A callable Python object (a function) that can execute the operation.
"""
def parse(self, openapi_spec_dict: Dict[str, Any]) -> List[ParsedOperation]:
"""Extracts an OpenAPI spec dict into a list of ParsedOperation objects.
ParsedOperation objects are further used for generating RestApiTool.
Args:
openapi_spec_dict: A dictionary representing the OpenAPI specification.
Returns:
A list of ParsedOperation objects.
"""
openapi_spec_dict = self._resolve_references(openapi_spec_dict)
operations = self._collect_operations(openapi_spec_dict)
return operations
def _collect_operations(
self, openapi_spec: Dict[str, Any]
) -> List[ParsedOperation]:
"""Collects operations from an OpenAPI spec."""
operations = []
# Taking first server url, or default to empty string if not present
base_url = ""
if openapi_spec.get("servers"):
base_url = openapi_spec["servers"][0].get("url", "")
# Get global security scheme (if any)
global_scheme_name = None
if openapi_spec.get("security"):
# Use first scheme by default.
scheme_names = list(openapi_spec["security"][0].keys())
global_scheme_name = scheme_names[0] if scheme_names else None
auth_schemes = openapi_spec.get("components", {}).get("securitySchemes", {})
for path, path_item in openapi_spec.get("paths", {}).items():
if path_item is None:
continue
for method in (
"get",
"post",
"put",
"delete",
"patch",
"head",
"options",
"trace",
):
operation_dict = path_item.get(method)
if operation_dict is None:
continue
# If operation ID is missing, assign an operation id based on path
# and method
if "operationId" not in operation_dict:
temp_id = to_snake_case(f"{path}_{method}")
operation_dict["operationId"] = temp_id
url = OperationEndpoint(base_url=base_url, path=path, method=method)
operation = Operation.model_validate(operation_dict)
operation_parser = OperationParser(operation)
# Check for operation-specific auth scheme
auth_scheme_name = operation_parser.get_auth_scheme_name()
auth_scheme_name = (
auth_scheme_name if auth_scheme_name else global_scheme_name
)
auth_scheme = (
auth_schemes.get(auth_scheme_name) if auth_scheme_name else None
)
parsed_op = ParsedOperation(
name=operation_parser.get_function_name(),
description=operation.description or operation.summary or "",
endpoint=url,
operation=operation,
parameters=operation_parser.get_parameters(),
return_value=operation_parser.get_return_value(),
auth_scheme=auth_scheme,
auth_credential=None, # Placeholder
additional_context={},
)
operations.append(parsed_op)
return operations
def _resolve_references(self, openapi_spec: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively resolves all $ref references in an OpenAPI specification.
Handles circular references correctly.
Args:
openapi_spec: A dictionary representing the OpenAPI specification.
Returns:
A dictionary representing the OpenAPI specification with all references
resolved.
"""
openapi_spec = copy.deepcopy(openapi_spec) # Work on a copy
resolved_cache = {} # Cache resolved references
def resolve_ref(ref_string, current_doc):
"""Resolves a single $ref string."""
parts = ref_string.split("/")
if parts[0] != "#":
raise ValueError(f"External references not supported: {ref_string}")
current = current_doc
for part in parts[1:]:
if part in current:
current = current[part]
else:
return None # Reference not found
return current
def recursive_resolve(obj, current_doc, seen_refs=None):
"""Recursively resolves references, handling circularity.
Args:
obj: The object to traverse.
current_doc: Document to search for refs.
seen_refs: A set to track already-visited references (for circularity
detection).
Returns:
The resolved object.
"""
if seen_refs is None:
seen_refs = set() # Initialize the set if it's the first call
if isinstance(obj, dict):
if "$ref" in obj and isinstance(obj["$ref"], str):
ref_string = obj["$ref"]
# Check for circularity
if ref_string in seen_refs and ref_string not in resolved_cache:
# Circular reference detected! Return a *copy* of the object,
# but *without* the $ref. This breaks the cycle while
# still maintaining the overall structure.
return {k: v for k, v in obj.items() if k != "$ref"}
seen_refs.add(ref_string) # Add the reference to the set
# Check if we have a cached resolved value
if ref_string in resolved_cache:
return copy.deepcopy(resolved_cache[ref_string])
resolved_value = resolve_ref(ref_string, current_doc)
if resolved_value is not None:
# Recursively resolve the *resolved* value,
# passing along the 'seen_refs' set
resolved_value = recursive_resolve(
resolved_value, current_doc, seen_refs
)
resolved_cache[ref_string] = resolved_value
return copy.deepcopy(resolved_value) # return the cached result
else:
return obj # return original if no resolved value.
else:
new_dict = {}
for key, value in obj.items():
new_dict[key] = recursive_resolve(value, current_doc, seen_refs)
return new_dict
elif isinstance(obj, list):
return [recursive_resolve(item, current_doc, seen_refs) for item in obj]
else:
return obj
return recursive_resolve(openapi_spec, openapi_spec)

View File

@@ -0,0 +1,144 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from typing import Any
from typing import Dict
from typing import Final
from typing import List
from typing import Literal
from typing import Optional
import yaml
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from .openapi_spec_parser import OpenApiSpecParser
from .rest_api_tool import RestApiTool
logger = logging.getLogger(__name__)
class OpenAPIToolset:
"""Class for parsing OpenAPI spec into a list of RestApiTool.
Usage:
```
# Initialize OpenAPI toolset from a spec string.
openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
spec_str_type="json")
# Or, initialize OpenAPI toolset from a spec dictionary.
openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)
# Add all tools to an agent.
agent = Agent(
tools=[*openapi_toolset.get_tools()]
)
# Or, add a single tool to an agent.
agent = Agent(
tools=[openapi_toolset.get_tool('tool_name')]
)
```
"""
def __init__(
self,
*,
spec_dict: Optional[Dict[str, Any]] = None,
spec_str: Optional[str] = None,
spec_str_type: Literal["json", "yaml"] = "json",
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
):
"""Initializes the OpenAPIToolset.
Usage:
```
# Initialize OpenAPI toolset from a spec string.
openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
spec_str_type="json")
# Or, initialize OpenAPI toolset from a spec dictionary.
openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)
# Add all tools to an agent.
agent = Agent(
tools=[*openapi_toolset.get_tools()]
)
# Or, add a single tool to an agent.
agent = Agent(
tools=[openapi_toolset.get_tool('tool_name')]
)
```
Args:
spec_dict: The OpenAPI spec dictionary. If provided, it will be used
instead of loading the spec from a string.
spec_str: The OpenAPI spec string in JSON or YAML format. It will be used
when spec_dict is not provided.
spec_str_type: The type of the OpenAPI spec string. Can be "json" or
"yaml".
auth_scheme: The auth scheme to use for all tools. Use AuthScheme or use
helpers in `google.adk.tools.openapi_tool.auth.auth_helpers`
auth_credential: The auth credential to use for all tools. Use
AuthCredential or use helpers in
`google.adk.tools.openapi_tool.auth.auth_helpers`
"""
if not spec_dict:
spec_dict = self._load_spec(spec_str, spec_str_type)
self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
if auth_scheme or auth_credential:
self._configure_auth_all(auth_scheme, auth_credential)
def _configure_auth_all(
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
):
"""Configure auth scheme and credential for all tools."""
for tool in self.tools:
if auth_scheme:
tool.configure_auth_scheme(auth_scheme)
if auth_credential:
tool.configure_auth_credential(auth_credential)
def get_tools(self) -> List[RestApiTool]:
"""Get all tools in the toolset."""
return self.tools
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
"""Get a tool by name."""
matching_tool = filter(lambda t: t.name == tool_name, self.tools)
return next(matching_tool, None)
def _load_spec(
self, spec_str: str, spec_type: Literal["json", "yaml"]
) -> Dict[str, Any]:
"""Loads the OpenAPI spec string into adictionary."""
if spec_type == "json":
return json.loads(spec_str)
elif spec_type == "yaml":
return yaml.safe_load(spec_str)
else:
raise ValueError(f"Unsupported spec type: {spec_type}")
def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]:
"""Parse OpenAPI spec into a list of RestApiTool."""
operations = OpenApiSpecParser().parse(openapi_spec_dict)
tools = []
for o in operations:
tool = RestApiTool.from_parsed_operation(o)
logger.info("Parsed tool: %s", tool.name)
tools.append(tool)
return tools

View File

@@ -0,0 +1,260 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from textwrap import dedent
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.models import Operation
from fastapi.openapi.models import Parameter
from fastapi.openapi.models import Schema
from ..common.common import ApiParameter
from ..common.common import PydocHelper
from ..common.common import to_snake_case
class OperationParser:
"""Generates parameters for Python functions from an OpenAPI operation.
This class processes an OpenApiOperation object and provides helper methods
to extract information needed to generate Python function declarations,
docstrings, signatures, and JSON schemas. It handles parameter processing,
name deduplication, and type hint generation.
"""
def __init__(
self, operation: Union[Operation, Dict[str, Any], str], should_parse=True
):
"""Initializes the OperationParser with an OpenApiOperation.
Args:
operation: The OpenApiOperation object or a dictionary to process.
should_parse: Whether to parse the operation during initialization.
"""
if isinstance(operation, dict):
self.operation = Operation.model_validate(operation)
elif isinstance(operation, str):
self.operation = Operation.model_validate_json(operation)
else:
self.operation = operation
self.params: List[ApiParameter] = []
self.return_value: Optional[ApiParameter] = None
if should_parse:
self._process_operation_parameters()
self._process_request_body()
self._process_return_value()
self._dedupe_param_names()
@classmethod
def load(
cls,
operation: Union[Operation, Dict[str, Any]],
params: List[ApiParameter],
return_value: Optional[ApiParameter] = None,
) -> 'OperationParser':
parser = cls(operation, should_parse=False)
parser.params = params
parser.return_value = return_value
return parser
def _process_operation_parameters(self):
"""Processes parameters from the OpenAPI operation."""
parameters = self.operation.parameters or []
for param in parameters:
if isinstance(param, Parameter):
original_name = param.name
description = param.description or ''
location = param.in_ or ''
schema = param.schema_ or {} # Use schema_ instead of .schema
self.params.append(
ApiParameter(
original_name=original_name,
param_location=location,
param_schema=schema,
description=description,
)
)
def _process_request_body(self):
"""Processes the request body from the OpenAPI operation."""
request_body = self.operation.requestBody
if not request_body:
return
content = request_body.content or {}
if not content:
return
# If request body is an object, expand the properties as parameters
for _, media_type_object in content.items():
schema = media_type_object.schema_ or {}
description = request_body.description or ''
if schema and schema.type == 'object':
for prop_name, prop_details in schema.properties.items():
self.params.append(
ApiParameter(
original_name=prop_name,
param_location='body',
param_schema=prop_details,
description=prop_details.description,
)
)
elif schema and schema.type == 'array':
self.params.append(
ApiParameter(
original_name='array',
param_location='body',
param_schema=schema,
description=description,
)
)
else:
self.params.append(
# Empty name for unnamed body param
ApiParameter(
original_name='',
param_location='body',
param_schema=schema,
description=description,
)
)
break # Process first mime type only
def _dedupe_param_names(self):
"""Deduplicates parameter names to avoid conflicts."""
params_cnt = {}
for param in self.params:
name = param.py_name
if name not in params_cnt:
params_cnt[name] = 0
else:
params_cnt[name] += 1
param.py_name = f'{name}_{params_cnt[name] -1}'
def _process_return_value(self) -> Parameter:
"""Returns a Parameter object representing the return type."""
responses = self.operation.responses or {}
# Default to Any if no 2xx response or if schema is missing
return_schema = Schema(type='Any')
# Take the 20x response with the smallest response code.
valid_codes = list(
filter(lambda k: k.startswith('2'), list(responses.keys()))
)
min_20x_status_code = min(valid_codes) if valid_codes else None
if min_20x_status_code and responses[min_20x_status_code].content:
content = responses[min_20x_status_code].content
for mime_type in content:
if content[mime_type].schema_:
return_schema = content[mime_type].schema_
break
self.return_value = ApiParameter(
original_name='',
param_location='',
param_schema=return_schema,
)
def get_function_name(self) -> str:
"""Returns the generated function name."""
operation_id = self.operation.operationId
if not operation_id:
raise ValueError('Operation ID is missing')
return to_snake_case(operation_id)[:60]
def get_return_type_hint(self) -> str:
"""Returns the return type hint string (like 'str', 'int', etc.)."""
return self.return_value.type_hint
def get_return_type_value(self) -> Any:
"""Returns the return type value (like str, int, List[str], etc.)."""
return self.return_value.type_value
def get_parameters(self) -> List[ApiParameter]:
"""Returns the list of Parameter objects."""
return self.params
def get_return_value(self) -> ApiParameter:
"""Returns the list of Parameter objects."""
return self.return_value
def get_auth_scheme_name(self) -> str:
"""Returns the name of the auth scheme for this operation from the spec."""
if self.operation.security:
scheme_name = list(self.operation.security[0].keys())[0]
return scheme_name
return ''
def get_pydoc_string(self) -> str:
"""Returns the generated PyDoc string."""
pydoc_params = [param.to_pydoc_string() for param in self.params]
pydoc_description = (
self.operation.summary or self.operation.description or ''
)
pydoc_return = PydocHelper.generate_return_doc(
self.operation.responses or {}
)
pydoc_arg_list = chr(10).join(
f' {param_doc}' for param_doc in pydoc_params
)
return dedent(f"""
\"\"\"{pydoc_description}
Args:
{pydoc_arg_list}
{pydoc_return}
\"\"\"
""").strip()
def get_json_schema(self) -> Dict[str, Any]:
"""Returns the JSON schema for the function arguments."""
properties = {
p.py_name: jsonable_encoder(p.param_schema, exclude_none=True)
for p in self.params
}
return {
'properties': properties,
'required': [p.py_name for p in self.params],
'title': f"{self.operation.operationId or 'unnamed'}_Arguments",
'type': 'object',
}
def get_signature_parameters(self) -> List[inspect.Parameter]:
"""Returns a list of inspect.Parameter objects for the function."""
return [
inspect.Parameter(
param.py_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=param.type_value,
)
for param in self.params
]
def get_annotations(self) -> Dict[str, Any]:
"""Returns a dictionary of parameter annotations for the function."""
annotations = {p.py_name: p.type_value for p in self.params}
annotations['return'] = self.get_return_type_value()
return annotations

View File

@@ -0,0 +1,496 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Union
from fastapi.openapi.models import Operation
from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
import requests
from typing_extensions import override
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ....tools import BaseTool
from ...tool_context import ToolContext
from ..auth.auth_helpers import credential_to_param
from ..auth.auth_helpers import dict_to_auth_scheme
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
from ..common.common import ApiParameter
from ..common.common import to_snake_case
from .openapi_spec_parser import OperationEndpoint
from .openapi_spec_parser import ParsedOperation
from .operation_parser import OperationParser
from .tool_auth_handler import ToolAuthHandler
def snake_to_lower_camel(snake_case_string: str):
"""Converts a snake_case string to a lower_camel_case string.
Args:
snake_case_string: The input snake_case string.
Returns:
The lower_camel_case string.
"""
if "_" not in snake_case_string:
return snake_case_string
return "".join([
s.lower() if i == 0 else s.capitalize()
for i, s in enumerate(snake_case_string.split("_"))
])
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
"""Converts an OpenAPI schema dictionary to a Gemini Schema object.
Args:
openapi_schema: The OpenAPI schema dictionary.
Returns:
A Pydantic Schema object. Returns None if input is None.
Raises TypeError if input is not a dict.
"""
if openapi_schema is None:
return None
if not isinstance(openapi_schema, dict):
raise TypeError("openapi_schema must be a dictionary")
pydantic_schema_data = {}
# Adding this to force adding a type to an empty dict
# This avoid "... one_of or any_of must specify a type" error
if not openapi_schema.get("type"):
openapi_schema["type"] = "object"
# Adding this to avoid "properties: should be non-empty for OBJECT type" error
# See b/385165182
if openapi_schema.get("type", "") == "object" and not openapi_schema.get(
"properties"
):
openapi_schema["properties"] = {"dummy_DO_NOT_GENERATE": {"type": "string"}}
for key, value in openapi_schema.items():
snake_case_key = to_snake_case(key)
# Check if the snake_case_key exists in the Schema model's fields.
if snake_case_key in Schema.model_fields:
if snake_case_key in ["title", "default", "format"]:
# Ignore these fields as Gemini backend doesn't recognize them, and will
# throw exception if they appear in the schema.
# Format: properties[expiration].format: only 'enum' and 'date-time' are
# supported for STRING type
continue
if snake_case_key == "properties" and isinstance(value, dict):
pydantic_schema_data[snake_case_key] = {
k: to_gemini_schema(v) for k, v in value.items()
}
elif snake_case_key == "items" and isinstance(value, dict):
pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
elif snake_case_key == "any_of" and isinstance(value, list):
pydantic_schema_data[snake_case_key] = [
to_gemini_schema(item) for item in value
]
# Important: Handle cases where the OpenAPI schema might contain lists
# or other structures that need to be recursively processed.
elif isinstance(value, list) and snake_case_key not in (
"enum",
"required",
"property_ordering",
):
new_list = []
for item in value:
if isinstance(item, dict):
new_list.append(to_gemini_schema(item))
else:
new_list.append(item)
pydantic_schema_data[snake_case_key] = new_list
elif isinstance(value, dict) and snake_case_key not in ("properties"):
# Handle dictionary which is neither properties or items
pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
else:
# Simple value assignment (int, str, bool, etc.)
pydantic_schema_data[snake_case_key] = value
return Schema(**pydantic_schema_data)
AuthPreparationState = Literal["pending", "done"]
class RestApiTool(BaseTool):
"""A generic tool that interacts with a REST API.
* Generates request params and body
* Attaches auth credentials to API call.
Example:
```
# Each API operation in the spec will be turned into its own tool
# Name of the tool is the operationId of that operation, in snake case
operations = OperationGenerator().parse(openapi_spec_dict)
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
```
"""
def __init__(
self,
name: str,
description: str,
endpoint: Union[OperationEndpoint, str],
operation: Union[Operation, str],
auth_scheme: Optional[Union[AuthScheme, str]] = None,
auth_credential: Optional[Union[AuthCredential, str]] = None,
should_parse_operation=True,
):
"""Initializes the RestApiTool with the given parameters.
To generate RestApiTool from OpenAPI Specs, use OperationGenerator.
Example:
```
# Each API operation in the spec will be turned into its own tool
# Name of the tool is the operationId of that operation, in snake case
operations = OperationGenerator().parse(openapi_spec_dict)
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
```
Hint: Use google.adk.tools.openapi_tool.auth.auth_helpers to construct
auth_scheme and auth_credential.
Args:
name: The name of the tool.
description: The description of the tool.
endpoint: Include the base_url, path, and method of the tool.
operation: Pydantic object or a dict. Representing the OpenAPI Operation
object
(https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#operation-object)
auth_scheme: The auth scheme of the tool. Representing the OpenAPI
SecurityScheme object
(https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#security-scheme-object)
auth_credential: The authentication credential of the tool.
should_parse_operation: Whether to parse the operation.
"""
# Gemini restrict the length of function name to be less than 64 characters
self.name = name[:60]
self.description = description
self.endpoint = (
OperationEndpoint.model_validate_json(endpoint)
if isinstance(endpoint, str)
else endpoint
)
self.operation = (
Operation.model_validate_json(operation)
if isinstance(operation, str)
else operation
)
self.auth_credential, self.auth_scheme = None, None
self.configure_auth_credential(auth_credential)
self.configure_auth_scheme(auth_scheme)
# Private properties
self.credential_exchanger = AutoAuthCredentialExchanger()
if should_parse_operation:
self._operation_parser = OperationParser(self.operation)
@classmethod
def from_parsed_operation(cls, parsed: ParsedOperation) -> "RestApiTool":
"""Initializes the RestApiTool from a ParsedOperation object.
Args:
parsed: A ParsedOperation object.
Returns:
A RestApiTool object.
"""
operation_parser = OperationParser.load(
parsed.operation, parsed.parameters, parsed.return_value
)
tool_name = to_snake_case(operation_parser.get_function_name())
generated = cls(
name=tool_name,
description=parsed.operation.description
or parsed.operation.summary
or "",
endpoint=parsed.endpoint,
operation=parsed.operation,
auth_scheme=parsed.auth_scheme,
auth_credential=parsed.auth_credential,
)
generated._operation_parser = operation_parser
return generated
@classmethod
def from_parsed_operation_str(
cls, parsed_operation_str: str
) -> "RestApiTool":
"""Initializes the RestApiTool from a dict.
Args:
parsed: A dict representation of a ParsedOperation object.
Returns:
A RestApiTool object.
"""
operation = ParsedOperation.model_validate_json(parsed_operation_str)
return RestApiTool.from_parsed_operation(operation)
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Returns the function declaration in the Gemini Schema format."""
schema_dict = self._operation_parser.get_json_schema()
parameters = to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
return function_decl
def configure_auth_scheme(
self, auth_scheme: Union[AuthScheme, Dict[str, Any]]
):
"""Configures the authentication scheme for the API call.
Args:
auth_scheme: AuthScheme|dict -: The authentication scheme. The dict is
converted to a AuthScheme object.
"""
if isinstance(auth_scheme, dict):
auth_scheme = dict_to_auth_scheme(auth_scheme)
self.auth_scheme = auth_scheme
def configure_auth_credential(
self, auth_credential: Optional[Union[AuthCredential, str]] = None
):
"""Configures the authentication credential for the API call.
Args:
auth_credential: AuthCredential|dict - The authentication credential.
The dict is converted to an AuthCredential object.
"""
if isinstance(auth_credential, str):
auth_credential = AuthCredential.model_validate_json(auth_credential)
self.auth_credential = auth_credential
def _prepare_auth_request_params(
self,
auth_scheme: AuthScheme,
auth_credential: AuthCredential,
) -> Tuple[List[ApiParameter], Dict[str, Any]]:
# Handle Authentication
if not auth_scheme or not auth_credential:
return
return credential_to_param(auth_scheme, auth_credential)
def _prepare_request_params(
self, parameters: List[ApiParameter], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
"""Prepares the request parameters for the API call.
Args:
parameters: A list of ApiParameter objects representing the parameters
for the API call.
kwargs: The keyword arguments passed to the call function from the Tool
caller.
Returns:
A dictionary containing the request parameters for the API call. This
initializes a requests.request() call.
Example:
self._prepare_request_params({"input_id": "test-id"})
"""
method = self.endpoint.method.lower()
if not method:
raise ValueError("Operation method not found.")
path_params: Dict[str, Any] = {}
query_params: Dict[str, Any] = {}
header_params: Dict[str, Any] = {}
cookie_params: Dict[str, Any] = {}
params_map: Dict[str, ApiParameter] = {p.py_name: p for p in parameters}
# Fill in path, query, header and cookie parameters to the request
for param_k, v in kwargs.items():
param_obj = params_map.get(param_k)
if not param_obj:
continue # If input arg not in the ApiParameter list, ignore it.
original_k = param_obj.original_name
param_location = param_obj.param_location
if param_location == "path":
path_params[original_k] = v
elif param_location == "query":
if v:
query_params[original_k] = v
elif param_location == "header":
header_params[original_k] = v
elif param_location == "cookie":
cookie_params[original_k] = v
# Construct URL
base_url = self.endpoint.base_url or ""
base_url = base_url[:-1] if base_url.endswith("/") else base_url
url = f"{base_url}{self.endpoint.path.format(**path_params)}"
# Construct body
body_kwargs: Dict[str, Any] = {}
request_body = self.operation.requestBody
if request_body:
for mime_type, media_type_object in request_body.content.items():
schema = media_type_object.schema_
body_data = None
if schema.type == "object":
body_data = {}
for param in parameters:
if param.param_location == "body" and param.py_name in kwargs:
body_data[param.original_name] = kwargs[param.py_name]
elif schema.type == "array":
for param in parameters:
if param.param_location == "body" and param.py_name == "array":
body_data = kwargs.get("array")
break
else: # like string
for param in parameters:
# original_name = '' indicating this param applies to the full body.
if param.param_location == "body" and not param.original_name:
body_data = (
kwargs.get(param.py_name) if param.py_name in kwargs else None
)
break
if mime_type == "application/json" or mime_type.endswith("+json"):
if body_data is not None:
body_kwargs["json"] = body_data
elif mime_type == "application/x-www-form-urlencoded":
body_kwargs["data"] = body_data
elif mime_type == "multipart/form-data":
body_kwargs["files"] = body_data
elif mime_type == "application/octet-stream":
body_kwargs["data"] = body_data
elif mime_type == "text/plain":
body_kwargs["data"] = body_data
if mime_type:
header_params["Content-Type"] = mime_type
break # Process only the first mime_type
filtered_query_params: Dict[str, Any] = {
k: v for k, v in query_params.items() if v is not None
}
request_params: Dict[str, Any] = {
"method": method,
"url": url,
"params": filtered_query_params,
"headers": header_params,
"cookies": cookie_params,
**body_kwargs,
}
return request_params
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]:
return self.call(args=args, tool_context=tool_context)
def call(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]:
"""Executes the REST API call.
Args:
args: Keyword arguments representing the operation parameters.
tool_context: The tool context (not used here, but required by the
interface).
Returns:
The API response as a dictionary.
"""
# Prepare auth credentials for the API call
tool_auth_handler = ToolAuthHandler.from_tool_context(
tool_context, self.auth_scheme, self.auth_credential
)
auth_result = tool_auth_handler.prepare_auth_credentials()
auth_state, auth_scheme, auth_credential = (
auth_result.state,
auth_result.auth_scheme,
auth_result.auth_credential,
)
if auth_state == "pending":
return {
"pending": True,
"message": "Needs your authorization to access your data.",
}
# Attach parameters from auth into main parameters list
api_params, api_args = self._operation_parser.get_parameters().copy(), args
if auth_credential:
# Attach parameters from auth into main parameters list
auth_param, auth_args = self._prepare_auth_request_params(
auth_scheme, auth_credential
)
if auth_param and auth_args:
api_params = [auth_param] + api_params
api_args.update(auth_args)
# Got all parameters. Call the API.
request_params = self._prepare_request_params(api_params, api_args)
response = requests.request(**request_params)
# Parse API response
try:
response.raise_for_status() # Raise HTTPError for bad responses
return response.json() # Try to decode JSON
except requests.exceptions.HTTPError:
error_details = response.content.decode("utf-8")
return {
"error": (
f"Tool {self.name} execution failed. Analyze this execution error"
" and your inputs. Retry with adjustments if applicable. But"
" make sure don't retry more than 3 times. Execution Error:"
f" {error_details}"
)
}
except ValueError:
return {"text": response.text} # Return text if not JSON
def __str__(self):
return (
f'RestApiTool(name="{self.name}", description="{self.description}",'
f' endpoint="{self.endpoint}")'
)
def __repr__(self):
return (
f'RestApiTool(name="{self.name}", description="{self.description}",'
f' endpoint="{self.endpoint}", operation="{self.operation}",'
f' auth_scheme="{self.auth_scheme}",'
f' auth_credential="{self.auth_credential}")'
)

View File

@@ -0,0 +1,268 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Literal
from typing import Optional
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from ....auth.auth_credential import AuthCredential
from ....auth.auth_credential import AuthCredentialTypes
from ....auth.auth_schemes import AuthScheme
from ....auth.auth_schemes import AuthSchemeType
from ....auth.auth_tool import AuthConfig
from ...tool_context import ToolContext
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
from ..auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
from ..auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
logger = logging.getLogger(__name__)
AuthPreparationState = Literal["pending", "done"]
class AuthPreparationResult(BaseModel):
"""Result of the credential preparation process."""
state: AuthPreparationState
auth_scheme: Optional[AuthScheme] = None
auth_credential: Optional[AuthCredential] = None
class ToolContextCredentialStore:
"""Handles storage and retrieval of credentials within a ToolContext."""
def __init__(self, tool_context: ToolContext):
self.tool_context = tool_context
def get_credential_key(
self,
auth_scheme: Optional[AuthScheme],
auth_credential: Optional[AuthCredential],
) -> str:
"""Generates a unique key for the given auth scheme and credential."""
scheme_name = (
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
if auth_scheme
else ""
)
credential_name = (
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
if auth_credential
else ""
)
# no need to prepend temp: namespace, session state is a copy, changes to
# it won't be persisted , only changes in event_action.state_delta will be
# persisted. temp: namespace will be cleared after current run. but tool
# want access token to be there stored across runs
return f"{scheme_name}_{credential_name}_existing_exchanged_credential"
def get_credential(
self,
auth_scheme: Optional[AuthScheme],
auth_credential: Optional[AuthCredential],
) -> Optional[AuthCredential]:
if not self.tool_context:
return None
token_key = self.get_credential_key(auth_scheme, auth_credential)
# TODO try not to use session state, this looks a hacky way, depend on
# session implementation, we don't want session to persist the token,
# meanwhile we want the token shared across runs.
serialized_credential = self.tool_context.state.get(token_key)
if not serialized_credential:
return None
return AuthCredential.model_validate(serialized_credential)
def store_credential(
self,
key: str,
auth_credential: Optional[AuthCredential],
):
if self.tool_context:
serializable_credential = jsonable_encoder(
auth_credential, exclude_none=True
)
self.tool_context.state[key] = serializable_credential
def remove_credential(self, key: str):
del self.tool_context.state[key]
class ToolAuthHandler:
"""Handles the preparation and exchange of authentication credentials for tools."""
def __init__(
self,
tool_context: ToolContext,
auth_scheme: Optional[AuthScheme],
auth_credential: Optional[AuthCredential],
credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
credential_store: Optional["ToolContextCredentialStore"] = None,
):
self.tool_context = tool_context
self.auth_scheme = (
auth_scheme.model_copy(deep=True) if auth_scheme else None
)
self.auth_credential = (
auth_credential.model_copy(deep=True) if auth_credential else None
)
self.credential_exchanger = (
credential_exchanger or AutoAuthCredentialExchanger()
)
self.credential_store = credential_store
self.should_store_credential = True
@classmethod
def from_tool_context(
cls,
tool_context: ToolContext,
auth_scheme: Optional[AuthScheme],
auth_credential: Optional[AuthCredential],
credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
) -> "ToolAuthHandler":
"""Creates a ToolAuthHandler instance from a ToolContext."""
credential_store = ToolContextCredentialStore(tool_context)
return cls(
tool_context,
auth_scheme,
auth_credential,
credential_exchanger,
credential_store,
)
def _handle_existing_credential(
self,
) -> Optional[AuthPreparationResult]:
"""Checks for and returns an existing, exchanged credential."""
if self.credential_store:
existing_credential = self.credential_store.get_credential(
self.auth_scheme, self.auth_credential
)
if existing_credential:
return AuthPreparationResult(
state="done",
auth_scheme=self.auth_scheme,
auth_credential=existing_credential,
)
return None
def _exchange_credential(
self, auth_credential: AuthCredential
) -> Optional[AuthPreparationResult]:
"""Handles an OpenID Connect authorization response."""
exchanged_credential = None
try:
exchanged_credential = self.credential_exchanger.exchange_credential(
self.auth_scheme, auth_credential
)
except Exception as e:
logger.error("Failed to exchange credential: %s", e)
return exchanged_credential
def _store_credential(self, auth_credential: AuthCredential) -> None:
"""stores the auth_credential."""
if self.credential_store:
key = self.credential_store.get_credential_key(
self.auth_scheme, self.auth_credential
)
self.credential_store.store_credential(key, auth_credential)
def _reqeust_credential(self) -> None:
"""Handles the case where an OpenID Connect or OAuth2 authentication request is needed."""
if self.auth_scheme.type_ in (
AuthSchemeType.openIdConnect,
AuthSchemeType.oauth2,
):
if not self.auth_credential or not self.auth_credential.oauth2:
raise ValueError(
f"auth_credential is empty for scheme {self.auth_scheme.type_}."
"Please create AuthCredential using OAuth2Auth."
)
if not self.auth_credential.oauth2.client_id:
raise AuthCredentialMissingError(
"OAuth2 credentials client_id is missing."
)
if not self.auth_credential.oauth2.client_secret:
raise AuthCredentialMissingError(
"OAuth2 credentials client_secret is missing."
)
self.tool_context.request_credential(
AuthConfig(
auth_scheme=self.auth_scheme,
raw_auth_credential=self.auth_credential,
)
)
return None
def _get_auth_response(self) -> AuthCredential:
return self.tool_context.get_auth_response(
AuthConfig(
auth_scheme=self.auth_scheme,
raw_auth_credential=self.auth_credential,
)
)
def _request_credential(self, auth_config: AuthConfig):
if not self.tool_context:
return
self.tool_context.request_credential(auth_config)
def prepare_auth_credentials(
self,
) -> AuthPreparationResult:
"""Prepares authentication credentials, handling exchange and user interaction."""
# no auth is needed
if not self.auth_scheme:
return AuthPreparationResult(state="done")
# Check for existing credential.
existing_result = self._handle_existing_credential()
if existing_result:
return existing_result
# fetch credential from adk framework
# Some auth scheme like OAuth2 AuthCode & OpenIDConnect may require
# multi-step exchange:
# client_id , client_secret -> auth_uri -> auth_code -> access_token
# -> bearer token
# adk framework supports exchange access_token already
fetched_credential = self._get_auth_response() or self.auth_credential
exchanged_credential = self._exchange_credential(fetched_credential)
if exchanged_credential:
self._store_credential(exchanged_credential)
return AuthPreparationResult(
state="done",
auth_scheme=self.auth_scheme,
auth_credential=exchanged_credential,
)
else:
self._reqeust_credential()
return AuthPreparationResult(
state="pending",
auth_scheme=self.auth_scheme,
auth_credential=self.auth_credential,
)