Moves unittests to root folder and adds github action to run unit tests. (#72)

* Move unit tests to root package.

* Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py

* Adds github workflow

* minor fix in lite_llm.py for python 3.9.

* format pyproject.toml
This commit is contained in:
Jack Sun
2025-04-11 08:25:59 -07:00
committed by GitHub
parent 59117b9b96
commit 05142a07cc
66 changed files with 50 additions and 2 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,628 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
import pytest
def create_minimal_openapi_spec() -> Dict[str, Any]:
"""Creates a minimal valid OpenAPI spec."""
return {
"openapi": "3.1.0",
"info": {"title": "Minimal API", "version": "1.0.0"},
"paths": {
"/test": {
"get": {
"summary": "Test GET endpoint",
"operationId": "testGet",
"responses": {
"200": {
"description": "Successful response",
"content": {
"application/json": {"schema": {"type": "string"}}
},
}
},
}
}
},
}
@pytest.fixture
def openapi_spec_generator():
"""Fixture for creating an OperationGenerator instance."""
return OpenApiSpecParser()
def test_parse_minimal_spec(openapi_spec_generator):
"""Test parsing a minimal OpenAPI specification."""
openapi_spec = create_minimal_openapi_spec()
parsed_operations = openapi_spec_generator.parse(openapi_spec)
op = parsed_operations[0]
assert len(parsed_operations) == 1
assert op.name == "test_get"
assert op.endpoint.path == "/test"
assert op.endpoint.method == "get"
assert op.return_value.type_value == str
def test_parse_spec_with_no_operation_id(openapi_spec_generator):
"""Test parsing a spec where operationId is missing (auto-generation)."""
openapi_spec = create_minimal_openapi_spec()
del openapi_spec["paths"]["/test"]["get"]["operationId"] # Remove operationId
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
# Check if operationId is auto generated based on path and method.
assert parsed_operations[0].name == "test_get"
def test_parse_spec_with_multiple_methods(openapi_spec_generator):
"""Test parsing a spec with multiple HTTP methods for the same path."""
openapi_spec = create_minimal_openapi_spec()
openapi_spec["paths"]["/test"]["post"] = {
"summary": "Test POST endpoint",
"operationId": "testPost",
"responses": {"200": {"description": "Successful response"}},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
operation_names = {op.name for op in parsed_operations}
assert len(parsed_operations) == 2
assert "test_get" in operation_names
assert "test_post" in operation_names
def test_parse_spec_with_parameters(openapi_spec_generator):
openapi_spec = create_minimal_openapi_spec()
openapi_spec["paths"]["/test"]["get"]["parameters"] = [
{"name": "param1", "in": "query", "schema": {"type": "string"}},
{"name": "param2", "in": "header", "schema": {"type": "integer"}},
]
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations[0].parameters) == 2
assert parsed_operations[0].parameters[0].original_name == "param1"
assert parsed_operations[0].parameters[0].param_location == "query"
assert parsed_operations[0].parameters[1].original_name == "param2"
assert parsed_operations[0].parameters[1].param_location == "header"
def test_parse_spec_with_request_body(openapi_spec_generator):
openapi_spec = create_minimal_openapi_spec()
openapi_spec["paths"]["/test"]["post"] = {
"summary": "Endpoint with request body",
"operationId": "testPostWithBody",
"requestBody": {
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {"name": {"type": "string"}},
}
}
}
},
"responses": {"200": {"description": "OK"}},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
post_operations = [
op for op in parsed_operations if op.endpoint.method == "post"
]
op = post_operations[0]
assert len(post_operations) == 1
assert op.name == "test_post_with_body"
assert len(op.parameters) == 1
assert op.parameters[0].original_name == "name"
assert op.parameters[0].type_value == str
def test_parse_spec_with_reference(openapi_spec_generator):
"""Test parsing a specification with $ref."""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "API with Refs", "version": "1.0.0"},
"paths": {
"/test_ref": {
"get": {
"summary": "Endpoint with ref",
"operationId": "testGetRef",
"responses": {
"200": {
"description": "Success",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/MySchema"
}
}
},
}
},
}
}
},
"components": {
"schemas": {
"MySchema": {
"type": "object",
"properties": {"name": {"type": "string"}},
}
}
},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
op = parsed_operations[0]
assert len(parsed_operations) == 1
assert op.return_value.type_value.__origin__ is dict
def test_parse_spec_with_circular_reference(openapi_spec_generator):
"""Test correct handling of circular $ref (important!)."""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "Circular Ref API", "version": "1.0.0"},
"paths": {
"/circular": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/A"}
}
},
}
}
}
}
},
"components": {
"schemas": {
"A": {
"type": "object",
"properties": {"b": {"$ref": "#/components/schemas/B"}},
},
"B": {
"type": "object",
"properties": {"a": {"$ref": "#/components/schemas/A"}},
},
}
},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
op = parsed_operations[0]
assert op.return_value.type_value.__origin__ is dict
assert op.return_value.type_hint == "Dict[str, Any]"
def test_parse_no_paths(openapi_spec_generator):
"""Test with a spec that has no paths defined."""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "No Paths API", "version": "1.0.0"},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 0 # Should be empty
def test_parse_empty_path_item(openapi_spec_generator):
"""Test a path item that is present but empty."""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "Empty Path Item API", "version": "1.0.0"},
"paths": {"/empty": None},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 0
def test_parse_spec_with_global_auth_scheme(openapi_spec_generator):
"""Test parsing with a global security scheme."""
openapi_spec = create_minimal_openapi_spec()
openapi_spec["security"] = [{"api_key": []}]
openapi_spec["components"] = {
"securitySchemes": {
"api_key": {"type": "apiKey", "in": "header", "name": "X-API-Key"}
}
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
op = parsed_operations[0]
assert len(parsed_operations) == 1
assert op.auth_scheme is not None
assert op.auth_scheme.type_.value == "apiKey"
def test_parse_spec_with_local_auth_scheme(openapi_spec_generator):
"""Test parsing with a local (operation-level) security scheme."""
openapi_spec = create_minimal_openapi_spec()
openapi_spec["paths"]["/test"]["get"]["security"] = [{"local_auth": []}]
openapi_spec["components"] = {
"securitySchemes": {"local_auth": {"type": "http", "scheme": "bearer"}}
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
op = parsed_operations[0]
assert op.auth_scheme is not None
assert op.auth_scheme.type_.value == "http"
assert op.auth_scheme.scheme == "bearer"
def test_parse_spec_with_servers(openapi_spec_generator):
"""Test parsing with server URLs."""
openapi_spec = create_minimal_openapi_spec()
openapi_spec["servers"] = [
{"url": "https://api.example.com"},
{"url": "http://localhost:8000"},
]
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
assert parsed_operations[0].endpoint.base_url == "https://api.example.com"
def test_parse_spec_with_no_servers(openapi_spec_generator):
"""Test with no servers defined (should default to empty string)."""
openapi_spec = create_minimal_openapi_spec()
if "servers" in openapi_spec:
del openapi_spec["servers"]
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
assert parsed_operations[0].endpoint.base_url == ""
def test_parse_spec_with_description(openapi_spec_generator):
openapi_spec = create_minimal_openapi_spec()
expected_description = "This is a test description."
openapi_spec["paths"]["/test"]["get"]["description"] = expected_description
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
assert parsed_operations[0].description == expected_description
def test_parse_spec_with_empty_description(openapi_spec_generator):
openapi_spec = create_minimal_openapi_spec()
openapi_spec["paths"]["/test"]["get"]["description"] = ""
openapi_spec["paths"]["/test"]["get"]["summary"] = ""
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
assert parsed_operations[0].description == ""
def test_parse_spec_with_no_description(openapi_spec_generator):
openapi_spec = create_minimal_openapi_spec()
# delete description
if "description" in openapi_spec["paths"]["/test"]["get"]:
del openapi_spec["paths"]["/test"]["get"]["description"]
if "summary" in openapi_spec["paths"]["/test"]["get"]:
del openapi_spec["paths"]["/test"]["get"]["summary"]
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
assert (
parsed_operations[0].description == ""
) # it should be initialized with empty string
def test_parse_invalid_openapi_spec_type(openapi_spec_generator):
"""Test that passing a non-dict object to parse raises TypeError"""
with pytest.raises(AttributeError):
openapi_spec_generator.parse(123) # type: ignore
with pytest.raises(AttributeError):
openapi_spec_generator.parse("openapi_spec") # type: ignore
with pytest.raises(AttributeError):
openapi_spec_generator.parse([]) # type: ignore
def test_parse_external_ref_raises_error(openapi_spec_generator):
"""Check that external references (not starting with #) raise ValueError."""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "External Ref API", "version": "1.0.0"},
"paths": {
"/external": {
"get": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": (
"external_file.json#/components/schemas/ExternalSchema"
)
}
}
},
}
}
}
}
},
}
with pytest.raises(ValueError):
openapi_spec_generator.parse(openapi_spec)
def test_parse_spec_with_multiple_paths_deep_refs(openapi_spec_generator):
"""Test specs with multiple paths, request/response bodies using deep refs."""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "Multiple Paths Deep Refs API", "version": "1.0.0"},
"paths": {
"/path1": {
"post": {
"operationId": "postPath1",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Request1"
}
}
}
},
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Response1"
}
}
},
}
},
}
},
"/path2": {
"put": {
"operationId": "putPath2",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Request2"
}
}
}
},
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Response2"
}
}
},
}
},
},
"get": {
"operationId": "getPath2",
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Response2"
}
}
},
}
},
},
},
},
"components": {
"schemas": {
"Request1": {
"type": "object",
"properties": {
"req1_prop1": {"$ref": "#/components/schemas/Level1_1"}
},
},
"Response1": {
"type": "object",
"properties": {
"res1_prop1": {"$ref": "#/components/schemas/Level1_2"}
},
},
"Request2": {
"type": "object",
"properties": {
"req2_prop1": {"$ref": "#/components/schemas/Level1_1"}
},
},
"Response2": {
"type": "object",
"properties": {
"res2_prop1": {"$ref": "#/components/schemas/Level1_2"}
},
},
"Level1_1": {
"type": "object",
"properties": {
"level1_1_prop1": {
"$ref": "#/components/schemas/Level2_1"
}
},
},
"Level1_2": {
"type": "object",
"properties": {
"level1_2_prop1": {
"$ref": "#/components/schemas/Level2_2"
}
},
},
"Level2_1": {
"type": "object",
"properties": {
"level2_1_prop1": {"$ref": "#/components/schemas/Level3"}
},
},
"Level2_2": {
"type": "object",
"properties": {"level2_2_prop1": {"type": "string"}},
},
"Level3": {"type": "integer"},
}
},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 3
# Verify Path 1
path1_ops = [op for op in parsed_operations if op.endpoint.path == "/path1"]
assert len(path1_ops) == 1
path1_op = path1_ops[0]
assert path1_op.name == "post_path1"
assert len(path1_op.parameters) == 1
assert path1_op.parameters[0].original_name == "req1_prop1"
assert (
path1_op.parameters[0]
.param_schema.properties["level1_1_prop1"]
.properties["level2_1_prop1"]
.type
== "integer"
)
assert (
path1_op.return_value.param_schema.properties["res1_prop1"]
.properties["level1_2_prop1"]
.properties["level2_2_prop1"]
.type
== "string"
)
# Verify Path 2
path2_ops = [
op
for op in parsed_operations
if op.endpoint.path == "/path2" and op.name == "put_path2"
]
path2_op = path2_ops[0]
assert path2_op is not None
assert len(path2_op.parameters) == 1
assert path2_op.parameters[0].original_name == "req2_prop1"
assert (
path2_op.parameters[0]
.param_schema.properties["level1_1_prop1"]
.properties["level2_1_prop1"]
.type
== "integer"
)
assert (
path2_op.return_value.param_schema.properties["res2_prop1"]
.properties["level1_2_prop1"]
.properties["level2_2_prop1"]
.type
== "string"
)
def test_parse_spec_with_duplicate_parameter_names(openapi_spec_generator):
"""Test handling of duplicate parameter names (one in query, one in body).
The expected behavior is that both parameters should be captured but with
different suffix, and
their `original_name` attributes should reflect their origin (query or body).
"""
openapi_spec = {
"openapi": "3.1.0",
"info": {"title": "Duplicate Parameter Names API", "version": "1.0.0"},
"paths": {
"/duplicate": {
"post": {
"operationId": "createWithDuplicate",
"parameters": [{
"name": "name",
"in": "query",
"schema": {"type": "string"},
}],
"requestBody": {
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {"name": {"type": "integer"}},
}
}
}
},
"responses": {"200": {"description": "OK"}},
}
}
},
}
parsed_operations = openapi_spec_generator.parse(openapi_spec)
assert len(parsed_operations) == 1
op = parsed_operations[0]
assert op.name == "create_with_duplicate"
assert len(op.parameters) == 2
query_param = None
body_param = None
for param in op.parameters:
if param.param_location == "query" and param.original_name == "name":
query_param = param
elif param.param_location == "body" and param.original_name == "name":
body_param = param
assert query_param is not None
assert query_param.original_name == "name"
assert query_param.py_name == "name"
assert body_param is not None
assert body_param.original_name == "name"
assert body_param.py_name == "name_0"

View File

@@ -0,0 +1,139 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from fastapi.openapi.models import MediaType
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import ParameterInType
from fastapi.openapi.models import SecuritySchemeType
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
import pytest
import yaml
def load_spec(file_path: str) -> Dict:
"""Loads the OpenAPI specification from a YAML file."""
with open(file_path, "r", encoding="utf-8") as f:
return yaml.safe_load(f)
@pytest.fixture
def openapi_spec() -> Dict:
"""Fixture to load the OpenAPI specification."""
current_dir = os.path.dirname(os.path.abspath(__file__))
# Join the directory path with the filename
yaml_path = os.path.join(current_dir, "test.yaml")
return load_spec(yaml_path)
def test_openapi_toolset_initialization_from_dict(openapi_spec: Dict):
"""Test initialization of OpenAPIToolset with a dictionary."""
toolset = OpenAPIToolset(spec_dict=openapi_spec)
assert isinstance(toolset.tools, list)
assert len(toolset.tools) == 5
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
def test_openapi_toolset_initialization_from_yaml_string(openapi_spec: Dict):
"""Test initialization of OpenAPIToolset with a YAML string."""
spec_str = yaml.dump(openapi_spec)
toolset = OpenAPIToolset(spec_str=spec_str, spec_str_type="yaml")
assert isinstance(toolset.tools, list)
assert len(toolset.tools) == 5
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
def test_openapi_toolset_tool_existing(openapi_spec: Dict):
"""Test the tool() method for an existing tool."""
toolset = OpenAPIToolset(spec_dict=openapi_spec)
tool_name = "calendar_calendars_insert" # Example operationId from the spec
tool = toolset.get_tool(tool_name)
assert isinstance(tool, RestApiTool)
assert tool.name == tool_name
assert tool.description == "Creates a secondary calendar."
assert tool.endpoint.method == "post"
assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
assert tool.endpoint.path == "/calendars"
assert tool.is_long_running is False
assert tool.operation.operationId == "calendar.calendars.insert"
assert tool.operation.description == "Creates a secondary calendar."
assert isinstance(
tool.operation.requestBody.content["application/json"], MediaType
)
assert len(tool.operation.responses) == 1
response = tool.operation.responses["200"]
assert response.description == "Successful response"
assert isinstance(response.content["application/json"], MediaType)
assert isinstance(tool.auth_scheme, OAuth2)
tool_name = "calendar_calendars_get"
tool = toolset.get_tool(tool_name)
assert isinstance(tool, RestApiTool)
assert tool.name == tool_name
assert tool.description == "Returns metadata for a calendar."
assert tool.endpoint.method == "get"
assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
assert tool.endpoint.path == "/calendars/{calendarId}"
assert tool.is_long_running is False
assert tool.operation.operationId == "calendar.calendars.get"
assert tool.operation.description == "Returns metadata for a calendar."
assert len(tool.operation.parameters) == 1
assert tool.operation.parameters[0].name == "calendarId"
assert tool.operation.parameters[0].in_ == ParameterInType.path
assert tool.operation.parameters[0].required is True
assert tool.operation.parameters[0].schema_.type == "string"
assert (
tool.operation.parameters[0].description
== "Calendar identifier. To retrieve calendar IDs call the"
" calendarList.list method. If you want to access the primary calendar"
' of the currently logged in user, use the "primary" keyword.'
)
assert isinstance(tool.auth_scheme, OAuth2)
assert isinstance(toolset.get_tool("calendar_calendars_update"), RestApiTool)
assert isinstance(toolset.get_tool("calendar_calendars_delete"), RestApiTool)
assert isinstance(toolset.get_tool("calendar_calendars_patch"), RestApiTool)
def test_openapi_toolset_tool_non_existing(openapi_spec: Dict):
"""Test the tool() method for a non-existing tool."""
toolset = OpenAPIToolset(spec_dict=openapi_spec)
tool = toolset.get_tool("non_existent_tool")
assert tool is None
def test_openapi_toolset_configure_auth_on_init(openapi_spec: Dict):
"""Test configuring auth during initialization."""
auth_scheme = APIKey(**{
"in": APIKeyIn.header, # Use alias name in dict
"name": "api_key",
"type": SecuritySchemeType.http,
})
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
toolset = OpenAPIToolset(
spec_dict=openapi_spec,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)
for tool in toolset.tools:
assert tool.auth_scheme == auth_scheme
assert tool.auth_credential == auth_credential

View File

@@ -0,0 +1,406 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi.openapi.models import MediaType
from fastapi.openapi.models import Operation
from fastapi.openapi.models import Parameter
from fastapi.openapi.models import RequestBody
from fastapi.openapi.models import Response
from fastapi.openapi.models import Schema
from google.adk.tools.openapi_tool.common.common import ApiParameter
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
import pytest
@pytest.fixture
def sample_operation() -> Operation:
"""Fixture to provide a sample OpenAPI Operation object."""
return Operation(
operationId='test_operation',
summary='Test Summary',
description='Test Description',
parameters=[
Parameter(**{
'name': 'param1',
'in': 'query',
'schema': Schema(type='string'),
'description': 'Parameter 1',
}),
Parameter(**{
'name': 'param2',
'in': 'header',
'schema': Schema(type='string'),
'description': 'Parameter 2',
}),
],
requestBody=RequestBody(
content={
'application/json': MediaType(
schema=Schema(
type='object',
properties={
'prop1': Schema(
type='string', description='Property 1'
),
'prop2': Schema(
type='integer', description='Property 2'
),
},
)
)
},
description='Request body description',
),
responses={
'200': Response(
description='Success',
content={
'application/json': MediaType(schema=Schema(type='string'))
},
),
'400': Response(description='Client Error'),
},
security=[{'oauth2': ['resource: read', 'resource: write']}],
)
def test_operation_parser_initialization(sample_operation):
"""Test initialization of OperationParser."""
parser = OperationParser(sample_operation)
assert parser.operation == sample_operation
assert len(parser.params) == 4 # 2 params + 2 request body props
assert parser.return_value is not None
def test_process_operation_parameters(sample_operation):
"""Test _process_operation_parameters method."""
parser = OperationParser(sample_operation, should_parse=False)
parser._process_operation_parameters()
assert len(parser.params) == 2
assert parser.params[0].original_name == 'param1'
assert parser.params[0].param_location == 'query'
assert parser.params[1].original_name == 'param2'
assert parser.params[1].param_location == 'header'
def test_process_request_body(sample_operation):
"""Test _process_request_body method."""
parser = OperationParser(sample_operation, should_parse=False)
parser._process_request_body()
assert len(parser.params) == 2 # 2 properties in request body
assert parser.params[0].original_name == 'prop1'
assert parser.params[0].param_location == 'body'
assert parser.params[1].original_name == 'prop2'
assert parser.params[1].param_location == 'body'
def test_process_request_body_array():
"""Test _process_request_body method with array schema."""
operation = Operation(
requestBody=RequestBody(
content={
'application/json': MediaType(
schema=Schema(
type='array',
items=Schema(
type='object',
properties={
'item_prop1': Schema(
type='string', description='Item Property 1'
),
'item_prop2': Schema(
type='integer', description='Item Property 2'
),
},
),
)
)
}
)
)
parser = OperationParser(operation, should_parse=False)
parser._process_request_body()
assert len(parser.params) == 1
assert parser.params[0].original_name == 'array'
assert parser.params[0].param_location == 'body'
# Check that schema is correctly propagated and is a dictionary
assert parser.params[0].param_schema.type == 'array'
assert parser.params[0].param_schema.items.type == 'object'
assert 'item_prop1' in parser.params[0].param_schema.items.properties
assert 'item_prop2' in parser.params[0].param_schema.items.properties
assert (
parser.params[0].param_schema.items.properties['item_prop1'].description
== 'Item Property 1'
)
assert (
parser.params[0].param_schema.items.properties['item_prop2'].description
== 'Item Property 2'
)
def test_process_request_body_no_name():
"""Test _process_request_body with a schema that has no properties (unnamed)"""
operation = Operation(
requestBody=RequestBody(
content={'application/json': MediaType(schema=Schema(type='string'))}
)
)
parser = OperationParser(operation, should_parse=False)
parser._process_request_body()
assert len(parser.params) == 1
assert parser.params[0].original_name == '' # No name
assert parser.params[0].param_location == 'body'
def test_dedupe_param_names(sample_operation):
"""Test _dedupe_param_names method."""
parser = OperationParser(sample_operation, should_parse=False)
# Add duplicate named parameters.
parser.params = [
ApiParameter(original_name='test', param_location='', param_schema={}),
ApiParameter(original_name='test', param_location='', param_schema={}),
ApiParameter(original_name='test', param_location='', param_schema={}),
]
parser._dedupe_param_names()
assert parser.params[0].py_name == 'test'
assert parser.params[1].py_name == 'test_0'
assert parser.params[2].py_name == 'test_1'
def test_process_return_value(sample_operation):
"""Test _process_return_value method."""
parser = OperationParser(sample_operation, should_parse=False)
parser._process_return_value()
assert parser.return_value is not None
assert parser.return_value.type_hint == 'str'
def test_process_return_value_no_2xx(sample_operation):
"""Tests _process_return_value when no 2xx response exists."""
operation_no_2xx = Operation(
responses={'400': Response(description='Client Error')}
)
parser = OperationParser(operation_no_2xx, should_parse=False)
parser._process_return_value()
assert parser.return_value is not None
assert parser.return_value.type_hint == 'Any'
def test_process_return_value_multiple_2xx(sample_operation):
"""Tests _process_return_value when multiple 2xx responses exist."""
operation_multi_2xx = Operation(
responses={
'201': Response(
description='Success',
content={
'application/json': MediaType(schema=Schema(type='integer'))
},
),
'202': Response(
description='Success',
content={'text/plain': MediaType(schema=Schema(type='string'))},
),
'200': Response(
description='Success',
content={
'application/pdf': MediaType(schema=Schema(type='boolean'))
},
),
'400': Response(
description='Failure',
content={
'application/xml': MediaType(schema=Schema(type='object'))
},
),
}
)
parser = OperationParser(operation_multi_2xx, should_parse=False)
parser._process_return_value()
assert parser.return_value is not None
# Take the content type of the 200 response since it's the smallest response
# code
assert parser.return_value.param_schema.type == 'boolean'
def test_process_return_value_no_content(sample_operation):
"""Test when 2xx response has no content"""
operation_no_content = Operation(
responses={'200': Response(description='Success', content={})}
)
parser = OperationParser(operation_no_content, should_parse=False)
parser._process_return_value()
assert parser.return_value.type_hint == 'Any'
def test_process_return_value_no_schema(sample_operation):
"""Tests when the 2xx response's content has no schema."""
operation_no_schema = Operation(
responses={
'200': Response(
description='Success',
content={'application/json': MediaType(schema=None)},
)
}
)
parser = OperationParser(operation_no_schema, should_parse=False)
parser._process_return_value()
assert parser.return_value.type_hint == 'Any'
def test_get_function_name(sample_operation):
"""Test get_function_name method."""
parser = OperationParser(sample_operation)
assert parser.get_function_name() == 'test_operation'
def test_get_function_name_missing_id():
"""Tests get_function_name when operationId is missing"""
operation = Operation() # No ID
parser = OperationParser(operation)
with pytest.raises(ValueError, match='Operation ID is missing'):
parser.get_function_name()
def test_get_return_type_hint(sample_operation):
"""Test get_return_type_hint method."""
parser = OperationParser(sample_operation)
assert parser.get_return_type_hint() == 'str'
def test_get_return_type_value(sample_operation):
"""Test get_return_type_value method."""
parser = OperationParser(sample_operation)
assert parser.get_return_type_value() == str
def test_get_parameters(sample_operation):
"""Test get_parameters method."""
parser = OperationParser(sample_operation)
params = parser.get_parameters()
assert len(params) == 4 # Correct count after processing
assert all(isinstance(p, ApiParameter) for p in params)
def test_get_return_value(sample_operation):
"""Test get_return_value method."""
parser = OperationParser(sample_operation)
return_value = parser.get_return_value()
assert isinstance(return_value, ApiParameter)
def test_get_auth_scheme_name(sample_operation):
"""Test get_auth_scheme_name method."""
parser = OperationParser(sample_operation)
assert parser.get_auth_scheme_name() == 'oauth2'
def test_get_auth_scheme_name_no_security():
"""Test get_auth_scheme_name when no security is present."""
operation = Operation(responses={})
parser = OperationParser(operation)
assert parser.get_auth_scheme_name() == ''
def test_get_pydoc_string(sample_operation):
"""Test get_pydoc_string method."""
parser = OperationParser(sample_operation)
pydoc_string = parser.get_pydoc_string()
assert 'Test Summary' in pydoc_string
assert 'Args:' in pydoc_string
assert 'param1 (str): Parameter 1' in pydoc_string
assert 'prop1 (str): Property 1' in pydoc_string
assert 'Returns (str):' in pydoc_string
assert 'Success' in pydoc_string
def test_get_json_schema(sample_operation):
"""Test get_json_schema method."""
parser = OperationParser(sample_operation)
json_schema = parser.get_json_schema()
assert json_schema['title'] == 'test_operation_Arguments'
assert json_schema['type'] == 'object'
assert 'param1' in json_schema['properties']
assert 'prop1' in json_schema['properties']
assert 'param1' in json_schema['required']
assert 'prop1' in json_schema['required']
def test_get_signature_parameters(sample_operation):
"""Test get_signature_parameters method."""
parser = OperationParser(sample_operation)
signature_params = parser.get_signature_parameters()
assert len(signature_params) == 4
assert signature_params[0].name == 'param1'
assert signature_params[0].annotation == str
assert signature_params[2].name == 'prop1'
assert signature_params[2].annotation == str
def test_get_annotations(sample_operation):
"""Test get_annotations method."""
parser = OperationParser(sample_operation)
annotations = parser.get_annotations()
assert len(annotations) == 5 # 4 parameters + return
assert annotations['param1'] == str
assert annotations['prop1'] == str
assert annotations['return'] == str
def test_load():
"""Test the load classmethod."""
operation = Operation(operationId='my_op') # Minimal operation
params = [
ApiParameter(
original_name='p1',
param_location='',
param_schema={'type': 'integer'},
)
]
return_value = ApiParameter(
original_name='', param_location='', param_schema={'type': 'string'}
)
parser = OperationParser.load(operation, params, return_value)
assert isinstance(parser, OperationParser)
assert parser.operation == operation
assert parser.params == params
assert parser.return_value == return_value
assert (
parser.get_function_name() == 'my_op'
) # Check that the operation is loaded
def test_operation_parser_with_dict():
"""Test initialization of OperationParser with a dictionary."""
operation_dict = {
'operationId': 'test_dict_operation',
'parameters': [
{'name': 'dict_param', 'in': 'query', 'schema': {'type': 'string'}}
],
'responses': {
'200': {
'description': 'Dict Success',
'content': {'application/json': {'schema': {'type': 'string'}}},
}
},
}
parser = OperationParser(operation_dict)
assert parser.operation.operationId == 'test_dict_operation'
assert len(parser.params) == 1
assert parser.params[0].original_name == 'dict_param'
assert parser.return_value.type_hint == 'str'

View File

@@ -0,0 +1,966 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from unittest.mock import MagicMock
from unittest.mock import patch
from fastapi.openapi.models import MediaType
from fastapi.openapi.models import Operation
from fastapi.openapi.models import Parameter as OpenAPIParameter
from fastapi.openapi.models import RequestBody
from fastapi.openapi.models import Schema as OpenAPISchema
from google.adk.sessions.state import State
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
from google.adk.tools.openapi_tool.common.common import ApiParameter
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import snake_to_lower_camel
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from google.adk.tools.tool_context import ToolContext
from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
from google.genai.types import Type
import pytest
class TestRestApiTool:
@pytest.fixture
def mock_tool_context(self):
"""Fixture for a mock OperationParser."""
mock_context = MagicMock(spec=ToolContext)
mock_context.state = State({}, {})
mock_context.get_auth_response.return_value = {}
mock_context.request_credential.return_value = {}
return mock_context
@pytest.fixture
def mock_operation_parser(self):
"""Fixture for a mock OperationParser."""
mock_parser = MagicMock(spec=OperationParser)
mock_parser.get_function_name.return_value = "mock_function_name"
mock_parser.get_json_schema.return_value = {}
mock_parser.get_parameters.return_value = []
mock_parser.get_return_type_hint.return_value = "str"
mock_parser.get_pydoc_string.return_value = "Mock docstring"
mock_parser.get_signature_parameters.return_value = []
mock_parser.get_return_type_value.return_value = str
mock_parser.get_annotations.return_value = {}
return mock_parser
@pytest.fixture
def sample_endpiont(self):
return OperationEndpoint(
base_url="https://example.com", path="/test", method="GET"
)
@pytest.fixture
def sample_operation(self):
return Operation(
operationId="testOperation",
description="Test operation",
parameters=[],
requestBody=RequestBody(
content={
"application/json": MediaType(
schema=OpenAPISchema(
type="object",
properties={
"testBodyParam": OpenAPISchema(type="string")
},
)
)
}
),
)
@pytest.fixture
def sample_api_parameters(self):
return [
ApiParameter(
original_name="test_param",
py_name="test_param",
param_location="query",
param_schema=OpenAPISchema(type="string"),
is_required=True,
),
ApiParameter(
original_name="",
py_name="test_body_param",
param_location="body",
param_schema=OpenAPISchema(type="string"),
is_required=True,
),
]
@pytest.fixture
def sample_return_parameter(self):
return ApiParameter(
original_name="test_param",
py_name="test_param",
param_location="query",
param_schema=OpenAPISchema(type="string"),
is_required=True,
)
@pytest.fixture
def sample_auth_scheme(self):
scheme, _ = token_to_scheme_credential(
"apikey", "header", "", "sample_auth_credential_internal_test"
)
return scheme
@pytest.fixture
def sample_auth_credential(self):
_, credential = token_to_scheme_credential(
"apikey", "header", "", "sample_auth_credential_internal_test"
)
return credential
def test_init(
self,
sample_endpiont,
sample_operation,
sample_auth_scheme,
sample_auth_credential,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_scheme=sample_auth_scheme,
auth_credential=sample_auth_credential,
)
assert tool.name == "test_tool"
assert tool.description == "Test Tool"
assert tool.endpoint == sample_endpiont
assert tool.operation == sample_operation
assert tool.auth_credential == sample_auth_credential
assert tool.auth_scheme == sample_auth_scheme
assert tool.credential_exchanger is not None
def test_from_parsed_operation_str(
self,
sample_endpiont,
sample_api_parameters,
sample_return_parameter,
sample_operation,
):
parsed_operation_str = json.dumps({
"name": "test_operation",
"description": "Test Description",
"endpoint": sample_endpiont.model_dump(),
"operation": sample_operation.model_dump(),
"auth_scheme": None,
"auth_credential": None,
"parameters": [p.model_dump() for p in sample_api_parameters],
"return_value": sample_return_parameter.model_dump(),
})
tool = RestApiTool.from_parsed_operation_str(parsed_operation_str)
assert tool.name == "test_operation"
def test_get_declaration(
self, sample_endpiont, sample_operation, mock_operation_parser
):
tool = RestApiTool(
name="test_tool",
description="Test description",
endpoint=sample_endpiont,
operation=sample_operation,
should_parse_operation=False,
)
tool._operation_parser = mock_operation_parser
declaration = tool._get_declaration()
assert isinstance(declaration, FunctionDeclaration)
assert declaration.name == "test_tool"
assert declaration.description == "Test description"
assert isinstance(declaration.parameters, Schema)
@patch(
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
)
def test_call_success(
self,
mock_request,
mock_tool_context,
sample_endpiont,
sample_operation,
sample_auth_scheme,
sample_auth_credential,
):
mock_response = MagicMock()
mock_response.json.return_value = {"result": "success"}
mock_request.return_value = mock_response
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_scheme=sample_auth_scheme,
auth_credential=sample_auth_credential,
)
# Call the method
result = tool.call(args={}, tool_context=mock_tool_context)
# Check the result
assert result == {"result": "success"}
@patch(
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
)
def test_call_auth_pending(
self,
mock_request,
sample_endpiont,
sample_operation,
sample_auth_scheme,
sample_auth_credential,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_scheme=sample_auth_scheme,
auth_credential=sample_auth_credential,
)
with patch(
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
) as mock_from_tool_context:
mock_tool_auth_handler_instance = MagicMock()
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
"pending"
)
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
response = tool.call(args={}, tool_context=None)
assert response == {
"pending": True,
"message": "Needs your authorization to access your data.",
}
def test_prepare_request_params_query_body(
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
):
# Create a mock Operation object
mock_operation = Operation(
operationId="test_op",
parameters=[
OpenAPIParameter(**{
"name": "testQueryParam",
"in": "query",
"schema": OpenAPISchema(type="string"),
})
],
requestBody=RequestBody(
content={
"application/json": MediaType(
schema=OpenAPISchema(
type="object",
properties={
"param1": OpenAPISchema(type="string"),
"param2": OpenAPISchema(type="integer"),
},
)
)
}
),
)
tool = RestApiTool(
name="test_tool",
description="test",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="param1",
py_name="param1",
param_location="body",
param_schema=OpenAPISchema(type="string"),
),
ApiParameter(
original_name="param2",
py_name="param2",
param_location="body",
param_schema=OpenAPISchema(type="integer"),
),
ApiParameter(
original_name="testQueryParam",
py_name="test_query_param",
param_location="query",
param_schema=OpenAPISchema(type="string"),
),
]
kwargs = {
"param1": "value1",
"param2": 123,
"test_query_param": "query_value",
}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["method"] == "get"
assert request_params["url"] == "https://example.com/test"
assert request_params["json"] == {"param1": "value1", "param2": 123}
assert request_params["params"] == {"testQueryParam": "query_value"}
def test_prepare_request_params_array(
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
):
mock_operation = Operation(
operationId="test_op",
requestBody=RequestBody(
content={
"application/json": MediaType(
schema=OpenAPISchema(
type="array", items=OpenAPISchema(type="string")
)
)
}
),
)
tool = RestApiTool(
name="test_tool",
description="test",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="array", # Match the parameter name
py_name="array",
param_location="body",
param_schema=OpenAPISchema(
type="array", items=OpenAPISchema(type="string")
),
)
]
kwargs = {"array": ["item1", "item2"]}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["json"] == ["item1", "item2"]
def test_prepare_request_params_string(
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
):
mock_operation = Operation(
operationId="test_op",
requestBody=RequestBody(
content={
"text/plain": MediaType(schema=OpenAPISchema(type="string"))
}
),
)
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="",
py_name="input_string",
param_location="body",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"input_string": "test_value"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["data"] == "test_value"
assert request_params["headers"]["Content-Type"] == "text/plain"
def test_prepare_request_params_form_data(
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
):
mock_operation = Operation(
operationId="test_op",
requestBody=RequestBody(
content={
"application/x-www-form-urlencoded": MediaType(
schema=OpenAPISchema(
type="object",
properties={"key1": OpenAPISchema(type="string")},
)
)
}
),
)
tool = RestApiTool(
name="test_tool",
description="test",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="key1",
py_name="key1",
param_location="body",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"key1": "value1"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["data"] == {"key1": "value1"}
assert (
request_params["headers"]["Content-Type"]
== "application/x-www-form-urlencoded"
)
def test_prepare_request_params_multipart(
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
):
mock_operation = Operation(
operationId="test_op",
requestBody=RequestBody(
content={
"multipart/form-data": MediaType(
schema=OpenAPISchema(
type="object",
properties={
"file1": OpenAPISchema(
type="string", format="binary"
)
},
)
)
}
),
)
tool = RestApiTool(
name="test_tool",
description="test",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="file1",
py_name="file1",
param_location="body",
param_schema=OpenAPISchema(type="string", format="binary"),
)
]
kwargs = {"file1": b"file_content"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["files"] == {"file1": b"file_content"}
assert request_params["headers"]["Content-Type"] == "multipart/form-data"
def test_prepare_request_params_octet_stream(
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
):
mock_operation = Operation(
operationId="test_op",
requestBody=RequestBody(
content={
"application/octet-stream": MediaType(
schema=OpenAPISchema(type="string", format="binary")
)
}
),
)
tool = RestApiTool(
name="test_tool",
description="test",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="",
py_name="data",
param_location="body",
param_schema=OpenAPISchema(type="string", format="binary"),
)
]
kwargs = {"data": b"binary_data"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["data"] == b"binary_data"
assert (
request_params["headers"]["Content-Type"] == "application/octet-stream"
)
def test_prepare_request_params_path_param(
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
):
mock_operation = Operation(operationId="test_op")
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="user_id",
py_name="user_id",
param_location="path",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"user_id": "123"}
endpoint_with_path = OperationEndpoint(
base_url="https://example.com", path="/test/{user_id}", method="get"
)
tool.endpoint = endpoint_with_path
request_params = tool._prepare_request_params(params, kwargs)
assert (
request_params["url"] == "https://example.com/test/123"
) # Path param replaced
def test_prepare_request_params_header_param(
self,
sample_endpiont,
sample_auth_credential,
sample_auth_scheme,
sample_operation,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="X-Custom-Header",
py_name="x_custom_header",
param_location="header",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"x_custom_header": "header_value"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["headers"]["X-Custom-Header"] == "header_value"
def test_prepare_request_params_cookie_param(
self,
sample_endpiont,
sample_auth_credential,
sample_auth_scheme,
sample_operation,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="session_id",
py_name="session_id",
param_location="cookie",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"session_id": "cookie_value"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["cookies"]["session_id"] == "cookie_value"
def test_prepare_request_params_multiple_mime_types(
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
):
# Test what happens when multiple mime types are specified. It should take
# the first one.
mock_operation = Operation(
operationId="test_op",
requestBody=RequestBody(
content={
"application/json": MediaType(
schema=OpenAPISchema(type="string")
),
"text/plain": MediaType(schema=OpenAPISchema(type="string")),
}
),
)
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=mock_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="",
py_name="input",
param_location="body",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"input": "some_value"}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["headers"]["Content-Type"] == "application/json"
def test_prepare_request_params_unknown_parameter(
self,
sample_endpiont,
sample_auth_credential,
sample_auth_scheme,
sample_operation,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="known_param",
py_name="known_param",
param_location="query",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"known_param": "value", "unknown_param": "unknown"}
request_params = tool._prepare_request_params(params, kwargs)
# Make sure unknown parameters are ignored and do not raise errors.
assert "unknown_param" not in request_params["params"]
def test_prepare_request_params_base_url_handling(
self, sample_auth_credential, sample_auth_scheme, sample_operation
):
# No base_url provided, should use path as is
tool_no_base = RestApiTool(
name="test_tool_no_base",
description="Test Tool",
endpoint=OperationEndpoint(base_url="", path="/no_base", method="get"),
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = []
kwargs = {}
request_params_no_base = tool_no_base._prepare_request_params(
params, kwargs
)
assert request_params_no_base["url"] == "/no_base"
tool_trailing_slash = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=OperationEndpoint(
base_url="https://example.com/", path="/trailing", method="get"
),
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
request_params_trailing = tool_trailing_slash._prepare_request_params(
params, kwargs
)
assert request_params_trailing["url"] == "https://example.com/trailing"
def test_prepare_request_params_no_unrecognized_query_parameter(
self,
sample_endpiont,
sample_auth_credential,
sample_auth_scheme,
sample_operation,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_credential=sample_auth_credential,
auth_scheme=sample_auth_scheme,
)
params = [
ApiParameter(
original_name="unrecognized_param",
py_name="unrecognized_param",
param_location="query",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"unrecognized_param": None} # Explicitly passing None
request_params = tool._prepare_request_params(params, kwargs)
# Query param not in sample_operation. It should be ignored.
assert "unrecognized_param" not in request_params["params"]
def test_prepare_request_params_no_credential(
self,
sample_endpiont,
sample_operation,
):
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpiont,
operation=sample_operation,
auth_credential=None,
auth_scheme=None,
)
params = [
ApiParameter(
original_name="param_name",
py_name="param_name",
param_location="query",
param_schema=OpenAPISchema(type="string"),
)
]
kwargs = {"param_name": "aaa", "empty_param": ""}
request_params = tool._prepare_request_params(params, kwargs)
assert "param_name" in request_params["params"]
assert "empty_param" not in request_params["params"]
class TestToGeminiSchema:
def test_to_gemini_schema_none(self):
assert to_gemini_schema(None) is None
def test_to_gemini_schema_not_dict(self):
with pytest.raises(TypeError, match="openapi_schema must be a dictionary"):
to_gemini_schema("not a dict")
def test_to_gemini_schema_empty_dict(self):
result = to_gemini_schema({})
assert isinstance(result, Schema)
assert result.type == Type.OBJECT
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
def test_to_gemini_schema_dict_with_only_object_type(self):
result = to_gemini_schema({"type": "object"})
assert isinstance(result, Schema)
assert result.type == Type.OBJECT
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
def test_to_gemini_schema_basic_types(self):
openapi_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"is_active": {"type": "boolean"},
},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert isinstance(gemini_schema, Schema)
assert gemini_schema.type == Type.OBJECT
assert gemini_schema.properties["name"].type == Type.STRING
assert gemini_schema.properties["age"].type == Type.INTEGER
assert gemini_schema.properties["is_active"].type == Type.BOOLEAN
def test_to_gemini_schema_nested_objects(self):
openapi_schema = {
"type": "object",
"properties": {
"address": {
"type": "object",
"properties": {
"street": {"type": "string"},
"city": {"type": "string"},
},
}
},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.properties["address"].type == Type.OBJECT
assert (
gemini_schema.properties["address"].properties["street"].type
== Type.STRING
)
assert (
gemini_schema.properties["address"].properties["city"].type
== Type.STRING
)
def test_to_gemini_schema_array(self):
openapi_schema = {
"type": "array",
"items": {"type": "string"},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.type == Type.ARRAY
assert gemini_schema.items.type == Type.STRING
def test_to_gemini_schema_nested_array(self):
openapi_schema = {
"type": "array",
"items": {
"type": "object",
"properties": {"name": {"type": "string"}},
},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.items.properties["name"].type == Type.STRING
def test_to_gemini_schema_any_of(self):
openapi_schema = {
"anyOf": [{"type": "string"}, {"type": "integer"}],
}
gemini_schema = to_gemini_schema(openapi_schema)
assert len(gemini_schema.any_of) == 2
assert gemini_schema.any_of[0].type == Type.STRING
assert gemini_schema.any_of[1].type == Type.INTEGER
def test_to_gemini_schema_general_list(self):
openapi_schema = {
"type": "array",
"properties": {
"list_field": {"type": "array", "items": {"type": "string"}},
},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.properties["list_field"].type == Type.ARRAY
assert gemini_schema.properties["list_field"].items.type == Type.STRING
def test_to_gemini_schema_enum(self):
openapi_schema = {"type": "string", "enum": ["a", "b", "c"]}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.enum == ["a", "b", "c"]
def test_to_gemini_schema_required(self):
openapi_schema = {
"type": "object",
"required": ["name"],
"properties": {"name": {"type": "string"}},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.required == ["name"]
def test_to_gemini_schema_nested_dict(self):
openapi_schema = {
"type": "object",
"properties": {"metadata": {"key1": "value1", "key2": 123}},
}
gemini_schema = to_gemini_schema(openapi_schema)
# Since metadata is not properties nor item, it will call to_gemini_schema recursively.
assert isinstance(gemini_schema.properties["metadata"], Schema)
assert (
gemini_schema.properties["metadata"].type == Type.OBJECT
) # add object type by default
assert gemini_schema.properties["metadata"].properties == {
"dummy_DO_NOT_GENERATE": Schema(type="string")
}
def test_to_gemini_schema_ignore_title_default_format(self):
openapi_schema = {
"type": "string",
"title": "Test Title",
"default": "default_value",
"format": "date",
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.title is None
assert gemini_schema.default is None
assert gemini_schema.format is None
def test_to_gemini_schema_property_ordering(self):
openapi_schema = {
"type": "object",
"propertyOrdering": ["name", "age"],
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.property_ordering == ["name", "age"]
def test_to_gemini_schema_converts_property_dict(self):
openapi_schema = {
"properties": {
"name": {"type": "string", "description": "The property key"},
"value": {"type": "string", "description": "The property value"},
},
"type": "object",
"description": "A single property entry in the Properties message.",
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.type == Type.OBJECT
assert gemini_schema.properties["name"].type == Type.STRING
assert gemini_schema.properties["value"].type == Type.STRING
def test_to_gemini_schema_remove_unrecognized_fields(self):
openapi_schema = {
"type": "string",
"description": "A single date string.",
"format": "date",
}
gemini_schema = to_gemini_schema(openapi_schema)
assert gemini_schema.type == Type.STRING
assert not gemini_schema.format
def test_snake_to_lower_camel():
assert snake_to_lower_camel("single") == "single"
assert snake_to_lower_camel("two_words") == "twoWords"
assert snake_to_lower_camel("three_word_example") == "threeWordExample"
assert not snake_to_lower_camel("")
assert snake_to_lower_camel("alreadyCamelCase") == "alreadyCamelCase"

View File

@@ -0,0 +1,201 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from unittest.mock import MagicMock
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_schemes import AuthScheme
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions.session import Session
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import OAuth2CredentialExchanger
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolAuthHandler
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolContextCredentialStore
from google.adk.tools.tool_context import ToolContext
import pytest
# Helper function to create a mock ToolContext
def create_mock_tool_context():
return ToolContext(
function_call_id='test-fc-id',
invocation_context=InvocationContext(
agent=LlmAgent(name='test'),
session=Session(app_name='test', user_id='123', id='123'),
invocation_id='123',
session_service=InMemorySessionService(),
),
)
# Test cases for OpenID Connect
class MockOpenIdConnectCredentialExchanger(OAuth2CredentialExchanger):
def __init__(
self, expected_scheme, expected_credential, expected_access_token
):
self.expected_scheme = expected_scheme
self.expected_credential = expected_credential
self.expected_access_token = expected_access_token
def exchange_credential(
self,
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
if auth_credential.oauth2 and (
auth_credential.oauth2.auth_response_uri
or auth_credential.oauth2.auth_code
):
auth_code = (
auth_credential.oauth2.auth_response_uri
if auth_credential.oauth2.auth_response_uri
else auth_credential.oauth2.auth_code
)
# Simulate the token exchange
updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
http=HttpAuth(
scheme='bearer',
credentials=HttpCredentials(
token=auth_code + self.expected_access_token
),
),
)
return updated_credential
# simulate the case of getting auth_uri
return None
def get_mock_openid_scheme_credential():
config_dict = {
'authorization_endpoint': 'test.com',
'token_endpoint': 'test.com',
}
scopes = ['test_scope']
credential_dict = {
'client_id': '123',
'client_secret': '456',
'redirect_uri': 'test.com',
}
return openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
# Fixture for the OpenID Connect security scheme
@pytest.fixture
def openid_connect_scheme():
scheme, _ = get_mock_openid_scheme_credential()
return scheme
# Fixture for a base OpenID Connect credential
@pytest.fixture
def openid_connect_credential():
_, credential = get_mock_openid_scheme_credential()
return credential
def test_openid_connect_no_auth_response(
openid_connect_scheme, openid_connect_credential
):
# Setup Mock exchanger
mock_exchanger = MockOpenIdConnectCredentialExchanger(
openid_connect_scheme, openid_connect_credential, None
)
tool_context = create_mock_tool_context()
credential_store = ToolContextCredentialStore(tool_context=tool_context)
handler = ToolAuthHandler(
tool_context,
openid_connect_scheme,
openid_connect_credential,
credential_exchanger=mock_exchanger,
credential_store=credential_store,
)
result = handler.prepare_auth_credentials()
assert result.state == 'pending'
assert result.auth_credential == openid_connect_credential
def test_openid_connect_with_auth_response(
openid_connect_scheme, openid_connect_credential, monkeypatch
):
mock_exchanger = MockOpenIdConnectCredentialExchanger(
openid_connect_scheme,
openid_connect_credential,
'test_access_token',
)
tool_context = create_mock_tool_context()
mock_auth_handler = MagicMock()
mock_auth_handler.get_auth_response.return_value = AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
oauth2=OAuth2Auth(auth_response_uri='test_auth_response_uri'),
)
mock_auth_handler_path = 'google.adk.tools.tool_context.AuthHandler'
monkeypatch.setattr(
mock_auth_handler_path, lambda *args, **kwargs: mock_auth_handler
)
credential_store = ToolContextCredentialStore(tool_context=tool_context)
handler = ToolAuthHandler(
tool_context,
openid_connect_scheme,
openid_connect_credential,
credential_exchanger=mock_exchanger,
credential_store=credential_store,
)
result = handler.prepare_auth_credentials()
assert result.state == 'done'
assert result.auth_credential.auth_type == AuthCredentialTypes.HTTP
assert 'test_access_token' in result.auth_credential.http.credentials.token
# Verify that the credential was stored:
stored_credential = credential_store.get_credential(
openid_connect_scheme, openid_connect_credential
)
assert stored_credential == result.auth_credential
mock_auth_handler.get_auth_response.assert_called_once()
def test_openid_connect_existing_token(
openid_connect_scheme, openid_connect_credential
):
_, existing_credential = token_to_scheme_credential(
'oauth2Token', 'header', 'bearer', '123123123'
)
tool_context = create_mock_tool_context()
# Store the credential to simulate existing credential
credential_store = ToolContextCredentialStore(tool_context=tool_context)
key = credential_store.get_credential_key(
openid_connect_scheme, openid_connect_credential
)
credential_store.store_credential(key, existing_credential)
handler = ToolAuthHandler(
tool_context,
openid_connect_scheme,
openid_connect_credential,
credential_store=credential_store,
)
result = handler.prepare_auth_credentials()
assert result.state == 'done'
assert result.auth_credential == existing_credential