mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-22 21:32:19 -06:00
Agent Development Kit(ADK)
An easy-to-use and powerful framework to build AI agents.
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .openapi_spec_parser import OpenApiSpecParser, OperationEndpoint, ParsedOperation
|
||||
from .openapi_toolset import OpenAPIToolset
|
||||
from .operation_parser import OperationParser
|
||||
from .rest_api_tool import AuthPreparationState, RestApiTool, snake_to_lower_camel, to_gemini_schema
|
||||
from .tool_auth_handler import ToolAuthHandler
|
||||
|
||||
__all__ = [
|
||||
'OpenApiSpecParser',
|
||||
'OperationEndpoint',
|
||||
'ParsedOperation',
|
||||
'OpenAPIToolset',
|
||||
'OperationParser',
|
||||
'RestApiTool',
|
||||
'to_gemini_schema',
|
||||
'snake_to_lower_camel',
|
||||
'AuthPreparationState',
|
||||
'ToolAuthHandler',
|
||||
]
|
||||
@@ -0,0 +1,231 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.openapi.models import Operation
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ..common.common import ApiParameter
|
||||
from ..common.common import to_snake_case
|
||||
from .operation_parser import OperationParser
|
||||
|
||||
|
||||
class OperationEndpoint(BaseModel):
|
||||
base_url: str
|
||||
path: str
|
||||
method: str
|
||||
|
||||
|
||||
class ParsedOperation(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
endpoint: OperationEndpoint
|
||||
operation: Operation
|
||||
parameters: List[ApiParameter]
|
||||
return_value: ApiParameter
|
||||
auth_scheme: Optional[AuthScheme] = None
|
||||
auth_credential: Optional[AuthCredential] = None
|
||||
additional_context: Optional[Any] = None
|
||||
|
||||
|
||||
class OpenApiSpecParser:
|
||||
"""Generates Python code, JSON schema, and callables for an OpenAPI operation.
|
||||
|
||||
This class takes an OpenApiOperation object and provides methods to generate:
|
||||
1. A string representation of a Python function that handles the operation.
|
||||
2. A JSON schema representing the input parameters of the operation.
|
||||
3. A callable Python object (a function) that can execute the operation.
|
||||
"""
|
||||
|
||||
def parse(self, openapi_spec_dict: Dict[str, Any]) -> List[ParsedOperation]:
|
||||
"""Extracts an OpenAPI spec dict into a list of ParsedOperation objects.
|
||||
|
||||
ParsedOperation objects are further used for generating RestApiTool.
|
||||
|
||||
Args:
|
||||
openapi_spec_dict: A dictionary representing the OpenAPI specification.
|
||||
|
||||
Returns:
|
||||
A list of ParsedOperation objects.
|
||||
"""
|
||||
|
||||
openapi_spec_dict = self._resolve_references(openapi_spec_dict)
|
||||
operations = self._collect_operations(openapi_spec_dict)
|
||||
return operations
|
||||
|
||||
def _collect_operations(
|
||||
self, openapi_spec: Dict[str, Any]
|
||||
) -> List[ParsedOperation]:
|
||||
"""Collects operations from an OpenAPI spec."""
|
||||
operations = []
|
||||
|
||||
# Taking first server url, or default to empty string if not present
|
||||
base_url = ""
|
||||
if openapi_spec.get("servers"):
|
||||
base_url = openapi_spec["servers"][0].get("url", "")
|
||||
|
||||
# Get global security scheme (if any)
|
||||
global_scheme_name = None
|
||||
if openapi_spec.get("security"):
|
||||
# Use first scheme by default.
|
||||
scheme_names = list(openapi_spec["security"][0].keys())
|
||||
global_scheme_name = scheme_names[0] if scheme_names else None
|
||||
|
||||
auth_schemes = openapi_spec.get("components", {}).get("securitySchemes", {})
|
||||
|
||||
for path, path_item in openapi_spec.get("paths", {}).items():
|
||||
if path_item is None:
|
||||
continue
|
||||
|
||||
for method in (
|
||||
"get",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"head",
|
||||
"options",
|
||||
"trace",
|
||||
):
|
||||
operation_dict = path_item.get(method)
|
||||
if operation_dict is None:
|
||||
continue
|
||||
|
||||
# If operation ID is missing, assign an operation id based on path
|
||||
# and method
|
||||
if "operationId" not in operation_dict:
|
||||
temp_id = to_snake_case(f"{path}_{method}")
|
||||
operation_dict["operationId"] = temp_id
|
||||
|
||||
url = OperationEndpoint(base_url=base_url, path=path, method=method)
|
||||
operation = Operation.model_validate(operation_dict)
|
||||
operation_parser = OperationParser(operation)
|
||||
|
||||
# Check for operation-specific auth scheme
|
||||
auth_scheme_name = operation_parser.get_auth_scheme_name()
|
||||
auth_scheme_name = (
|
||||
auth_scheme_name if auth_scheme_name else global_scheme_name
|
||||
)
|
||||
auth_scheme = (
|
||||
auth_schemes.get(auth_scheme_name) if auth_scheme_name else None
|
||||
)
|
||||
|
||||
parsed_op = ParsedOperation(
|
||||
name=operation_parser.get_function_name(),
|
||||
description=operation.description or operation.summary or "",
|
||||
endpoint=url,
|
||||
operation=operation,
|
||||
parameters=operation_parser.get_parameters(),
|
||||
return_value=operation_parser.get_return_value(),
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=None, # Placeholder
|
||||
additional_context={},
|
||||
)
|
||||
operations.append(parsed_op)
|
||||
|
||||
return operations
|
||||
|
||||
def _resolve_references(self, openapi_spec: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Recursively resolves all $ref references in an OpenAPI specification.
|
||||
|
||||
Handles circular references correctly.
|
||||
|
||||
Args:
|
||||
openapi_spec: A dictionary representing the OpenAPI specification.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the OpenAPI specification with all references
|
||||
resolved.
|
||||
"""
|
||||
|
||||
openapi_spec = copy.deepcopy(openapi_spec) # Work on a copy
|
||||
resolved_cache = {} # Cache resolved references
|
||||
|
||||
def resolve_ref(ref_string, current_doc):
|
||||
"""Resolves a single $ref string."""
|
||||
parts = ref_string.split("/")
|
||||
if parts[0] != "#":
|
||||
raise ValueError(f"External references not supported: {ref_string}")
|
||||
|
||||
current = current_doc
|
||||
for part in parts[1:]:
|
||||
if part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return None # Reference not found
|
||||
return current
|
||||
|
||||
def recursive_resolve(obj, current_doc, seen_refs=None):
|
||||
"""Recursively resolves references, handling circularity.
|
||||
|
||||
Args:
|
||||
obj: The object to traverse.
|
||||
current_doc: Document to search for refs.
|
||||
seen_refs: A set to track already-visited references (for circularity
|
||||
detection).
|
||||
|
||||
Returns:
|
||||
The resolved object.
|
||||
"""
|
||||
if seen_refs is None:
|
||||
seen_refs = set() # Initialize the set if it's the first call
|
||||
|
||||
if isinstance(obj, dict):
|
||||
if "$ref" in obj and isinstance(obj["$ref"], str):
|
||||
ref_string = obj["$ref"]
|
||||
|
||||
# Check for circularity
|
||||
if ref_string in seen_refs and ref_string not in resolved_cache:
|
||||
# Circular reference detected! Return a *copy* of the object,
|
||||
# but *without* the $ref. This breaks the cycle while
|
||||
# still maintaining the overall structure.
|
||||
return {k: v for k, v in obj.items() if k != "$ref"}
|
||||
|
||||
seen_refs.add(ref_string) # Add the reference to the set
|
||||
|
||||
# Check if we have a cached resolved value
|
||||
if ref_string in resolved_cache:
|
||||
return copy.deepcopy(resolved_cache[ref_string])
|
||||
|
||||
resolved_value = resolve_ref(ref_string, current_doc)
|
||||
if resolved_value is not None:
|
||||
# Recursively resolve the *resolved* value,
|
||||
# passing along the 'seen_refs' set
|
||||
resolved_value = recursive_resolve(
|
||||
resolved_value, current_doc, seen_refs
|
||||
)
|
||||
resolved_cache[ref_string] = resolved_value
|
||||
return copy.deepcopy(resolved_value) # return the cached result
|
||||
else:
|
||||
return obj # return original if no resolved value.
|
||||
|
||||
else:
|
||||
new_dict = {}
|
||||
for key, value in obj.items():
|
||||
new_dict[key] = recursive_resolve(value, current_doc, seen_refs)
|
||||
return new_dict
|
||||
|
||||
elif isinstance(obj, list):
|
||||
return [recursive_resolve(item, current_doc, seen_refs) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
return recursive_resolve(openapi_spec, openapi_spec)
|
||||
@@ -0,0 +1,144 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Final
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from .openapi_spec_parser import OpenApiSpecParser
|
||||
from .rest_api_tool import RestApiTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAPIToolset:
|
||||
"""Class for parsing OpenAPI spec into a list of RestApiTool.
|
||||
|
||||
Usage:
|
||||
```
|
||||
# Initialize OpenAPI toolset from a spec string.
|
||||
openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
|
||||
spec_str_type="json")
|
||||
# Or, initialize OpenAPI toolset from a spec dictionary.
|
||||
openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)
|
||||
|
||||
# Add all tools to an agent.
|
||||
agent = Agent(
|
||||
tools=[*openapi_toolset.get_tools()]
|
||||
)
|
||||
# Or, add a single tool to an agent.
|
||||
agent = Agent(
|
||||
tools=[openapi_toolset.get_tool('tool_name')]
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
spec_dict: Optional[Dict[str, Any]] = None,
|
||||
spec_str: Optional[str] = None,
|
||||
spec_str_type: Literal["json", "yaml"] = "json",
|
||||
auth_scheme: Optional[AuthScheme] = None,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
):
|
||||
"""Initializes the OpenAPIToolset.
|
||||
|
||||
Usage:
|
||||
```
|
||||
# Initialize OpenAPI toolset from a spec string.
|
||||
openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
|
||||
spec_str_type="json")
|
||||
# Or, initialize OpenAPI toolset from a spec dictionary.
|
||||
openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)
|
||||
|
||||
# Add all tools to an agent.
|
||||
agent = Agent(
|
||||
tools=[*openapi_toolset.get_tools()]
|
||||
)
|
||||
# Or, add a single tool to an agent.
|
||||
agent = Agent(
|
||||
tools=[openapi_toolset.get_tool('tool_name')]
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
spec_dict: The OpenAPI spec dictionary. If provided, it will be used
|
||||
instead of loading the spec from a string.
|
||||
spec_str: The OpenAPI spec string in JSON or YAML format. It will be used
|
||||
when spec_dict is not provided.
|
||||
spec_str_type: The type of the OpenAPI spec string. Can be "json" or
|
||||
"yaml".
|
||||
auth_scheme: The auth scheme to use for all tools. Use AuthScheme or use
|
||||
helpers in `google.adk.tools.openapi_tool.auth.auth_helpers`
|
||||
auth_credential: The auth credential to use for all tools. Use
|
||||
AuthCredential or use helpers in
|
||||
`google.adk.tools.openapi_tool.auth.auth_helpers`
|
||||
"""
|
||||
if not spec_dict:
|
||||
spec_dict = self._load_spec(spec_str, spec_str_type)
|
||||
self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
|
||||
if auth_scheme or auth_credential:
|
||||
self._configure_auth_all(auth_scheme, auth_credential)
|
||||
|
||||
def _configure_auth_all(
|
||||
self, auth_scheme: AuthScheme, auth_credential: AuthCredential
|
||||
):
|
||||
"""Configure auth scheme and credential for all tools."""
|
||||
|
||||
for tool in self.tools:
|
||||
if auth_scheme:
|
||||
tool.configure_auth_scheme(auth_scheme)
|
||||
if auth_credential:
|
||||
tool.configure_auth_credential(auth_credential)
|
||||
|
||||
def get_tools(self) -> List[RestApiTool]:
|
||||
"""Get all tools in the toolset."""
|
||||
return self.tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
|
||||
"""Get a tool by name."""
|
||||
matching_tool = filter(lambda t: t.name == tool_name, self.tools)
|
||||
return next(matching_tool, None)
|
||||
|
||||
def _load_spec(
|
||||
self, spec_str: str, spec_type: Literal["json", "yaml"]
|
||||
) -> Dict[str, Any]:
|
||||
"""Loads the OpenAPI spec string into adictionary."""
|
||||
if spec_type == "json":
|
||||
return json.loads(spec_str)
|
||||
elif spec_type == "yaml":
|
||||
return yaml.safe_load(spec_str)
|
||||
else:
|
||||
raise ValueError(f"Unsupported spec type: {spec_type}")
|
||||
|
||||
def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]:
|
||||
"""Parse OpenAPI spec into a list of RestApiTool."""
|
||||
operations = OpenApiSpecParser().parse(openapi_spec_dict)
|
||||
|
||||
tools = []
|
||||
for o in operations:
|
||||
tool = RestApiTool.from_parsed_operation(o)
|
||||
logger.info("Parsed tool: %s", tool.name)
|
||||
tools.append(tool)
|
||||
return tools
|
||||
@@ -0,0 +1,260 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.openapi.models import Operation
|
||||
from fastapi.openapi.models import Parameter
|
||||
from fastapi.openapi.models import Schema
|
||||
|
||||
from ..common.common import ApiParameter
|
||||
from ..common.common import PydocHelper
|
||||
from ..common.common import to_snake_case
|
||||
|
||||
|
||||
class OperationParser:
|
||||
"""Generates parameters for Python functions from an OpenAPI operation.
|
||||
|
||||
This class processes an OpenApiOperation object and provides helper methods
|
||||
to extract information needed to generate Python function declarations,
|
||||
docstrings, signatures, and JSON schemas. It handles parameter processing,
|
||||
name deduplication, and type hint generation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, operation: Union[Operation, Dict[str, Any], str], should_parse=True
|
||||
):
|
||||
"""Initializes the OperationParser with an OpenApiOperation.
|
||||
|
||||
Args:
|
||||
operation: The OpenApiOperation object or a dictionary to process.
|
||||
should_parse: Whether to parse the operation during initialization.
|
||||
"""
|
||||
if isinstance(operation, dict):
|
||||
self.operation = Operation.model_validate(operation)
|
||||
elif isinstance(operation, str):
|
||||
self.operation = Operation.model_validate_json(operation)
|
||||
else:
|
||||
self.operation = operation
|
||||
|
||||
self.params: List[ApiParameter] = []
|
||||
self.return_value: Optional[ApiParameter] = None
|
||||
if should_parse:
|
||||
self._process_operation_parameters()
|
||||
self._process_request_body()
|
||||
self._process_return_value()
|
||||
self._dedupe_param_names()
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
operation: Union[Operation, Dict[str, Any]],
|
||||
params: List[ApiParameter],
|
||||
return_value: Optional[ApiParameter] = None,
|
||||
) -> 'OperationParser':
|
||||
parser = cls(operation, should_parse=False)
|
||||
parser.params = params
|
||||
parser.return_value = return_value
|
||||
return parser
|
||||
|
||||
def _process_operation_parameters(self):
|
||||
"""Processes parameters from the OpenAPI operation."""
|
||||
parameters = self.operation.parameters or []
|
||||
for param in parameters:
|
||||
if isinstance(param, Parameter):
|
||||
original_name = param.name
|
||||
description = param.description or ''
|
||||
location = param.in_ or ''
|
||||
schema = param.schema_ or {} # Use schema_ instead of .schema
|
||||
|
||||
self.params.append(
|
||||
ApiParameter(
|
||||
original_name=original_name,
|
||||
param_location=location,
|
||||
param_schema=schema,
|
||||
description=description,
|
||||
)
|
||||
)
|
||||
|
||||
def _process_request_body(self):
|
||||
"""Processes the request body from the OpenAPI operation."""
|
||||
request_body = self.operation.requestBody
|
||||
if not request_body:
|
||||
return
|
||||
|
||||
content = request_body.content or {}
|
||||
if not content:
|
||||
return
|
||||
|
||||
# If request body is an object, expand the properties as parameters
|
||||
for _, media_type_object in content.items():
|
||||
schema = media_type_object.schema_ or {}
|
||||
description = request_body.description or ''
|
||||
|
||||
if schema and schema.type == 'object':
|
||||
for prop_name, prop_details in schema.properties.items():
|
||||
self.params.append(
|
||||
ApiParameter(
|
||||
original_name=prop_name,
|
||||
param_location='body',
|
||||
param_schema=prop_details,
|
||||
description=prop_details.description,
|
||||
)
|
||||
)
|
||||
|
||||
elif schema and schema.type == 'array':
|
||||
self.params.append(
|
||||
ApiParameter(
|
||||
original_name='array',
|
||||
param_location='body',
|
||||
param_schema=schema,
|
||||
description=description,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.params.append(
|
||||
# Empty name for unnamed body param
|
||||
ApiParameter(
|
||||
original_name='',
|
||||
param_location='body',
|
||||
param_schema=schema,
|
||||
description=description,
|
||||
)
|
||||
)
|
||||
break # Process first mime type only
|
||||
|
||||
def _dedupe_param_names(self):
|
||||
"""Deduplicates parameter names to avoid conflicts."""
|
||||
params_cnt = {}
|
||||
for param in self.params:
|
||||
name = param.py_name
|
||||
if name not in params_cnt:
|
||||
params_cnt[name] = 0
|
||||
else:
|
||||
params_cnt[name] += 1
|
||||
param.py_name = f'{name}_{params_cnt[name] -1}'
|
||||
|
||||
def _process_return_value(self) -> Parameter:
|
||||
"""Returns a Parameter object representing the return type."""
|
||||
responses = self.operation.responses or {}
|
||||
# Default to Any if no 2xx response or if schema is missing
|
||||
return_schema = Schema(type='Any')
|
||||
|
||||
# Take the 20x response with the smallest response code.
|
||||
valid_codes = list(
|
||||
filter(lambda k: k.startswith('2'), list(responses.keys()))
|
||||
)
|
||||
min_20x_status_code = min(valid_codes) if valid_codes else None
|
||||
|
||||
if min_20x_status_code and responses[min_20x_status_code].content:
|
||||
content = responses[min_20x_status_code].content
|
||||
for mime_type in content:
|
||||
if content[mime_type].schema_:
|
||||
return_schema = content[mime_type].schema_
|
||||
break
|
||||
|
||||
self.return_value = ApiParameter(
|
||||
original_name='',
|
||||
param_location='',
|
||||
param_schema=return_schema,
|
||||
)
|
||||
|
||||
def get_function_name(self) -> str:
|
||||
"""Returns the generated function name."""
|
||||
operation_id = self.operation.operationId
|
||||
if not operation_id:
|
||||
raise ValueError('Operation ID is missing')
|
||||
return to_snake_case(operation_id)[:60]
|
||||
|
||||
def get_return_type_hint(self) -> str:
|
||||
"""Returns the return type hint string (like 'str', 'int', etc.)."""
|
||||
return self.return_value.type_hint
|
||||
|
||||
def get_return_type_value(self) -> Any:
|
||||
"""Returns the return type value (like str, int, List[str], etc.)."""
|
||||
return self.return_value.type_value
|
||||
|
||||
def get_parameters(self) -> List[ApiParameter]:
|
||||
"""Returns the list of Parameter objects."""
|
||||
return self.params
|
||||
|
||||
def get_return_value(self) -> ApiParameter:
|
||||
"""Returns the list of Parameter objects."""
|
||||
return self.return_value
|
||||
|
||||
def get_auth_scheme_name(self) -> str:
|
||||
"""Returns the name of the auth scheme for this operation from the spec."""
|
||||
if self.operation.security:
|
||||
scheme_name = list(self.operation.security[0].keys())[0]
|
||||
return scheme_name
|
||||
return ''
|
||||
|
||||
def get_pydoc_string(self) -> str:
|
||||
"""Returns the generated PyDoc string."""
|
||||
pydoc_params = [param.to_pydoc_string() for param in self.params]
|
||||
pydoc_description = (
|
||||
self.operation.summary or self.operation.description or ''
|
||||
)
|
||||
pydoc_return = PydocHelper.generate_return_doc(
|
||||
self.operation.responses or {}
|
||||
)
|
||||
pydoc_arg_list = chr(10).join(
|
||||
f' {param_doc}' for param_doc in pydoc_params
|
||||
)
|
||||
return dedent(f"""
|
||||
\"\"\"{pydoc_description}
|
||||
|
||||
Args:
|
||||
{pydoc_arg_list}
|
||||
|
||||
{pydoc_return}
|
||||
\"\"\"
|
||||
""").strip()
|
||||
|
||||
def get_json_schema(self) -> Dict[str, Any]:
|
||||
"""Returns the JSON schema for the function arguments."""
|
||||
properties = {
|
||||
p.py_name: jsonable_encoder(p.param_schema, exclude_none=True)
|
||||
for p in self.params
|
||||
}
|
||||
return {
|
||||
'properties': properties,
|
||||
'required': [p.py_name for p in self.params],
|
||||
'title': f"{self.operation.operationId or 'unnamed'}_Arguments",
|
||||
'type': 'object',
|
||||
}
|
||||
|
||||
def get_signature_parameters(self) -> List[inspect.Parameter]:
|
||||
"""Returns a list of inspect.Parameter objects for the function."""
|
||||
return [
|
||||
inspect.Parameter(
|
||||
param.py_name,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
annotation=param.type_value,
|
||||
)
|
||||
for param in self.params
|
||||
]
|
||||
|
||||
def get_annotations(self) -> Dict[str, Any]:
|
||||
"""Returns a dictionary of parameter annotations for the function."""
|
||||
annotations = {p.py_name: p.type_value for p in self.params}
|
||||
annotations['return'] = self.get_return_type_value()
|
||||
return annotations
|
||||
@@ -0,0 +1,496 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from fastapi.openapi.models import Operation
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from google.genai.types import Schema
|
||||
import requests
|
||||
from typing_extensions import override
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ....tools import BaseTool
|
||||
from ...tool_context import ToolContext
|
||||
from ..auth.auth_helpers import credential_to_param
|
||||
from ..auth.auth_helpers import dict_to_auth_scheme
|
||||
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
|
||||
from ..common.common import ApiParameter
|
||||
from ..common.common import to_snake_case
|
||||
from .openapi_spec_parser import OperationEndpoint
|
||||
from .openapi_spec_parser import ParsedOperation
|
||||
from .operation_parser import OperationParser
|
||||
from .tool_auth_handler import ToolAuthHandler
|
||||
|
||||
|
||||
def snake_to_lower_camel(snake_case_string: str):
|
||||
"""Converts a snake_case string to a lower_camel_case string.
|
||||
|
||||
Args:
|
||||
snake_case_string: The input snake_case string.
|
||||
|
||||
Returns:
|
||||
The lower_camel_case string.
|
||||
"""
|
||||
if "_" not in snake_case_string:
|
||||
return snake_case_string
|
||||
|
||||
return "".join([
|
||||
s.lower() if i == 0 else s.capitalize()
|
||||
for i, s in enumerate(snake_case_string.split("_"))
|
||||
])
|
||||
|
||||
|
||||
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
|
||||
"""Converts an OpenAPI schema dictionary to a Gemini Schema object.
|
||||
|
||||
Args:
|
||||
openapi_schema: The OpenAPI schema dictionary.
|
||||
|
||||
Returns:
|
||||
A Pydantic Schema object. Returns None if input is None.
|
||||
Raises TypeError if input is not a dict.
|
||||
"""
|
||||
if openapi_schema is None:
|
||||
return None
|
||||
|
||||
if not isinstance(openapi_schema, dict):
|
||||
raise TypeError("openapi_schema must be a dictionary")
|
||||
|
||||
pydantic_schema_data = {}
|
||||
|
||||
# Adding this to force adding a type to an empty dict
|
||||
# This avoid "... one_of or any_of must specify a type" error
|
||||
if not openapi_schema.get("type"):
|
||||
openapi_schema["type"] = "object"
|
||||
|
||||
# Adding this to avoid "properties: should be non-empty for OBJECT type" error
|
||||
# See b/385165182
|
||||
if openapi_schema.get("type", "") == "object" and not openapi_schema.get(
|
||||
"properties"
|
||||
):
|
||||
openapi_schema["properties"] = {"dummy_DO_NOT_GENERATE": {"type": "string"}}
|
||||
|
||||
for key, value in openapi_schema.items():
|
||||
snake_case_key = to_snake_case(key)
|
||||
# Check if the snake_case_key exists in the Schema model's fields.
|
||||
if snake_case_key in Schema.model_fields:
|
||||
if snake_case_key in ["title", "default", "format"]:
|
||||
# Ignore these fields as Gemini backend doesn't recognize them, and will
|
||||
# throw exception if they appear in the schema.
|
||||
# Format: properties[expiration].format: only 'enum' and 'date-time' are
|
||||
# supported for STRING type
|
||||
continue
|
||||
if snake_case_key == "properties" and isinstance(value, dict):
|
||||
pydantic_schema_data[snake_case_key] = {
|
||||
k: to_gemini_schema(v) for k, v in value.items()
|
||||
}
|
||||
elif snake_case_key == "items" and isinstance(value, dict):
|
||||
pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
|
||||
elif snake_case_key == "any_of" and isinstance(value, list):
|
||||
pydantic_schema_data[snake_case_key] = [
|
||||
to_gemini_schema(item) for item in value
|
||||
]
|
||||
# Important: Handle cases where the OpenAPI schema might contain lists
|
||||
# or other structures that need to be recursively processed.
|
||||
elif isinstance(value, list) and snake_case_key not in (
|
||||
"enum",
|
||||
"required",
|
||||
"property_ordering",
|
||||
):
|
||||
new_list = []
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
new_list.append(to_gemini_schema(item))
|
||||
else:
|
||||
new_list.append(item)
|
||||
pydantic_schema_data[snake_case_key] = new_list
|
||||
elif isinstance(value, dict) and snake_case_key not in ("properties"):
|
||||
# Handle dictionary which is neither properties or items
|
||||
pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
|
||||
else:
|
||||
# Simple value assignment (int, str, bool, etc.)
|
||||
pydantic_schema_data[snake_case_key] = value
|
||||
|
||||
return Schema(**pydantic_schema_data)
|
||||
|
||||
|
||||
AuthPreparationState = Literal["pending", "done"]
|
||||
|
||||
|
||||
class RestApiTool(BaseTool):
|
||||
"""A generic tool that interacts with a REST API.
|
||||
|
||||
* Generates request params and body
|
||||
* Attaches auth credentials to API call.
|
||||
|
||||
Example:
|
||||
```
|
||||
# Each API operation in the spec will be turned into its own tool
|
||||
# Name of the tool is the operationId of that operation, in snake case
|
||||
operations = OperationGenerator().parse(openapi_spec_dict)
|
||||
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
endpoint: Union[OperationEndpoint, str],
|
||||
operation: Union[Operation, str],
|
||||
auth_scheme: Optional[Union[AuthScheme, str]] = None,
|
||||
auth_credential: Optional[Union[AuthCredential, str]] = None,
|
||||
should_parse_operation=True,
|
||||
):
|
||||
"""Initializes the RestApiTool with the given parameters.
|
||||
|
||||
To generate RestApiTool from OpenAPI Specs, use OperationGenerator.
|
||||
Example:
|
||||
```
|
||||
# Each API operation in the spec will be turned into its own tool
|
||||
# Name of the tool is the operationId of that operation, in snake case
|
||||
operations = OperationGenerator().parse(openapi_spec_dict)
|
||||
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
|
||||
```
|
||||
|
||||
Hint: Use google.adk.tools.openapi_tool.auth.auth_helpers to construct
|
||||
auth_scheme and auth_credential.
|
||||
|
||||
Args:
|
||||
name: The name of the tool.
|
||||
description: The description of the tool.
|
||||
endpoint: Include the base_url, path, and method of the tool.
|
||||
operation: Pydantic object or a dict. Representing the OpenAPI Operation
|
||||
object
|
||||
(https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#operation-object)
|
||||
auth_scheme: The auth scheme of the tool. Representing the OpenAPI
|
||||
SecurityScheme object
|
||||
(https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#security-scheme-object)
|
||||
auth_credential: The authentication credential of the tool.
|
||||
should_parse_operation: Whether to parse the operation.
|
||||
"""
|
||||
# Gemini restrict the length of function name to be less than 64 characters
|
||||
self.name = name[:60]
|
||||
self.description = description
|
||||
self.endpoint = (
|
||||
OperationEndpoint.model_validate_json(endpoint)
|
||||
if isinstance(endpoint, str)
|
||||
else endpoint
|
||||
)
|
||||
self.operation = (
|
||||
Operation.model_validate_json(operation)
|
||||
if isinstance(operation, str)
|
||||
else operation
|
||||
)
|
||||
self.auth_credential, self.auth_scheme = None, None
|
||||
|
||||
self.configure_auth_credential(auth_credential)
|
||||
self.configure_auth_scheme(auth_scheme)
|
||||
|
||||
# Private properties
|
||||
self.credential_exchanger = AutoAuthCredentialExchanger()
|
||||
if should_parse_operation:
|
||||
self._operation_parser = OperationParser(self.operation)
|
||||
|
||||
@classmethod
|
||||
def from_parsed_operation(cls, parsed: ParsedOperation) -> "RestApiTool":
|
||||
"""Initializes the RestApiTool from a ParsedOperation object.
|
||||
|
||||
Args:
|
||||
parsed: A ParsedOperation object.
|
||||
|
||||
Returns:
|
||||
A RestApiTool object.
|
||||
"""
|
||||
operation_parser = OperationParser.load(
|
||||
parsed.operation, parsed.parameters, parsed.return_value
|
||||
)
|
||||
|
||||
tool_name = to_snake_case(operation_parser.get_function_name())
|
||||
generated = cls(
|
||||
name=tool_name,
|
||||
description=parsed.operation.description
|
||||
or parsed.operation.summary
|
||||
or "",
|
||||
endpoint=parsed.endpoint,
|
||||
operation=parsed.operation,
|
||||
auth_scheme=parsed.auth_scheme,
|
||||
auth_credential=parsed.auth_credential,
|
||||
)
|
||||
generated._operation_parser = operation_parser
|
||||
return generated
|
||||
|
||||
@classmethod
|
||||
def from_parsed_operation_str(
|
||||
cls, parsed_operation_str: str
|
||||
) -> "RestApiTool":
|
||||
"""Initializes the RestApiTool from a dict.
|
||||
|
||||
Args:
|
||||
parsed: A dict representation of a ParsedOperation object.
|
||||
|
||||
Returns:
|
||||
A RestApiTool object.
|
||||
"""
|
||||
operation = ParsedOperation.model_validate_json(parsed_operation_str)
|
||||
return RestApiTool.from_parsed_operation(operation)
|
||||
|
||||
@override
|
||||
def _get_declaration(self) -> FunctionDeclaration:
|
||||
"""Returns the function declaration in the Gemini Schema format."""
|
||||
schema_dict = self._operation_parser.get_json_schema()
|
||||
parameters = to_gemini_schema(schema_dict)
|
||||
function_decl = FunctionDeclaration(
|
||||
name=self.name, description=self.description, parameters=parameters
|
||||
)
|
||||
return function_decl
|
||||
|
||||
def configure_auth_scheme(
|
||||
self, auth_scheme: Union[AuthScheme, Dict[str, Any]]
|
||||
):
|
||||
"""Configures the authentication scheme for the API call.
|
||||
|
||||
Args:
|
||||
auth_scheme: AuthScheme|dict -: The authentication scheme. The dict is
|
||||
converted to a AuthScheme object.
|
||||
"""
|
||||
if isinstance(auth_scheme, dict):
|
||||
auth_scheme = dict_to_auth_scheme(auth_scheme)
|
||||
self.auth_scheme = auth_scheme
|
||||
|
||||
def configure_auth_credential(
|
||||
self, auth_credential: Optional[Union[AuthCredential, str]] = None
|
||||
):
|
||||
"""Configures the authentication credential for the API call.
|
||||
|
||||
Args:
|
||||
auth_credential: AuthCredential|dict - The authentication credential.
|
||||
The dict is converted to an AuthCredential object.
|
||||
"""
|
||||
if isinstance(auth_credential, str):
|
||||
auth_credential = AuthCredential.model_validate_json(auth_credential)
|
||||
self.auth_credential = auth_credential
|
||||
|
||||
def _prepare_auth_request_params(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: AuthCredential,
|
||||
) -> Tuple[List[ApiParameter], Dict[str, Any]]:
|
||||
# Handle Authentication
|
||||
if not auth_scheme or not auth_credential:
|
||||
return
|
||||
|
||||
return credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
def _prepare_request_params(
|
||||
self, parameters: List[ApiParameter], kwargs: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Prepares the request parameters for the API call.
|
||||
|
||||
Args:
|
||||
parameters: A list of ApiParameter objects representing the parameters
|
||||
for the API call.
|
||||
kwargs: The keyword arguments passed to the call function from the Tool
|
||||
caller.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the request parameters for the API call. This
|
||||
initializes a requests.request() call.
|
||||
|
||||
Example:
|
||||
self._prepare_request_params({"input_id": "test-id"})
|
||||
"""
|
||||
method = self.endpoint.method.lower()
|
||||
if not method:
|
||||
raise ValueError("Operation method not found.")
|
||||
|
||||
path_params: Dict[str, Any] = {}
|
||||
query_params: Dict[str, Any] = {}
|
||||
header_params: Dict[str, Any] = {}
|
||||
cookie_params: Dict[str, Any] = {}
|
||||
|
||||
params_map: Dict[str, ApiParameter] = {p.py_name: p for p in parameters}
|
||||
|
||||
# Fill in path, query, header and cookie parameters to the request
|
||||
for param_k, v in kwargs.items():
|
||||
param_obj = params_map.get(param_k)
|
||||
if not param_obj:
|
||||
continue # If input arg not in the ApiParameter list, ignore it.
|
||||
|
||||
original_k = param_obj.original_name
|
||||
param_location = param_obj.param_location
|
||||
|
||||
if param_location == "path":
|
||||
path_params[original_k] = v
|
||||
elif param_location == "query":
|
||||
if v:
|
||||
query_params[original_k] = v
|
||||
elif param_location == "header":
|
||||
header_params[original_k] = v
|
||||
elif param_location == "cookie":
|
||||
cookie_params[original_k] = v
|
||||
|
||||
# Construct URL
|
||||
base_url = self.endpoint.base_url or ""
|
||||
base_url = base_url[:-1] if base_url.endswith("/") else base_url
|
||||
url = f"{base_url}{self.endpoint.path.format(**path_params)}"
|
||||
|
||||
# Construct body
|
||||
body_kwargs: Dict[str, Any] = {}
|
||||
request_body = self.operation.requestBody
|
||||
if request_body:
|
||||
for mime_type, media_type_object in request_body.content.items():
|
||||
schema = media_type_object.schema_
|
||||
body_data = None
|
||||
|
||||
if schema.type == "object":
|
||||
body_data = {}
|
||||
for param in parameters:
|
||||
if param.param_location == "body" and param.py_name in kwargs:
|
||||
body_data[param.original_name] = kwargs[param.py_name]
|
||||
|
||||
elif schema.type == "array":
|
||||
for param in parameters:
|
||||
if param.param_location == "body" and param.py_name == "array":
|
||||
body_data = kwargs.get("array")
|
||||
break
|
||||
else: # like string
|
||||
for param in parameters:
|
||||
# original_name = '' indicating this param applies to the full body.
|
||||
if param.param_location == "body" and not param.original_name:
|
||||
body_data = (
|
||||
kwargs.get(param.py_name) if param.py_name in kwargs else None
|
||||
)
|
||||
break
|
||||
|
||||
if mime_type == "application/json" or mime_type.endswith("+json"):
|
||||
if body_data is not None:
|
||||
body_kwargs["json"] = body_data
|
||||
elif mime_type == "application/x-www-form-urlencoded":
|
||||
body_kwargs["data"] = body_data
|
||||
elif mime_type == "multipart/form-data":
|
||||
body_kwargs["files"] = body_data
|
||||
elif mime_type == "application/octet-stream":
|
||||
body_kwargs["data"] = body_data
|
||||
elif mime_type == "text/plain":
|
||||
body_kwargs["data"] = body_data
|
||||
|
||||
if mime_type:
|
||||
header_params["Content-Type"] = mime_type
|
||||
break # Process only the first mime_type
|
||||
|
||||
filtered_query_params: Dict[str, Any] = {
|
||||
k: v for k, v in query_params.items() if v is not None
|
||||
}
|
||||
|
||||
request_params: Dict[str, Any] = {
|
||||
"method": method,
|
||||
"url": url,
|
||||
"params": filtered_query_params,
|
||||
"headers": header_params,
|
||||
"cookies": cookie_params,
|
||||
**body_kwargs,
|
||||
}
|
||||
|
||||
return request_params
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
|
||||
) -> Dict[str, Any]:
|
||||
return self.call(args=args, tool_context=tool_context)
|
||||
|
||||
def call(
|
||||
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
|
||||
) -> Dict[str, Any]:
|
||||
"""Executes the REST API call.
|
||||
|
||||
Args:
|
||||
args: Keyword arguments representing the operation parameters.
|
||||
tool_context: The tool context (not used here, but required by the
|
||||
interface).
|
||||
|
||||
Returns:
|
||||
The API response as a dictionary.
|
||||
"""
|
||||
# Prepare auth credentials for the API call
|
||||
tool_auth_handler = ToolAuthHandler.from_tool_context(
|
||||
tool_context, self.auth_scheme, self.auth_credential
|
||||
)
|
||||
auth_result = tool_auth_handler.prepare_auth_credentials()
|
||||
auth_state, auth_scheme, auth_credential = (
|
||||
auth_result.state,
|
||||
auth_result.auth_scheme,
|
||||
auth_result.auth_credential,
|
||||
)
|
||||
|
||||
if auth_state == "pending":
|
||||
return {
|
||||
"pending": True,
|
||||
"message": "Needs your authorization to access your data.",
|
||||
}
|
||||
|
||||
# Attach parameters from auth into main parameters list
|
||||
api_params, api_args = self._operation_parser.get_parameters().copy(), args
|
||||
if auth_credential:
|
||||
# Attach parameters from auth into main parameters list
|
||||
auth_param, auth_args = self._prepare_auth_request_params(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
if auth_param and auth_args:
|
||||
api_params = [auth_param] + api_params
|
||||
api_args.update(auth_args)
|
||||
|
||||
# Got all parameters. Call the API.
|
||||
request_params = self._prepare_request_params(api_params, api_args)
|
||||
response = requests.request(**request_params)
|
||||
|
||||
# Parse API response
|
||||
try:
|
||||
response.raise_for_status() # Raise HTTPError for bad responses
|
||||
return response.json() # Try to decode JSON
|
||||
except requests.exceptions.HTTPError:
|
||||
error_details = response.content.decode("utf-8")
|
||||
return {
|
||||
"error": (
|
||||
f"Tool {self.name} execution failed. Analyze this execution error"
|
||||
" and your inputs. Retry with adjustments if applicable. But"
|
||||
" make sure don't retry more than 3 times. Execution Error:"
|
||||
f" {error_details}"
|
||||
)
|
||||
}
|
||||
except ValueError:
|
||||
return {"text": response.text} # Return text if not JSON
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f'RestApiTool(name="{self.name}", description="{self.description}",'
|
||||
f' endpoint="{self.endpoint}")'
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'RestApiTool(name="{self.name}", description="{self.description}",'
|
||||
f' endpoint="{self.endpoint}", operation="{self.operation}",'
|
||||
f' auth_scheme="{self.auth_scheme}",'
|
||||
f' auth_credential="{self.auth_credential}")'
|
||||
)
|
||||
@@ -0,0 +1,268 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_credential import AuthCredentialTypes
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ....auth.auth_schemes import AuthSchemeType
|
||||
from ....auth.auth_tool import AuthConfig
|
||||
from ...tool_context import ToolContext
|
||||
from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
|
||||
from ..auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
|
||||
from ..auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AuthPreparationState = Literal["pending", "done"]
|
||||
|
||||
|
||||
class AuthPreparationResult(BaseModel):
|
||||
"""Result of the credential preparation process."""
|
||||
|
||||
state: AuthPreparationState
|
||||
auth_scheme: Optional[AuthScheme] = None
|
||||
auth_credential: Optional[AuthCredential] = None
|
||||
|
||||
|
||||
class ToolContextCredentialStore:
|
||||
"""Handles storage and retrieval of credentials within a ToolContext."""
|
||||
|
||||
def __init__(self, tool_context: ToolContext):
|
||||
self.tool_context = tool_context
|
||||
|
||||
def get_credential_key(
|
||||
self,
|
||||
auth_scheme: Optional[AuthScheme],
|
||||
auth_credential: Optional[AuthCredential],
|
||||
) -> str:
|
||||
"""Generates a unique key for the given auth scheme and credential."""
|
||||
scheme_name = (
|
||||
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
|
||||
if auth_scheme
|
||||
else ""
|
||||
)
|
||||
credential_name = (
|
||||
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
|
||||
if auth_credential
|
||||
else ""
|
||||
)
|
||||
# no need to prepend temp: namespace, session state is a copy, changes to
|
||||
# it won't be persisted , only changes in event_action.state_delta will be
|
||||
# persisted. temp: namespace will be cleared after current run. but tool
|
||||
# want access token to be there stored across runs
|
||||
|
||||
return f"{scheme_name}_{credential_name}_existing_exchanged_credential"
|
||||
|
||||
def get_credential(
|
||||
self,
|
||||
auth_scheme: Optional[AuthScheme],
|
||||
auth_credential: Optional[AuthCredential],
|
||||
) -> Optional[AuthCredential]:
|
||||
if not self.tool_context:
|
||||
return None
|
||||
|
||||
token_key = self.get_credential_key(auth_scheme, auth_credential)
|
||||
# TODO try not to use session state, this looks a hacky way, depend on
|
||||
# session implementation, we don't want session to persist the token,
|
||||
# meanwhile we want the token shared across runs.
|
||||
serialized_credential = self.tool_context.state.get(token_key)
|
||||
if not serialized_credential:
|
||||
return None
|
||||
return AuthCredential.model_validate(serialized_credential)
|
||||
|
||||
def store_credential(
|
||||
self,
|
||||
key: str,
|
||||
auth_credential: Optional[AuthCredential],
|
||||
):
|
||||
if self.tool_context:
|
||||
serializable_credential = jsonable_encoder(
|
||||
auth_credential, exclude_none=True
|
||||
)
|
||||
self.tool_context.state[key] = serializable_credential
|
||||
|
||||
def remove_credential(self, key: str):
|
||||
del self.tool_context.state[key]
|
||||
|
||||
|
||||
class ToolAuthHandler:
|
||||
"""Handles the preparation and exchange of authentication credentials for tools."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_context: ToolContext,
|
||||
auth_scheme: Optional[AuthScheme],
|
||||
auth_credential: Optional[AuthCredential],
|
||||
credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
|
||||
credential_store: Optional["ToolContextCredentialStore"] = None,
|
||||
):
|
||||
self.tool_context = tool_context
|
||||
self.auth_scheme = (
|
||||
auth_scheme.model_copy(deep=True) if auth_scheme else None
|
||||
)
|
||||
self.auth_credential = (
|
||||
auth_credential.model_copy(deep=True) if auth_credential else None
|
||||
)
|
||||
self.credential_exchanger = (
|
||||
credential_exchanger or AutoAuthCredentialExchanger()
|
||||
)
|
||||
self.credential_store = credential_store
|
||||
self.should_store_credential = True
|
||||
|
||||
@classmethod
|
||||
def from_tool_context(
|
||||
cls,
|
||||
tool_context: ToolContext,
|
||||
auth_scheme: Optional[AuthScheme],
|
||||
auth_credential: Optional[AuthCredential],
|
||||
credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
|
||||
) -> "ToolAuthHandler":
|
||||
"""Creates a ToolAuthHandler instance from a ToolContext."""
|
||||
credential_store = ToolContextCredentialStore(tool_context)
|
||||
return cls(
|
||||
tool_context,
|
||||
auth_scheme,
|
||||
auth_credential,
|
||||
credential_exchanger,
|
||||
credential_store,
|
||||
)
|
||||
|
||||
def _handle_existing_credential(
|
||||
self,
|
||||
) -> Optional[AuthPreparationResult]:
|
||||
"""Checks for and returns an existing, exchanged credential."""
|
||||
if self.credential_store:
|
||||
existing_credential = self.credential_store.get_credential(
|
||||
self.auth_scheme, self.auth_credential
|
||||
)
|
||||
if existing_credential:
|
||||
return AuthPreparationResult(
|
||||
state="done",
|
||||
auth_scheme=self.auth_scheme,
|
||||
auth_credential=existing_credential,
|
||||
)
|
||||
return None
|
||||
|
||||
def _exchange_credential(
|
||||
self, auth_credential: AuthCredential
|
||||
) -> Optional[AuthPreparationResult]:
|
||||
"""Handles an OpenID Connect authorization response."""
|
||||
|
||||
exchanged_credential = None
|
||||
try:
|
||||
exchanged_credential = self.credential_exchanger.exchange_credential(
|
||||
self.auth_scheme, auth_credential
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to exchange credential: %s", e)
|
||||
return exchanged_credential
|
||||
|
||||
def _store_credential(self, auth_credential: AuthCredential) -> None:
|
||||
"""stores the auth_credential."""
|
||||
|
||||
if self.credential_store:
|
||||
key = self.credential_store.get_credential_key(
|
||||
self.auth_scheme, self.auth_credential
|
||||
)
|
||||
self.credential_store.store_credential(key, auth_credential)
|
||||
|
||||
def _reqeust_credential(self) -> None:
|
||||
"""Handles the case where an OpenID Connect or OAuth2 authentication request is needed."""
|
||||
if self.auth_scheme.type_ in (
|
||||
AuthSchemeType.openIdConnect,
|
||||
AuthSchemeType.oauth2,
|
||||
):
|
||||
if not self.auth_credential or not self.auth_credential.oauth2:
|
||||
raise ValueError(
|
||||
f"auth_credential is empty for scheme {self.auth_scheme.type_}."
|
||||
"Please create AuthCredential using OAuth2Auth."
|
||||
)
|
||||
|
||||
if not self.auth_credential.oauth2.client_id:
|
||||
raise AuthCredentialMissingError(
|
||||
"OAuth2 credentials client_id is missing."
|
||||
)
|
||||
|
||||
if not self.auth_credential.oauth2.client_secret:
|
||||
raise AuthCredentialMissingError(
|
||||
"OAuth2 credentials client_secret is missing."
|
||||
)
|
||||
|
||||
self.tool_context.request_credential(
|
||||
AuthConfig(
|
||||
auth_scheme=self.auth_scheme,
|
||||
raw_auth_credential=self.auth_credential,
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_auth_response(self) -> AuthCredential:
|
||||
return self.tool_context.get_auth_response(
|
||||
AuthConfig(
|
||||
auth_scheme=self.auth_scheme,
|
||||
raw_auth_credential=self.auth_credential,
|
||||
)
|
||||
)
|
||||
|
||||
def _request_credential(self, auth_config: AuthConfig):
|
||||
if not self.tool_context:
|
||||
return
|
||||
self.tool_context.request_credential(auth_config)
|
||||
|
||||
def prepare_auth_credentials(
|
||||
self,
|
||||
) -> AuthPreparationResult:
|
||||
"""Prepares authentication credentials, handling exchange and user interaction."""
|
||||
|
||||
# no auth is needed
|
||||
if not self.auth_scheme:
|
||||
return AuthPreparationResult(state="done")
|
||||
|
||||
# Check for existing credential.
|
||||
existing_result = self._handle_existing_credential()
|
||||
if existing_result:
|
||||
return existing_result
|
||||
|
||||
# fetch credential from adk framework
|
||||
# Some auth scheme like OAuth2 AuthCode & OpenIDConnect may require
|
||||
# multi-step exchange:
|
||||
# client_id , client_secret -> auth_uri -> auth_code -> access_token
|
||||
# -> bearer token
|
||||
# adk framework supports exchange access_token already
|
||||
fetched_credential = self._get_auth_response() or self.auth_credential
|
||||
|
||||
exchanged_credential = self._exchange_credential(fetched_credential)
|
||||
|
||||
if exchanged_credential:
|
||||
self._store_credential(exchanged_credential)
|
||||
return AuthPreparationResult(
|
||||
state="done",
|
||||
auth_scheme=self.auth_scheme,
|
||||
auth_credential=exchanged_credential,
|
||||
)
|
||||
else:
|
||||
self._reqeust_credential()
|
||||
return AuthPreparationResult(
|
||||
state="pending",
|
||||
auth_scheme=self.auth_scheme,
|
||||
auth_credential=self.auth_credential,
|
||||
)
|
||||
Reference in New Issue
Block a user