mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-19 03:42:22 -06:00
Moves unittests to root folder and adds github action to run unit tests. (#72)
* Move unit tests to root package. * Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py * Adds github workflow * minor fix in lite_llm.py for python 3.9. * format pyproject.toml
This commit is contained in:
1367
tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml
Normal file
1367
tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,628 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
|
||||
import pytest
|
||||
|
||||
|
||||
def create_minimal_openapi_spec() -> Dict[str, Any]:
|
||||
"""Creates a minimal valid OpenAPI spec."""
|
||||
return {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Minimal API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/test": {
|
||||
"get": {
|
||||
"summary": "Test GET endpoint",
|
||||
"operationId": "testGet",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"application/json": {"schema": {"type": "string"}}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_spec_generator():
|
||||
"""Fixture for creating an OperationGenerator instance."""
|
||||
return OpenApiSpecParser()
|
||||
|
||||
|
||||
def test_parse_minimal_spec(openapi_spec_generator):
|
||||
"""Test parsing a minimal OpenAPI specification."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert op.name == "test_get"
|
||||
assert op.endpoint.path == "/test"
|
||||
assert op.endpoint.method == "get"
|
||||
assert op.return_value.type_value == str
|
||||
|
||||
|
||||
def test_parse_spec_with_no_operation_id(openapi_spec_generator):
|
||||
"""Test parsing a spec where operationId is missing (auto-generation)."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
del openapi_spec["paths"]["/test"]["get"]["operationId"] # Remove operationId
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
# Check if operationId is auto generated based on path and method.
|
||||
assert parsed_operations[0].name == "test_get"
|
||||
|
||||
|
||||
def test_parse_spec_with_multiple_methods(openapi_spec_generator):
|
||||
"""Test parsing a spec with multiple HTTP methods for the same path."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["post"] = {
|
||||
"summary": "Test POST endpoint",
|
||||
"operationId": "testPost",
|
||||
"responses": {"200": {"description": "Successful response"}},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
operation_names = {op.name for op in parsed_operations}
|
||||
|
||||
assert len(parsed_operations) == 2
|
||||
assert "test_get" in operation_names
|
||||
assert "test_post" in operation_names
|
||||
|
||||
|
||||
def test_parse_spec_with_parameters(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["get"]["parameters"] = [
|
||||
{"name": "param1", "in": "query", "schema": {"type": "string"}},
|
||||
{"name": "param2", "in": "header", "schema": {"type": "integer"}},
|
||||
]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations[0].parameters) == 2
|
||||
assert parsed_operations[0].parameters[0].original_name == "param1"
|
||||
assert parsed_operations[0].parameters[0].param_location == "query"
|
||||
assert parsed_operations[0].parameters[1].original_name == "param2"
|
||||
assert parsed_operations[0].parameters[1].param_location == "header"
|
||||
|
||||
|
||||
def test_parse_spec_with_request_body(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["post"] = {
|
||||
"summary": "Endpoint with request body",
|
||||
"operationId": "testPostWithBody",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
post_operations = [
|
||||
op for op in parsed_operations if op.endpoint.method == "post"
|
||||
]
|
||||
op = post_operations[0]
|
||||
|
||||
assert len(post_operations) == 1
|
||||
assert op.name == "test_post_with_body"
|
||||
assert len(op.parameters) == 1
|
||||
assert op.parameters[0].original_name == "name"
|
||||
assert op.parameters[0].type_value == str
|
||||
|
||||
|
||||
def test_parse_spec_with_reference(openapi_spec_generator):
|
||||
"""Test parsing a specification with $ref."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "API with Refs", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/test_ref": {
|
||||
"get": {
|
||||
"summary": "Endpoint with ref",
|
||||
"operationId": "testGetRef",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/MySchema"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"MySchema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert op.return_value.type_value.__origin__ is dict
|
||||
|
||||
|
||||
def test_parse_spec_with_circular_reference(openapi_spec_generator):
|
||||
"""Test correct handling of circular $ref (important!)."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Circular Ref API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/circular": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {"$ref": "#/components/schemas/A"}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"A": {
|
||||
"type": "object",
|
||||
"properties": {"b": {"$ref": "#/components/schemas/B"}},
|
||||
},
|
||||
"B": {
|
||||
"type": "object",
|
||||
"properties": {"a": {"$ref": "#/components/schemas/A"}},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 1
|
||||
|
||||
op = parsed_operations[0]
|
||||
assert op.return_value.type_value.__origin__ is dict
|
||||
assert op.return_value.type_hint == "Dict[str, Any]"
|
||||
|
||||
|
||||
def test_parse_no_paths(openapi_spec_generator):
|
||||
"""Test with a spec that has no paths defined."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "No Paths API", "version": "1.0.0"},
|
||||
}
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 0 # Should be empty
|
||||
|
||||
|
||||
def test_parse_empty_path_item(openapi_spec_generator):
|
||||
"""Test a path item that is present but empty."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Empty Path Item API", "version": "1.0.0"},
|
||||
"paths": {"/empty": None},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 0
|
||||
|
||||
|
||||
def test_parse_spec_with_global_auth_scheme(openapi_spec_generator):
|
||||
"""Test parsing with a global security scheme."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["security"] = [{"api_key": []}]
|
||||
openapi_spec["components"] = {
|
||||
"securitySchemes": {
|
||||
"api_key": {"type": "apiKey", "in": "header", "name": "X-API-Key"}
|
||||
}
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert op.auth_scheme is not None
|
||||
assert op.auth_scheme.type_.value == "apiKey"
|
||||
|
||||
|
||||
def test_parse_spec_with_local_auth_scheme(openapi_spec_generator):
|
||||
"""Test parsing with a local (operation-level) security scheme."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["get"]["security"] = [{"local_auth": []}]
|
||||
openapi_spec["components"] = {
|
||||
"securitySchemes": {"local_auth": {"type": "http", "scheme": "bearer"}}
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert op.auth_scheme is not None
|
||||
assert op.auth_scheme.type_.value == "http"
|
||||
assert op.auth_scheme.scheme == "bearer"
|
||||
|
||||
|
||||
def test_parse_spec_with_servers(openapi_spec_generator):
|
||||
"""Test parsing with server URLs."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["servers"] = [
|
||||
{"url": "https://api.example.com"},
|
||||
{"url": "http://localhost:8000"},
|
||||
]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].endpoint.base_url == "https://api.example.com"
|
||||
|
||||
|
||||
def test_parse_spec_with_no_servers(openapi_spec_generator):
|
||||
"""Test with no servers defined (should default to empty string)."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
if "servers" in openapi_spec:
|
||||
del openapi_spec["servers"]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].endpoint.base_url == ""
|
||||
|
||||
|
||||
def test_parse_spec_with_description(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
expected_description = "This is a test description."
|
||||
openapi_spec["paths"]["/test"]["get"]["description"] = expected_description
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].description == expected_description
|
||||
|
||||
|
||||
def test_parse_spec_with_empty_description(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["get"]["description"] = ""
|
||||
openapi_spec["paths"]["/test"]["get"]["summary"] = ""
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].description == ""
|
||||
|
||||
|
||||
def test_parse_spec_with_no_description(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
|
||||
# delete description
|
||||
if "description" in openapi_spec["paths"]["/test"]["get"]:
|
||||
del openapi_spec["paths"]["/test"]["get"]["description"]
|
||||
if "summary" in openapi_spec["paths"]["/test"]["get"]:
|
||||
del openapi_spec["paths"]["/test"]["get"]["summary"]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert (
|
||||
parsed_operations[0].description == ""
|
||||
) # it should be initialized with empty string
|
||||
|
||||
|
||||
def test_parse_invalid_openapi_spec_type(openapi_spec_generator):
|
||||
"""Test that passing a non-dict object to parse raises TypeError"""
|
||||
with pytest.raises(AttributeError):
|
||||
openapi_spec_generator.parse(123) # type: ignore
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
openapi_spec_generator.parse("openapi_spec") # type: ignore
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
openapi_spec_generator.parse([]) # type: ignore
|
||||
|
||||
|
||||
def test_parse_external_ref_raises_error(openapi_spec_generator):
|
||||
"""Check that external references (not starting with #) raise ValueError."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "External Ref API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/external": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": (
|
||||
"external_file.json#/components/schemas/ExternalSchema"
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
|
||||
def test_parse_spec_with_multiple_paths_deep_refs(openapi_spec_generator):
|
||||
"""Test specs with multiple paths, request/response bodies using deep refs."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Multiple Paths Deep Refs API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/path1": {
|
||||
"post": {
|
||||
"operationId": "postPath1",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Request1"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Response1"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"/path2": {
|
||||
"put": {
|
||||
"operationId": "putPath2",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Request2"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Response2"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
"get": {
|
||||
"operationId": "getPath2",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Response2"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"Request1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"req1_prop1": {"$ref": "#/components/schemas/Level1_1"}
|
||||
},
|
||||
},
|
||||
"Response1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"res1_prop1": {"$ref": "#/components/schemas/Level1_2"}
|
||||
},
|
||||
},
|
||||
"Request2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"req2_prop1": {"$ref": "#/components/schemas/Level1_1"}
|
||||
},
|
||||
},
|
||||
"Response2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"res2_prop1": {"$ref": "#/components/schemas/Level1_2"}
|
||||
},
|
||||
},
|
||||
"Level1_1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level1_1_prop1": {
|
||||
"$ref": "#/components/schemas/Level2_1"
|
||||
}
|
||||
},
|
||||
},
|
||||
"Level1_2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level1_2_prop1": {
|
||||
"$ref": "#/components/schemas/Level2_2"
|
||||
}
|
||||
},
|
||||
},
|
||||
"Level2_1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level2_1_prop1": {"$ref": "#/components/schemas/Level3"}
|
||||
},
|
||||
},
|
||||
"Level2_2": {
|
||||
"type": "object",
|
||||
"properties": {"level2_2_prop1": {"type": "string"}},
|
||||
},
|
||||
"Level3": {"type": "integer"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 3
|
||||
|
||||
# Verify Path 1
|
||||
path1_ops = [op for op in parsed_operations if op.endpoint.path == "/path1"]
|
||||
assert len(path1_ops) == 1
|
||||
path1_op = path1_ops[0]
|
||||
assert path1_op.name == "post_path1"
|
||||
|
||||
assert len(path1_op.parameters) == 1
|
||||
assert path1_op.parameters[0].original_name == "req1_prop1"
|
||||
assert (
|
||||
path1_op.parameters[0]
|
||||
.param_schema.properties["level1_1_prop1"]
|
||||
.properties["level2_1_prop1"]
|
||||
.type
|
||||
== "integer"
|
||||
)
|
||||
assert (
|
||||
path1_op.return_value.param_schema.properties["res1_prop1"]
|
||||
.properties["level1_2_prop1"]
|
||||
.properties["level2_2_prop1"]
|
||||
.type
|
||||
== "string"
|
||||
)
|
||||
|
||||
# Verify Path 2
|
||||
path2_ops = [
|
||||
op
|
||||
for op in parsed_operations
|
||||
if op.endpoint.path == "/path2" and op.name == "put_path2"
|
||||
]
|
||||
path2_op = path2_ops[0]
|
||||
assert path2_op is not None
|
||||
assert len(path2_op.parameters) == 1
|
||||
assert path2_op.parameters[0].original_name == "req2_prop1"
|
||||
assert (
|
||||
path2_op.parameters[0]
|
||||
.param_schema.properties["level1_1_prop1"]
|
||||
.properties["level2_1_prop1"]
|
||||
.type
|
||||
== "integer"
|
||||
)
|
||||
assert (
|
||||
path2_op.return_value.param_schema.properties["res2_prop1"]
|
||||
.properties["level1_2_prop1"]
|
||||
.properties["level2_2_prop1"]
|
||||
.type
|
||||
== "string"
|
||||
)
|
||||
|
||||
|
||||
def test_parse_spec_with_duplicate_parameter_names(openapi_spec_generator):
|
||||
"""Test handling of duplicate parameter names (one in query, one in body).
|
||||
|
||||
The expected behavior is that both parameters should be captured but with
|
||||
different suffix, and
|
||||
their `original_name` attributes should reflect their origin (query or body).
|
||||
"""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Duplicate Parameter Names API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/duplicate": {
|
||||
"post": {
|
||||
"operationId": "createWithDuplicate",
|
||||
"parameters": [{
|
||||
"name": "name",
|
||||
"in": "query",
|
||||
"schema": {"type": "string"},
|
||||
}],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "integer"}},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 1
|
||||
op = parsed_operations[0]
|
||||
assert op.name == "create_with_duplicate"
|
||||
assert len(op.parameters) == 2
|
||||
|
||||
query_param = None
|
||||
body_param = None
|
||||
for param in op.parameters:
|
||||
if param.param_location == "query" and param.original_name == "name":
|
||||
query_param = param
|
||||
elif param.param_location == "body" and param.original_name == "name":
|
||||
body_param = param
|
||||
|
||||
assert query_param is not None
|
||||
assert query_param.original_name == "name"
|
||||
assert query_param.py_name == "name"
|
||||
|
||||
assert body_param is not None
|
||||
assert body_param.original_name == "name"
|
||||
assert body_param.py_name == "name_0"
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from fastapi.openapi.models import MediaType
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import ParameterInType
|
||||
from fastapi.openapi.models import SecuritySchemeType
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
|
||||
def load_spec(file_path: str) -> Dict:
|
||||
"""Loads the OpenAPI specification from a YAML file."""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_spec() -> Dict:
|
||||
"""Fixture to load the OpenAPI specification."""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# Join the directory path with the filename
|
||||
yaml_path = os.path.join(current_dir, "test.yaml")
|
||||
return load_spec(yaml_path)
|
||||
|
||||
|
||||
def test_openapi_toolset_initialization_from_dict(openapi_spec: Dict):
|
||||
"""Test initialization of OpenAPIToolset with a dictionary."""
|
||||
toolset = OpenAPIToolset(spec_dict=openapi_spec)
|
||||
assert isinstance(toolset.tools, list)
|
||||
assert len(toolset.tools) == 5
|
||||
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
|
||||
|
||||
|
||||
def test_openapi_toolset_initialization_from_yaml_string(openapi_spec: Dict):
|
||||
"""Test initialization of OpenAPIToolset with a YAML string."""
|
||||
spec_str = yaml.dump(openapi_spec)
|
||||
toolset = OpenAPIToolset(spec_str=spec_str, spec_str_type="yaml")
|
||||
assert isinstance(toolset.tools, list)
|
||||
assert len(toolset.tools) == 5
|
||||
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
|
||||
|
||||
|
||||
def test_openapi_toolset_tool_existing(openapi_spec: Dict):
|
||||
"""Test the tool() method for an existing tool."""
|
||||
toolset = OpenAPIToolset(spec_dict=openapi_spec)
|
||||
tool_name = "calendar_calendars_insert" # Example operationId from the spec
|
||||
tool = toolset.get_tool(tool_name)
|
||||
assert isinstance(tool, RestApiTool)
|
||||
assert tool.name == tool_name
|
||||
assert tool.description == "Creates a secondary calendar."
|
||||
assert tool.endpoint.method == "post"
|
||||
assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
|
||||
assert tool.endpoint.path == "/calendars"
|
||||
assert tool.is_long_running is False
|
||||
assert tool.operation.operationId == "calendar.calendars.insert"
|
||||
assert tool.operation.description == "Creates a secondary calendar."
|
||||
assert isinstance(
|
||||
tool.operation.requestBody.content["application/json"], MediaType
|
||||
)
|
||||
assert len(tool.operation.responses) == 1
|
||||
response = tool.operation.responses["200"]
|
||||
assert response.description == "Successful response"
|
||||
assert isinstance(response.content["application/json"], MediaType)
|
||||
assert isinstance(tool.auth_scheme, OAuth2)
|
||||
|
||||
tool_name = "calendar_calendars_get"
|
||||
tool = toolset.get_tool(tool_name)
|
||||
assert isinstance(tool, RestApiTool)
|
||||
assert tool.name == tool_name
|
||||
assert tool.description == "Returns metadata for a calendar."
|
||||
assert tool.endpoint.method == "get"
|
||||
assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
|
||||
assert tool.endpoint.path == "/calendars/{calendarId}"
|
||||
assert tool.is_long_running is False
|
||||
assert tool.operation.operationId == "calendar.calendars.get"
|
||||
assert tool.operation.description == "Returns metadata for a calendar."
|
||||
assert len(tool.operation.parameters) == 1
|
||||
assert tool.operation.parameters[0].name == "calendarId"
|
||||
assert tool.operation.parameters[0].in_ == ParameterInType.path
|
||||
assert tool.operation.parameters[0].required is True
|
||||
assert tool.operation.parameters[0].schema_.type == "string"
|
||||
assert (
|
||||
tool.operation.parameters[0].description
|
||||
== "Calendar identifier. To retrieve calendar IDs call the"
|
||||
" calendarList.list method. If you want to access the primary calendar"
|
||||
' of the currently logged in user, use the "primary" keyword.'
|
||||
)
|
||||
assert isinstance(tool.auth_scheme, OAuth2)
|
||||
|
||||
assert isinstance(toolset.get_tool("calendar_calendars_update"), RestApiTool)
|
||||
assert isinstance(toolset.get_tool("calendar_calendars_delete"), RestApiTool)
|
||||
assert isinstance(toolset.get_tool("calendar_calendars_patch"), RestApiTool)
|
||||
|
||||
|
||||
def test_openapi_toolset_tool_non_existing(openapi_spec: Dict):
|
||||
"""Test the tool() method for a non-existing tool."""
|
||||
toolset = OpenAPIToolset(spec_dict=openapi_spec)
|
||||
tool = toolset.get_tool("non_existent_tool")
|
||||
assert tool is None
|
||||
|
||||
|
||||
def test_openapi_toolset_configure_auth_on_init(openapi_spec: Dict):
|
||||
"""Test configuring auth during initialization."""
|
||||
|
||||
auth_scheme = APIKey(**{
|
||||
"in": APIKeyIn.header, # Use alias name in dict
|
||||
"name": "api_key",
|
||||
"type": SecuritySchemeType.http,
|
||||
})
|
||||
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
|
||||
toolset = OpenAPIToolset(
|
||||
spec_dict=openapi_spec,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
for tool in toolset.tools:
|
||||
assert tool.auth_scheme == auth_scheme
|
||||
assert tool.auth_credential == auth_credential
|
||||
@@ -0,0 +1,406 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from fastapi.openapi.models import MediaType
|
||||
from fastapi.openapi.models import Operation
|
||||
from fastapi.openapi.models import Parameter
|
||||
from fastapi.openapi.models import RequestBody
|
||||
from fastapi.openapi.models import Response
|
||||
from fastapi.openapi.models import Schema
|
||||
from google.adk.tools.openapi_tool.common.common import ApiParameter
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_operation() -> Operation:
|
||||
"""Fixture to provide a sample OpenAPI Operation object."""
|
||||
return Operation(
|
||||
operationId='test_operation',
|
||||
summary='Test Summary',
|
||||
description='Test Description',
|
||||
parameters=[
|
||||
Parameter(**{
|
||||
'name': 'param1',
|
||||
'in': 'query',
|
||||
'schema': Schema(type='string'),
|
||||
'description': 'Parameter 1',
|
||||
}),
|
||||
Parameter(**{
|
||||
'name': 'param2',
|
||||
'in': 'header',
|
||||
'schema': Schema(type='string'),
|
||||
'description': 'Parameter 2',
|
||||
}),
|
||||
],
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
'application/json': MediaType(
|
||||
schema=Schema(
|
||||
type='object',
|
||||
properties={
|
||||
'prop1': Schema(
|
||||
type='string', description='Property 1'
|
||||
),
|
||||
'prop2': Schema(
|
||||
type='integer', description='Property 2'
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
},
|
||||
description='Request body description',
|
||||
),
|
||||
responses={
|
||||
'200': Response(
|
||||
description='Success',
|
||||
content={
|
||||
'application/json': MediaType(schema=Schema(type='string'))
|
||||
},
|
||||
),
|
||||
'400': Response(description='Client Error'),
|
||||
},
|
||||
security=[{'oauth2': ['resource: read', 'resource: write']}],
|
||||
)
|
||||
|
||||
|
||||
def test_operation_parser_initialization(sample_operation):
|
||||
"""Test initialization of OperationParser."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.operation == sample_operation
|
||||
assert len(parser.params) == 4 # 2 params + 2 request body props
|
||||
assert parser.return_value is not None
|
||||
|
||||
|
||||
def test_process_operation_parameters(sample_operation):
|
||||
"""Test _process_operation_parameters method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
parser._process_operation_parameters()
|
||||
assert len(parser.params) == 2
|
||||
assert parser.params[0].original_name == 'param1'
|
||||
assert parser.params[0].param_location == 'query'
|
||||
assert parser.params[1].original_name == 'param2'
|
||||
assert parser.params[1].param_location == 'header'
|
||||
|
||||
|
||||
def test_process_request_body(sample_operation):
|
||||
"""Test _process_request_body method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
parser._process_request_body()
|
||||
assert len(parser.params) == 2 # 2 properties in request body
|
||||
assert parser.params[0].original_name == 'prop1'
|
||||
assert parser.params[0].param_location == 'body'
|
||||
assert parser.params[1].original_name == 'prop2'
|
||||
assert parser.params[1].param_location == 'body'
|
||||
|
||||
|
||||
def test_process_request_body_array():
|
||||
"""Test _process_request_body method with array schema."""
|
||||
operation = Operation(
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
'application/json': MediaType(
|
||||
schema=Schema(
|
||||
type='array',
|
||||
items=Schema(
|
||||
type='object',
|
||||
properties={
|
||||
'item_prop1': Schema(
|
||||
type='string', description='Item Property 1'
|
||||
),
|
||||
'item_prop2': Schema(
|
||||
type='integer', description='Item Property 2'
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
parser = OperationParser(operation, should_parse=False)
|
||||
parser._process_request_body()
|
||||
assert len(parser.params) == 1
|
||||
assert parser.params[0].original_name == 'array'
|
||||
assert parser.params[0].param_location == 'body'
|
||||
# Check that schema is correctly propagated and is a dictionary
|
||||
assert parser.params[0].param_schema.type == 'array'
|
||||
assert parser.params[0].param_schema.items.type == 'object'
|
||||
assert 'item_prop1' in parser.params[0].param_schema.items.properties
|
||||
assert 'item_prop2' in parser.params[0].param_schema.items.properties
|
||||
assert (
|
||||
parser.params[0].param_schema.items.properties['item_prop1'].description
|
||||
== 'Item Property 1'
|
||||
)
|
||||
assert (
|
||||
parser.params[0].param_schema.items.properties['item_prop2'].description
|
||||
== 'Item Property 2'
|
||||
)
|
||||
|
||||
|
||||
def test_process_request_body_no_name():
|
||||
"""Test _process_request_body with a schema that has no properties (unnamed)"""
|
||||
operation = Operation(
|
||||
requestBody=RequestBody(
|
||||
content={'application/json': MediaType(schema=Schema(type='string'))}
|
||||
)
|
||||
)
|
||||
parser = OperationParser(operation, should_parse=False)
|
||||
parser._process_request_body()
|
||||
assert len(parser.params) == 1
|
||||
assert parser.params[0].original_name == '' # No name
|
||||
assert parser.params[0].param_location == 'body'
|
||||
|
||||
|
||||
def test_dedupe_param_names(sample_operation):
|
||||
"""Test _dedupe_param_names method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
# Add duplicate named parameters.
|
||||
parser.params = [
|
||||
ApiParameter(original_name='test', param_location='', param_schema={}),
|
||||
ApiParameter(original_name='test', param_location='', param_schema={}),
|
||||
ApiParameter(original_name='test', param_location='', param_schema={}),
|
||||
]
|
||||
parser._dedupe_param_names()
|
||||
assert parser.params[0].py_name == 'test'
|
||||
assert parser.params[1].py_name == 'test_0'
|
||||
assert parser.params[2].py_name == 'test_1'
|
||||
|
||||
|
||||
def test_process_return_value(sample_operation):
|
||||
"""Test _process_return_value method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value is not None
|
||||
assert parser.return_value.type_hint == 'str'
|
||||
|
||||
|
||||
def test_process_return_value_no_2xx(sample_operation):
|
||||
"""Tests _process_return_value when no 2xx response exists."""
|
||||
operation_no_2xx = Operation(
|
||||
responses={'400': Response(description='Client Error')}
|
||||
)
|
||||
parser = OperationParser(operation_no_2xx, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value is not None
|
||||
assert parser.return_value.type_hint == 'Any'
|
||||
|
||||
|
||||
def test_process_return_value_multiple_2xx(sample_operation):
|
||||
"""Tests _process_return_value when multiple 2xx responses exist."""
|
||||
operation_multi_2xx = Operation(
|
||||
responses={
|
||||
'201': Response(
|
||||
description='Success',
|
||||
content={
|
||||
'application/json': MediaType(schema=Schema(type='integer'))
|
||||
},
|
||||
),
|
||||
'202': Response(
|
||||
description='Success',
|
||||
content={'text/plain': MediaType(schema=Schema(type='string'))},
|
||||
),
|
||||
'200': Response(
|
||||
description='Success',
|
||||
content={
|
||||
'application/pdf': MediaType(schema=Schema(type='boolean'))
|
||||
},
|
||||
),
|
||||
'400': Response(
|
||||
description='Failure',
|
||||
content={
|
||||
'application/xml': MediaType(schema=Schema(type='object'))
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
parser = OperationParser(operation_multi_2xx, should_parse=False)
|
||||
parser._process_return_value()
|
||||
|
||||
assert parser.return_value is not None
|
||||
# Take the content type of the 200 response since it's the smallest response
|
||||
# code
|
||||
assert parser.return_value.param_schema.type == 'boolean'
|
||||
|
||||
|
||||
def test_process_return_value_no_content(sample_operation):
|
||||
"""Test when 2xx response has no content"""
|
||||
operation_no_content = Operation(
|
||||
responses={'200': Response(description='Success', content={})}
|
||||
)
|
||||
parser = OperationParser(operation_no_content, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value.type_hint == 'Any'
|
||||
|
||||
|
||||
def test_process_return_value_no_schema(sample_operation):
|
||||
"""Tests when the 2xx response's content has no schema."""
|
||||
operation_no_schema = Operation(
|
||||
responses={
|
||||
'200': Response(
|
||||
description='Success',
|
||||
content={'application/json': MediaType(schema=None)},
|
||||
)
|
||||
}
|
||||
)
|
||||
parser = OperationParser(operation_no_schema, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value.type_hint == 'Any'
|
||||
|
||||
|
||||
def test_get_function_name(sample_operation):
|
||||
"""Test get_function_name method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_function_name() == 'test_operation'
|
||||
|
||||
|
||||
def test_get_function_name_missing_id():
|
||||
"""Tests get_function_name when operationId is missing"""
|
||||
operation = Operation() # No ID
|
||||
parser = OperationParser(operation)
|
||||
with pytest.raises(ValueError, match='Operation ID is missing'):
|
||||
parser.get_function_name()
|
||||
|
||||
|
||||
def test_get_return_type_hint(sample_operation):
|
||||
"""Test get_return_type_hint method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_return_type_hint() == 'str'
|
||||
|
||||
|
||||
def test_get_return_type_value(sample_operation):
|
||||
"""Test get_return_type_value method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_return_type_value() == str
|
||||
|
||||
|
||||
def test_get_parameters(sample_operation):
|
||||
"""Test get_parameters method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
params = parser.get_parameters()
|
||||
assert len(params) == 4 # Correct count after processing
|
||||
assert all(isinstance(p, ApiParameter) for p in params)
|
||||
|
||||
|
||||
def test_get_return_value(sample_operation):
|
||||
"""Test get_return_value method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
return_value = parser.get_return_value()
|
||||
assert isinstance(return_value, ApiParameter)
|
||||
|
||||
|
||||
def test_get_auth_scheme_name(sample_operation):
|
||||
"""Test get_auth_scheme_name method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_auth_scheme_name() == 'oauth2'
|
||||
|
||||
|
||||
def test_get_auth_scheme_name_no_security():
|
||||
"""Test get_auth_scheme_name when no security is present."""
|
||||
operation = Operation(responses={})
|
||||
parser = OperationParser(operation)
|
||||
assert parser.get_auth_scheme_name() == ''
|
||||
|
||||
|
||||
def test_get_pydoc_string(sample_operation):
|
||||
"""Test get_pydoc_string method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
pydoc_string = parser.get_pydoc_string()
|
||||
assert 'Test Summary' in pydoc_string
|
||||
assert 'Args:' in pydoc_string
|
||||
assert 'param1 (str): Parameter 1' in pydoc_string
|
||||
assert 'prop1 (str): Property 1' in pydoc_string
|
||||
assert 'Returns (str):' in pydoc_string
|
||||
assert 'Success' in pydoc_string
|
||||
|
||||
|
||||
def test_get_json_schema(sample_operation):
|
||||
"""Test get_json_schema method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
json_schema = parser.get_json_schema()
|
||||
assert json_schema['title'] == 'test_operation_Arguments'
|
||||
assert json_schema['type'] == 'object'
|
||||
assert 'param1' in json_schema['properties']
|
||||
assert 'prop1' in json_schema['properties']
|
||||
assert 'param1' in json_schema['required']
|
||||
assert 'prop1' in json_schema['required']
|
||||
|
||||
|
||||
def test_get_signature_parameters(sample_operation):
|
||||
"""Test get_signature_parameters method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
signature_params = parser.get_signature_parameters()
|
||||
assert len(signature_params) == 4
|
||||
assert signature_params[0].name == 'param1'
|
||||
assert signature_params[0].annotation == str
|
||||
assert signature_params[2].name == 'prop1'
|
||||
assert signature_params[2].annotation == str
|
||||
|
||||
|
||||
def test_get_annotations(sample_operation):
|
||||
"""Test get_annotations method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
annotations = parser.get_annotations()
|
||||
assert len(annotations) == 5 # 4 parameters + return
|
||||
assert annotations['param1'] == str
|
||||
assert annotations['prop1'] == str
|
||||
assert annotations['return'] == str
|
||||
|
||||
|
||||
def test_load():
|
||||
"""Test the load classmethod."""
|
||||
operation = Operation(operationId='my_op') # Minimal operation
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name='p1',
|
||||
param_location='',
|
||||
param_schema={'type': 'integer'},
|
||||
)
|
||||
]
|
||||
return_value = ApiParameter(
|
||||
original_name='', param_location='', param_schema={'type': 'string'}
|
||||
)
|
||||
|
||||
parser = OperationParser.load(operation, params, return_value)
|
||||
|
||||
assert isinstance(parser, OperationParser)
|
||||
assert parser.operation == operation
|
||||
assert parser.params == params
|
||||
assert parser.return_value == return_value
|
||||
assert (
|
||||
parser.get_function_name() == 'my_op'
|
||||
) # Check that the operation is loaded
|
||||
|
||||
|
||||
def test_operation_parser_with_dict():
|
||||
"""Test initialization of OperationParser with a dictionary."""
|
||||
operation_dict = {
|
||||
'operationId': 'test_dict_operation',
|
||||
'parameters': [
|
||||
{'name': 'dict_param', 'in': 'query', 'schema': {'type': 'string'}}
|
||||
],
|
||||
'responses': {
|
||||
'200': {
|
||||
'description': 'Dict Success',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
parser = OperationParser(operation_dict)
|
||||
assert parser.operation.operationId == 'test_dict_operation'
|
||||
assert len(parser.params) == 1
|
||||
assert parser.params[0].original_name == 'dict_param'
|
||||
assert parser.return_value.type_hint == 'str'
|
||||
@@ -0,0 +1,966 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.openapi.models import MediaType
|
||||
from fastapi.openapi.models import Operation
|
||||
from fastapi.openapi.models import Parameter as OpenAPIParameter
|
||||
from fastapi.openapi.models import RequestBody
|
||||
from fastapi.openapi.models import Schema as OpenAPISchema
|
||||
from google.adk.sessions.state import State
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.common.common import ApiParameter
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import snake_to_lower_camel
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from google.genai.types import Schema
|
||||
from google.genai.types import Type
|
||||
import pytest
|
||||
|
||||
|
||||
class TestRestApiTool:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_context(self):
|
||||
"""Fixture for a mock OperationParser."""
|
||||
mock_context = MagicMock(spec=ToolContext)
|
||||
mock_context.state = State({}, {})
|
||||
mock_context.get_auth_response.return_value = {}
|
||||
mock_context.request_credential.return_value = {}
|
||||
return mock_context
|
||||
|
||||
@pytest.fixture
|
||||
def mock_operation_parser(self):
|
||||
"""Fixture for a mock OperationParser."""
|
||||
mock_parser = MagicMock(spec=OperationParser)
|
||||
mock_parser.get_function_name.return_value = "mock_function_name"
|
||||
mock_parser.get_json_schema.return_value = {}
|
||||
mock_parser.get_parameters.return_value = []
|
||||
mock_parser.get_return_type_hint.return_value = "str"
|
||||
mock_parser.get_pydoc_string.return_value = "Mock docstring"
|
||||
mock_parser.get_signature_parameters.return_value = []
|
||||
mock_parser.get_return_type_value.return_value = str
|
||||
mock_parser.get_annotations.return_value = {}
|
||||
return mock_parser
|
||||
|
||||
@pytest.fixture
|
||||
def sample_endpiont(self):
|
||||
return OperationEndpoint(
|
||||
base_url="https://example.com", path="/test", method="GET"
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_operation(self):
|
||||
return Operation(
|
||||
operationId="testOperation",
|
||||
description="Test operation",
|
||||
parameters=[],
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={
|
||||
"testBodyParam": OpenAPISchema(type="string")
|
||||
},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_api_parameters(self):
|
||||
return [
|
||||
ApiParameter(
|
||||
original_name="test_param",
|
||||
py_name="test_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
is_required=True,
|
||||
),
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="test_body_param",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
is_required=True,
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def sample_return_parameter(self):
|
||||
return ApiParameter(
|
||||
original_name="test_param",
|
||||
py_name="test_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
is_required=True,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_auth_scheme(self):
|
||||
scheme, _ = token_to_scheme_credential(
|
||||
"apikey", "header", "", "sample_auth_credential_internal_test"
|
||||
)
|
||||
return scheme
|
||||
|
||||
@pytest.fixture
|
||||
def sample_auth_credential(self):
|
||||
_, credential = token_to_scheme_credential(
|
||||
"apikey", "header", "", "sample_auth_credential_internal_test"
|
||||
)
|
||||
return credential
|
||||
|
||||
def test_init(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
)
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.description == "Test Tool"
|
||||
assert tool.endpoint == sample_endpiont
|
||||
assert tool.operation == sample_operation
|
||||
assert tool.auth_credential == sample_auth_credential
|
||||
assert tool.auth_scheme == sample_auth_scheme
|
||||
assert tool.credential_exchanger is not None
|
||||
|
||||
def test_from_parsed_operation_str(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_api_parameters,
|
||||
sample_return_parameter,
|
||||
sample_operation,
|
||||
):
|
||||
parsed_operation_str = json.dumps({
|
||||
"name": "test_operation",
|
||||
"description": "Test Description",
|
||||
"endpoint": sample_endpiont.model_dump(),
|
||||
"operation": sample_operation.model_dump(),
|
||||
"auth_scheme": None,
|
||||
"auth_credential": None,
|
||||
"parameters": [p.model_dump() for p in sample_api_parameters],
|
||||
"return_value": sample_return_parameter.model_dump(),
|
||||
})
|
||||
|
||||
tool = RestApiTool.from_parsed_operation_str(parsed_operation_str)
|
||||
assert tool.name == "test_operation"
|
||||
|
||||
def test_get_declaration(
|
||||
self, sample_endpiont, sample_operation, mock_operation_parser
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test description",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
should_parse_operation=False,
|
||||
)
|
||||
tool._operation_parser = mock_operation_parser
|
||||
|
||||
declaration = tool._get_declaration()
|
||||
assert isinstance(declaration, FunctionDeclaration)
|
||||
assert declaration.name == "test_tool"
|
||||
assert declaration.description == "Test description"
|
||||
assert isinstance(declaration.parameters, Schema)
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
|
||||
)
|
||||
def test_call_success(
|
||||
self,
|
||||
mock_request,
|
||||
mock_tool_context,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
)
|
||||
|
||||
# Call the method
|
||||
result = tool.call(args={}, tool_context=mock_tool_context)
|
||||
|
||||
# Check the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
|
||||
)
|
||||
def test_call_auth_pending(
|
||||
self,
|
||||
mock_request,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
):
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
)
|
||||
with patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
|
||||
) as mock_from_tool_context:
|
||||
mock_tool_auth_handler_instance = MagicMock()
|
||||
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
|
||||
"pending"
|
||||
)
|
||||
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
|
||||
|
||||
response = tool.call(args={}, tool_context=None)
|
||||
assert response == {
|
||||
"pending": True,
|
||||
"message": "Needs your authorization to access your data.",
|
||||
}
|
||||
|
||||
def test_prepare_request_params_query_body(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
# Create a mock Operation object
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
parameters=[
|
||||
OpenAPIParameter(**{
|
||||
"name": "testQueryParam",
|
||||
"in": "query",
|
||||
"schema": OpenAPISchema(type="string"),
|
||||
})
|
||||
],
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={
|
||||
"param1": OpenAPISchema(type="string"),
|
||||
"param2": OpenAPISchema(type="integer"),
|
||||
},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="param1",
|
||||
py_name="param1",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
),
|
||||
ApiParameter(
|
||||
original_name="param2",
|
||||
py_name="param2",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="integer"),
|
||||
),
|
||||
ApiParameter(
|
||||
original_name="testQueryParam",
|
||||
py_name="test_query_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
),
|
||||
]
|
||||
kwargs = {
|
||||
"param1": "value1",
|
||||
"param2": 123,
|
||||
"test_query_param": "query_value",
|
||||
}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
assert request_params["method"] == "get"
|
||||
assert request_params["url"] == "https://example.com/test"
|
||||
assert request_params["json"] == {"param1": "value1", "param2": 123}
|
||||
assert request_params["params"] == {"testQueryParam": "query_value"}
|
||||
|
||||
def test_prepare_request_params_array(
|
||||
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="array", items=OpenAPISchema(type="string")
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="array", # Match the parameter name
|
||||
py_name="array",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(
|
||||
type="array", items=OpenAPISchema(type="string")
|
||||
),
|
||||
)
|
||||
]
|
||||
kwargs = {"array": ["item1", "item2"]}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["json"] == ["item1", "item2"]
|
||||
|
||||
def test_prepare_request_params_string(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"text/plain": MediaType(schema=OpenAPISchema(type="string"))
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="input_string",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"input_string": "test_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["data"] == "test_value"
|
||||
assert request_params["headers"]["Content-Type"] == "text/plain"
|
||||
|
||||
def test_prepare_request_params_form_data(
|
||||
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/x-www-form-urlencoded": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={"key1": OpenAPISchema(type="string")},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="key1",
|
||||
py_name="key1",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"key1": "value1"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["data"] == {"key1": "value1"}
|
||||
assert (
|
||||
request_params["headers"]["Content-Type"]
|
||||
== "application/x-www-form-urlencoded"
|
||||
)
|
||||
|
||||
def test_prepare_request_params_multipart(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"multipart/form-data": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={
|
||||
"file1": OpenAPISchema(
|
||||
type="string", format="binary"
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="file1",
|
||||
py_name="file1",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string", format="binary"),
|
||||
)
|
||||
]
|
||||
kwargs = {"file1": b"file_content"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["files"] == {"file1": b"file_content"}
|
||||
assert request_params["headers"]["Content-Type"] == "multipart/form-data"
|
||||
|
||||
def test_prepare_request_params_octet_stream(
|
||||
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/octet-stream": MediaType(
|
||||
schema=OpenAPISchema(type="string", format="binary")
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="data",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string", format="binary"),
|
||||
)
|
||||
]
|
||||
kwargs = {"data": b"binary_data"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["data"] == b"binary_data"
|
||||
assert (
|
||||
request_params["headers"]["Content-Type"] == "application/octet-stream"
|
||||
)
|
||||
|
||||
def test_prepare_request_params_path_param(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
mock_operation = Operation(operationId="test_op")
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="user_id",
|
||||
py_name="user_id",
|
||||
param_location="path",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"user_id": "123"}
|
||||
endpoint_with_path = OperationEndpoint(
|
||||
base_url="https://example.com", path="/test/{user_id}", method="get"
|
||||
)
|
||||
tool.endpoint = endpoint_with_path
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert (
|
||||
request_params["url"] == "https://example.com/test/123"
|
||||
) # Path param replaced
|
||||
|
||||
def test_prepare_request_params_header_param(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="X-Custom-Header",
|
||||
py_name="x_custom_header",
|
||||
param_location="header",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"x_custom_header": "header_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["headers"]["X-Custom-Header"] == "header_value"
|
||||
|
||||
def test_prepare_request_params_cookie_param(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="session_id",
|
||||
py_name="session_id",
|
||||
param_location="cookie",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"session_id": "cookie_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["cookies"]["session_id"] == "cookie_value"
|
||||
|
||||
def test_prepare_request_params_multiple_mime_types(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
# Test what happens when multiple mime types are specified. It should take
|
||||
# the first one.
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(type="string")
|
||||
),
|
||||
"text/plain": MediaType(schema=OpenAPISchema(type="string")),
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="input",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"input": "some_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["headers"]["Content-Type"] == "application/json"
|
||||
|
||||
def test_prepare_request_params_unknown_parameter(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="known_param",
|
||||
py_name="known_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"known_param": "value", "unknown_param": "unknown"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
# Make sure unknown parameters are ignored and do not raise errors.
|
||||
assert "unknown_param" not in request_params["params"]
|
||||
|
||||
def test_prepare_request_params_base_url_handling(
|
||||
self, sample_auth_credential, sample_auth_scheme, sample_operation
|
||||
):
|
||||
# No base_url provided, should use path as is
|
||||
tool_no_base = RestApiTool(
|
||||
name="test_tool_no_base",
|
||||
description="Test Tool",
|
||||
endpoint=OperationEndpoint(base_url="", path="/no_base", method="get"),
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = []
|
||||
kwargs = {}
|
||||
|
||||
request_params_no_base = tool_no_base._prepare_request_params(
|
||||
params, kwargs
|
||||
)
|
||||
assert request_params_no_base["url"] == "/no_base"
|
||||
|
||||
tool_trailing_slash = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=OperationEndpoint(
|
||||
base_url="https://example.com/", path="/trailing", method="get"
|
||||
),
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
|
||||
request_params_trailing = tool_trailing_slash._prepare_request_params(
|
||||
params, kwargs
|
||||
)
|
||||
assert request_params_trailing["url"] == "https://example.com/trailing"
|
||||
|
||||
def test_prepare_request_params_no_unrecognized_query_parameter(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="unrecognized_param",
|
||||
py_name="unrecognized_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"unrecognized_param": None} # Explicitly passing None
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
# Query param not in sample_operation. It should be ignored.
|
||||
assert "unrecognized_param" not in request_params["params"]
|
||||
|
||||
def test_prepare_request_params_no_credential(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=None,
|
||||
auth_scheme=None,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="param_name",
|
||||
py_name="param_name",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"param_name": "aaa", "empty_param": ""}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert "param_name" in request_params["params"]
|
||||
assert "empty_param" not in request_params["params"]
|
||||
|
||||
|
||||
class TestToGeminiSchema:
|
||||
|
||||
def test_to_gemini_schema_none(self):
|
||||
assert to_gemini_schema(None) is None
|
||||
|
||||
def test_to_gemini_schema_not_dict(self):
|
||||
with pytest.raises(TypeError, match="openapi_schema must be a dictionary"):
|
||||
to_gemini_schema("not a dict")
|
||||
|
||||
def test_to_gemini_schema_empty_dict(self):
|
||||
result = to_gemini_schema({})
|
||||
assert isinstance(result, Schema)
|
||||
assert result.type == Type.OBJECT
|
||||
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
|
||||
|
||||
def test_to_gemini_schema_dict_with_only_object_type(self):
|
||||
result = to_gemini_schema({"type": "object"})
|
||||
assert isinstance(result, Schema)
|
||||
assert result.type == Type.OBJECT
|
||||
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
|
||||
|
||||
def test_to_gemini_schema_basic_types(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"is_active": {"type": "boolean"},
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert isinstance(gemini_schema, Schema)
|
||||
assert gemini_schema.type == Type.OBJECT
|
||||
assert gemini_schema.properties["name"].type == Type.STRING
|
||||
assert gemini_schema.properties["age"].type == Type.INTEGER
|
||||
assert gemini_schema.properties["is_active"].type == Type.BOOLEAN
|
||||
|
||||
def test_to_gemini_schema_nested_objects(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"street": {"type": "string"},
|
||||
"city": {"type": "string"},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.properties["address"].type == Type.OBJECT
|
||||
assert (
|
||||
gemini_schema.properties["address"].properties["street"].type
|
||||
== Type.STRING
|
||||
)
|
||||
assert (
|
||||
gemini_schema.properties["address"].properties["city"].type
|
||||
== Type.STRING
|
||||
)
|
||||
|
||||
def test_to_gemini_schema_array(self):
|
||||
openapi_schema = {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.type == Type.ARRAY
|
||||
assert gemini_schema.items.type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_nested_array(self):
|
||||
openapi_schema = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.items.properties["name"].type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_any_of(self):
|
||||
openapi_schema = {
|
||||
"anyOf": [{"type": "string"}, {"type": "integer"}],
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert len(gemini_schema.any_of) == 2
|
||||
assert gemini_schema.any_of[0].type == Type.STRING
|
||||
assert gemini_schema.any_of[1].type == Type.INTEGER
|
||||
|
||||
def test_to_gemini_schema_general_list(self):
|
||||
openapi_schema = {
|
||||
"type": "array",
|
||||
"properties": {
|
||||
"list_field": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.properties["list_field"].type == Type.ARRAY
|
||||
assert gemini_schema.properties["list_field"].items.type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_enum(self):
|
||||
openapi_schema = {"type": "string", "enum": ["a", "b", "c"]}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.enum == ["a", "b", "c"]
|
||||
|
||||
def test_to_gemini_schema_required(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"required": ["name"],
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.required == ["name"]
|
||||
|
||||
def test_to_gemini_schema_nested_dict(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"properties": {"metadata": {"key1": "value1", "key2": 123}},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
# Since metadata is not properties nor item, it will call to_gemini_schema recursively.
|
||||
assert isinstance(gemini_schema.properties["metadata"], Schema)
|
||||
assert (
|
||||
gemini_schema.properties["metadata"].type == Type.OBJECT
|
||||
) # add object type by default
|
||||
assert gemini_schema.properties["metadata"].properties == {
|
||||
"dummy_DO_NOT_GENERATE": Schema(type="string")
|
||||
}
|
||||
|
||||
def test_to_gemini_schema_ignore_title_default_format(self):
|
||||
openapi_schema = {
|
||||
"type": "string",
|
||||
"title": "Test Title",
|
||||
"default": "default_value",
|
||||
"format": "date",
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
|
||||
assert gemini_schema.title is None
|
||||
assert gemini_schema.default is None
|
||||
assert gemini_schema.format is None
|
||||
|
||||
def test_to_gemini_schema_property_ordering(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"propertyOrdering": ["name", "age"],
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.property_ordering == ["name", "age"]
|
||||
|
||||
def test_to_gemini_schema_converts_property_dict(self):
|
||||
openapi_schema = {
|
||||
"properties": {
|
||||
"name": {"type": "string", "description": "The property key"},
|
||||
"value": {"type": "string", "description": "The property value"},
|
||||
},
|
||||
"type": "object",
|
||||
"description": "A single property entry in the Properties message.",
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.type == Type.OBJECT
|
||||
assert gemini_schema.properties["name"].type == Type.STRING
|
||||
assert gemini_schema.properties["value"].type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_remove_unrecognized_fields(self):
|
||||
openapi_schema = {
|
||||
"type": "string",
|
||||
"description": "A single date string.",
|
||||
"format": "date",
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.type == Type.STRING
|
||||
assert not gemini_schema.format
|
||||
|
||||
|
||||
def test_snake_to_lower_camel():
|
||||
assert snake_to_lower_camel("single") == "single"
|
||||
assert snake_to_lower_camel("two_words") == "twoWords"
|
||||
assert snake_to_lower_camel("three_word_example") == "threeWordExample"
|
||||
assert not snake_to_lower_camel("")
|
||||
assert snake_to_lower_camel("alreadyCamelCase") == "alreadyCamelCase"
|
||||
@@ -0,0 +1,201 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import HttpAuth
|
||||
from google.adk.auth.auth_credential import HttpCredentials
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import OAuth2CredentialExchanger
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolAuthHandler
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolContextCredentialStore
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
import pytest
|
||||
|
||||
|
||||
# Helper function to create a mock ToolContext
|
||||
def create_mock_tool_context():
|
||||
return ToolContext(
|
||||
function_call_id='test-fc-id',
|
||||
invocation_context=InvocationContext(
|
||||
agent=LlmAgent(name='test'),
|
||||
session=Session(app_name='test', user_id='123', id='123'),
|
||||
invocation_id='123',
|
||||
session_service=InMemorySessionService(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Test cases for OpenID Connect
|
||||
class MockOpenIdConnectCredentialExchanger(OAuth2CredentialExchanger):
|
||||
|
||||
def __init__(
|
||||
self, expected_scheme, expected_credential, expected_access_token
|
||||
):
|
||||
self.expected_scheme = expected_scheme
|
||||
self.expected_credential = expected_credential
|
||||
self.expected_access_token = expected_access_token
|
||||
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
if auth_credential.oauth2 and (
|
||||
auth_credential.oauth2.auth_response_uri
|
||||
or auth_credential.oauth2.auth_code
|
||||
):
|
||||
auth_code = (
|
||||
auth_credential.oauth2.auth_response_uri
|
||||
if auth_credential.oauth2.auth_response_uri
|
||||
else auth_credential.oauth2.auth_code
|
||||
)
|
||||
# Simulate the token exchange
|
||||
updated_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
|
||||
http=HttpAuth(
|
||||
scheme='bearer',
|
||||
credentials=HttpCredentials(
|
||||
token=auth_code + self.expected_access_token
|
||||
),
|
||||
),
|
||||
)
|
||||
return updated_credential
|
||||
|
||||
# simulate the case of getting auth_uri
|
||||
return None
|
||||
|
||||
|
||||
def get_mock_openid_scheme_credential():
|
||||
config_dict = {
|
||||
'authorization_endpoint': 'test.com',
|
||||
'token_endpoint': 'test.com',
|
||||
}
|
||||
scopes = ['test_scope']
|
||||
credential_dict = {
|
||||
'client_id': '123',
|
||||
'client_secret': '456',
|
||||
'redirect_uri': 'test.com',
|
||||
}
|
||||
return openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
|
||||
|
||||
|
||||
# Fixture for the OpenID Connect security scheme
|
||||
@pytest.fixture
|
||||
def openid_connect_scheme():
|
||||
scheme, _ = get_mock_openid_scheme_credential()
|
||||
return scheme
|
||||
|
||||
|
||||
# Fixture for a base OpenID Connect credential
|
||||
@pytest.fixture
|
||||
def openid_connect_credential():
|
||||
_, credential = get_mock_openid_scheme_credential()
|
||||
return credential
|
||||
|
||||
|
||||
def test_openid_connect_no_auth_response(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
):
|
||||
# Setup Mock exchanger
|
||||
mock_exchanger = MockOpenIdConnectCredentialExchanger(
|
||||
openid_connect_scheme, openid_connect_credential, None
|
||||
)
|
||||
tool_context = create_mock_tool_context()
|
||||
credential_store = ToolContextCredentialStore(tool_context=tool_context)
|
||||
handler = ToolAuthHandler(
|
||||
tool_context,
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
credential_exchanger=mock_exchanger,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
result = handler.prepare_auth_credentials()
|
||||
assert result.state == 'pending'
|
||||
assert result.auth_credential == openid_connect_credential
|
||||
|
||||
|
||||
def test_openid_connect_with_auth_response(
|
||||
openid_connect_scheme, openid_connect_credential, monkeypatch
|
||||
):
|
||||
mock_exchanger = MockOpenIdConnectCredentialExchanger(
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
'test_access_token',
|
||||
)
|
||||
tool_context = create_mock_tool_context()
|
||||
|
||||
mock_auth_handler = MagicMock()
|
||||
mock_auth_handler.get_auth_response.return_value = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
|
||||
oauth2=OAuth2Auth(auth_response_uri='test_auth_response_uri'),
|
||||
)
|
||||
mock_auth_handler_path = 'google.adk.tools.tool_context.AuthHandler'
|
||||
monkeypatch.setattr(
|
||||
mock_auth_handler_path, lambda *args, **kwargs: mock_auth_handler
|
||||
)
|
||||
|
||||
credential_store = ToolContextCredentialStore(tool_context=tool_context)
|
||||
handler = ToolAuthHandler(
|
||||
tool_context,
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
credential_exchanger=mock_exchanger,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
result = handler.prepare_auth_credentials()
|
||||
assert result.state == 'done'
|
||||
assert result.auth_credential.auth_type == AuthCredentialTypes.HTTP
|
||||
assert 'test_access_token' in result.auth_credential.http.credentials.token
|
||||
# Verify that the credential was stored:
|
||||
stored_credential = credential_store.get_credential(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
)
|
||||
assert stored_credential == result.auth_credential
|
||||
mock_auth_handler.get_auth_response.assert_called_once()
|
||||
|
||||
|
||||
def test_openid_connect_existing_token(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
):
|
||||
_, existing_credential = token_to_scheme_credential(
|
||||
'oauth2Token', 'header', 'bearer', '123123123'
|
||||
)
|
||||
tool_context = create_mock_tool_context()
|
||||
# Store the credential to simulate existing credential
|
||||
credential_store = ToolContextCredentialStore(tool_context=tool_context)
|
||||
key = credential_store.get_credential_key(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
)
|
||||
credential_store.store_credential(key, existing_credential)
|
||||
|
||||
handler = ToolAuthHandler(
|
||||
tool_context,
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
result = handler.prepare_auth_credentials()
|
||||
assert result.state == 'done'
|
||||
assert result.auth_credential == existing_credential
|
||||
Reference in New Issue
Block a user