mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 11:22:22 -06:00
Moves unittests to root folder and adds github action to run unit tests. (#72)
* Move unit tests to root package. * Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py * Adds github workflow * minor fix in lite_llm.py for python 3.9. * format pyproject.toml
This commit is contained in:
14
tests/unittests/__init__.py
Normal file
14
tests/unittests/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
14
tests/unittests/agents/__init__.py
Normal file
14
tests/unittests/agents/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
407
tests/unittests/agents/test_base_agent.py
Normal file
407
tests/unittests/agents/test_base_agent.py
Normal file
@@ -0,0 +1,407 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Testings for the BaseAgent."""
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.events import Event
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.genai import types
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def _before_agent_callback_noop(callback_context: CallbackContext) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _before_agent_callback_bypass_agent(
|
||||
callback_context: CallbackContext,
|
||||
) -> types.Content:
|
||||
return types.Content(parts=[types.Part(text='agent run is bypassed.')])
|
||||
|
||||
|
||||
def _after_agent_callback_noop(callback_context: CallbackContext) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _after_agent_callback_append_agent_reply(
|
||||
callback_context: CallbackContext,
|
||||
) -> types.Content:
|
||||
return types.Content(
|
||||
parts=[types.Part(text='Agent reply from after agent callback.')]
|
||||
)
|
||||
|
||||
|
||||
class _IncompleteAgent(BaseAgent):
|
||||
pass
|
||||
|
||||
|
||||
class _TestingAgent(BaseAgent):
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
branch=ctx.branch,
|
||||
invocation_id=ctx.invocation_id,
|
||||
content=types.Content(parts=[types.Part(text='Hello, world!')]),
|
||||
)
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
invocation_id=ctx.invocation_id,
|
||||
branch=ctx.branch,
|
||||
content=types.Content(parts=[types.Part(text='Hello, live!')]),
|
||||
)
|
||||
|
||||
|
||||
def _create_parent_invocation_context(
|
||||
test_name: str, agent: BaseAgent, branch: Optional[str] = None
|
||||
) -> InvocationContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
invocation_id=f'{test_name}_invocation_id',
|
||||
branch=branch,
|
||||
agent=agent,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
|
||||
def test_invalid_agent_name():
|
||||
with pytest.raises(ValueError):
|
||||
_ = _TestingAgent(name='not an identifier')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async(request: pytest.FixtureRequest):
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
events = [e async for e in agent.run_async(parent_ctx)]
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].author == agent.name
|
||||
assert events[0].content.parts[0].text == 'Hello, world!'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_branch(request: pytest.FixtureRequest):
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, agent, branch='parent_branch'
|
||||
)
|
||||
|
||||
events = [e async for e in agent.run_async(parent_ctx)]
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].author == agent.name
|
||||
assert events[0].content.parts[0].text == 'Hello, world!'
|
||||
assert events[0].branch.endswith(agent.name)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_before_agent_callback_noop(
|
||||
request: pytest.FixtureRequest,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> Union[types.Content, None]:
|
||||
# Arrange
|
||||
agent = _TestingAgent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
before_agent_callback=_before_agent_callback_noop,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
|
||||
spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')
|
||||
|
||||
# Act
|
||||
_ = [e async for e in agent.run_async(parent_ctx)]
|
||||
|
||||
# Assert
|
||||
spy_before_agent_callback.assert_called_once()
|
||||
_, kwargs = spy_before_agent_callback.call_args
|
||||
assert 'callback_context' in kwargs
|
||||
assert isinstance(kwargs['callback_context'], CallbackContext)
|
||||
|
||||
spy_run_async_impl.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_before_agent_callback_bypass_agent(
|
||||
request: pytest.FixtureRequest,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
):
|
||||
# Arrange
|
||||
agent = _TestingAgent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
before_agent_callback=_before_agent_callback_bypass_agent,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
|
||||
spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')
|
||||
|
||||
# Act
|
||||
events = [e async for e in agent.run_async(parent_ctx)]
|
||||
|
||||
# Assert
|
||||
spy_before_agent_callback.assert_called_once()
|
||||
spy_run_async_impl.assert_not_called()
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].content.parts[0].text == 'agent run is bypassed.'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_after_agent_callback_noop(
|
||||
request: pytest.FixtureRequest,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
):
|
||||
# Arrange
|
||||
agent = _TestingAgent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
after_agent_callback=_after_agent_callback_noop,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
|
||||
|
||||
# Act
|
||||
events = [e async for e in agent.run_async(parent_ctx)]
|
||||
|
||||
# Assert
|
||||
spy_after_agent_callback.assert_called_once()
|
||||
_, kwargs = spy_after_agent_callback.call_args
|
||||
assert 'callback_context' in kwargs
|
||||
assert isinstance(kwargs['callback_context'], CallbackContext)
|
||||
assert len(events) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_after_agent_callback_append_reply(
|
||||
request: pytest.FixtureRequest,
|
||||
):
|
||||
# Arrange
|
||||
agent = _TestingAgent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
after_agent_callback=_after_agent_callback_append_agent_reply,
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
# Act
|
||||
events = [e async for e in agent.run_async(parent_ctx)]
|
||||
|
||||
# Assert
|
||||
assert len(events) == 2
|
||||
assert events[1].author == agent.name
|
||||
assert (
|
||||
events[1].content.parts[0].text
|
||||
== 'Agent reply from after agent callback.'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
|
||||
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
[e async for e in agent.run_async(parent_ctx)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_live(request: pytest.FixtureRequest):
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
events = [e async for e in agent.run_live(parent_ctx)]
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].author == agent.name
|
||||
assert events[0].content.parts[0].text == 'Hello, live!'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_live_with_branch(request: pytest.FixtureRequest):
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, agent, branch='parent_branch'
|
||||
)
|
||||
|
||||
events = [e async for e in agent.run_live(parent_ctx)]
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].author == agent.name
|
||||
assert events[0].content.parts[0].text == 'Hello, live!'
|
||||
assert events[0].branch.endswith(agent.name)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_live_incomplete_agent(request: pytest.FixtureRequest):
|
||||
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, agent
|
||||
)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
[e async for e in agent.run_live(parent_ctx)]
|
||||
|
||||
|
||||
def test_set_parent_agent_for_sub_agents(request: pytest.FixtureRequest):
|
||||
sub_agents: list[BaseAgent] = [
|
||||
_TestingAgent(name=f'{request.function.__name__}_sub_agent_1'),
|
||||
_TestingAgent(name=f'{request.function.__name__}_sub_agent_2'),
|
||||
]
|
||||
parent = _TestingAgent(
|
||||
name=f'{request.function.__name__}_parent',
|
||||
sub_agents=sub_agents,
|
||||
)
|
||||
|
||||
for sub_agent in sub_agents:
|
||||
assert sub_agent.parent_agent == parent
|
||||
|
||||
|
||||
def test_find_agent(request: pytest.FixtureRequest):
|
||||
grand_sub_agent_1 = _TestingAgent(
|
||||
name=f'{request.function.__name__}__grand_sub_agent_1'
|
||||
)
|
||||
grand_sub_agent_2 = _TestingAgent(
|
||||
name=f'{request.function.__name__}__grand_sub_agent_2'
|
||||
)
|
||||
sub_agent_1 = _TestingAgent(
|
||||
name=f'{request.function.__name__}_sub_agent_1',
|
||||
sub_agents=[grand_sub_agent_1],
|
||||
)
|
||||
sub_agent_2 = _TestingAgent(
|
||||
name=f'{request.function.__name__}_sub_agent_2',
|
||||
sub_agents=[grand_sub_agent_2],
|
||||
)
|
||||
parent = _TestingAgent(
|
||||
name=f'{request.function.__name__}_parent',
|
||||
sub_agents=[sub_agent_1, sub_agent_2],
|
||||
)
|
||||
|
||||
assert parent.find_agent(parent.name) == parent
|
||||
assert parent.find_agent(sub_agent_1.name) == sub_agent_1
|
||||
assert parent.find_agent(sub_agent_2.name) == sub_agent_2
|
||||
assert parent.find_agent(grand_sub_agent_1.name) == grand_sub_agent_1
|
||||
assert parent.find_agent(grand_sub_agent_2.name) == grand_sub_agent_2
|
||||
assert sub_agent_1.find_agent(grand_sub_agent_1.name) == grand_sub_agent_1
|
||||
assert sub_agent_1.find_agent(grand_sub_agent_2.name) is None
|
||||
assert sub_agent_2.find_agent(grand_sub_agent_1.name) is None
|
||||
assert sub_agent_2.find_agent(sub_agent_2.name) == sub_agent_2
|
||||
assert parent.find_agent('not_exist') is None
|
||||
|
||||
|
||||
def test_find_sub_agent(request: pytest.FixtureRequest):
|
||||
grand_sub_agent_1 = _TestingAgent(
|
||||
name=f'{request.function.__name__}__grand_sub_agent_1'
|
||||
)
|
||||
grand_sub_agent_2 = _TestingAgent(
|
||||
name=f'{request.function.__name__}__grand_sub_agent_2'
|
||||
)
|
||||
sub_agent_1 = _TestingAgent(
|
||||
name=f'{request.function.__name__}_sub_agent_1',
|
||||
sub_agents=[grand_sub_agent_1],
|
||||
)
|
||||
sub_agent_2 = _TestingAgent(
|
||||
name=f'{request.function.__name__}_sub_agent_2',
|
||||
sub_agents=[grand_sub_agent_2],
|
||||
)
|
||||
parent = _TestingAgent(
|
||||
name=f'{request.function.__name__}_parent',
|
||||
sub_agents=[sub_agent_1, sub_agent_2],
|
||||
)
|
||||
|
||||
assert parent.find_sub_agent(sub_agent_1.name) == sub_agent_1
|
||||
assert parent.find_sub_agent(sub_agent_2.name) == sub_agent_2
|
||||
assert parent.find_sub_agent(grand_sub_agent_1.name) == grand_sub_agent_1
|
||||
assert parent.find_sub_agent(grand_sub_agent_2.name) == grand_sub_agent_2
|
||||
assert sub_agent_1.find_sub_agent(grand_sub_agent_1.name) == grand_sub_agent_1
|
||||
assert sub_agent_1.find_sub_agent(grand_sub_agent_2.name) is None
|
||||
assert sub_agent_2.find_sub_agent(grand_sub_agent_1.name) is None
|
||||
assert sub_agent_2.find_sub_agent(grand_sub_agent_2.name) == grand_sub_agent_2
|
||||
assert parent.find_sub_agent(parent.name) is None
|
||||
assert parent.find_sub_agent('not_exist') is None
|
||||
|
||||
|
||||
def test_root_agent(request: pytest.FixtureRequest):
|
||||
grand_sub_agent_1 = _TestingAgent(
|
||||
name=f'{request.function.__name__}__grand_sub_agent_1'
|
||||
)
|
||||
grand_sub_agent_2 = _TestingAgent(
|
||||
name=f'{request.function.__name__}__grand_sub_agent_2'
|
||||
)
|
||||
sub_agent_1 = _TestingAgent(
|
||||
name=f'{request.function.__name__}_sub_agent_1',
|
||||
sub_agents=[grand_sub_agent_1],
|
||||
)
|
||||
sub_agent_2 = _TestingAgent(
|
||||
name=f'{request.function.__name__}_sub_agent_2',
|
||||
sub_agents=[grand_sub_agent_2],
|
||||
)
|
||||
parent = _TestingAgent(
|
||||
name=f'{request.function.__name__}_parent',
|
||||
sub_agents=[sub_agent_1, sub_agent_2],
|
||||
)
|
||||
|
||||
assert parent.root_agent == parent
|
||||
assert sub_agent_1.root_agent == parent
|
||||
assert sub_agent_2.root_agent == parent
|
||||
assert grand_sub_agent_1.root_agent == parent
|
||||
assert grand_sub_agent_2.root_agent == parent
|
||||
|
||||
|
||||
def test_set_parent_agent_for_sub_agent_twice(
|
||||
request: pytest.FixtureRequest,
|
||||
):
|
||||
sub_agent = _TestingAgent(name=f'{request.function.__name__}_sub_agent')
|
||||
_ = _TestingAgent(
|
||||
name=f'{request.function.__name__}_parent_1',
|
||||
sub_agents=[sub_agent],
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
_ = _TestingAgent(
|
||||
name=f'{request.function.__name__}_parent_2',
|
||||
sub_agents=[sub_agent],
|
||||
)
|
||||
191
tests/unittests/agents/test_langgraph_agent.py
Normal file
191
tests/unittests/agents/test_langgraph_agent.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.langgraph_agent import LangGraphAgent
|
||||
from google.adk.events import Event
|
||||
from google.genai import types
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langgraph.graph.graph import CompiledGraph
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"checkpointer_value, events_list, expected_messages",
|
||||
[
|
||||
(
|
||||
MagicMock(),
|
||||
[
|
||||
Event(
|
||||
invocation_id="test_invocation_id",
|
||||
author="user",
|
||||
content=types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text="test prompt")],
|
||||
),
|
||||
),
|
||||
Event(
|
||||
invocation_id="test_invocation_id",
|
||||
author="root_agent",
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[types.Part.from_text(text="(some delegation)")],
|
||||
),
|
||||
),
|
||||
],
|
||||
[
|
||||
SystemMessage(content="test system prompt"),
|
||||
HumanMessage(content="test prompt"),
|
||||
],
|
||||
),
|
||||
(
|
||||
None,
|
||||
[
|
||||
Event(
|
||||
invocation_id="test_invocation_id",
|
||||
author="user",
|
||||
content=types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text="user prompt 1")],
|
||||
),
|
||||
),
|
||||
Event(
|
||||
invocation_id="test_invocation_id",
|
||||
author="root_agent",
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part.from_text(text="root agent response")
|
||||
],
|
||||
),
|
||||
),
|
||||
Event(
|
||||
invocation_id="test_invocation_id",
|
||||
author="weather_agent",
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part.from_text(text="weather agent response")
|
||||
],
|
||||
),
|
||||
),
|
||||
Event(
|
||||
invocation_id="test_invocation_id",
|
||||
author="user",
|
||||
content=types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text="user prompt 2")],
|
||||
),
|
||||
),
|
||||
],
|
||||
[
|
||||
SystemMessage(content="test system prompt"),
|
||||
HumanMessage(content="user prompt 1"),
|
||||
AIMessage(content="weather agent response"),
|
||||
HumanMessage(content="user prompt 2"),
|
||||
],
|
||||
),
|
||||
(
|
||||
MagicMock(),
|
||||
[
|
||||
Event(
|
||||
invocation_id="test_invocation_id",
|
||||
author="user",
|
||||
content=types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text="user prompt 1")],
|
||||
),
|
||||
),
|
||||
Event(
|
||||
invocation_id="test_invocation_id",
|
||||
author="root_agent",
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part.from_text(text="root agent response")
|
||||
],
|
||||
),
|
||||
),
|
||||
Event(
|
||||
invocation_id="test_invocation_id",
|
||||
author="weather_agent",
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part.from_text(text="weather agent response")
|
||||
],
|
||||
),
|
||||
),
|
||||
Event(
|
||||
invocation_id="test_invocation_id",
|
||||
author="user",
|
||||
content=types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text="user prompt 2")],
|
||||
),
|
||||
),
|
||||
],
|
||||
[
|
||||
SystemMessage(content="test system prompt"),
|
||||
HumanMessage(content="user prompt 2"),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_langgraph_agent(
|
||||
checkpointer_value, events_list, expected_messages
|
||||
):
|
||||
mock_graph = MagicMock(spec=CompiledGraph)
|
||||
mock_graph_state = MagicMock()
|
||||
mock_graph_state.values = {}
|
||||
mock_graph.get_state.return_value = mock_graph_state
|
||||
|
||||
mock_graph.checkpointer = checkpointer_value
|
||||
mock_graph.invoke.return_value = {
|
||||
"messages": [AIMessage(content="test response")]
|
||||
}
|
||||
|
||||
mock_parent_context = MagicMock(spec=InvocationContext)
|
||||
mock_session = MagicMock()
|
||||
mock_parent_context.session = mock_session
|
||||
mock_parent_context.branch = "parent_agent"
|
||||
mock_parent_context.end_invocation = False
|
||||
mock_session.events = events_list
|
||||
mock_parent_context.invocation_id = "test_invocation_id"
|
||||
mock_parent_context.model_copy.return_value = mock_parent_context
|
||||
|
||||
weather_agent = LangGraphAgent(
|
||||
name="weather_agent",
|
||||
description="A agent that answers weather questions",
|
||||
instruction="test system prompt",
|
||||
graph=mock_graph,
|
||||
)
|
||||
|
||||
result_event = None
|
||||
async for event in weather_agent.run_async(mock_parent_context):
|
||||
result_event = event
|
||||
|
||||
assert result_event.author == "weather_agent"
|
||||
assert result_event.content.parts[0].text == "test response"
|
||||
|
||||
mock_graph.invoke.assert_called_once()
|
||||
mock_graph.invoke.assert_called_with(
|
||||
{"messages": expected_messages},
|
||||
{"configurable": {"thread_id": mock_session.id}},
|
||||
)
|
||||
138
tests/unittests/agents/test_llm_agent_callbacks.py
Normal file
138
tests/unittests/agents/test_llm_agent_callbacks.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.models import LlmRequest
|
||||
from google.adk.models import LlmResponse
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
from .. import utils
|
||||
|
||||
|
||||
class MockBeforeModelCallback(BaseModel):
|
||||
mock_response: str
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
llm_request: LlmRequest,
|
||||
) -> LlmResponse:
|
||||
return LlmResponse(
|
||||
content=utils.ModelContent(
|
||||
[types.Part.from_text(text=self.mock_response)]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MockAfterModelCallback(BaseModel):
|
||||
mock_response: str
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
llm_response: LlmResponse,
|
||||
) -> LlmResponse:
|
||||
return LlmResponse(
|
||||
content=utils.ModelContent(
|
||||
[types.Part.from_text(text=self.mock_response)]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def noop_callback(**kwargs) -> Optional[LlmResponse]:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_model_callback():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_model_callback=MockBeforeModelCallback(
|
||||
mock_response='before_model_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'before_model_callback'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_model_callback_noop():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_model_callback=noop_callback,
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'model_response'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_model_callback_end():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_model_callback=MockBeforeModelCallback(
|
||||
mock_response='before_model_callback',
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'before_model_callback'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_model_callback():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_model_callback=MockAfterModelCallback(
|
||||
mock_response='after_model_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [
|
||||
('root_agent', 'after_model_callback'),
|
||||
]
|
||||
231
tests/unittests/agents/test_llm_agent_fields.py
Normal file
231
tests/unittests/agents/test_llm_agent_fields.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for canonical_xxx fields in LlmAgent."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.agents.loop_agent import LoopAgent
|
||||
from google.adk.agents.readonly_context import ReadonlyContext
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.models.registry import LLMRegistry
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
|
||||
def _create_readonly_context(
|
||||
agent: LlmAgent, state: Optional[dict[str, Any]] = None
|
||||
) -> ReadonlyContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
app_name='test_app', user_id='test_user', state=state
|
||||
)
|
||||
invocation_context = InvocationContext(
|
||||
invocation_id='test_id',
|
||||
agent=agent,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
)
|
||||
return ReadonlyContext(invocation_context)
|
||||
|
||||
|
||||
def test_canonical_model_empty():
|
||||
agent = LlmAgent(name='test_agent')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = agent.canonical_model
|
||||
|
||||
|
||||
def test_canonical_model_str():
|
||||
agent = LlmAgent(name='test_agent', model='gemini-pro')
|
||||
|
||||
assert agent.canonical_model.model == 'gemini-pro'
|
||||
|
||||
|
||||
def test_canonical_model_llm():
|
||||
llm = LLMRegistry.new_llm('gemini-pro')
|
||||
agent = LlmAgent(name='test_agent', model=llm)
|
||||
|
||||
assert agent.canonical_model == llm
|
||||
|
||||
|
||||
def test_canonical_model_inherit():
|
||||
sub_agent = LlmAgent(name='sub_agent')
|
||||
parent_agent = LlmAgent(
|
||||
name='parent_agent', model='gemini-pro', sub_agents=[sub_agent]
|
||||
)
|
||||
|
||||
assert sub_agent.canonical_model == parent_agent.canonical_model
|
||||
|
||||
|
||||
def test_canonical_instruction_str():
|
||||
agent = LlmAgent(name='test_agent', instruction='instruction')
|
||||
ctx = _create_readonly_context(agent)
|
||||
|
||||
assert agent.canonical_instruction(ctx) == 'instruction'
|
||||
|
||||
|
||||
def test_canonical_instruction():
|
||||
def _instruction_provider(ctx: ReadonlyContext) -> str:
|
||||
return f'instruction: {ctx.state["state_var"]}'
|
||||
|
||||
agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
|
||||
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
|
||||
|
||||
assert agent.canonical_instruction(ctx) == 'instruction: state_value'
|
||||
|
||||
|
||||
def test_canonical_global_instruction_str():
|
||||
agent = LlmAgent(name='test_agent', global_instruction='global instruction')
|
||||
ctx = _create_readonly_context(agent)
|
||||
|
||||
assert agent.canonical_global_instruction(ctx) == 'global instruction'
|
||||
|
||||
|
||||
def test_canonical_global_instruction():
|
||||
def _global_instruction_provider(ctx: ReadonlyContext) -> str:
|
||||
return f'global instruction: {ctx.state["state_var"]}'
|
||||
|
||||
agent = LlmAgent(
|
||||
name='test_agent', global_instruction=_global_instruction_provider
|
||||
)
|
||||
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
|
||||
|
||||
assert (
|
||||
agent.canonical_global_instruction(ctx)
|
||||
== 'global instruction: state_value'
|
||||
)
|
||||
|
||||
|
||||
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):
|
||||
with caplog.at_level('WARNING'):
|
||||
|
||||
class Schema(BaseModel):
|
||||
pass
|
||||
|
||||
agent = LlmAgent(
|
||||
name='test_agent',
|
||||
output_schema=Schema,
|
||||
)
|
||||
|
||||
# Transfer is automatically disabled
|
||||
assert agent.disallow_transfer_to_parent
|
||||
assert agent.disallow_transfer_to_peers
|
||||
assert (
|
||||
'output_schema cannot co-exist with agent transfer configurations.'
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
|
||||
def test_output_schema_with_sub_agents_will_throw():
|
||||
class Schema(BaseModel):
|
||||
pass
|
||||
|
||||
sub_agent = LlmAgent(
|
||||
name='sub_agent',
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = LlmAgent(
|
||||
name='test_agent',
|
||||
output_schema=Schema,
|
||||
sub_agents=[sub_agent],
|
||||
)
|
||||
|
||||
|
||||
def test_output_schema_with_tools_will_throw():
|
||||
class Schema(BaseModel):
|
||||
pass
|
||||
|
||||
def _a_tool():
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = LlmAgent(
|
||||
name='test_agent',
|
||||
output_schema=Schema,
|
||||
tools=[_a_tool],
|
||||
)
|
||||
|
||||
|
||||
def test_before_model_callback():
|
||||
def _before_model_callback(
|
||||
callback_context: CallbackContext,
|
||||
llm_request: LlmRequest,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
agent = LlmAgent(
|
||||
name='test_agent', before_model_callback=_before_model_callback
|
||||
)
|
||||
|
||||
# TODO: add more logic assertions later.
|
||||
assert agent.before_model_callback is not None
|
||||
|
||||
|
||||
def test_validate_generate_content_config_thinking_config_throw():
|
||||
with pytest.raises(ValueError):
|
||||
_ = LlmAgent(
|
||||
name='test_agent',
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
thinking_config=types.ThinkingConfig()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_validate_generate_content_config_tools_throw():
|
||||
with pytest.raises(ValueError):
|
||||
_ = LlmAgent(
|
||||
name='test_agent',
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
tools=[types.Tool(function_declarations=[])]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_validate_generate_content_config_system_instruction_throw():
|
||||
with pytest.raises(ValueError):
|
||||
_ = LlmAgent(
|
||||
name='test_agent',
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
system_instruction='system instruction'
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_validate_generate_content_config_response_schema_throw():
|
||||
class Schema(BaseModel):
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = LlmAgent(
|
||||
name='test_agent',
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
response_schema=Schema
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_allow_transfer_by_default():
|
||||
sub_agent = LlmAgent(name='sub_agent')
|
||||
agent = LlmAgent(name='test_agent', sub_agents=[sub_agent])
|
||||
|
||||
assert not agent.disallow_transfer_to_parent
|
||||
assert not agent.disallow_transfer_to_peers
|
||||
136
tests/unittests/agents/test_loop_agent.py
Normal file
136
tests/unittests/agents/test_loop_agent.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Testings for the SequentialAgent."""
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.loop_agent import LoopAgent
|
||||
from google.adk.events import Event
|
||||
from google.adk.events import EventActions
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.genai import types
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class _TestingAgent(BaseAgent):
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
invocation_id=ctx.invocation_id,
|
||||
content=types.Content(
|
||||
parts=[types.Part(text=f'Hello, async {self.name}!')]
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
invocation_id=ctx.invocation_id,
|
||||
content=types.Content(
|
||||
parts=[types.Part(text=f'Hello, live {self.name}!')]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class _TestingAgentWithEscalateAction(BaseAgent):
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
invocation_id=ctx.invocation_id,
|
||||
content=types.Content(
|
||||
parts=[types.Part(text=f'Hello, async {self.name}!')]
|
||||
),
|
||||
actions=EventActions(escalate=True),
|
||||
)
|
||||
|
||||
|
||||
def _create_parent_invocation_context(
|
||||
test_name: str, agent: BaseAgent
|
||||
) -> InvocationContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
invocation_id=f'{test_name}_invocation_id',
|
||||
agent=agent,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async(request: pytest.FixtureRequest):
|
||||
agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
|
||||
loop_agent = LoopAgent(
|
||||
name=f'{request.function.__name__}_test_loop_agent',
|
||||
max_iterations=2,
|
||||
sub_agents=[
|
||||
agent,
|
||||
],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, loop_agent
|
||||
)
|
||||
events = [e async for e in loop_agent.run_async(parent_ctx)]
|
||||
|
||||
assert len(events) == 2
|
||||
assert events[0].author == agent.name
|
||||
assert events[1].author == agent.name
|
||||
assert events[0].content.parts[0].text == f'Hello, async {agent.name}!'
|
||||
assert events[1].content.parts[0].text == f'Hello, async {agent.name}!'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
|
||||
non_escalating_agent = _TestingAgent(
|
||||
name=f'{request.function.__name__}_test_non_escalating_agent'
|
||||
)
|
||||
escalating_agent = _TestingAgentWithEscalateAction(
|
||||
name=f'{request.function.__name__}_test_escalating_agent'
|
||||
)
|
||||
loop_agent = LoopAgent(
|
||||
name=f'{request.function.__name__}_test_loop_agent',
|
||||
sub_agents=[non_escalating_agent, escalating_agent],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, loop_agent
|
||||
)
|
||||
events = [e async for e in loop_agent.run_async(parent_ctx)]
|
||||
|
||||
# Only two events are generated because the sub escalating_agent escalates.
|
||||
assert len(events) == 2
|
||||
assert events[0].author == non_escalating_agent.name
|
||||
assert events[1].author == escalating_agent.name
|
||||
assert events[0].content.parts[0].text == (
|
||||
f'Hello, async {non_escalating_agent.name}!'
|
||||
)
|
||||
assert events[1].content.parts[0].text == (
|
||||
f'Hello, async {escalating_agent.name}!'
|
||||
)
|
||||
92
tests/unittests/agents/test_parallel_agent.py
Normal file
92
tests/unittests/agents/test_parallel_agent.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the ParallelAgent."""
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.parallel_agent import ParallelAgent
|
||||
from google.adk.events import Event
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.genai import types
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class _TestingAgent(BaseAgent):
|
||||
|
||||
delay: float = 0
|
||||
"""The delay before the agent generates an event."""
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
await asyncio.sleep(self.delay)
|
||||
yield Event(
|
||||
author=self.name,
|
||||
branch=ctx.branch,
|
||||
invocation_id=ctx.invocation_id,
|
||||
content=types.Content(
|
||||
parts=[types.Part(text=f'Hello, async {self.name}!')]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _create_parent_invocation_context(
|
||||
test_name: str, agent: BaseAgent
|
||||
) -> InvocationContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
invocation_id=f'{test_name}_invocation_id',
|
||||
agent=agent,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async(request: pytest.FixtureRequest):
|
||||
agent1 = _TestingAgent(
|
||||
name=f'{request.function.__name__}_test_agent_1',
|
||||
delay=0.5,
|
||||
)
|
||||
agent2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
|
||||
parallel_agent = ParallelAgent(
|
||||
name=f'{request.function.__name__}_test_parallel_agent',
|
||||
sub_agents=[
|
||||
agent1,
|
||||
agent2,
|
||||
],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, parallel_agent
|
||||
)
|
||||
events = [e async for e in parallel_agent.run_async(parent_ctx)]
|
||||
|
||||
assert len(events) == 2
|
||||
# agent2 generates an event first, then agent1. Because they run in parallel
|
||||
# and agent1 has a delay.
|
||||
assert events[0].author == agent2.name
|
||||
assert events[1].author == agent1.name
|
||||
assert events[0].branch.endswith(agent2.name)
|
||||
assert events[1].branch.endswith(agent1.name)
|
||||
assert events[0].content.parts[0].text == f'Hello, async {agent2.name}!'
|
||||
assert events[1].content.parts[0].text == f'Hello, async {agent1.name}!'
|
||||
114
tests/unittests/agents/test_sequential_agent.py
Normal file
114
tests/unittests/agents/test_sequential_agent.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Testings for the SequentialAgent."""
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.sequential_agent import SequentialAgent
|
||||
from google.adk.events import Event
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.genai import types
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class _TestingAgent(BaseAgent):
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
invocation_id=ctx.invocation_id,
|
||||
content=types.Content(
|
||||
parts=[types.Part(text=f'Hello, async {self.name}!')]
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
yield Event(
|
||||
author=self.name,
|
||||
invocation_id=ctx.invocation_id,
|
||||
content=types.Content(
|
||||
parts=[types.Part(text=f'Hello, live {self.name}!')]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _create_parent_invocation_context(
|
||||
test_name: str, agent: BaseAgent
|
||||
) -> InvocationContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
invocation_id=f'{test_name}_invocation_id',
|
||||
agent=agent,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async(request: pytest.FixtureRequest):
|
||||
agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')
|
||||
agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
|
||||
sequential_agent = SequentialAgent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
sub_agents=[
|
||||
agent_1,
|
||||
agent_2,
|
||||
],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, sequential_agent
|
||||
)
|
||||
events = [e async for e in sequential_agent.run_async(parent_ctx)]
|
||||
|
||||
assert len(events) == 2
|
||||
assert events[0].author == agent_1.name
|
||||
assert events[1].author == agent_2.name
|
||||
assert events[0].content.parts[0].text == f'Hello, async {agent_1.name}!'
|
||||
assert events[1].content.parts[0].text == f'Hello, async {agent_2.name}!'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_live(request: pytest.FixtureRequest):
|
||||
agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')
|
||||
agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
|
||||
sequential_agent = SequentialAgent(
|
||||
name=f'{request.function.__name__}_test_agent',
|
||||
sub_agents=[
|
||||
agent_1,
|
||||
agent_2,
|
||||
],
|
||||
)
|
||||
parent_ctx = _create_parent_invocation_context(
|
||||
request.function.__name__, sequential_agent
|
||||
)
|
||||
events = [e async for e in sequential_agent.run_live(parent_ctx)]
|
||||
|
||||
assert len(events) == 2
|
||||
assert events[0].author == agent_1.name
|
||||
assert events[1].author == agent_2.name
|
||||
assert events[0].content.parts[0].text == f'Hello, live {agent_1.name}!'
|
||||
assert events[1].content.parts[0].text == f'Hello, live {agent_2.name}!'
|
||||
14
tests/unittests/artifacts/__init__.py
Normal file
14
tests/unittests/artifacts/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
276
tests/unittests/artifacts/test_artifact_service.py
Normal file
276
tests/unittests/artifacts/test_artifact_service.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the artifact service."""
|
||||
|
||||
import enum
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from google.adk.artifacts import GcsArtifactService
|
||||
from google.adk.artifacts import InMemoryArtifactService
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
Enum = enum.Enum
|
||||
|
||||
|
||||
class ArtifactServiceType(Enum):
|
||||
IN_MEMORY = "IN_MEMORY"
|
||||
GCS = "GCS"
|
||||
|
||||
|
||||
class MockBlob:
|
||||
"""Mocks a GCS Blob object.
|
||||
|
||||
This class provides mock implementations for a few common GCS Blob methods,
|
||||
allowing the user to test code that interacts with GCS without actually
|
||||
connecting to a real bucket.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initializes a MockBlob.
|
||||
|
||||
Args:
|
||||
name: The name of the blob.
|
||||
"""
|
||||
self.name = name
|
||||
self.content: Optional[bytes] = None
|
||||
self.content_type: Optional[str] = None
|
||||
|
||||
def upload_from_string(
|
||||
self, data: Union[str, bytes], content_type: Optional[str] = None
|
||||
) -> None:
|
||||
"""Mocks uploading data to the blob (from a string or bytes).
|
||||
|
||||
Args:
|
||||
data: The data to upload (string or bytes).
|
||||
content_type: The content type of the data (optional).
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
self.content = data.encode("utf-8")
|
||||
elif isinstance(data, bytes):
|
||||
self.content = data
|
||||
else:
|
||||
raise TypeError("data must be str or bytes")
|
||||
|
||||
if content_type:
|
||||
self.content_type = content_type
|
||||
|
||||
def download_as_bytes(self) -> bytes:
|
||||
"""Mocks downloading the blob's content as bytes.
|
||||
|
||||
Returns:
|
||||
bytes: The content of the blob as bytes.
|
||||
|
||||
Raises:
|
||||
Exception: If the blob doesn't exist (hasn't been uploaded to).
|
||||
"""
|
||||
if self.content is None:
|
||||
return b""
|
||||
return self.content
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Mocks deleting a blob."""
|
||||
self.content = None
|
||||
self.content_type = None
|
||||
|
||||
|
||||
class MockBucket:
|
||||
"""Mocks a GCS Bucket object."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initializes a MockBucket.
|
||||
|
||||
Args:
|
||||
name: The name of the bucket.
|
||||
"""
|
||||
self.name = name
|
||||
self.blobs: dict[str, MockBlob] = {}
|
||||
|
||||
def blob(self, blob_name: str) -> MockBlob:
|
||||
"""Mocks getting a Blob object (doesn't create it in storage).
|
||||
|
||||
Args:
|
||||
blob_name: The name of the blob.
|
||||
|
||||
Returns:
|
||||
A MockBlob instance.
|
||||
"""
|
||||
if blob_name not in self.blobs:
|
||||
self.blobs[blob_name] = MockBlob(blob_name)
|
||||
return self.blobs[blob_name]
|
||||
|
||||
|
||||
class MockClient:
|
||||
"""Mocks the GCS Client."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initializes MockClient."""
|
||||
self.buckets: dict[str, MockBucket] = {}
|
||||
|
||||
def bucket(self, bucket_name: str) -> MockBucket:
|
||||
"""Mocks getting a Bucket object."""
|
||||
if bucket_name not in self.buckets:
|
||||
self.buckets[bucket_name] = MockBucket(bucket_name)
|
||||
return self.buckets[bucket_name]
|
||||
|
||||
def list_blobs(self, bucket: MockBucket, prefix: Optional[str] = None):
|
||||
"""Mocks listing blobs in a bucket, optionally with a prefix."""
|
||||
if prefix:
|
||||
return [
|
||||
blob for name, blob in bucket.blobs.items() if name.startswith(prefix)
|
||||
]
|
||||
return list(bucket.blobs.values())
|
||||
|
||||
|
||||
def mock_gcs_artifact_service():
|
||||
"""Creates a mock GCS artifact service for testing."""
|
||||
service = GcsArtifactService(bucket_name="test_bucket")
|
||||
service.storage_client = MockClient()
|
||||
service.bucket = service.storage_client.bucket("test_bucket")
|
||||
return service
|
||||
|
||||
|
||||
def get_artifact_service(
|
||||
service_type: ArtifactServiceType = ArtifactServiceType.IN_MEMORY,
|
||||
):
|
||||
"""Creates an artifact service for testing."""
|
||||
if service_type == ArtifactServiceType.GCS:
|
||||
return mock_gcs_artifact_service()
|
||||
return InMemoryArtifactService()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
|
||||
)
|
||||
def test_load_empty(service_type):
|
||||
"""Tests loading an artifact when none exists."""
|
||||
artifact_service = get_artifact_service(service_type)
|
||||
assert not artifact_service.load_artifact(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
session_id="session_id",
|
||||
filename="filename",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
|
||||
)
|
||||
def test_save_load_delete(service_type):
|
||||
"""Tests saving, loading, and deleting an artifact."""
|
||||
artifact_service = get_artifact_service(service_type)
|
||||
artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
|
||||
app_name = "app0"
|
||||
user_id = "user0"
|
||||
session_id = "123"
|
||||
filename = "file456"
|
||||
|
||||
artifact_service.save_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
artifact=artifact,
|
||||
)
|
||||
assert (
|
||||
artifact_service.load_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
)
|
||||
== artifact
|
||||
)
|
||||
|
||||
artifact_service.delete_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
)
|
||||
assert not artifact_service.load_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
|
||||
)
|
||||
def test_list_keys(service_type):
|
||||
"""Tests listing keys in the artifact service."""
|
||||
artifact_service = get_artifact_service(service_type)
|
||||
artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
|
||||
app_name = "app0"
|
||||
user_id = "user0"
|
||||
session_id = "123"
|
||||
filename = "filename"
|
||||
filenames = [filename + str(i) for i in range(5)]
|
||||
|
||||
for f in filenames:
|
||||
artifact_service.save_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=f,
|
||||
artifact=artifact,
|
||||
)
|
||||
|
||||
assert (
|
||||
artifact_service.list_artifact_keys(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
== filenames
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
|
||||
)
|
||||
def test_list_versions(service_type):
|
||||
"""Tests listing versions of an artifact."""
|
||||
artifact_service = get_artifact_service(service_type)
|
||||
|
||||
app_name = "app0"
|
||||
user_id = "user0"
|
||||
session_id = "123"
|
||||
filename = "filename"
|
||||
versions = [
|
||||
types.Part.from_bytes(
|
||||
data=i.to_bytes(2, byteorder="big"), mime_type="text/plain"
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
for i in range(3):
|
||||
artifact_service.save_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
artifact=versions[i],
|
||||
)
|
||||
|
||||
response_versions = artifact_service.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
assert response_versions == list(range(3))
|
||||
578
tests/unittests/auth/test_auth_handler.py
Normal file
578
tests/unittests/auth/test_auth_handler.py
Normal file
@@ -0,0 +1,578 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OAuthFlowAuthorizationCode
|
||||
from fastapi.openapi.models import OAuthFlows
|
||||
import pytest
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_handler import AuthHandler
|
||||
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
|
||||
from google.adk.auth.auth_tool import AuthConfig
|
||||
|
||||
|
||||
# Mock classes for testing
|
||||
class MockState(dict):
|
||||
"""Mock State class for testing."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get(self, key, default=None):
|
||||
return super().get(key, default)
|
||||
|
||||
|
||||
class MockOAuth2Session:
|
||||
"""Mock OAuth2Session for testing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_id=None,
|
||||
client_secret=None,
|
||||
scope=None,
|
||||
redirect_uri=None,
|
||||
state=None,
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.scope = scope
|
||||
self.redirect_uri = redirect_uri
|
||||
self.state = state
|
||||
|
||||
def create_authorization_url(self, url):
|
||||
return f"{url}?client_id={self.client_id}&scope={self.scope}", "mock_state"
|
||||
|
||||
def fetch_token(
|
||||
self,
|
||||
token_endpoint,
|
||||
authorization_response=None,
|
||||
code=None,
|
||||
grant_type=None,
|
||||
):
|
||||
return {
|
||||
"access_token": "mock_access_token",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "mock_refresh_token",
|
||||
}
|
||||
|
||||
|
||||
# Fixtures for common test objects
|
||||
@pytest.fixture
|
||||
def oauth2_auth_scheme():
|
||||
"""Create an OAuth2 auth scheme for testing."""
|
||||
# Create the OAuthFlows object first
|
||||
flows = OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl="https://example.com/oauth2/authorize",
|
||||
tokenUrl="https://example.com/oauth2/token",
|
||||
scopes={"read": "Read access", "write": "Write access"},
|
||||
)
|
||||
)
|
||||
|
||||
# Then create the OAuth2 object with the flows
|
||||
return OAuth2(flows=flows)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openid_auth_scheme():
|
||||
"""Create an OpenID Connect auth scheme for testing."""
|
||||
return OpenIdConnectWithConfig(
|
||||
openIdConnectUrl="https://example.com/.well-known/openid-configuration",
|
||||
authorization_endpoint="https://example.com/oauth2/authorize",
|
||||
token_endpoint="https://example.com/oauth2/token",
|
||||
scopes=["openid", "profile", "email"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_credentials():
|
||||
"""Create OAuth2 credentials for testing."""
|
||||
return AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="mock_client_id",
|
||||
client_secret="mock_client_secret",
|
||||
redirect_uri="https://example.com/callback",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_credentials_with_token():
|
||||
"""Create OAuth2 credentials with a token for testing."""
|
||||
return AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="mock_client_id",
|
||||
client_secret="mock_client_secret",
|
||||
redirect_uri="https://example.com/callback",
|
||||
token={
|
||||
"access_token": "mock_access_token",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "mock_refresh_token",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_credentials_with_auth_uri():
|
||||
"""Create OAuth2 credentials with an auth URI for testing."""
|
||||
return AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="mock_client_id",
|
||||
client_secret="mock_client_secret",
|
||||
redirect_uri="https://example.com/callback",
|
||||
auth_uri="https://example.com/oauth2/authorize?client_id=mock_client_id&scope=read,write",
|
||||
state="mock_state",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_credentials_with_auth_code():
|
||||
"""Create OAuth2 credentials with an auth code for testing."""
|
||||
return AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="mock_client_id",
|
||||
client_secret="mock_client_secret",
|
||||
redirect_uri="https://example.com/callback",
|
||||
auth_uri="https://example.com/oauth2/authorize?client_id=mock_client_id&scope=read,write",
|
||||
state="mock_state",
|
||||
auth_code="mock_auth_code",
|
||||
auth_response_uri="https://example.com/callback?code=mock_auth_code&state=mock_state",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_config(oauth2_auth_scheme, oauth2_credentials):
|
||||
"""Create an AuthConfig for testing."""
|
||||
# Create a copy of the credentials for the exchanged_auth_credential
|
||||
exchanged_credential = oauth2_credentials.model_copy(deep=True)
|
||||
|
||||
return AuthConfig(
|
||||
auth_scheme=oauth2_auth_scheme,
|
||||
raw_auth_credential=oauth2_credentials,
|
||||
exchanged_auth_credential=exchanged_credential,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_config_with_exchanged(
|
||||
oauth2_auth_scheme, oauth2_credentials, oauth2_credentials_with_auth_uri
|
||||
):
|
||||
"""Create an AuthConfig with exchanged credentials for testing."""
|
||||
return AuthConfig(
|
||||
auth_scheme=oauth2_auth_scheme,
|
||||
raw_auth_credential=oauth2_credentials,
|
||||
exchanged_auth_credential=oauth2_credentials_with_auth_uri,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_config_with_auth_code(
|
||||
oauth2_auth_scheme, oauth2_credentials, oauth2_credentials_with_auth_code
|
||||
):
|
||||
"""Create an AuthConfig with auth code for testing."""
|
||||
return AuthConfig(
|
||||
auth_scheme=oauth2_auth_scheme,
|
||||
raw_auth_credential=oauth2_credentials,
|
||||
exchanged_auth_credential=oauth2_credentials_with_auth_code,
|
||||
)
|
||||
|
||||
|
||||
class TestAuthHandlerInit:
|
||||
"""Tests for the AuthHandler initialization."""
|
||||
|
||||
def test_init(self, auth_config):
|
||||
"""Test the initialization of AuthHandler."""
|
||||
handler = AuthHandler(auth_config)
|
||||
assert handler.auth_config == auth_config
|
||||
|
||||
|
||||
class TestGetCredentialKey:
|
||||
"""Tests for the get_credential_key method."""
|
||||
|
||||
def test_get_credential_key(self, auth_config):
|
||||
"""Test generating a unique credential key."""
|
||||
handler = AuthHandler(auth_config)
|
||||
key = handler.get_credential_key()
|
||||
assert key.startswith("temp:adk_oauth2_")
|
||||
assert "_oauth2_" in key
|
||||
|
||||
def test_get_credential_key_with_extras(self, auth_config):
|
||||
"""Test generating a key when model_extra exists."""
|
||||
# Add model_extra to test cleanup
|
||||
|
||||
original_key = AuthHandler(auth_config).get_credential_key()
|
||||
key = AuthHandler(auth_config).get_credential_key()
|
||||
|
||||
auth_config.auth_scheme.model_extra["extra_field"] = "value"
|
||||
auth_config.raw_auth_credential.model_extra["extra_field"] = "value"
|
||||
|
||||
assert original_key == key
|
||||
assert "extra_field" in auth_config.auth_scheme.model_extra
|
||||
assert "extra_field" in auth_config.raw_auth_credential.model_extra
|
||||
|
||||
|
||||
class TestGenerateAuthUri:
|
||||
"""Tests for the generate_auth_uri method."""
|
||||
|
||||
@pytest.mark.skip(reason="broken tests")
|
||||
@patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
|
||||
def test_generate_auth_uri_oauth2(self, auth_config):
|
||||
"""Test generating an auth URI for OAuth2."""
|
||||
handler = AuthHandler(auth_config)
|
||||
result = handler.generate_auth_uri()
|
||||
|
||||
assert result.oauth2.auth_uri.startswith(
|
||||
"https://example.com/oauth2/authorize"
|
||||
)
|
||||
assert "client_id=mock_client_id" in result.oauth2.auth_uri
|
||||
assert result.oauth2.state == "mock_state"
|
||||
|
||||
@pytest.mark.skip(reason="broken tests")
|
||||
@patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
|
||||
def test_generate_auth_uri_openid(
|
||||
self, openid_auth_scheme, oauth2_credentials
|
||||
):
|
||||
"""Test generating an auth URI for OpenID Connect."""
|
||||
# Create a copy for the exchanged credential
|
||||
exchanged = oauth2_credentials.model_copy(deep=True)
|
||||
|
||||
config = AuthConfig(
|
||||
auth_scheme=openid_auth_scheme,
|
||||
raw_auth_credential=oauth2_credentials,
|
||||
exchanged_auth_credential=exchanged,
|
||||
)
|
||||
handler = AuthHandler(config)
|
||||
result = handler.generate_auth_uri()
|
||||
|
||||
assert result.oauth2.auth_uri.startswith(
|
||||
"https://example.com/oauth2/authorize"
|
||||
)
|
||||
assert "client_id=mock_client_id" in result.oauth2.auth_uri
|
||||
assert result.oauth2.state == "mock_state"
|
||||
|
||||
|
||||
class TestGenerateAuthRequest:
|
||||
"""Tests for the generate_auth_request method."""
|
||||
|
||||
def test_non_oauth_scheme(self):
|
||||
"""Test with a non-OAuth auth scheme."""
|
||||
# Use a SecurityBase instance without using APIKey which has validation issues
|
||||
api_key_scheme = APIKey(**{"name": "test_api_key", "in": APIKeyIn.header})
|
||||
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_api_key"
|
||||
)
|
||||
|
||||
# Create a copy for the exchanged credential
|
||||
exchanged = credential.model_copy(deep=True)
|
||||
|
||||
config = AuthConfig(
|
||||
auth_scheme=api_key_scheme,
|
||||
raw_auth_credential=credential,
|
||||
exchanged_auth_credential=exchanged,
|
||||
)
|
||||
|
||||
handler = AuthHandler(config)
|
||||
result = handler.generate_auth_request()
|
||||
|
||||
assert result == config
|
||||
|
||||
def test_with_existing_auth_uri(self, auth_config_with_exchanged):
|
||||
"""Test when auth_uri already exists in exchanged credential."""
|
||||
handler = AuthHandler(auth_config_with_exchanged)
|
||||
result = handler.generate_auth_request()
|
||||
|
||||
assert (
|
||||
result.exchanged_auth_credential.oauth2.auth_uri
|
||||
== auth_config_with_exchanged.exchanged_auth_credential.oauth2.auth_uri
|
||||
)
|
||||
|
||||
def test_missing_raw_credential(self, oauth2_auth_scheme):
|
||||
"""Test when raw_auth_credential is missing."""
|
||||
|
||||
config = AuthConfig(
|
||||
auth_scheme=oauth2_auth_scheme,
|
||||
)
|
||||
handler = AuthHandler(config)
|
||||
|
||||
with pytest.raises(ValueError, match="requires auth_credential"):
|
||||
handler.generate_auth_request()
|
||||
|
||||
def test_missing_oauth2_in_raw_credential(self, oauth2_auth_scheme):
|
||||
"""Test when oauth2 is missing in raw_auth_credential."""
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_api_key"
|
||||
)
|
||||
|
||||
# Create a copy for the exchanged credential
|
||||
exchanged = credential.model_copy(deep=True)
|
||||
|
||||
config = AuthConfig(
|
||||
auth_scheme=oauth2_auth_scheme,
|
||||
raw_auth_credential=credential,
|
||||
exchanged_auth_credential=exchanged,
|
||||
)
|
||||
handler = AuthHandler(config)
|
||||
|
||||
with pytest.raises(ValueError, match="requires oauth2 in auth_credential"):
|
||||
handler.generate_auth_request()
|
||||
|
||||
def test_auth_uri_in_raw_credential(
|
||||
self, oauth2_auth_scheme, oauth2_credentials_with_auth_uri
|
||||
):
|
||||
"""Test when auth_uri exists in raw_credential."""
|
||||
config = AuthConfig(
|
||||
auth_scheme=oauth2_auth_scheme,
|
||||
raw_auth_credential=oauth2_credentials_with_auth_uri,
|
||||
exchanged_auth_credential=oauth2_credentials_with_auth_uri.model_copy(
|
||||
deep=True
|
||||
),
|
||||
)
|
||||
handler = AuthHandler(config)
|
||||
result = handler.generate_auth_request()
|
||||
|
||||
assert (
|
||||
result.exchanged_auth_credential.oauth2.auth_uri
|
||||
== oauth2_credentials_with_auth_uri.oauth2.auth_uri
|
||||
)
|
||||
|
||||
def test_missing_client_credentials(self, oauth2_auth_scheme):
|
||||
"""Test when client_id or client_secret is missing."""
|
||||
bad_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(redirect_uri="https://example.com/callback"),
|
||||
)
|
||||
|
||||
# Create a copy for the exchanged credential
|
||||
exchanged = bad_credential.model_copy(deep=True)
|
||||
|
||||
config = AuthConfig(
|
||||
auth_scheme=oauth2_auth_scheme,
|
||||
raw_auth_credential=bad_credential,
|
||||
exchanged_auth_credential=exchanged,
|
||||
)
|
||||
handler = AuthHandler(config)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="requires both client_id and client_secret"
|
||||
):
|
||||
handler.generate_auth_request()
|
||||
|
||||
@patch("google.adk.auth.auth_handler.AuthHandler.generate_auth_uri")
|
||||
def test_generate_new_auth_uri(self, mock_generate_auth_uri, auth_config):
|
||||
"""Test generating a new auth URI."""
|
||||
mock_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="mock_client_id",
|
||||
client_secret="mock_client_secret",
|
||||
redirect_uri="https://example.com/callback",
|
||||
auth_uri="https://example.com/generated",
|
||||
state="generated_state",
|
||||
),
|
||||
)
|
||||
mock_generate_auth_uri.return_value = mock_credential
|
||||
|
||||
handler = AuthHandler(auth_config)
|
||||
result = handler.generate_auth_request()
|
||||
|
||||
assert mock_generate_auth_uri.called
|
||||
assert result.exchanged_auth_credential == mock_credential
|
||||
|
||||
|
||||
class TestGetAuthResponse:
|
||||
"""Tests for the get_auth_response method."""
|
||||
|
||||
def test_get_auth_response_exists(
|
||||
self, auth_config, oauth2_credentials_with_auth_uri
|
||||
):
|
||||
"""Test retrieving an existing auth response from state."""
|
||||
handler = AuthHandler(auth_config)
|
||||
state = MockState()
|
||||
|
||||
# Store a credential in the state
|
||||
credential_key = handler.get_credential_key()
|
||||
state[credential_key] = oauth2_credentials_with_auth_uri
|
||||
|
||||
result = handler.get_auth_response(state)
|
||||
assert result == oauth2_credentials_with_auth_uri
|
||||
|
||||
def test_get_auth_response_not_exists(self, auth_config):
|
||||
"""Test retrieving a non-existent auth response from state."""
|
||||
handler = AuthHandler(auth_config)
|
||||
state = MockState()
|
||||
|
||||
result = handler.get_auth_response(state)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestParseAndStoreAuthResponse:
|
||||
"""Tests for the parse_and_store_auth_response method."""
|
||||
|
||||
def test_non_oauth_scheme(self, auth_config_with_exchanged):
|
||||
"""Test with a non-OAuth auth scheme."""
|
||||
# Modify the auth scheme type to be non-OAuth
|
||||
auth_config = copy.deepcopy(auth_config_with_exchanged)
|
||||
auth_config.auth_scheme = APIKey(
|
||||
**{"name": "test_api_key", "in": APIKeyIn.header}
|
||||
)
|
||||
|
||||
handler = AuthHandler(auth_config)
|
||||
state = MockState()
|
||||
|
||||
handler.parse_and_store_auth_response(state)
|
||||
|
||||
credential_key = handler.get_credential_key()
|
||||
assert state[credential_key] == auth_config.exchanged_auth_credential
|
||||
|
||||
@patch("google.adk.auth.auth_handler.AuthHandler.exchange_auth_token")
|
||||
def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged):
|
||||
"""Test with an OAuth auth scheme."""
|
||||
mock_exchange_token.return_value = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(token={"access_token": "exchanged_token"}),
|
||||
)
|
||||
|
||||
handler = AuthHandler(auth_config_with_exchanged)
|
||||
state = MockState()
|
||||
|
||||
handler.parse_and_store_auth_response(state)
|
||||
|
||||
credential_key = handler.get_credential_key()
|
||||
assert state[credential_key] == mock_exchange_token.return_value
|
||||
assert mock_exchange_token.called
|
||||
|
||||
|
||||
class TestExchangeAuthToken:
|
||||
"""Tests for the exchange_auth_token method."""
|
||||
|
||||
def test_token_exchange_not_supported(
|
||||
self, auth_config_with_auth_code, monkeypatch
|
||||
):
|
||||
"""Test when token exchange is not supported."""
|
||||
monkeypatch.setattr(
|
||||
"google.adk.auth.auth_handler.SUPPORT_TOKEN_EXCHANGE", False
|
||||
)
|
||||
|
||||
handler = AuthHandler(auth_config_with_auth_code)
|
||||
result = handler.exchange_auth_token()
|
||||
|
||||
assert result == auth_config_with_auth_code.exchanged_auth_credential
|
||||
|
||||
def test_openid_missing_token_endpoint(
|
||||
self, openid_auth_scheme, oauth2_credentials_with_auth_code
|
||||
):
|
||||
"""Test OpenID Connect without a token endpoint."""
|
||||
# Create a scheme without token_endpoint
|
||||
scheme_without_token = copy.deepcopy(openid_auth_scheme)
|
||||
delattr(scheme_without_token, "token_endpoint")
|
||||
|
||||
config = AuthConfig(
|
||||
auth_scheme=scheme_without_token,
|
||||
raw_auth_credential=oauth2_credentials_with_auth_code,
|
||||
exchanged_auth_credential=oauth2_credentials_with_auth_code,
|
||||
)
|
||||
|
||||
handler = AuthHandler(config)
|
||||
result = handler.exchange_auth_token()
|
||||
|
||||
assert result == oauth2_credentials_with_auth_code
|
||||
|
||||
def test_oauth2_missing_token_url(
|
||||
self, oauth2_auth_scheme, oauth2_credentials_with_auth_code
|
||||
):
|
||||
"""Test OAuth2 without a token URL."""
|
||||
# Create a scheme without tokenUrl
|
||||
scheme_without_token = copy.deepcopy(oauth2_auth_scheme)
|
||||
scheme_without_token.flows.authorizationCode.tokenUrl = None
|
||||
|
||||
config = AuthConfig(
|
||||
auth_scheme=scheme_without_token,
|
||||
raw_auth_credential=oauth2_credentials_with_auth_code,
|
||||
exchanged_auth_credential=oauth2_credentials_with_auth_code,
|
||||
)
|
||||
|
||||
handler = AuthHandler(config)
|
||||
result = handler.exchange_auth_token()
|
||||
|
||||
assert result == oauth2_credentials_with_auth_code
|
||||
|
||||
def test_non_oauth_scheme(self, auth_config_with_auth_code):
|
||||
"""Test with a non-OAuth auth scheme."""
|
||||
# Modify the auth scheme type to be non-OAuth
|
||||
auth_config = copy.deepcopy(auth_config_with_auth_code)
|
||||
auth_config.auth_scheme = APIKey(
|
||||
**{"name": "test_api_key", "in": APIKeyIn.header}
|
||||
)
|
||||
|
||||
handler = AuthHandler(auth_config)
|
||||
result = handler.exchange_auth_token()
|
||||
|
||||
assert result == auth_config.exchanged_auth_credential
|
||||
|
||||
def test_missing_credentials(self, oauth2_auth_scheme):
|
||||
"""Test with missing credentials."""
|
||||
empty_credential = AuthCredential(auth_type=AuthCredentialTypes.OAUTH2)
|
||||
|
||||
config = AuthConfig(
|
||||
auth_scheme=oauth2_auth_scheme,
|
||||
exchanged_auth_credential=empty_credential,
|
||||
)
|
||||
|
||||
handler = AuthHandler(config)
|
||||
result = handler.exchange_auth_token()
|
||||
|
||||
assert result == empty_credential
|
||||
|
||||
def test_credentials_with_token(
|
||||
self, auth_config, oauth2_credentials_with_token
|
||||
):
|
||||
"""Test when credentials already have a token."""
|
||||
config = AuthConfig(
|
||||
auth_scheme=auth_config.auth_scheme,
|
||||
raw_auth_credential=auth_config.raw_auth_credential,
|
||||
exchanged_auth_credential=oauth2_credentials_with_token,
|
||||
)
|
||||
|
||||
handler = AuthHandler(config)
|
||||
result = handler.exchange_auth_token()
|
||||
|
||||
assert result == oauth2_credentials_with_token
|
||||
|
||||
@patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
|
||||
def test_successful_token_exchange(self, auth_config_with_auth_code):
|
||||
"""Test a successful token exchange."""
|
||||
handler = AuthHandler(auth_config_with_auth_code)
|
||||
result = handler.exchange_auth_token()
|
||||
|
||||
assert result.oauth2.token["access_token"] == "mock_access_token"
|
||||
assert result.oauth2.token["refresh_token"] == "mock_refresh_token"
|
||||
assert result.auth_type == AuthCredentialTypes.OAUTH2
|
||||
73
tests/unittests/conftest.py
Normal file
73
tests/unittests/conftest.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
from pytest import fixture
|
||||
from pytest import FixtureRequest
|
||||
from pytest import hookimpl
|
||||
from pytest import Metafunc
|
||||
|
||||
_ENV_VARS = {
|
||||
'GOOGLE_API_KEY': 'fake_google_api_key',
|
||||
'GOOGLE_CLOUD_PROJECT': 'fake_google_cloud_project',
|
||||
'GOOGLE_CLOUD_LOCATION': 'fake_google_cloud_location',
|
||||
}
|
||||
|
||||
ENV_SETUPS = {
|
||||
'GOOGLE_AI': {
|
||||
'GOOGLE_GENAI_USE_VERTEXAI': '0',
|
||||
**_ENV_VARS,
|
||||
},
|
||||
'VERTEX': {
|
||||
'GOOGLE_GENAI_USE_VERTEXAI': '1',
|
||||
**_ENV_VARS,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@fixture(autouse=True)
|
||||
def env_variables(request: FixtureRequest):
|
||||
# Set up the environment
|
||||
env_name: str = request.param
|
||||
envs = ENV_SETUPS[env_name]
|
||||
original_env = {key: os.environ.get(key) for key in envs}
|
||||
os.environ.update(envs)
|
||||
|
||||
yield # Run the test
|
||||
|
||||
# Restore the environment
|
||||
for key in envs:
|
||||
if (original_val := original_env.get(key)) is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = original_val
|
||||
|
||||
|
||||
@hookimpl(tryfirst=True)
|
||||
def pytest_generate_tests(metafunc: Metafunc):
|
||||
"""Generate test cases for each environment setup."""
|
||||
if env_variables.__name__ in metafunc.fixturenames:
|
||||
if not _is_explicitly_marked(env_variables.__name__, metafunc):
|
||||
metafunc.parametrize(
|
||||
env_variables.__name__, ENV_SETUPS.keys(), indirect=True
|
||||
)
|
||||
|
||||
|
||||
def _is_explicitly_marked(mark_name: str, metafunc: Metafunc) -> bool:
|
||||
if hasattr(metafunc.function, 'pytestmark'):
|
||||
for mark in metafunc.function.pytestmark:
|
||||
if mark.name == 'parametrize' and mark.args[0] == mark_name:
|
||||
return True
|
||||
return False
|
||||
14
tests/unittests/fast_api/__init__.py
Normal file
14
tests/unittests/fast_api/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
269
tests/unittests/fast_api/test_fast_api.py
Normal file
269
tests/unittests/fast_api/test_fast_api.py
Normal file
@@ -0,0 +1,269 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import types as ptypes
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from google.adk.agents import BaseAgent
|
||||
from google.adk.agents import LiveRequest
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.cli.fast_api import AgentRunRequest
|
||||
from google.adk.cli.fast_api import get_fast_api_app
|
||||
from google.adk.cli.utils import envs
|
||||
from google.adk.events import Event
|
||||
from google.adk.runners import Runner
|
||||
from google.genai import types
|
||||
import httpx
|
||||
import pytest
|
||||
from uvicorn.main import run as uvicorn_run
|
||||
import websockets
|
||||
|
||||
|
||||
# Here we “fake” the agent module that get_fast_api_app expects.
|
||||
# The server code does: `agent_module = importlib.import_module(agent_name)`
|
||||
# and then accesses: agent_module.agent.root_agent.
|
||||
class DummyAgent(BaseAgent):
|
||||
pass
|
||||
|
||||
|
||||
dummy_module = ptypes.ModuleType("test_agent")
|
||||
dummy_module.agent = ptypes.SimpleNamespace(
|
||||
root_agent=DummyAgent(name="dummy_agent")
|
||||
)
|
||||
sys.modules["test_app"] = dummy_module
|
||||
envs.load_dotenv_for_agent("test_app", ".")
|
||||
|
||||
event1 = Event(
|
||||
author="dummy agent",
|
||||
invocation_id="invocation_id",
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(text="LLM reply", inline_data=None)]
|
||||
),
|
||||
)
|
||||
|
||||
event2 = Event(
|
||||
author="dummy agent",
|
||||
invocation_id="invocation_id",
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part(
|
||||
text=None,
|
||||
inline_data=types.Blob(
|
||||
mime_type="audio/pcm;rate=24000", data=b"\x00\xFF"
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
event3 = Event(
|
||||
author="dummy agent", invocation_id="invocation_id", interrupted=True
|
||||
)
|
||||
|
||||
|
||||
# For simplicity, we patch Runner.run_live to yield dummy events.
|
||||
# We use SimpleNamespace to mimic attribute-access (i.e. event.content.parts).
|
||||
async def dummy_run_live(
|
||||
self, session, live_request_queue
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
# Immediately yield a dummy event with a text reply.
|
||||
yield event1
|
||||
await asyncio.sleep(0)
|
||||
|
||||
yield event2
|
||||
await asyncio.sleep(0)
|
||||
|
||||
yield event3
|
||||
|
||||
raise Exception()
|
||||
|
||||
|
||||
async def dummy_run_async(
|
||||
self,
|
||||
user_id,
|
||||
session_id,
|
||||
new_message,
|
||||
run_config: RunConfig = RunConfig(),
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
# Immediately yield a dummy event with a text reply.
|
||||
yield event1
|
||||
await asyncio.sleep(0)
|
||||
|
||||
yield event2
|
||||
await asyncio.sleep(0)
|
||||
|
||||
yield event3
|
||||
|
||||
return
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Pytest fixtures to patch methods and start the server
|
||||
###############################################################################
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_runner(monkeypatch):
|
||||
# Patch the Runner methods to use our dummy implementations.
|
||||
monkeypatch.setattr(Runner, "run_live", dummy_run_live)
|
||||
monkeypatch.setattr(Runner, "run_async", dummy_run_async)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def start_server():
|
||||
"""Start the FastAPI server in a background thread."""
|
||||
|
||||
def run_server():
|
||||
uvicorn_run(
|
||||
get_fast_api_app(agent_dir=".", web=True),
|
||||
host="0.0.0.0",
|
||||
log_config=None,
|
||||
)
|
||||
|
||||
server_thread = threading.Thread(target=run_server, daemon=True)
|
||||
server_thread.start()
|
||||
# Wait a moment to ensure the server is up.
|
||||
time.sleep(2)
|
||||
yield
|
||||
# The daemon thread will be terminated when tests complete.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_endpoint():
|
||||
base_http_url = "http://127.0.0.1:8000"
|
||||
user_id = "test_user"
|
||||
session_id = "test_session"
|
||||
|
||||
# Ensure that the session exists (create if necessary).
|
||||
url_create = (
|
||||
f"{base_http_url}/apps/test_app/users/{user_id}/sessions/{session_id}"
|
||||
)
|
||||
httpx.post(url_create, json={"state": {}})
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Make a POST request to the SSE endpoint.
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{base_http_url}/run_sse",
|
||||
json=json.loads(
|
||||
AgentRunRequest(
|
||||
app_name="test_app",
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
new_message=types.Content(
|
||||
parts=[types.Part(text="Hello via SSE", inline_data=None)]
|
||||
),
|
||||
streaming=False,
|
||||
).model_dump_json(exclude_none=True)
|
||||
),
|
||||
) as response:
|
||||
# Ensure the status code and header are as expected.
|
||||
assert response.status_code == 200
|
||||
assert (
|
||||
response.headers.get("content-type")
|
||||
== "text/event-stream; charset=utf-8"
|
||||
)
|
||||
|
||||
# Iterate over events from the stream.
|
||||
event_count = 0
|
||||
event_buffer = ""
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
event_buffer += line + "\n"
|
||||
|
||||
# An SSE event is terminated by an empty line (double newline)
|
||||
if line == "" and event_buffer.strip():
|
||||
# Process the complete event
|
||||
event_data = None
|
||||
for event_line in event_buffer.split("\n"):
|
||||
if event_line.startswith("data: "):
|
||||
event_data = event_line[6:] # Remove "data: " prefix
|
||||
|
||||
if event_data:
|
||||
event_count += 1
|
||||
if event_count == 1:
|
||||
assert event_data == event1.model_dump_json(
|
||||
exclude_none=True, by_alias=True
|
||||
)
|
||||
elif event_count == 2:
|
||||
assert event_data == event2.model_dump_json(
|
||||
exclude_none=True, by_alias=True
|
||||
)
|
||||
elif event_count == 3:
|
||||
assert event_data == event3.model_dump_json(
|
||||
exclude_none=True, by_alias=True
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
# Reset buffer for next event
|
||||
event_buffer = ""
|
||||
|
||||
assert event_count == 3 # Expecting 3 events from dummy_run_async
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_endpoint():
|
||||
base_http_url = "http://127.0.0.1:8000"
|
||||
base_ws_url = "ws://127.0.0.1:8000"
|
||||
user_id = "test_user"
|
||||
session_id = "test_session"
|
||||
|
||||
# Ensure that the session exists (create if necessary).
|
||||
url_create = (
|
||||
f"{base_http_url}/apps/test_app/users/{user_id}/sessions/{session_id}"
|
||||
)
|
||||
httpx.post(url_create, json={"state": {}})
|
||||
|
||||
ws_url = f"{base_ws_url}/run_live?app_name=test_app&user_id={user_id}&session_id={session_id}"
|
||||
async with websockets.connect(ws_url) as ws:
|
||||
# --- Test sending text data ---
|
||||
text_payload = LiveRequest(
|
||||
content=types.Content(
|
||||
parts=[types.Part(text="Hello via WebSocket", inline_data=None)]
|
||||
)
|
||||
)
|
||||
await ws.send(text_payload.model_dump_json())
|
||||
# Wait for a reply from our dummy_run_live.
|
||||
reply = await ws.recv()
|
||||
event = Event.model_validate_json(reply)
|
||||
assert event.content.parts[0].text == "LLM reply"
|
||||
|
||||
# --- Test sending binary data (allowed mime type "audio/pcm") ---
|
||||
sample_audio = b"\x00\xFF"
|
||||
binary_payload = LiveRequest(
|
||||
blob=types.Blob(
|
||||
mime_type="audio/pcm",
|
||||
data=sample_audio,
|
||||
)
|
||||
)
|
||||
await ws.send(binary_payload.model_dump_json())
|
||||
# Wait for a reply.
|
||||
reply = await ws.recv()
|
||||
event = Event.model_validate_json(reply)
|
||||
assert (
|
||||
event.content.parts[0].inline_data.mime_type == "audio/pcm;rate=24000"
|
||||
)
|
||||
assert event.content.parts[0].inline_data.data == b"\x00\xFF"
|
||||
|
||||
reply = await ws.recv()
|
||||
event = Event.model_validate_json(reply)
|
||||
assert event.interrupted is True
|
||||
assert event.content is None
|
||||
14
tests/unittests/flows/__init__.py
Normal file
14
tests/unittests/flows/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
14
tests/unittests/flows/llm_flows/__init__.py
Normal file
14
tests/unittests/flows/llm_flows/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
142
tests/unittests/flows/llm_flows/_test_examples.py
Normal file
142
tests/unittests/flows/llm_flows/_test_examples.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO: delete and rewrite unit tests
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.examples import BaseExampleProvider
|
||||
from google.adk.examples import Example
|
||||
from google.adk.flows.llm_flows import examples
|
||||
from google.adk.models.base_llm import LlmRequest
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_examples():
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(model="gemini-1.5-flash", name="agent", examples=[])
|
||||
invocation_context = utils.create_invocation_context(
|
||||
agent=agent, user_content=""
|
||||
)
|
||||
|
||||
async for _ in examples.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_examples():
|
||||
example_list = [
|
||||
Example(
|
||||
input=types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text="test1")],
|
||||
),
|
||||
output=[
|
||||
types.Content(
|
||||
role="model",
|
||||
parts=[types.Part.from_text(text="response1")],
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
examples=example_list,
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(
|
||||
agent=agent, user_content="test"
|
||||
)
|
||||
|
||||
async for _ in examples.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert (
|
||||
request.config.system_instruction
|
||||
== "<EXAMPLES>\nBegin few-shot\nThe following are examples of user"
|
||||
" queries and model responses using the available tools.\n\nEXAMPLE"
|
||||
" 1:\nBegin example\n[user]\ntest1\n\n[model]\nresponse1\nEnd"
|
||||
" example\n\nEnd few-shot\nNow, try to follow these examples and"
|
||||
" complete the following conversation\n<EXAMPLES>"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_base_example_provider():
|
||||
class TestExampleProvider(BaseExampleProvider):
|
||||
|
||||
def get_examples(self, query: str) -> list[Example]:
|
||||
if query == "test":
|
||||
return [
|
||||
Example(
|
||||
input=types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text="test")],
|
||||
),
|
||||
output=[
|
||||
types.Content(
|
||||
role="model",
|
||||
parts=[types.Part.from_text(text="response1")],
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
provider = TestExampleProvider()
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
examples=provider,
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(
|
||||
agent=agent, user_content="test"
|
||||
)
|
||||
|
||||
async for _ in examples.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert (
|
||||
request.config.system_instruction
|
||||
== "<EXAMPLES>\nBegin few-shot\nThe following are examples of user"
|
||||
" queries and model responses using the available tools.\n\nEXAMPLE"
|
||||
" 1:\nBegin example\n[user]\ntest\n\n[model]\nresponse1\nEnd"
|
||||
" example\n\nEnd few-shot\nNow, try to follow these examples and"
|
||||
" complete the following conversation\n<EXAMPLES>"
|
||||
)
|
||||
311
tests/unittests/flows/llm_flows/test_agent_transfer.py
Normal file
311
tests/unittests/flows/llm_flows/test_agent_transfer.py
Normal file
@@ -0,0 +1,311 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.agents.loop_agent import LoopAgent
|
||||
from google.adk.agents.sequential_agent import SequentialAgent
|
||||
from google.adk.tools import exit_loop
|
||||
from google.genai.types import Part
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def transfer_call_part(agent_name: str) -> Part:
|
||||
return Part.from_function_call(
|
||||
name='transfer_to_agent', args={'agent_name': agent_name}
|
||||
)
|
||||
|
||||
|
||||
TRANSFER_RESPONSE_PART = Part.from_function_response(
|
||||
name='transfer_to_agent', response={}
|
||||
)
|
||||
|
||||
|
||||
def test_auto_to_auto():
|
||||
response = [
|
||||
transfer_call_part('sub_agent_1'),
|
||||
'response1',
|
||||
'response2',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
# root (auto) - sub_agent_1 (auto)
|
||||
sub_agent_1 = Agent(name='sub_agent_1', model=mockModel)
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
sub_agents=[sub_agent_1],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
# Asserts the transfer.
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', transfer_call_part('sub_agent_1')),
|
||||
('root_agent', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1', 'response1'),
|
||||
]
|
||||
|
||||
# sub_agent_1 should still be the current agent.
|
||||
assert utils.simplify_events(runner.run('test2')) == [
|
||||
('sub_agent_1', 'response2'),
|
||||
]
|
||||
|
||||
|
||||
def test_auto_to_single():
|
||||
response = [
|
||||
transfer_call_part('sub_agent_1'),
|
||||
'response1',
|
||||
'response2',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
# root (auto) - sub_agent_1 (single)
|
||||
sub_agent_1 = Agent(
|
||||
name='sub_agent_1',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
root_agent = Agent(
|
||||
name='root_agent', model=mockModel, sub_agents=[sub_agent_1]
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
# Asserts the responses.
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', transfer_call_part('sub_agent_1')),
|
||||
('root_agent', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1', 'response1'),
|
||||
]
|
||||
|
||||
# root_agent should still be the current agent, becaues sub_agent_1 is single.
|
||||
assert utils.simplify_events(runner.run('test2')) == [
|
||||
('root_agent', 'response2'),
|
||||
]
|
||||
|
||||
|
||||
def test_auto_to_auto_to_single():
|
||||
response = [
|
||||
transfer_call_part('sub_agent_1'),
|
||||
# sub_agent_1 transfers to sub_agent_1_1.
|
||||
transfer_call_part('sub_agent_1_1'),
|
||||
'response1',
|
||||
'response2',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
# root (auto) - sub_agent_1 (auto) - sub_agent_1_1 (single)
|
||||
sub_agent_1_1 = Agent(
|
||||
name='sub_agent_1_1',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
sub_agent_1 = Agent(
|
||||
name='sub_agent_1', model=mockModel, sub_agents=[sub_agent_1_1]
|
||||
)
|
||||
root_agent = Agent(
|
||||
name='root_agent', model=mockModel, sub_agents=[sub_agent_1]
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
# Asserts the responses.
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', transfer_call_part('sub_agent_1')),
|
||||
('root_agent', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1', transfer_call_part('sub_agent_1_1')),
|
||||
('sub_agent_1', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1_1', 'response1'),
|
||||
]
|
||||
|
||||
# sub_agent_1 should still be the current agent. sub_agent_1_1 is single so it should
|
||||
# not be the current agent, otherwise the conversation will be tied to
|
||||
# sub_agent_1_1 forever.
|
||||
assert utils.simplify_events(runner.run('test2')) == [
|
||||
('sub_agent_1', 'response2'),
|
||||
]
|
||||
|
||||
|
||||
def test_auto_to_sequential():
|
||||
response = [
|
||||
transfer_call_part('sub_agent_1'),
|
||||
# sub_agent_1 responds directly instead of transfering.
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
# root (auto) - sub_agent_1 (sequential) - sub_agent_1_1 (single)
|
||||
# \ sub_agent_1_2 (single)
|
||||
sub_agent_1_1 = Agent(
|
||||
name='sub_agent_1_1',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
sub_agent_1_2 = Agent(
|
||||
name='sub_agent_1_2',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
sub_agent_1 = SequentialAgent(
|
||||
name='sub_agent_1',
|
||||
sub_agents=[sub_agent_1_1, sub_agent_1_2],
|
||||
)
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
sub_agents=[sub_agent_1],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
# Asserts the transfer.
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', transfer_call_part('sub_agent_1')),
|
||||
('root_agent', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1_1', 'response1'),
|
||||
('sub_agent_1_2', 'response2'),
|
||||
]
|
||||
|
||||
# root_agent should still be the current agent because sub_agent_1 is sequential.
|
||||
assert utils.simplify_events(runner.run('test2')) == [
|
||||
('root_agent', 'response3'),
|
||||
]
|
||||
|
||||
|
||||
def test_auto_to_sequential_to_auto():
|
||||
response = [
|
||||
transfer_call_part('sub_agent_1'),
|
||||
# sub_agent_1 responds directly instead of transfering.
|
||||
'response1',
|
||||
transfer_call_part('sub_agent_1_2_1'),
|
||||
'response2',
|
||||
'response3',
|
||||
'response4',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
# root (auto) - sub_agent_1 (seq) - sub_agent_1_1 (single)
|
||||
# \ sub_agent_1_2 (auto) - sub_agent_1_2_1 (auto)
|
||||
# \ sub_agent_1_3 (single)
|
||||
sub_agent_1_1 = Agent(
|
||||
name='sub_agent_1_1',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
sub_agent_1_2_1 = Agent(name='sub_agent_1_2_1', model=mockModel)
|
||||
sub_agent_1_2 = Agent(
|
||||
name='sub_agent_1_2',
|
||||
model=mockModel,
|
||||
sub_agents=[sub_agent_1_2_1],
|
||||
)
|
||||
sub_agent_1_3 = Agent(
|
||||
name='sub_agent_1_3',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
sub_agent_1 = SequentialAgent(
|
||||
name='sub_agent_1',
|
||||
sub_agents=[sub_agent_1_1, sub_agent_1_2, sub_agent_1_3],
|
||||
)
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
sub_agents=[sub_agent_1],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
# Asserts the transfer.
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', transfer_call_part('sub_agent_1')),
|
||||
('root_agent', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1_1', 'response1'),
|
||||
('sub_agent_1_2', transfer_call_part('sub_agent_1_2_1')),
|
||||
('sub_agent_1_2', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1_2_1', 'response2'),
|
||||
('sub_agent_1_3', 'response3'),
|
||||
]
|
||||
|
||||
# root_agent should still be the current agent because sub_agent_1 is sequential.
|
||||
assert utils.simplify_events(runner.run('test2')) == [
|
||||
('root_agent', 'response4'),
|
||||
]
|
||||
|
||||
|
||||
def test_auto_to_loop():
|
||||
response = [
|
||||
transfer_call_part('sub_agent_1'),
|
||||
# sub_agent_1 responds directly instead of transfering.
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
Part.from_function_call(name='exit_loop', args={}),
|
||||
'response4',
|
||||
'response5',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
# root (auto) - sub_agent_1 (loop) - sub_agent_1_1 (single)
|
||||
# \ sub_agent_1_2 (single)
|
||||
sub_agent_1_1 = Agent(
|
||||
name='sub_agent_1_1',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
sub_agent_1_2 = Agent(
|
||||
name='sub_agent_1_2',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
tools=[exit_loop],
|
||||
)
|
||||
sub_agent_1 = LoopAgent(
|
||||
name='sub_agent_1',
|
||||
sub_agents=[sub_agent_1_1, sub_agent_1_2],
|
||||
)
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
sub_agents=[sub_agent_1],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
# Asserts the transfer.
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
# Transfers to sub_agent_1.
|
||||
('root_agent', transfer_call_part('sub_agent_1')),
|
||||
('root_agent', TRANSFER_RESPONSE_PART),
|
||||
# Loops.
|
||||
('sub_agent_1_1', 'response1'),
|
||||
('sub_agent_1_2', 'response2'),
|
||||
('sub_agent_1_1', 'response3'),
|
||||
# Exits.
|
||||
('sub_agent_1_2', Part.from_function_call(name='exit_loop', args={})),
|
||||
(
|
||||
'sub_agent_1_2',
|
||||
Part.from_function_response(name='exit_loop', response={}),
|
||||
),
|
||||
# root_agent summarizes.
|
||||
('root_agent', 'response4'),
|
||||
]
|
||||
|
||||
# root_agent should still be the current agent because sub_agent_1 is loop.
|
||||
assert utils.simplify_events(runner.run('test2')) == [
|
||||
('root_agent', 'response5'),
|
||||
]
|
||||
244
tests/unittests/flows/llm_flows/test_functions_long_running.py
Normal file
244
tests/unittests/flows/llm_flows/test_functions_long_running.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.tools import ToolContext
|
||||
from google.adk.tools.long_running_tool import LongRunningFunctionTool
|
||||
from google.genai.types import Part
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def test_async_function():
|
||||
responses = [
|
||||
Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
'response4',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=responses)
|
||||
function_called = 0
|
||||
|
||||
def increase_by_one(x: int, tool_context: ToolContext) -> int:
|
||||
nonlocal function_called
|
||||
|
||||
function_called += 1
|
||||
return {'status': 'pending'}
|
||||
|
||||
# Calls the first time.
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
tools=[LongRunningFunctionTool(func=increase_by_one)],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test1')
|
||||
|
||||
# Asserts the requests.
|
||||
assert len(mockModel.requests) == 2
|
||||
# 1 item: user content
|
||||
assert mockModel.requests[0].contents == [
|
||||
utils.UserContent('test1'),
|
||||
]
|
||||
increase_by_one_call = Part.from_function_call(
|
||||
name='increase_by_one', args={'x': 1}
|
||||
)
|
||||
pending_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'status': 'pending'}
|
||||
)
|
||||
|
||||
assert utils.simplify_contents(mockModel.requests[1].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', pending_response),
|
||||
]
|
||||
|
||||
# Asserts the function calls.
|
||||
assert function_called == 1
|
||||
|
||||
# Asserts the responses.
|
||||
assert utils.simplify_events(events) == [
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='increase_by_one', response={'status': 'pending'}
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
assert events[0].long_running_tool_ids
|
||||
|
||||
# Updates with another pending progress.
|
||||
still_waiting_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'status': 'still waiting'}
|
||||
)
|
||||
events = runner.run(utils.UserContent(still_waiting_response))
|
||||
# We have one new request.
|
||||
assert len(mockModel.requests) == 3
|
||||
assert utils.simplify_contents(mockModel.requests[2].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', still_waiting_response),
|
||||
]
|
||||
|
||||
assert utils.simplify_events(events) == [('root_agent', 'response2')]
|
||||
|
||||
# Calls when the result is ready.
|
||||
result_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 2}
|
||||
)
|
||||
events = runner.run(utils.UserContent(result_response))
|
||||
# We have one new request.
|
||||
assert len(mockModel.requests) == 4
|
||||
assert utils.simplify_contents(mockModel.requests[3].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', result_response),
|
||||
]
|
||||
assert utils.simplify_events(events) == [('root_agent', 'response3')]
|
||||
|
||||
# Calls when the result is ready. Here we still accept the result and do
|
||||
# another summarization. Whether this is the right behavior is TBD.
|
||||
another_result_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 3}
|
||||
)
|
||||
events = runner.run(utils.UserContent(another_result_response))
|
||||
# We have one new request.
|
||||
assert len(mockModel.requests) == 5
|
||||
assert utils.simplify_contents(mockModel.requests[4].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', another_result_response),
|
||||
]
|
||||
assert utils.simplify_events(events) == [('root_agent', 'response4')]
|
||||
|
||||
# At the end, function_called should still be 1.
|
||||
assert function_called == 1
|
||||
|
||||
|
||||
def test_async_function_with_none_response():
|
||||
responses = [
|
||||
Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
'response4',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=responses)
|
||||
function_called = 0
|
||||
|
||||
def increase_by_one(x: int, tool_context: ToolContext) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return 'pending'
|
||||
|
||||
# Calls the first time.
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
tools=[LongRunningFunctionTool(func=increase_by_one)],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test1')
|
||||
|
||||
# Asserts the requests.
|
||||
assert len(mockModel.requests) == 2
|
||||
# 1 item: user content
|
||||
assert mockModel.requests[0].contents == [
|
||||
utils.UserContent('test1'),
|
||||
]
|
||||
increase_by_one_call = Part.from_function_call(
|
||||
name='increase_by_one', args={'x': 1}
|
||||
)
|
||||
|
||||
assert utils.simplify_contents(mockModel.requests[1].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
(
|
||||
'user',
|
||||
Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 'pending'}
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
# Asserts the function calls.
|
||||
assert function_called == 1
|
||||
|
||||
# Asserts the responses.
|
||||
assert utils.simplify_events(events) == [
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 'pending'}
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
# Updates with another pending progress.
|
||||
still_waiting_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'status': 'still waiting'}
|
||||
)
|
||||
events = runner.run(utils.UserContent(still_waiting_response))
|
||||
# We have one new request.
|
||||
assert len(mockModel.requests) == 3
|
||||
assert utils.simplify_contents(mockModel.requests[2].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', still_waiting_response),
|
||||
]
|
||||
|
||||
assert utils.simplify_events(events) == [('root_agent', 'response2')]
|
||||
|
||||
# Calls when the result is ready.
|
||||
result_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 2}
|
||||
)
|
||||
events = runner.run(utils.UserContent(result_response))
|
||||
# We have one new request.
|
||||
assert len(mockModel.requests) == 4
|
||||
assert utils.simplify_contents(mockModel.requests[3].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', result_response),
|
||||
]
|
||||
assert utils.simplify_events(events) == [('root_agent', 'response3')]
|
||||
|
||||
# Calls when the result is ready. Here we still accept the result and do
|
||||
# another summarization. Whether this is the right behavior is TBD.
|
||||
another_result_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 3}
|
||||
)
|
||||
events = runner.run(utils.UserContent(another_result_response))
|
||||
# We have one new request.
|
||||
assert len(mockModel.requests) == 5
|
||||
assert utils.simplify_contents(mockModel.requests[4].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', another_result_response),
|
||||
]
|
||||
assert utils.simplify_events(events) == [('root_agent', 'response4')]
|
||||
|
||||
# At the end, function_called should still be 1.
|
||||
assert function_called == 1
|
||||
346
tests/unittests/flows/llm_flows/test_functions_request_euc.py
Normal file
346
tests/unittests/flows/llm_flows/test_functions_request_euc.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OAuthFlowAuthorizationCode
|
||||
from fastapi.openapi.models import OAuthFlows
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.auth import AuthConfig
|
||||
from google.adk.auth import AuthCredential
|
||||
from google.adk.auth import AuthCredentialTypes
|
||||
from google.adk.auth import OAuth2Auth
|
||||
from google.adk.flows.llm_flows import functions
|
||||
from google.adk.tools import AuthToolArguments
|
||||
from google.adk.tools import ToolContext
|
||||
from google.genai import types
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def function_call(function_call_id, name, args: dict[str, Any]) -> types.Part:
|
||||
part = types.Part.from_function_call(name=name, args=args)
|
||||
part.function_call.id = function_call_id
|
||||
return part
|
||||
|
||||
|
||||
def test_function_request_euc():
|
||||
responses = [
|
||||
[
|
||||
types.Part.from_function_call(name='call_external_api1', args={}),
|
||||
types.Part.from_function_call(name='call_external_api2', args={}),
|
||||
],
|
||||
[
|
||||
types.Part.from_text(text='response1'),
|
||||
],
|
||||
]
|
||||
|
||||
auth_config1 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_1',
|
||||
client_secret='oauth_client_secret1',
|
||||
),
|
||||
),
|
||||
)
|
||||
auth_config2 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_2',
|
||||
client_secret='oauth_client_secret2',
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
|
||||
def call_external_api1(tool_context: ToolContext) -> Optional[int]:
|
||||
tool_context.request_credential(auth_config1)
|
||||
|
||||
def call_external_api2(tool_context: ToolContext) -> Optional[int]:
|
||||
tool_context.request_credential(auth_config2)
|
||||
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[call_external_api1, call_external_api2],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test')
|
||||
assert events[0].content.parts[0].function_call is not None
|
||||
assert events[0].content.parts[1].function_call is not None
|
||||
auth_configs = list(events[2].actions.requested_auth_configs.values())
|
||||
exchanged_auth_config1 = auth_configs[0]
|
||||
exchanged_auth_config2 = auth_configs[1]
|
||||
assert exchanged_auth_config1.auth_scheme == auth_config1.auth_scheme
|
||||
assert (
|
||||
exchanged_auth_config1.raw_auth_credential
|
||||
== auth_config1.raw_auth_credential
|
||||
)
|
||||
assert (
|
||||
exchanged_auth_config1.exchanged_auth_credential.oauth2.auth_uri
|
||||
is not None
|
||||
)
|
||||
assert exchanged_auth_config2.auth_scheme == auth_config2.auth_scheme
|
||||
assert (
|
||||
exchanged_auth_config2.raw_auth_credential
|
||||
== auth_config2.raw_auth_credential
|
||||
)
|
||||
assert (
|
||||
exchanged_auth_config2.exchanged_auth_credential.oauth2.auth_uri
|
||||
is not None
|
||||
)
|
||||
function_call_ids = list(events[2].actions.requested_auth_configs.keys())
|
||||
|
||||
for idx, part in enumerate(events[1].content.parts):
|
||||
reqeust_euc_function_call = part.function_call
|
||||
assert reqeust_euc_function_call is not None
|
||||
assert (
|
||||
reqeust_euc_function_call.name
|
||||
== functions.REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
)
|
||||
args = AuthToolArguments.model_validate(reqeust_euc_function_call.args)
|
||||
|
||||
assert args.function_call_id == function_call_ids[idx]
|
||||
args.auth_config.auth_scheme.model_extra.clear()
|
||||
assert args.auth_config.auth_scheme == auth_configs[idx].auth_scheme
|
||||
assert (
|
||||
args.auth_config.raw_auth_credential
|
||||
== auth_configs[idx].raw_auth_credential
|
||||
)
|
||||
|
||||
|
||||
def test_function_get_auth_response():
|
||||
id_1 = 'id_1'
|
||||
id_2 = 'id_2'
|
||||
responses = [
|
||||
[
|
||||
function_call(id_1, 'call_external_api1', {}),
|
||||
function_call(id_2, 'call_external_api2', {}),
|
||||
],
|
||||
[
|
||||
types.Part.from_text(text='response1'),
|
||||
],
|
||||
[
|
||||
types.Part.from_text(text='response2'),
|
||||
],
|
||||
]
|
||||
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
function_invoked = 0
|
||||
|
||||
auth_config1 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_1',
|
||||
client_secret='oauth_client_secret1',
|
||||
),
|
||||
),
|
||||
)
|
||||
auth_config2 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_2',
|
||||
client_secret='oauth_client_secret2',
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
auth_response1 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_1',
|
||||
client_secret='oauth_client_secret1',
|
||||
),
|
||||
),
|
||||
exchanged_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_1',
|
||||
client_secret='oauth_client_secret1',
|
||||
token={'access_token': 'token1'},
|
||||
),
|
||||
),
|
||||
)
|
||||
auth_response2 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_2',
|
||||
client_secret='oauth_client_secret2',
|
||||
),
|
||||
),
|
||||
exchanged_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_2',
|
||||
client_secret='oauth_client_secret2',
|
||||
token={'access_token': 'token2'},
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def call_external_api1(tool_context: ToolContext) -> int:
|
||||
nonlocal function_invoked
|
||||
function_invoked += 1
|
||||
auth_response = tool_context.get_auth_response(auth_config1)
|
||||
if not auth_response:
|
||||
tool_context.request_credential(auth_config1)
|
||||
return
|
||||
assert auth_response == auth_response1.exchanged_auth_credential
|
||||
return 1
|
||||
|
||||
def call_external_api2(tool_context: ToolContext) -> int:
|
||||
nonlocal function_invoked
|
||||
function_invoked += 1
|
||||
auth_response = tool_context.get_auth_response(auth_config2)
|
||||
if not auth_response:
|
||||
tool_context.request_credential(auth_config2)
|
||||
return
|
||||
assert auth_response == auth_response2.exchanged_auth_credential
|
||||
return 2
|
||||
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[call_external_api1, call_external_api2],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
runner.run('test')
|
||||
request_euc_function_call_event = runner.session.events[-3]
|
||||
function_response1 = types.FunctionResponse(
|
||||
name=request_euc_function_call_event.content.parts[0].function_call.name,
|
||||
response=auth_response1.model_dump(),
|
||||
)
|
||||
function_response1.id = request_euc_function_call_event.content.parts[
|
||||
0
|
||||
].function_call.id
|
||||
|
||||
function_response2 = types.FunctionResponse(
|
||||
name=request_euc_function_call_event.content.parts[1].function_call.name,
|
||||
response=auth_response2.model_dump(),
|
||||
)
|
||||
function_response2.id = request_euc_function_call_event.content.parts[
|
||||
1
|
||||
].function_call.id
|
||||
runner.run(
|
||||
new_message=types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(function_response=function_response1),
|
||||
types.Part(function_response=function_response2),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
assert function_invoked == 4
|
||||
reqeust = mock_model.requests[-1]
|
||||
content = reqeust.contents[-1]
|
||||
parts = content.parts
|
||||
assert len(parts) == 2
|
||||
assert parts[0].function_response.name == 'call_external_api1'
|
||||
assert parts[0].function_response.response == {'result': 1}
|
||||
assert parts[1].function_response.name == 'call_external_api2'
|
||||
assert parts[1].function_response.response == {'result': 2}
|
||||
93
tests/unittests/flows/llm_flows/test_functions_sequential.py
Normal file
93
tests/unittests/flows/llm_flows/test_functions_sequential.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.genai import types
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def function_call(args: dict[str, Any]) -> types.Part:
|
||||
return types.Part.from_function_call(name='increase_by_one', args=args)
|
||||
|
||||
|
||||
def function_response(response: dict[str, Any]) -> types.Part:
|
||||
return types.Part.from_function_response(
|
||||
name='increase_by_one', response=response
|
||||
)
|
||||
|
||||
|
||||
def test_sequential_calls():
|
||||
responses = [
|
||||
function_call({'x': 1}),
|
||||
function_call({'x': 2}),
|
||||
function_call({'x': 3}),
|
||||
'response1',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=responses)
|
||||
function_called = 0
|
||||
|
||||
def increase_by_one(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x + 1
|
||||
|
||||
agent = Agent(name='root_agent', model=mockModel, tools=[increase_by_one])
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
result = utils.simplify_events(runner.run('test'))
|
||||
assert result == [
|
||||
('root_agent', function_call({'x': 1})),
|
||||
('root_agent', function_response({'result': 2})),
|
||||
('root_agent', function_call({'x': 2})),
|
||||
('root_agent', function_response({'result': 3})),
|
||||
('root_agent', function_call({'x': 3})),
|
||||
('root_agent', function_response({'result': 4})),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
# Asserts the requests.
|
||||
assert len(mockModel.requests) == 4
|
||||
# 1 item: user content
|
||||
assert utils.simplify_contents(mockModel.requests[0].contents) == [
|
||||
('user', 'test')
|
||||
]
|
||||
# 3 items: user content, functaion call / response for the 1st call
|
||||
assert utils.simplify_contents(mockModel.requests[1].contents) == [
|
||||
('user', 'test'),
|
||||
('model', function_call({'x': 1})),
|
||||
('user', function_response({'result': 2})),
|
||||
]
|
||||
# 5 items: user content, functaion call / response for two calls
|
||||
assert utils.simplify_contents(mockModel.requests[2].contents) == [
|
||||
('user', 'test'),
|
||||
('model', function_call({'x': 1})),
|
||||
('user', function_response({'result': 2})),
|
||||
('model', function_call({'x': 2})),
|
||||
('user', function_response({'result': 3})),
|
||||
]
|
||||
# 7 items: user content, functaion call / response for three calls
|
||||
assert utils.simplify_contents(mockModel.requests[3].contents) == [
|
||||
('user', 'test'),
|
||||
('model', function_call({'x': 1})),
|
||||
('user', function_response({'result': 2})),
|
||||
('model', function_call({'x': 2})),
|
||||
('user', function_response({'result': 3})),
|
||||
('model', function_call({'x': 3})),
|
||||
('user', function_response({'result': 4})),
|
||||
]
|
||||
|
||||
# Asserts the function calls.
|
||||
assert function_called == 3
|
||||
258
tests/unittests/flows/llm_flows/test_functions_simple.py
Normal file
258
tests/unittests/flows/llm_flows/test_functions_simple.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import Callable
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.tools import ToolContext
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def test_simple_function():
|
||||
function_call_1 = types.Part.from_function_call(
|
||||
name='increase_by_one', args={'x': 1}
|
||||
)
|
||||
function_respones_2 = types.Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 2}
|
||||
)
|
||||
responses: list[types.Content] = [
|
||||
function_call_1,
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
'response4',
|
||||
]
|
||||
function_called = 0
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
|
||||
def increase_by_one(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x + 1
|
||||
|
||||
agent = Agent(name='root_agent', model=mock_model, tools=[increase_by_one])
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', function_call_1),
|
||||
('root_agent', function_respones_2),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
# Asserts the requests.
|
||||
assert utils.simplify_contents(mock_model.requests[0].contents) == [
|
||||
('user', 'test')
|
||||
]
|
||||
assert utils.simplify_contents(mock_model.requests[1].contents) == [
|
||||
('user', 'test'),
|
||||
('model', function_call_1),
|
||||
('user', function_respones_2),
|
||||
]
|
||||
|
||||
# Asserts the function calls.
|
||||
assert function_called == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function():
|
||||
function_calls = [
|
||||
types.Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
types.Part.from_function_call(name='multiple_by_two', args={'x': 2}),
|
||||
types.Part.from_function_call(name='multiple_by_two_sync', args={'x': 3}),
|
||||
]
|
||||
function_responses = [
|
||||
types.Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 2}
|
||||
),
|
||||
types.Part.from_function_response(
|
||||
name='multiple_by_two', response={'result': 4}
|
||||
),
|
||||
types.Part.from_function_response(
|
||||
name='multiple_by_two_sync', response={'result': 6}
|
||||
),
|
||||
]
|
||||
|
||||
responses: list[types.Content] = [
|
||||
function_calls,
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
'response4',
|
||||
]
|
||||
function_called = 0
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
|
||||
async def increase_by_one(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x + 1
|
||||
|
||||
async def multiple_by_two(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x * 2
|
||||
|
||||
def multiple_by_two_sync(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x * 2
|
||||
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[increase_by_one, multiple_by_two, multiple_by_two_sync],
|
||||
)
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
events = await runner.run_async_with_new_session('test')
|
||||
assert utils.simplify_events(events) == [
|
||||
('root_agent', function_calls),
|
||||
('root_agent', function_responses),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
# Asserts the requests.
|
||||
assert utils.simplify_contents(mock_model.requests[0].contents) == [
|
||||
('user', 'test')
|
||||
]
|
||||
assert utils.simplify_contents(mock_model.requests[1].contents) == [
|
||||
('user', 'test'),
|
||||
('model', function_calls),
|
||||
('user', function_responses),
|
||||
]
|
||||
|
||||
# Asserts the function calls.
|
||||
assert function_called == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_tool():
|
||||
function_calls = [
|
||||
types.Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
types.Part.from_function_call(name='multiple_by_two', args={'x': 2}),
|
||||
types.Part.from_function_call(name='multiple_by_two_sync', args={'x': 3}),
|
||||
]
|
||||
function_responses = [
|
||||
types.Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 2}
|
||||
),
|
||||
types.Part.from_function_response(
|
||||
name='multiple_by_two', response={'result': 4}
|
||||
),
|
||||
types.Part.from_function_response(
|
||||
name='multiple_by_two_sync', response={'result': 6}
|
||||
),
|
||||
]
|
||||
|
||||
responses: list[types.Content] = [
|
||||
function_calls,
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
'response4',
|
||||
]
|
||||
function_called = 0
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
|
||||
async def increase_by_one(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x + 1
|
||||
|
||||
async def multiple_by_two(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x * 2
|
||||
|
||||
def multiple_by_two_sync(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x * 2
|
||||
|
||||
class TestTool(FunctionTool):
|
||||
|
||||
def __init__(self, func: Callable[..., Any]):
|
||||
super().__init__(func=func)
|
||||
|
||||
wrapped_increase_by_one = TestTool(func=increase_by_one)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[wrapped_increase_by_one, multiple_by_two, multiple_by_two_sync],
|
||||
)
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
events = await runner.run_async_with_new_session('test')
|
||||
assert utils.simplify_events(events) == [
|
||||
('root_agent', function_calls),
|
||||
('root_agent', function_responses),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
# Asserts the requests.
|
||||
assert utils.simplify_contents(mock_model.requests[0].contents) == [
|
||||
('user', 'test')
|
||||
]
|
||||
assert utils.simplify_contents(mock_model.requests[1].contents) == [
|
||||
('user', 'test'),
|
||||
('model', function_calls),
|
||||
('user', function_responses),
|
||||
]
|
||||
|
||||
# Asserts the function calls.
|
||||
assert function_called == 3
|
||||
|
||||
|
||||
def test_update_state():
|
||||
mock_model = utils.MockModel.create(
|
||||
responses=[
|
||||
types.Part.from_function_call(name='update_state', args={}),
|
||||
'response1',
|
||||
]
|
||||
)
|
||||
|
||||
def update_state(tool_context: ToolContext):
|
||||
tool_context.state['x'] = 1
|
||||
|
||||
agent = Agent(name='root_agent', model=mock_model, tools=[update_state])
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
runner.run('test')
|
||||
assert runner.session.state['x'] == 1
|
||||
|
||||
|
||||
def test_function_call_id():
|
||||
responses = [
|
||||
types.Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
|
||||
def increase_by_one(x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
agent = Agent(name='root_agent', model=mock_model, tools=[increase_by_one])
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test')
|
||||
for reqeust in mock_model.requests:
|
||||
for content in reqeust.contents:
|
||||
for part in content.parts:
|
||||
if part.function_call:
|
||||
assert part.function_call.id is None
|
||||
if part.function_response:
|
||||
assert part.function_response.id is None
|
||||
assert events[0].content.parts[0].function_call.id.startswith('adk-')
|
||||
assert events[1].content.parts[0].function_response.id.startswith('adk-')
|
||||
66
tests/unittests/flows/llm_flows/test_identity.py
Normal file
66
tests/unittests/flows/llm_flows/test_identity.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.flows.llm_flows import identity
|
||||
from google.adk.models import LlmRequest
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_description():
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(model="gemini-1.5-flash", name="agent")
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
|
||||
async for _ in identity.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == (
|
||||
"""You are an agent. Your internal name is "agent"."""
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_description():
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
description="test description",
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
|
||||
async for _ in identity.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == "\n\n".join([
|
||||
'You are an agent. Your internal name is "agent".',
|
||||
' The description about you is "test description"',
|
||||
])
|
||||
164
tests/unittests/flows/llm_flows/test_instructions.py
Normal file
164
tests/unittests/flows/llm_flows/test_instructions.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.agents.readonly_context import ReadonlyContext
|
||||
from google.adk.flows.llm_flows import instructions
|
||||
from google.adk.models import LlmRequest
|
||||
from google.adk.sessions import Session
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_system_instruction():
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
instruction=("""Use the echo_info tool to echo { customerId }, \
|
||||
{{customer_int }, { non-identifier-float}}, \
|
||||
{'key1': 'value1'} and {{'key2': 'value2'}}."""),
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
id="test_id",
|
||||
state={"customerId": "1234567890", "customer_int": 30},
|
||||
)
|
||||
|
||||
async for _ in instructions.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == (
|
||||
"""Use the echo_info tool to echo 1234567890, 30, \
|
||||
{ non-identifier-float}}, {'key1': 'value1'} and {{'key2': 'value2'}}."""
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_system_instruction():
|
||||
def build_function_instruction(readonly_context: ReadonlyContext) -> str:
|
||||
return (
|
||||
"This is the function agent instruction for invocation:"
|
||||
f" {readonly_context.invocation_id}."
|
||||
)
|
||||
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
instruction=build_function_instruction,
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
id="test_id",
|
||||
state={"customerId": "1234567890", "customer_int": 30},
|
||||
)
|
||||
|
||||
async for _ in instructions.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == (
|
||||
"This is the function agent instruction for invocation: test_id."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_system_instruction():
|
||||
sub_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="sub_agent",
|
||||
instruction="This is the sub agent instruction.",
|
||||
)
|
||||
root_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="root_agent",
|
||||
global_instruction="This is the global instruction.",
|
||||
sub_agents=[sub_agent],
|
||||
)
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=sub_agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
id="test_id",
|
||||
state={"customerId": "1234567890", "customer_int": 30},
|
||||
)
|
||||
|
||||
async for _ in instructions.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == (
|
||||
"This is the global instruction.\n\nThis is the sub agent instruction."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_system_instruction_with_namespace():
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
instruction=(
|
||||
"""Use the echo_info tool to echo { customerId }, {app:key}, {user:key}, {a:key}."""
|
||||
),
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
id="test_id",
|
||||
state={
|
||||
"customerId": "1234567890",
|
||||
"app:key": "app_value",
|
||||
"user:key": "user_value",
|
||||
},
|
||||
)
|
||||
|
||||
async for _ in instructions.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == (
|
||||
"""Use the echo_info tool to echo 1234567890, app_value, user_value, {a:key}."""
|
||||
)
|
||||
142
tests/unittests/flows/llm_flows/test_model_callbacks.py
Normal file
142
tests/unittests/flows/llm_flows/test_model_callbacks.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.models import LlmRequest
|
||||
from google.adk.models import LlmResponse
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
class MockBeforeModelCallback(BaseModel):
|
||||
mock_response: str
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
llm_request: LlmRequest,
|
||||
) -> LlmResponse:
|
||||
return LlmResponse(
|
||||
content=utils.ModelContent(
|
||||
[types.Part.from_text(text=self.mock_response)]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MockAfterModelCallback(BaseModel):
|
||||
mock_response: str
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
llm_response: LlmResponse,
|
||||
) -> LlmResponse:
|
||||
return LlmResponse(
|
||||
content=utils.ModelContent(
|
||||
[types.Part.from_text(text=self.mock_response)]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def noop_callback(**kwargs) -> Optional[LlmResponse]:
|
||||
pass
|
||||
|
||||
|
||||
def test_before_model_callback():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_model_callback=MockBeforeModelCallback(
|
||||
mock_response='before_model_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', 'before_model_callback'),
|
||||
]
|
||||
|
||||
|
||||
def test_before_model_callback_noop():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_model_callback=noop_callback,
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', 'model_response'),
|
||||
]
|
||||
|
||||
|
||||
def test_before_model_callback_end():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_model_callback=MockBeforeModelCallback(
|
||||
mock_response='before_model_callback',
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', 'before_model_callback'),
|
||||
]
|
||||
|
||||
|
||||
def test_after_model_callback():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_model_callback=MockAfterModelCallback(
|
||||
mock_response='after_model_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', 'after_model_callback'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_model_callback_noop():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_model_callback=noop_callback,
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [('root_agent', 'model_response')]
|
||||
46
tests/unittests/flows/llm_flows/test_other_configs.py
Normal file
46
tests/unittests/flows/llm_flows/test_other_configs.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.tools import ToolContext
|
||||
from google.genai.types import Part
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def test_output_schema():
|
||||
class CustomOutput(BaseModel):
|
||||
custom_field: str
|
||||
|
||||
response = [
|
||||
'response1',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
output_schema=CustomOutput,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
assert len(mockModel.requests) == 1
|
||||
assert mockModel.requests[0].config.response_schema == CustomOutput
|
||||
assert mockModel.requests[0].config.response_mime_type == 'application/json'
|
||||
269
tests/unittests/flows/llm_flows/test_tool_callbacks.py
Normal file
269
tests/unittests/flows/llm_flows/test_tool_callbacks.py
Normal file
@@ -0,0 +1,269 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.tools import BaseTool
|
||||
from google.adk.tools import ToolContext
|
||||
from google.genai import types
|
||||
from google.genai.types import Part
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def simple_function(input_str: str) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
|
||||
class MockBeforeToolCallback(BaseModel):
|
||||
mock_response: dict[str, object]
|
||||
modify_tool_request: bool = False
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
tool: BaseTool,
|
||||
args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
) -> dict[str, object]:
|
||||
if self.modify_tool_request:
|
||||
args['input_str'] = 'modified_input'
|
||||
return None
|
||||
return self.mock_response
|
||||
|
||||
|
||||
class MockAfterToolCallback(BaseModel):
|
||||
mock_response: dict[str, object]
|
||||
modify_tool_request: bool = False
|
||||
modify_tool_response: bool = False
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
tool: BaseTool,
|
||||
args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
tool_response: dict[str, Any] = None,
|
||||
) -> dict[str, object]:
|
||||
if self.modify_tool_request:
|
||||
args['input_str'] = 'modified_input'
|
||||
return None
|
||||
if self.modify_tool_response:
|
||||
tool_response['result'] = 'modified_output'
|
||||
return tool_response
|
||||
return self.mock_response
|
||||
|
||||
|
||||
def noop_callback(
|
||||
**kwargs,
|
||||
) -> dict[str, object]:
|
||||
pass
|
||||
|
||||
|
||||
def test_before_tool_callback():
|
||||
responses = [
|
||||
types.Part.from_function_call(name='simple_function', args={}),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_tool_callback=MockBeforeToolCallback(
|
||||
mock_response={'test': 'before_tool_callback'}
|
||||
),
|
||||
tools=[simple_function],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', Part.from_function_call(name='simple_function', args={})),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='simple_function', response={'test': 'before_tool_callback'}
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
|
||||
def test_before_tool_callback_noop():
|
||||
responses = [
|
||||
types.Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_tool_callback=noop_callback,
|
||||
tools=[simple_function],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='simple_function',
|
||||
response={'result': 'simple_function_call'},
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
|
||||
def test_before_tool_callback_modify_tool_request():
|
||||
responses = [
|
||||
types.Part.from_function_call(name='simple_function', args={}),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_tool_callback=MockBeforeToolCallback(
|
||||
mock_response={'test': 'before_tool_callback'},
|
||||
modify_tool_request=True,
|
||||
),
|
||||
tools=[simple_function],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', Part.from_function_call(name='simple_function', args={})),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='simple_function',
|
||||
response={'result': 'modified_input'},
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
|
||||
def test_after_tool_callback():
|
||||
responses = [
|
||||
types.Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_tool_callback=MockAfterToolCallback(
|
||||
mock_response={'test': 'after_tool_callback'}
|
||||
),
|
||||
tools=[simple_function],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='simple_function', response={'test': 'after_tool_callback'}
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
|
||||
def test_after_tool_callback_noop():
|
||||
responses = [
|
||||
types.Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_tool_callback=noop_callback,
|
||||
tools=[simple_function],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='simple_function',
|
||||
response={'result': 'simple_function_call'},
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
|
||||
def test_after_tool_callback_modify_tool_response():
|
||||
responses = [
|
||||
types.Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_tool_callback=MockAfterToolCallback(
|
||||
mock_response={'result': 'after_tool_callback'},
|
||||
modify_tool_response=True,
|
||||
),
|
||||
tools=[simple_function],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='simple_function',
|
||||
response={'result': 'modified_output'},
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
14
tests/unittests/models/__init__.py
Normal file
14
tests/unittests/models/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
224
tests/unittests/models/test_google_llm.py
Normal file
224
tests/unittests/models/test_google_llm.py
Normal file
@@ -0,0 +1,224 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from unittest import mock
|
||||
|
||||
from google.adk import version
|
||||
from google.adk.models.gemini_llm_connection import GeminiLlmConnection
|
||||
from google.adk.models.google_llm import Gemini
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.models.llm_response import LlmResponse
|
||||
from google.genai import types
|
||||
from google.genai.types import Content
|
||||
from google.genai.types import Part
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generate_content_response():
|
||||
return types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=Content(
|
||||
role="model",
|
||||
parts=[Part.from_text(text="Hello, how can I help you?")],
|
||||
),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gemini_llm():
|
||||
return Gemini(model="gemini-1.5-flash")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_request():
|
||||
return LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
|
||||
config=types.GenerateContentConfig(
|
||||
temperature=0.1,
|
||||
response_modalities=[types.Modality.TEXT],
|
||||
system_instruction="You are a helpful assistant",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_supported_models():
|
||||
models = Gemini.supported_models()
|
||||
assert len(models) == 3
|
||||
assert models[0] == r"gemini-.*"
|
||||
assert models[1] == r"projects\/.+\/locations\/.+\/endpoints\/.+"
|
||||
assert (
|
||||
models[2]
|
||||
== r"projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+"
|
||||
)
|
||||
|
||||
|
||||
def test_client_version_header():
|
||||
model = Gemini(model="gemini-1.5-flash")
|
||||
client = model.api_client
|
||||
expected_header = (
|
||||
f"google-adk/{version.__version__}"
|
||||
f" gl-python/{sys.version.split()[0]} google-genai-sdk/"
|
||||
)
|
||||
assert (
|
||||
expected_header
|
||||
in client._api_client._http_options.headers["x-goog-api-client"]
|
||||
)
|
||||
assert (
|
||||
expected_header in client._api_client._http_options.headers["user-agent"]
|
||||
)
|
||||
|
||||
|
||||
def test_maybe_append_user_content(gemini_llm, llm_request):
|
||||
# Test with user content already present
|
||||
gemini_llm._maybe_append_user_content(llm_request)
|
||||
assert len(llm_request.contents) == 1
|
||||
|
||||
# Test with model content as the last message
|
||||
llm_request.contents.append(
|
||||
Content(role="model", parts=[Part.from_text(text="Response")])
|
||||
)
|
||||
gemini_llm._maybe_append_user_content(llm_request)
|
||||
assert len(llm_request.contents) == 3
|
||||
assert llm_request.contents[-1].role == "user"
|
||||
assert "Continue processing" in llm_request.contents[-1].parts[0].text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async(
|
||||
gemini_llm, llm_request, generate_content_response
|
||||
):
|
||||
with mock.patch.object(gemini_llm, "api_client") as mock_client:
|
||||
# Create a mock coroutine that returns the generate_content_response
|
||||
async def mock_coro():
|
||||
return generate_content_response
|
||||
|
||||
# Assign the coroutine to the mocked method
|
||||
mock_client.aio.models.generate_content.return_value = mock_coro()
|
||||
|
||||
responses = [
|
||||
resp
|
||||
async for resp in gemini_llm.generate_content_async(
|
||||
llm_request, stream=False
|
||||
)
|
||||
]
|
||||
|
||||
assert len(responses) == 1
|
||||
assert isinstance(responses[0], LlmResponse)
|
||||
assert responses[0].content.parts[0].text == "Hello, how can I help you?"
|
||||
mock_client.aio.models.generate_content.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async_stream(gemini_llm, llm_request):
|
||||
with mock.patch.object(gemini_llm, "api_client") as mock_client:
|
||||
# Create mock stream responses
|
||||
class MockAsyncIterator:
|
||||
|
||||
def __init__(self, seq):
|
||||
self.iter = iter(seq)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
return next(self.iter)
|
||||
except StopIteration:
|
||||
raise StopAsyncIteration
|
||||
|
||||
mock_responses = [
|
||||
types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=Content(
|
||||
role="model", parts=[Part.from_text(text="Hello")]
|
||||
),
|
||||
finish_reason=None,
|
||||
)
|
||||
]
|
||||
),
|
||||
types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=Content(
|
||||
role="model", parts=[Part.from_text(text=", how")]
|
||||
),
|
||||
finish_reason=None,
|
||||
)
|
||||
]
|
||||
),
|
||||
types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=Content(
|
||||
role="model",
|
||||
parts=[Part.from_text(text=" can I help you?")],
|
||||
),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
)
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
# Create a mock coroutine that returns the MockAsyncIterator
|
||||
async def mock_coro():
|
||||
return MockAsyncIterator(mock_responses)
|
||||
|
||||
# Set the mock to return the coroutine
|
||||
mock_client.aio.models.generate_content_stream.return_value = mock_coro()
|
||||
|
||||
responses = [
|
||||
resp
|
||||
async for resp in gemini_llm.generate_content_async(
|
||||
llm_request, stream=True
|
||||
)
|
||||
]
|
||||
|
||||
# Assertions remain the same
|
||||
assert len(responses) == 4
|
||||
assert responses[0].partial is True
|
||||
assert responses[1].partial is True
|
||||
assert responses[2].partial is True
|
||||
assert responses[3].content.parts[0].text == "Hello, how can I help you?"
|
||||
mock_client.aio.models.generate_content_stream.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect(gemini_llm, llm_request):
|
||||
# Create a mock connection
|
||||
mock_connection = mock.MagicMock(spec=GeminiLlmConnection)
|
||||
|
||||
# Create a mock context manager
|
||||
class MockContextManager:
|
||||
|
||||
async def __aenter__(self):
|
||||
return mock_connection
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
# Mock the connect method at the class level
|
||||
with mock.patch(
|
||||
"google.adk.models.google_llm.Gemini.connect",
|
||||
return_value=MockContextManager(),
|
||||
):
|
||||
async with gemini_llm.connect(llm_request) as connection:
|
||||
assert connection is mock_connection
|
||||
804
tests/unittests/models/test_litellm.py
Normal file
804
tests/unittests/models/test_litellm.py
Normal file
@@ -0,0 +1,804 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import Mock
|
||||
from google.adk.models.lite_llm import _content_to_message_param
|
||||
from google.adk.models.lite_llm import _function_declaration_to_tool_param
|
||||
from google.adk.models.lite_llm import _get_content
|
||||
from google.adk.models.lite_llm import _message_to_generate_content_response
|
||||
from google.adk.models.lite_llm import _model_response_to_chunk
|
||||
from google.adk.models.lite_llm import _to_litellm_role
|
||||
from google.adk.models.lite_llm import FunctionChunk
|
||||
from google.adk.models.lite_llm import LiteLlm
|
||||
from google.adk.models.lite_llm import LiteLLMClient
|
||||
from google.adk.models.lite_llm import TextChunk
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.genai import types
|
||||
from litellm import ChatCompletionAssistantMessage
|
||||
from litellm import ChatCompletionMessageToolCall
|
||||
from litellm import Function
|
||||
from litellm.types.utils import ChatCompletionDeltaToolCall
|
||||
from litellm.types.utils import Choices
|
||||
from litellm.types.utils import Delta
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.types.utils import StreamingChoices
|
||||
import pytest
|
||||
|
||||
LLM_REQUEST_WITH_FUNCTION_DECLARATION = LlmRequest(
|
||||
contents=[
|
||||
types.Content(
|
||||
role="user", parts=[types.Part.from_text(text="Test prompt")]
|
||||
)
|
||||
],
|
||||
config=types.GenerateContentConfig(
|
||||
tools=[
|
||||
types.Tool(
|
||||
function_declarations=[
|
||||
types.FunctionDeclaration(
|
||||
name="test_function",
|
||||
description="Test function description",
|
||||
parameters=types.Schema(
|
||||
type=types.Type.OBJECT,
|
||||
properties={
|
||||
"test_arg": types.Schema(
|
||||
type=types.Type.STRING
|
||||
),
|
||||
"array_arg": types.Schema(
|
||||
type=types.Type.ARRAY,
|
||||
items={
|
||||
"type": types.Type.STRING,
|
||||
},
|
||||
),
|
||||
"nested_arg": types.Schema(
|
||||
type=types.Type.OBJECT,
|
||||
properties={
|
||||
"nested_key1": types.Schema(
|
||||
type=types.Type.STRING
|
||||
),
|
||||
"nested_key2": types.Schema(
|
||||
type=types.Type.STRING
|
||||
),
|
||||
},
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
STREAMING_MODEL_RESPONSE = [
|
||||
ModelResponse(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
finish_reason=None,
|
||||
delta=Delta(
|
||||
role="assistant",
|
||||
content="zero, ",
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
ModelResponse(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
finish_reason=None,
|
||||
delta=Delta(
|
||||
role="assistant",
|
||||
content="one, ",
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
ModelResponse(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
finish_reason=None,
|
||||
delta=Delta(
|
||||
role="assistant",
|
||||
content="two:",
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
ModelResponse(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
finish_reason=None,
|
||||
delta=Delta(
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ChatCompletionDeltaToolCall(
|
||||
type="function",
|
||||
id="test_tool_call_id",
|
||||
function=Function(
|
||||
name="test_function",
|
||||
arguments='{"test_arg": "test_',
|
||||
),
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
ModelResponse(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
finish_reason=None,
|
||||
delta=Delta(
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ChatCompletionDeltaToolCall(
|
||||
type="function",
|
||||
id=None,
|
||||
function=Function(
|
||||
name=None,
|
||||
arguments='value"}',
|
||||
),
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
ModelResponse(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
finish_reason="tool_use",
|
||||
)
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response():
|
||||
return ModelResponse(
|
||||
choices=[
|
||||
Choices(
|
||||
message=ChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content="Test response",
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
type="function",
|
||||
id="test_tool_call_id",
|
||||
function=Function(
|
||||
name="test_function",
|
||||
arguments='{"test_arg": "test_value"}',
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_acompletion(mock_response):
|
||||
return AsyncMock(return_value=mock_response)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_completion(mock_response):
|
||||
return Mock(return_value=mock_response)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client(mock_acompletion, mock_completion):
|
||||
return MockLLMClient(mock_acompletion, mock_completion)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lite_llm_instance(mock_client):
|
||||
return LiteLlm(model="test_model", llm_client=mock_client)
|
||||
|
||||
|
||||
class MockLLMClient(LiteLLMClient):
|
||||
|
||||
def __init__(self, acompletion_mock, completion_mock):
|
||||
self.acompletion_mock = acompletion_mock
|
||||
self.completion_mock = completion_mock
|
||||
|
||||
async def acompletion(self, model, messages, tools, **kwargs):
|
||||
return await self.acompletion_mock(
|
||||
model=model, messages=messages, tools=tools, **kwargs
|
||||
)
|
||||
|
||||
def completion(self, model, messages, tools, stream, **kwargs):
|
||||
return self.completion_mock(
|
||||
model=model, messages=messages, tools=tools, stream=stream, **kwargs
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async(mock_acompletion, lite_llm_instance):
|
||||
|
||||
async for response in lite_llm_instance.generate_content_async(
|
||||
LLM_REQUEST_WITH_FUNCTION_DECLARATION
|
||||
):
|
||||
assert response.content.role == "model"
|
||||
assert response.content.parts[0].text == "Test response"
|
||||
assert response.content.parts[1].function_call.name == "test_function"
|
||||
assert response.content.parts[1].function_call.args == {
|
||||
"test_arg": "test_value"
|
||||
}
|
||||
assert response.content.parts[1].function_call.id == "test_tool_call_id"
|
||||
|
||||
mock_acompletion.assert_called_once()
|
||||
|
||||
_, kwargs = mock_acompletion.call_args
|
||||
assert kwargs["model"] == "test_model"
|
||||
assert kwargs["messages"][0]["role"] == "user"
|
||||
assert kwargs["messages"][0]["content"] == "Test prompt"
|
||||
assert kwargs["tools"][0]["function"]["name"] == "test_function"
|
||||
assert (
|
||||
kwargs["tools"][0]["function"]["description"]
|
||||
== "Test function description"
|
||||
)
|
||||
assert (
|
||||
kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
|
||||
"type"
|
||||
]
|
||||
== "string"
|
||||
)
|
||||
|
||||
|
||||
function_declaration_test_cases = [
|
||||
(
|
||||
"simple_function",
|
||||
types.FunctionDeclaration(
|
||||
name="test_function",
|
||||
description="Test function description",
|
||||
parameters=types.Schema(
|
||||
type=types.Type.OBJECT,
|
||||
properties={
|
||||
"test_arg": types.Schema(type=types.Type.STRING),
|
||||
"array_arg": types.Schema(
|
||||
type=types.Type.ARRAY,
|
||||
items=types.Schema(
|
||||
type=types.Type.STRING,
|
||||
),
|
||||
),
|
||||
"nested_arg": types.Schema(
|
||||
type=types.Type.OBJECT,
|
||||
properties={
|
||||
"nested_key1": types.Schema(type=types.Type.STRING),
|
||||
"nested_key2": types.Schema(type=types.Type.STRING),
|
||||
},
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_function",
|
||||
"description": "Test function description",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"test_arg": {"type": "string"},
|
||||
"array_arg": {
|
||||
"items": {"type": "string"},
|
||||
"type": "array",
|
||||
},
|
||||
"nested_arg": {
|
||||
"properties": {
|
||||
"nested_key1": {"type": "string"},
|
||||
"nested_key2": {"type": "string"},
|
||||
},
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
(
|
||||
"no_description",
|
||||
types.FunctionDeclaration(
|
||||
name="test_function_no_description",
|
||||
parameters=types.Schema(
|
||||
type=types.Type.OBJECT,
|
||||
properties={
|
||||
"test_arg": types.Schema(type=types.Type.STRING),
|
||||
},
|
||||
),
|
||||
),
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_function_no_description",
|
||||
"description": "",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"test_arg": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
(
|
||||
"empty_parameters",
|
||||
types.FunctionDeclaration(
|
||||
name="test_function_empty_params",
|
||||
parameters=types.Schema(type=types.Type.OBJECT, properties={}),
|
||||
),
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_function_empty_params",
|
||||
"description": "",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
(
|
||||
"nested_array",
|
||||
types.FunctionDeclaration(
|
||||
name="test_function_nested_array",
|
||||
parameters=types.Schema(
|
||||
type=types.Type.OBJECT,
|
||||
properties={
|
||||
"array_arg": types.Schema(
|
||||
type=types.Type.ARRAY,
|
||||
items=types.Schema(
|
||||
type=types.Type.OBJECT,
|
||||
properties={
|
||||
"nested_key": types.Schema(
|
||||
type=types.Type.STRING
|
||||
)
|
||||
},
|
||||
),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_function_nested_array",
|
||||
"description": "",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"array_arg": {
|
||||
"items": {
|
||||
"properties": {
|
||||
"nested_key": {"type": "string"}
|
||||
},
|
||||
"type": "object",
|
||||
},
|
||||
"type": "array",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"_, function_declaration, expected_output",
|
||||
function_declaration_test_cases,
|
||||
ids=[case[0] for case in function_declaration_test_cases],
|
||||
)
|
||||
def test_function_declaration_to_tool_param(
|
||||
_, function_declaration, expected_output
|
||||
):
|
||||
assert (
|
||||
_function_declaration_to_tool_param(function_declaration)
|
||||
== expected_output
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async_with_system_instruction(
|
||||
lite_llm_instance, mock_acompletion
|
||||
):
|
||||
mock_response_with_system_instruction = ModelResponse(
|
||||
choices=[
|
||||
Choices(
|
||||
message=ChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content="Test response",
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
mock_acompletion.return_value = mock_response_with_system_instruction
|
||||
|
||||
llm_request = LlmRequest(
|
||||
contents=[
|
||||
types.Content(
|
||||
role="user", parts=[types.Part.from_text(text="Test prompt")]
|
||||
)
|
||||
],
|
||||
config=types.GenerateContentConfig(
|
||||
system_instruction="Test system instruction"
|
||||
),
|
||||
)
|
||||
|
||||
async for response in lite_llm_instance.generate_content_async(llm_request):
|
||||
assert response.content.role == "model"
|
||||
assert response.content.parts[0].text == "Test response"
|
||||
|
||||
mock_acompletion.assert_called_once()
|
||||
|
||||
_, kwargs = mock_acompletion.call_args
|
||||
assert kwargs["model"] == "test_model"
|
||||
assert kwargs["messages"][0]["role"] == "developer"
|
||||
assert kwargs["messages"][0]["content"] == "Test system instruction"
|
||||
assert kwargs["messages"][1]["role"] == "user"
|
||||
assert kwargs["messages"][1]["content"] == "Test prompt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async_with_tool_response(
|
||||
lite_llm_instance, mock_acompletion
|
||||
):
|
||||
mock_response_with_tool_response = ModelResponse(
|
||||
choices=[
|
||||
Choices(
|
||||
message=ChatCompletionAssistantMessage(
|
||||
role="tool",
|
||||
content='{"result": "test_result"}',
|
||||
tool_call_id="test_tool_call_id",
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
mock_acompletion.return_value = mock_response_with_tool_response
|
||||
|
||||
llm_request = LlmRequest(
|
||||
contents=[
|
||||
types.Content(
|
||||
role="user", parts=[types.Part.from_text(text="Test prompt")]
|
||||
),
|
||||
types.Content(
|
||||
role="tool",
|
||||
parts=[
|
||||
types.Part.from_function_response(
|
||||
name="test_function",
|
||||
response={"result": "test_result"},
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
config=types.GenerateContentConfig(
|
||||
system_instruction="test instruction",
|
||||
),
|
||||
)
|
||||
async for response in lite_llm_instance.generate_content_async(llm_request):
|
||||
assert response.content.role == "model"
|
||||
assert response.content.parts[0].text == '{"result": "test_result"}'
|
||||
|
||||
mock_acompletion.assert_called_once()
|
||||
|
||||
_, kwargs = mock_acompletion.call_args
|
||||
assert kwargs["model"] == "test_model"
|
||||
|
||||
assert kwargs["messages"][2]["role"] == "tool"
|
||||
assert kwargs["messages"][2]["content"] == '{"result": "test_result"}'
|
||||
|
||||
|
||||
def test_content_to_message_param_user_message():
|
||||
content = types.Content(
|
||||
role="user", parts=[types.Part.from_text(text="Test prompt")]
|
||||
)
|
||||
message = _content_to_message_param(content)
|
||||
assert message["role"] == "user"
|
||||
assert message["content"] == "Test prompt"
|
||||
|
||||
|
||||
def test_content_to_message_param_assistant_message():
|
||||
content = types.Content(
|
||||
role="assistant", parts=[types.Part.from_text(text="Test response")]
|
||||
)
|
||||
message = _content_to_message_param(content)
|
||||
assert message["role"] == "assistant"
|
||||
assert message["content"] == "Test response"
|
||||
|
||||
|
||||
def test_content_to_message_param_function_call():
|
||||
content = types.Content(
|
||||
role="assistant",
|
||||
parts=[
|
||||
types.Part.from_function_call(
|
||||
name="test_function", args={"test_arg": "test_value"}
|
||||
)
|
||||
],
|
||||
)
|
||||
content.parts[0].function_call.id = "test_tool_call_id"
|
||||
message = _content_to_message_param(content)
|
||||
assert message["role"] == "assistant"
|
||||
assert message["content"] == []
|
||||
assert message["tool_calls"][0].type == "function"
|
||||
assert message["tool_calls"][0].id == "test_tool_call_id"
|
||||
assert message["tool_calls"][0].function.name == "test_function"
|
||||
assert (
|
||||
message["tool_calls"][0].function.arguments
|
||||
== '{"test_arg": "test_value"}'
|
||||
)
|
||||
|
||||
|
||||
def test_message_to_generate_content_response_text():
|
||||
message = ChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content="Test response",
|
||||
)
|
||||
response = _message_to_generate_content_response(message)
|
||||
assert response.content.role == "model"
|
||||
assert response.content.parts[0].text == "Test response"
|
||||
|
||||
|
||||
def test_message_to_generate_content_response_tool_call():
|
||||
message = ChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
type="function",
|
||||
id="test_tool_call_id",
|
||||
function=Function(
|
||||
name="test_function",
|
||||
arguments='{"test_arg": "test_value"}',
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
response = _message_to_generate_content_response(message)
|
||||
assert response.content.role == "model"
|
||||
assert response.content.parts[0].function_call.name == "test_function"
|
||||
assert response.content.parts[0].function_call.args == {
|
||||
"test_arg": "test_value"
|
||||
}
|
||||
assert response.content.parts[0].function_call.id == "test_tool_call_id"
|
||||
|
||||
|
||||
def test_get_content_text():
|
||||
parts = [types.Part.from_text(text="Test text")]
|
||||
content = _get_content(parts)
|
||||
assert content == "Test text"
|
||||
|
||||
|
||||
def test_get_content_image():
|
||||
parts = [
|
||||
types.Part.from_bytes(data=b"test_image_data", mime_type="image/png")
|
||||
]
|
||||
content = _get_content(parts)
|
||||
assert content[0]["type"] == "image_url"
|
||||
assert content[0]["image_url"] == "data:image/png;base64,dGVzdF9pbWFnZV9kYXRh"
|
||||
|
||||
|
||||
def test_get_content_video():
|
||||
parts = [
|
||||
types.Part.from_bytes(data=b"test_video_data", mime_type="video/mp4")
|
||||
]
|
||||
content = _get_content(parts)
|
||||
assert content[0]["type"] == "video_url"
|
||||
assert content[0]["video_url"] == "data:video/mp4;base64,dGVzdF92aWRlb19kYXRh"
|
||||
|
||||
|
||||
def test_to_litellm_role():
|
||||
assert _to_litellm_role("model") == "assistant"
|
||||
assert _to_litellm_role("assistant") == "assistant"
|
||||
assert _to_litellm_role("user") == "user"
|
||||
assert _to_litellm_role(None) == "user"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"response, expected_chunk, expected_finished",
|
||||
[
|
||||
(
|
||||
ModelResponse(
|
||||
choices=[
|
||||
{
|
||||
"message": {
|
||||
"content": "this is a test",
|
||||
}
|
||||
}
|
||||
]
|
||||
),
|
||||
TextChunk(text="this is a test"),
|
||||
"stop",
|
||||
),
|
||||
(
|
||||
ModelResponse(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
finish_reason=None,
|
||||
delta=Delta(
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ChatCompletionDeltaToolCall(
|
||||
type="function",
|
||||
id="1",
|
||||
function=Function(
|
||||
name="test_function",
|
||||
arguments='{"key": "va',
|
||||
),
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
FunctionChunk(id="1", name="test_function", args='{"key": "va'),
|
||||
None,
|
||||
),
|
||||
(
|
||||
ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
|
||||
None,
|
||||
"tool_calls",
|
||||
),
|
||||
(ModelResponse(choices=[{}]), None, "stop"),
|
||||
],
|
||||
)
|
||||
def test_model_response_to_chunk(response, expected_chunk, expected_finished):
|
||||
result = list(_model_response_to_chunk(response))
|
||||
assert len(result) == 1
|
||||
chunk, finished = result[0]
|
||||
if expected_chunk:
|
||||
assert isinstance(chunk, type(expected_chunk))
|
||||
assert chunk == expected_chunk
|
||||
else:
|
||||
assert chunk is None
|
||||
assert finished == expected_finished
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acompletion_additional_args(mock_acompletion, mock_client):
|
||||
lite_llm_instance = LiteLlm(
|
||||
# valid args
|
||||
model="test_model",
|
||||
llm_client=mock_client,
|
||||
api_key="test_key",
|
||||
api_base="some://url",
|
||||
api_version="2024-09-12",
|
||||
# invalid args (ignored)
|
||||
stream=True,
|
||||
messages=[{"role": "invalid", "content": "invalid"}],
|
||||
tools=[{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "invalid",
|
||||
},
|
||||
}],
|
||||
)
|
||||
|
||||
async for response in lite_llm_instance.generate_content_async(
|
||||
LLM_REQUEST_WITH_FUNCTION_DECLARATION
|
||||
):
|
||||
assert response.content.role == "model"
|
||||
assert response.content.parts[0].text == "Test response"
|
||||
assert response.content.parts[1].function_call.name == "test_function"
|
||||
assert response.content.parts[1].function_call.args == {
|
||||
"test_arg": "test_value"
|
||||
}
|
||||
assert response.content.parts[1].function_call.id == "test_tool_call_id"
|
||||
|
||||
mock_acompletion.assert_called_once()
|
||||
|
||||
_, kwargs = mock_acompletion.call_args
|
||||
|
||||
assert kwargs["model"] == "test_model"
|
||||
assert kwargs["messages"][0]["role"] == "user"
|
||||
assert kwargs["messages"][0]["content"] == "Test prompt"
|
||||
assert kwargs["tools"][0]["function"]["name"] == "test_function"
|
||||
assert "stream" not in kwargs
|
||||
assert "llm_client" not in kwargs
|
||||
assert kwargs["api_base"] == "some://url"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_additional_args(mock_completion, mock_client):
|
||||
lite_llm_instance = LiteLlm(
|
||||
# valid args
|
||||
model="test_model",
|
||||
llm_client=mock_client,
|
||||
api_key="test_key",
|
||||
api_base="some://url",
|
||||
api_version="2024-09-12",
|
||||
# invalid args (ignored)
|
||||
stream=False,
|
||||
messages=[{"role": "invalid", "content": "invalid"}],
|
||||
tools=[{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "invalid",
|
||||
},
|
||||
}],
|
||||
)
|
||||
|
||||
mock_completion.return_value = iter(STREAMING_MODEL_RESPONSE)
|
||||
|
||||
responses = [
|
||||
response
|
||||
async for response in lite_llm_instance.generate_content_async(
|
||||
LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
|
||||
)
|
||||
]
|
||||
assert len(responses) == 4
|
||||
mock_completion.assert_called_once()
|
||||
|
||||
_, kwargs = mock_completion.call_args
|
||||
|
||||
assert kwargs["model"] == "test_model"
|
||||
assert kwargs["messages"][0]["role"] == "user"
|
||||
assert kwargs["messages"][0]["content"] == "Test prompt"
|
||||
assert kwargs["tools"][0]["function"]["name"] == "test_function"
|
||||
assert kwargs["stream"]
|
||||
assert "llm_client" not in kwargs
|
||||
assert kwargs["api_base"] == "some://url"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async_stream(
|
||||
mock_completion, lite_llm_instance
|
||||
):
|
||||
|
||||
mock_completion.return_value = iter(STREAMING_MODEL_RESPONSE)
|
||||
|
||||
responses = [
|
||||
response
|
||||
async for response in lite_llm_instance.generate_content_async(
|
||||
LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
|
||||
)
|
||||
]
|
||||
assert len(responses) == 4
|
||||
assert responses[0].content.role == "model"
|
||||
assert responses[0].content.parts[0].text == "zero, "
|
||||
assert responses[1].content.role == "model"
|
||||
assert responses[1].content.parts[0].text == "one, "
|
||||
assert responses[2].content.role == "model"
|
||||
assert responses[2].content.parts[0].text == "two:"
|
||||
assert responses[3].content.role == "model"
|
||||
assert responses[3].content.parts[0].function_call.name == "test_function"
|
||||
assert responses[3].content.parts[0].function_call.args == {
|
||||
"test_arg": "test_value"
|
||||
}
|
||||
assert responses[3].content.parts[0].function_call.id == "test_tool_call_id"
|
||||
mock_completion.assert_called_once()
|
||||
|
||||
_, kwargs = mock_completion.call_args
|
||||
assert kwargs["model"] == "test_model"
|
||||
assert kwargs["messages"][0]["role"] == "user"
|
||||
assert kwargs["messages"][0]["content"] == "Test prompt"
|
||||
assert kwargs["tools"][0]["function"]["name"] == "test_function"
|
||||
assert (
|
||||
kwargs["tools"][0]["function"]["description"]
|
||||
== "Test function description"
|
||||
)
|
||||
assert (
|
||||
kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
|
||||
"type"
|
||||
]
|
||||
== "string"
|
||||
)
|
||||
60
tests/unittests/models/test_models.py
Normal file
60
tests/unittests/models/test_models.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk import models
|
||||
from google.adk.models.anthropic_llm import Claude
|
||||
from google.adk.models.google_llm import Gemini
|
||||
from google.adk.models.registry import LLMRegistry
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'model_name',
|
||||
[
|
||||
'gemini-1.5-flash',
|
||||
'gemini-1.5-flash-001',
|
||||
'gemini-1.5-flash-002',
|
||||
'gemini-1.5-pro',
|
||||
'gemini-1.5-pro-001',
|
||||
'gemini-1.5-pro-002',
|
||||
'gemini-2.0-flash-exp',
|
||||
'projects/123456/locations/us-central1/endpoints/123456', # finetuned vertex gemini endpoint
|
||||
'projects/123456/locations/us-central1/publishers/google/models/gemini-2.0-flash-exp', # vertex gemini long name
|
||||
],
|
||||
)
|
||||
def test_match_gemini_family(model_name):
|
||||
assert models.LLMRegistry.resolve(model_name) is Gemini
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'model_name',
|
||||
[
|
||||
'claude-3-5-haiku@20241022',
|
||||
'claude-3-5-sonnet-v2@20241022',
|
||||
'claude-3-5-sonnet@20240620',
|
||||
'claude-3-haiku@20240307',
|
||||
'claude-3-opus@20240229',
|
||||
'claude-3-sonnet@20240229',
|
||||
],
|
||||
)
|
||||
def test_match_claude_family(model_name):
|
||||
LLMRegistry.register(Claude)
|
||||
|
||||
assert models.LLMRegistry.resolve(model_name) is Claude
|
||||
|
||||
|
||||
def test_non_exist_model():
|
||||
with pytest.raises(ValueError) as e_info:
|
||||
models.LLMRegistry.resolve('non-exist-model')
|
||||
assert 'Model non-exist-model not found.' in str(e_info.value)
|
||||
14
tests/unittests/sessions/__init__.py
Normal file
14
tests/unittests/sessions/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
227
tests/unittests/sessions/test_session_service.py
Normal file
227
tests/unittests/sessions/test_session_service.py
Normal file
@@ -0,0 +1,227 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
import pytest
|
||||
|
||||
from google.adk.events import Event
|
||||
from google.adk.events import EventActions
|
||||
from google.adk.sessions import DatabaseSessionService
|
||||
from google.adk.sessions import InMemorySessionService
|
||||
from google.genai import types
|
||||
|
||||
|
||||
class SessionServiceType(enum.Enum):
|
||||
IN_MEMORY = 'IN_MEMORY'
|
||||
DATABASE = 'DATABASE'
|
||||
|
||||
|
||||
def get_session_service(
|
||||
service_type: SessionServiceType = SessionServiceType.IN_MEMORY,
|
||||
):
|
||||
"""Creates a session service for testing."""
|
||||
if service_type == SessionServiceType.DATABASE:
|
||||
return DatabaseSessionService('sqlite:///:memory:')
|
||||
return InMemorySessionService()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_get_empty_session(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
assert not session_service.get_session(
|
||||
app_name='my_app', user_id='test_user', session_id='123'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_create_get_session(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'test_user'
|
||||
state = {'key': 'value'}
|
||||
|
||||
session = session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state
|
||||
)
|
||||
assert session.app_name == app_name
|
||||
assert session.user_id == user_id
|
||||
assert session.id
|
||||
assert session.state == state
|
||||
assert (
|
||||
session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
== session
|
||||
)
|
||||
|
||||
session_id = session.id
|
||||
session_service.delete_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
|
||||
assert (
|
||||
not session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
== session
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_create_and_list_sessions(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'test_user'
|
||||
|
||||
session_ids = ['session' + str(i) for i in range(5)]
|
||||
for session_id in session_ids:
|
||||
session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
|
||||
sessions = session_service.list_sessions(
|
||||
app_name=app_name, user_id=user_id
|
||||
).sessions
|
||||
for i in range(len(sessions)):
|
||||
assert sessions[i].id == session_ids[i]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_session_state(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id_1 = 'user1'
|
||||
user_id_2 = 'user2'
|
||||
session_id_11 = 'session11'
|
||||
session_id_12 = 'session12'
|
||||
session_id_2 = 'session2'
|
||||
state_11 = {'key11': 'value11'}
|
||||
state_12 = {'key12': 'value12'}
|
||||
|
||||
session_11 = session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id_1,
|
||||
state=state_11,
|
||||
session_id=session_id_11,
|
||||
)
|
||||
session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id_1,
|
||||
state=state_12,
|
||||
session_id=session_id_12,
|
||||
)
|
||||
session_service.create_session(
|
||||
app_name=app_name, user_id=user_id_2, session_id=session_id_2
|
||||
)
|
||||
|
||||
assert session_11.state.get('key11') == 'value11'
|
||||
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
author='user',
|
||||
content=types.Content(role='user', parts=[types.Part(text='text')]),
|
||||
actions=EventActions(
|
||||
state_delta={
|
||||
'app:key': 'value',
|
||||
'user:key1': 'value1',
|
||||
'temp:key': 'temp',
|
||||
'key11': 'value11_new',
|
||||
}
|
||||
),
|
||||
)
|
||||
session_service.append_event(session=session_11, event=event)
|
||||
|
||||
# User and app state is stored, temp state is filtered.
|
||||
assert session_11.state.get('app:key') == 'value'
|
||||
assert session_11.state.get('key11') == 'value11_new'
|
||||
assert session_11.state.get('user:key1') == 'value1'
|
||||
assert not session_11.state.get('temp:key')
|
||||
|
||||
session_12 = session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_1, session_id=session_id_12
|
||||
)
|
||||
# After getting a new instance, the session_12 got the user and app state,
|
||||
# even append_event is not applied to it, temp state has no effect
|
||||
assert session_12.state.get('key12') == 'value12'
|
||||
assert not session_12.state.get('temp:key')
|
||||
|
||||
# The user1's state is not visible to user2, app state is visible
|
||||
session_2 = session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_2, session_id=session_id_2
|
||||
)
|
||||
assert session_2.state.get('app:key') == 'value'
|
||||
assert not session_2.state.get('user:key1')
|
||||
|
||||
assert not session_2.state.get('user:key1')
|
||||
|
||||
# The change to session_11 is persisted
|
||||
session_11 = session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_1, session_id=session_id_11
|
||||
)
|
||||
assert session_11.state.get('key11') == 'value11_new'
|
||||
assert session_11.state.get('user:key1') == 'value1'
|
||||
assert not session_11.state.get('temp:key')
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"service_type", [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_create_new_session_will_merge_states(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
session_id_1 = 'session1'
|
||||
session_id_2 = 'session2'
|
||||
state_1 = {'key1': 'value1'}
|
||||
|
||||
session_1 = session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1
|
||||
)
|
||||
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
author='user',
|
||||
content=types.Content(role='user', parts=[types.Part(text='text')]),
|
||||
actions=EventActions(
|
||||
state_delta={
|
||||
'app:key': 'value',
|
||||
'user:key1': 'value1',
|
||||
'temp:key': 'temp',
|
||||
}
|
||||
),
|
||||
)
|
||||
session_service.append_event(session=session_1, event=event)
|
||||
|
||||
# User and app state is stored, temp state is filtered.
|
||||
assert session_1.state.get('app:key') == 'value'
|
||||
assert session_1.state.get('key1') == 'value1'
|
||||
assert session_1.state.get('user:key1') == 'value1'
|
||||
assert not session_1.state.get('temp:key')
|
||||
|
||||
session_2 = session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state={}, session_id=session_id_2
|
||||
)
|
||||
# Session 2 has the persisted states
|
||||
assert session_2.state.get('app:key') == 'value'
|
||||
assert session_2.state.get('user:key1') == 'value1'
|
||||
assert not session_2.state.get('key1')
|
||||
assert not session_2.state.get('temp:key')
|
||||
246
tests/unittests/sessions/test_vertex_ai_session_service.py
Normal file
246
tests/unittests/sessions/test_vertex_ai_session_service.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
import this
|
||||
from typing import Any
|
||||
import uuid
|
||||
from dateutil.parser import isoparse
|
||||
from google.adk.events import Event
|
||||
from google.adk.events import EventActions
|
||||
from google.adk.sessions import Session
|
||||
from google.adk.sessions import VertexAiSessionService
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
|
||||
MOCK_SESSION_JSON_1 = {
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/1'
|
||||
),
|
||||
'createTime': '2024-12-12T12:12:12.123456Z',
|
||||
'updateTime': '2024-12-12T12:12:12.123456Z',
|
||||
'sessionState': {
|
||||
'key': {'value': 'test_value'},
|
||||
},
|
||||
'userId': 'user',
|
||||
}
|
||||
MOCK_SESSION_JSON_2 = {
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/2'
|
||||
),
|
||||
'updateTime': '2024-12-13T12:12:12.123456Z',
|
||||
'userId': 'user',
|
||||
}
|
||||
MOCK_SESSION_JSON_3 = {
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/3'
|
||||
),
|
||||
'updateTime': '2024-12-14T12:12:12.123456Z',
|
||||
'userId': 'user2',
|
||||
}
|
||||
MOCK_EVENT_JSON = [
|
||||
{
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/test_engine/sessions/1/events/123'
|
||||
),
|
||||
'invocationId': '123',
|
||||
'author': 'user',
|
||||
'timestamp': '2024-12-12T12:12:12.123456Z',
|
||||
'content': {
|
||||
'parts': [
|
||||
{'text': 'test_content'},
|
||||
],
|
||||
},
|
||||
'actions': {
|
||||
'stateDelta': {
|
||||
'key': {'value': 'test_value'},
|
||||
},
|
||||
'transferAgent': 'agent',
|
||||
},
|
||||
'eventMetadata': {
|
||||
'partial': False,
|
||||
'turnComplete': True,
|
||||
'interrupted': False,
|
||||
'branch': '',
|
||||
'longRunningToolIds': ['tool1'],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
MOCK_SESSION = Session(
|
||||
app_name='123',
|
||||
user_id='user',
|
||||
id='1',
|
||||
state=MOCK_SESSION_JSON_1['sessionState'],
|
||||
last_update_time=isoparse(MOCK_SESSION_JSON_1['updateTime']).timestamp(),
|
||||
events=[
|
||||
Event(
|
||||
id='123',
|
||||
invocation_id='123',
|
||||
author='user',
|
||||
timestamp=isoparse(MOCK_EVENT_JSON[0]['timestamp']).timestamp(),
|
||||
content=types.Content(parts=[types.Part(text='test_content')]),
|
||||
actions=EventActions(
|
||||
transfer_to_agent='agent',
|
||||
state_delta={'key': {'value': 'test_value'}},
|
||||
),
|
||||
partial=False,
|
||||
turn_complete=True,
|
||||
interrupted=False,
|
||||
branch='',
|
||||
long_running_tool_ids={'tool1'},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
|
||||
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions$'
|
||||
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
|
||||
LRO_REGEX = r'^operations/([^/]+)$'
|
||||
|
||||
|
||||
class MockApiClient:
|
||||
"""Mocks the API Client."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initializes MockClient."""
|
||||
this.session_dict: dict[str, Any] = {}
|
||||
this.event_dict: dict[str, list[Any]] = {}
|
||||
|
||||
def request(self, http_method: str, path: str, request_dict: dict[str, Any]):
|
||||
"""Mocks the API Client request method."""
|
||||
if http_method == 'GET':
|
||||
if re.match(SESSION_REGEX, path):
|
||||
match = re.match(SESSION_REGEX, path)
|
||||
if match:
|
||||
session_id = match.group(2)
|
||||
if session_id in self.session_dict:
|
||||
return self.session_dict[session_id]
|
||||
else:
|
||||
raise ValueError(f'Session not found: {session_id}')
|
||||
elif re.match(SESSIONS_REGEX, path):
|
||||
return {
|
||||
'sessions': self.session_dict.values(),
|
||||
}
|
||||
elif re.match(EVENTS_REGEX, path):
|
||||
match = re.match(EVENTS_REGEX, path)
|
||||
if match:
|
||||
return {'sessionEvents': self.event_dict[match.group(2)]}
|
||||
elif re.match(LRO_REGEX, path):
|
||||
return {
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/123'
|
||||
),
|
||||
'done': True,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f'Unsupported path: {path}')
|
||||
elif http_method == 'POST':
|
||||
id = str(uuid.uuid4())
|
||||
self.session_dict[id] = {
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/'
|
||||
+ id
|
||||
),
|
||||
'userId': request_dict['user_id'],
|
||||
'sessionState': request_dict.get('sessionState', {}),
|
||||
'updateTime': '2024-12-12T12:12:12.123456Z',
|
||||
}
|
||||
return {
|
||||
'name': (
|
||||
'projects/test_project/locations/test_location/'
|
||||
'reasoningEngines/test_engine/sessions/123'
|
||||
),
|
||||
'done': False,
|
||||
}
|
||||
elif http_method == 'DELETE':
|
||||
match = re.match(SESSION_REGEX, path)
|
||||
if match:
|
||||
self.session_dict.pop(match.group(2))
|
||||
else:
|
||||
raise ValueError(f'Unsupported http method: {http_method}')
|
||||
|
||||
|
||||
def mock_vertex_ai_session_service():
|
||||
"""Creates a mock Vertex AI Session service for testing."""
|
||||
service = VertexAiSessionService(
|
||||
project='test-project', location='test-location'
|
||||
)
|
||||
service.api_client = MockApiClient()
|
||||
service.api_client.session_dict = {
|
||||
'1': MOCK_SESSION_JSON_1,
|
||||
'2': MOCK_SESSION_JSON_2,
|
||||
'3': MOCK_SESSION_JSON_3,
|
||||
}
|
||||
service.api_client.event_dict = {
|
||||
'1': MOCK_EVENT_JSON,
|
||||
}
|
||||
return service
|
||||
|
||||
|
||||
def test_get_empty_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
assert session_service.get_session(
|
||||
app_name='123', user_id='user', session_id='0'
|
||||
)
|
||||
assert str(excinfo.value) == 'Session not found: 0'
|
||||
|
||||
|
||||
def test_get_and_delete_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
|
||||
assert (
|
||||
session_service.get_session(
|
||||
app_name='123', user_id='user', session_id='1'
|
||||
)
|
||||
== MOCK_SESSION
|
||||
)
|
||||
|
||||
session_service.delete_session(app_name='123', user_id='user', session_id='1')
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
assert session_service.get_session(
|
||||
app_name='123', user_id='user', session_id='1'
|
||||
)
|
||||
assert str(excinfo.value) == 'Session not found: 1'
|
||||
|
||||
def test_list_sessions():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
sessions = session_service.list_sessions(app_name='123', user_id='user')
|
||||
assert len(sessions.sessions) == 2
|
||||
assert sessions.sessions[0].id == '1'
|
||||
assert sessions.sessions[1].id == '2'
|
||||
|
||||
def test_create_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
session = session_service.create_session(
|
||||
app_name='123', user_id='user', state={'key': 'value'}
|
||||
)
|
||||
assert session.state == {'key': 'value'}
|
||||
assert session.app_name == '123'
|
||||
assert session.user_id == 'user'
|
||||
assert session.last_update_time is not None
|
||||
|
||||
session_id = session.id
|
||||
assert session == session_service.get_session(
|
||||
app_name='123', user_id='user', session_id=session_id
|
||||
)
|
||||
14
tests/unittests/streaming/__init__.py
Normal file
14
tests/unittests/streaming/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
50
tests/unittests/streaming/test_streaming.py
Normal file
50
tests/unittests/streaming/test_streaming.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.agents import LiveRequestQueue
|
||||
from google.adk.models import LlmResponse
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from .. import utils
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Streaming is hanging.')
|
||||
def test_streaming():
|
||||
response1 = LlmResponse(
|
||||
turn_complete=True,
|
||||
)
|
||||
|
||||
mock_model = utils.MockModel.create([response1])
|
||||
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(
|
||||
root_agent=root_agent, response_modalities=['AUDIO']
|
||||
)
|
||||
live_request_queue = LiveRequestQueue()
|
||||
live_request_queue.send_realtime(
|
||||
blob=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm')
|
||||
)
|
||||
res_events = runner.run_live(live_request_queue)
|
||||
|
||||
assert res_events is not None, 'Expected a list of events, got None.'
|
||||
assert (
|
||||
len(res_events) > 0
|
||||
), 'Expected at least one response, but got an empty list.'
|
||||
14
tests/unittests/tools/__init__.py
Normal file
14
tests/unittests/tools/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
499
tests/unittests/tools/apihub_tool/clients/test_apihub_client.py
Normal file
499
tests/unittests/tools/apihub_tool/clients/test_apihub_client.py
Normal file
@@ -0,0 +1,499 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
from google.adk.tools.apihub_tool.clients.apihub_client import APIHubClient
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
# Mock data for API responses
|
||||
MOCK_API_LIST = {
|
||||
"apis": [
|
||||
{"name": "projects/test-project/locations/us-central1/apis/api1"},
|
||||
{"name": "projects/test-project/locations/us-central1/apis/api2"},
|
||||
]
|
||||
}
|
||||
MOCK_API_DETAIL = {
|
||||
"name": "projects/test-project/locations/us-central1/apis/api1",
|
||||
"versions": [
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
],
|
||||
}
|
||||
MOCK_API_VERSION = {
|
||||
"name": "projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
"specs": [
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
],
|
||||
}
|
||||
MOCK_SPEC_CONTENT = {"contents": base64.b64encode(b"spec content").decode()}
|
||||
|
||||
|
||||
# Test cases
|
||||
class TestAPIHubClient:
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
return APIHubClient(access_token="mocked_token")
|
||||
|
||||
@pytest.fixture
|
||||
def service_account_config(self):
|
||||
return json.dumps({
|
||||
"type": "service_account",
|
||||
"project_id": "test",
|
||||
"token_uri": "test.com",
|
||||
"client_email": "test@example.com",
|
||||
"private_key": "1234",
|
||||
})
|
||||
|
||||
@patch("requests.get")
|
||||
def test_list_apis(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_API_LIST
|
||||
mock_get.return_value.status_code = 200
|
||||
|
||||
apis = client.list_apis("test-project", "us-central1")
|
||||
assert apis == MOCK_API_LIST["apis"]
|
||||
mock_get.assert_called_once_with(
|
||||
"https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis",
|
||||
headers={
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": "Bearer mocked_token",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_list_apis_empty(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = {"apis": []}
|
||||
mock_get.return_value.status_code = 200
|
||||
|
||||
apis = client.list_apis("test-project", "us-central1")
|
||||
assert apis == []
|
||||
|
||||
@patch("requests.get")
|
||||
def test_list_apis_error(self, mock_get, client):
|
||||
mock_get.return_value.raise_for_status.side_effect = HTTPError
|
||||
|
||||
with pytest.raises(HTTPError):
|
||||
client.list_apis("test-project", "us-central1")
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_api(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_API_DETAIL
|
||||
mock_get.return_value.status_code = 200
|
||||
api = client.get_api(
|
||||
"projects/test-project/locations/us-central1/apis/api1"
|
||||
)
|
||||
assert api == MOCK_API_DETAIL
|
||||
mock_get.assert_called_once_with(
|
||||
"https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis/api1",
|
||||
headers={
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": "Bearer mocked_token",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_api_error(self, mock_get, client):
|
||||
mock_get.return_value.raise_for_status.side_effect = HTTPError
|
||||
with pytest.raises(HTTPError):
|
||||
client.get_api("projects/test-project/locations/us-central1/apis/api1")
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_api_version(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_API_VERSION
|
||||
mock_get.return_value.status_code = 200
|
||||
api_version = client.get_api_version(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
)
|
||||
assert api_version == MOCK_API_VERSION
|
||||
mock_get.assert_called_once_with(
|
||||
"https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
headers={
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": "Bearer mocked_token",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_api_version_error(self, mock_get, client):
|
||||
mock_get.return_value.raise_for_status.side_effect = HTTPError
|
||||
with pytest.raises(HTTPError):
|
||||
client.get_api_version(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_SPEC_CONTENT
|
||||
mock_get.return_value.status_code = 200
|
||||
spec_content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
)
|
||||
assert spec_content == "spec content"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1:contents",
|
||||
headers={
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"Authorization": "Bearer mocked_token",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_empty(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = {"contents": ""}
|
||||
mock_get.return_value.status_code = 200
|
||||
spec_content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
)
|
||||
assert spec_content == ""
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_error(self, mock_get, client):
|
||||
mock_get.return_value.raise_for_status.side_effect = HTTPError
|
||||
with pytest.raises(HTTPError):
|
||||
client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url_or_path, expected",
|
||||
[
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
None,
|
||||
None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1",
|
||||
),
|
||||
),
|
||||
(
|
||||
"https://console.cloud.google.com/apigee/api-hub/projects/test-project/locations/us-central1/apis/api1/versions/v1?project=test-project",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"https://console.cloud.google.com/apigee/api-hub/projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1?project=test-project",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1",
|
||||
),
|
||||
),
|
||||
(
|
||||
"/projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1",
|
||||
None,
|
||||
),
|
||||
),
|
||||
( # Added trailing slashes
|
||||
"projects/test-project/locations/us-central1/apis/api1/",
|
||||
(
|
||||
"projects/test-project/locations/us-central1/apis/api1",
|
||||
None,
|
||||
None,
|
||||
),
|
||||
),
|
||||
( # case location name
|
||||
"projects/test-project/locations/LOCATION/apis/api1/",
|
||||
(
|
||||
"projects/test-project/locations/LOCATION/apis/api1",
|
||||
None,
|
||||
None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"projects/p1/locations/l1/apis/a1/versions/v1/specs/s1",
|
||||
(
|
||||
"projects/p1/locations/l1/apis/a1",
|
||||
"projects/p1/locations/l1/apis/a1/versions/v1",
|
||||
"projects/p1/locations/l1/apis/a1/versions/v1/specs/s1",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_resource_name(self, client, url_or_path, expected):
|
||||
result = client._extract_resource_name(url_or_path)
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url_or_path, expected_error_message",
|
||||
[
|
||||
(
|
||||
"invalid-path",
|
||||
"Project ID not found in URL or path in APIHubClient.",
|
||||
),
|
||||
(
|
||||
"projects/test-project",
|
||||
"Location not found in URL or path in APIHubClient.",
|
||||
),
|
||||
(
|
||||
"projects/test-project/locations/us-central1",
|
||||
"API id not found in URL or path in APIHubClient.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_resource_name_invalid(
|
||||
self, client, url_or_path, expected_error_message
|
||||
):
|
||||
with pytest.raises(ValueError, match=expected_error_message):
|
||||
client._extract_resource_name(url_or_path)
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.service_account.Credentials.from_service_account_info"
|
||||
)
|
||||
def test_get_access_token_use_default_credential(
|
||||
self,
|
||||
mock_from_service_account_info,
|
||||
mock_default_service_credential,
|
||||
):
|
||||
mock_credential = MagicMock()
|
||||
mock_credential.token = "default_token"
|
||||
mock_default_service_credential.return_value = (
|
||||
mock_credential,
|
||||
"project_id",
|
||||
)
|
||||
mock_config_credential = MagicMock()
|
||||
mock_config_credential.token = "config_token"
|
||||
mock_from_service_account_info.return_value = mock_config_credential
|
||||
|
||||
client = APIHubClient()
|
||||
token = client._get_access_token()
|
||||
assert token == "default_token"
|
||||
mock_credential.refresh.assert_called_once()
|
||||
assert client.credential_cache == mock_credential
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.service_account.Credentials.from_service_account_info"
|
||||
)
|
||||
def test_get_access_token_use_configured_service_account(
|
||||
self,
|
||||
mock_from_service_account_info,
|
||||
mock_default_service_credential,
|
||||
service_account_config,
|
||||
):
|
||||
mock_credential = MagicMock()
|
||||
mock_credential.token = "default_token"
|
||||
mock_default_service_credential.return_value = (
|
||||
mock_credential,
|
||||
"project_id",
|
||||
)
|
||||
mock_config_credential = MagicMock()
|
||||
mock_config_credential.token = "config_token"
|
||||
mock_from_service_account_info.return_value = mock_config_credential
|
||||
|
||||
client = APIHubClient(service_account_json=service_account_config)
|
||||
token = client._get_access_token()
|
||||
|
||||
assert token == "config_token"
|
||||
mock_from_service_account_info.assert_called_once_with(
|
||||
json.loads(service_account_config),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
mock_config_credential.refresh.assert_called_once()
|
||||
assert client.credential_cache == mock_config_credential
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
def test_get_access_token_not_expired_use_cached_token(
|
||||
self, mock_default_credential
|
||||
):
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "default_service_account_token"
|
||||
mock_default_credential.return_value = (mock_credentials, "")
|
||||
|
||||
client = APIHubClient()
|
||||
# Call #1: Setup cache
|
||||
token = client._get_access_token()
|
||||
assert token == "default_service_account_token"
|
||||
mock_default_credential.assert_called_once()
|
||||
|
||||
# Call #2: Reuse cache
|
||||
mock_credentials.reset_mock()
|
||||
mock_credentials.expired = False
|
||||
token = client._get_access_token()
|
||||
assert token == "default_service_account_token"
|
||||
mock_credentials.refresh.assert_not_called()
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
def test_get_access_token_expired_refresh(self, mock_default_credential):
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "default_service_account_token"
|
||||
mock_default_credential.return_value = (mock_credentials, "")
|
||||
client = APIHubClient()
|
||||
|
||||
# Call #1: Setup cache
|
||||
token = client._get_access_token()
|
||||
assert token == "default_service_account_token"
|
||||
mock_default_credential.assert_called_once()
|
||||
|
||||
# Call #2: Cache expired
|
||||
mock_credentials.reset_mock()
|
||||
mock_credentials.expired = True
|
||||
token = client._get_access_token()
|
||||
mock_credentials.refresh.assert_called_once()
|
||||
assert token == "default_service_account_token"
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
|
||||
)
|
||||
def test_get_access_token_no_credentials(
|
||||
self, mock_default_service_credential
|
||||
):
|
||||
mock_default_service_credential.return_value = (None, None)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Please provide a service account or an access token to API Hub"
|
||||
" client."
|
||||
),
|
||||
):
|
||||
# no service account client
|
||||
APIHubClient()._get_access_token()
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_api_level(self, mock_get, client):
|
||||
mock_get.side_effect = [
|
||||
MagicMock(status_code=200, json=lambda: MOCK_API_DETAIL), # For get_api
|
||||
MagicMock(
|
||||
status_code=200, json=lambda: MOCK_API_VERSION
|
||||
), # For get_api_version
|
||||
MagicMock(
|
||||
status_code=200, json=lambda: MOCK_SPEC_CONTENT
|
||||
), # For get_spec_content
|
||||
]
|
||||
|
||||
content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1"
|
||||
)
|
||||
assert content == "spec content"
|
||||
# Check calls - get_api, get_api_version, then get_spec_content
|
||||
assert mock_get.call_count == 3
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_version_level(self, mock_get, client):
|
||||
mock_get.side_effect = [
|
||||
MagicMock(
|
||||
status_code=200, json=lambda: MOCK_API_VERSION
|
||||
), # For get_api_version
|
||||
MagicMock(
|
||||
status_code=200, json=lambda: MOCK_SPEC_CONTENT
|
||||
), # For get_spec_content
|
||||
]
|
||||
|
||||
content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
)
|
||||
assert content == "spec content"
|
||||
assert mock_get.call_count == 2 # get_api_version and get_spec_content
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_spec_level(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = MOCK_SPEC_CONTENT
|
||||
mock_get.return_value.status_code = 200
|
||||
|
||||
content = client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
|
||||
)
|
||||
assert content == "spec content"
|
||||
mock_get.assert_called_once() # Only get_spec_content should be called
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_no_versions(self, mock_get, client):
|
||||
mock_get.return_value.json.return_value = {
|
||||
"name": "projects/test-project/locations/us-central1/apis/api1",
|
||||
"versions": [],
|
||||
} # No versions
|
||||
mock_get.return_value.status_code = 200
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"No versions found in API Hub resource:"
|
||||
" projects/test-project/locations/us-central1/apis/api1"
|
||||
),
|
||||
):
|
||||
client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1"
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_no_specs(self, mock_get, client):
|
||||
mock_get.side_effect = [
|
||||
MagicMock(status_code=200, json=lambda: MOCK_API_DETAIL),
|
||||
MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {
|
||||
"name": (
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
),
|
||||
"specs": [],
|
||||
},
|
||||
), # No specs
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"No specs found in API Hub version:"
|
||||
" projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
),
|
||||
):
|
||||
client.get_spec_content(
|
||||
"projects/test-project/locations/us-central1/apis/api1/versions/v1"
|
||||
)
|
||||
|
||||
@patch("requests.get")
|
||||
def test_get_spec_content_invalid_path(self, mock_get, client):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Project ID not found in URL or path in APIHubClient. Input"
|
||||
" path is 'invalid-path'."
|
||||
),
|
||||
):
|
||||
client.get_spec_content("invalid-path")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
204
tests/unittests/tools/apihub_tool/test_apihub_toolset.py
Normal file
204
tests/unittests/tools/apihub_tool/test_apihub_toolset.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.tools.apihub_tool.apihub_toolset import APIHubToolset
|
||||
from google.adk.tools.apihub_tool.clients.apihub_client import BaseAPIHubClient
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
|
||||
class MockAPIHubClient(BaseAPIHubClient):
|
||||
|
||||
def get_spec_content(self, apihub_resource_name: str) -> str:
|
||||
return """
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
version: 1.0.0
|
||||
title: Mock API
|
||||
description: Mock API Description
|
||||
paths:
|
||||
/test:
|
||||
get:
|
||||
summary: Test GET endpoint
|
||||
operationId: testGet
|
||||
responses:
|
||||
'200':
|
||||
description: Successful response
|
||||
"""
|
||||
|
||||
|
||||
# Fixture for a basic APIHubToolset
|
||||
@pytest.fixture
|
||||
def basic_apihub_toolset():
|
||||
apihub_client = MockAPIHubClient()
|
||||
tool = APIHubToolset(
|
||||
apihub_resource_name='test_resource', apihub_client=apihub_client
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
# Fixture for an APIHubToolset with lazy loading
|
||||
@pytest.fixture
|
||||
def lazy_apihub_toolset():
|
||||
apihub_client = MockAPIHubClient()
|
||||
tool = APIHubToolset(
|
||||
apihub_resource_name='test_resource',
|
||||
apihub_client=apihub_client,
|
||||
lazy_load_spec=True,
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
# Fixture for auth scheme
|
||||
@pytest.fixture
|
||||
def mock_auth_scheme():
|
||||
return MagicMock(spec=AuthScheme)
|
||||
|
||||
|
||||
# Fixture for auth credential
|
||||
@pytest.fixture
|
||||
def mock_auth_credential():
|
||||
return MagicMock(spec=AuthCredential)
|
||||
|
||||
|
||||
# Test cases
|
||||
def test_apihub_toolset_initialization(basic_apihub_toolset):
|
||||
assert basic_apihub_toolset.name == 'mock_api'
|
||||
assert basic_apihub_toolset.description == 'Mock API Description'
|
||||
assert basic_apihub_toolset.apihub_resource_name == 'test_resource'
|
||||
assert not basic_apihub_toolset.lazy_load_spec
|
||||
assert len(basic_apihub_toolset.generated_tools) == 1
|
||||
assert 'test_get' in basic_apihub_toolset.generated_tools
|
||||
|
||||
|
||||
def test_apihub_toolset_lazy_loading(lazy_apihub_toolset):
|
||||
assert lazy_apihub_toolset.lazy_load_spec
|
||||
assert not lazy_apihub_toolset.generated_tools
|
||||
|
||||
tools = lazy_apihub_toolset.get_tools()
|
||||
assert len(tools) == 1
|
||||
assert lazy_apihub_toolset.get_tool('test_get') == tools[0]
|
||||
|
||||
|
||||
def test_apihub_toolset_no_title_in_spec(basic_apihub_toolset):
|
||||
spec = """
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
version: 1.0.0
|
||||
paths:
|
||||
/empty_desc_test:
|
||||
delete:
|
||||
summary: Test DELETE endpoint
|
||||
operationId: emptyDescTest
|
||||
responses:
|
||||
'200':
|
||||
description: Successful response
|
||||
"""
|
||||
|
||||
class MockAPIHubClientEmptySpec(BaseAPIHubClient):
|
||||
|
||||
def get_spec_content(self, apihub_resource_name: str) -> str:
|
||||
return spec
|
||||
|
||||
apihub_client = MockAPIHubClientEmptySpec()
|
||||
toolset = APIHubToolset(
|
||||
apihub_resource_name='test_resource',
|
||||
apihub_client=apihub_client,
|
||||
)
|
||||
|
||||
assert toolset.name == 'unnamed'
|
||||
|
||||
|
||||
def test_apihub_toolset_empty_description_in_spec():
|
||||
spec = """
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
version: 1.0.0
|
||||
title: Empty Description API
|
||||
paths:
|
||||
/empty_desc_test:
|
||||
delete:
|
||||
summary: Test DELETE endpoint
|
||||
operationId: emptyDescTest
|
||||
responses:
|
||||
'200':
|
||||
description: Successful response
|
||||
"""
|
||||
|
||||
class MockAPIHubClientEmptySpec(BaseAPIHubClient):
|
||||
|
||||
def get_spec_content(self, apihub_resource_name: str) -> str:
|
||||
return spec
|
||||
|
||||
apihub_client = MockAPIHubClientEmptySpec()
|
||||
toolset = APIHubToolset(
|
||||
apihub_resource_name='test_resource',
|
||||
apihub_client=apihub_client,
|
||||
)
|
||||
|
||||
assert toolset.name == 'empty_description_api'
|
||||
assert toolset.description == ''
|
||||
|
||||
|
||||
def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential):
|
||||
apihub_client = MockAPIHubClient()
|
||||
tool = APIHubToolset(
|
||||
apihub_resource_name='test_resource',
|
||||
apihub_client=apihub_client,
|
||||
auth_scheme=mock_auth_scheme,
|
||||
auth_credential=mock_auth_credential,
|
||||
)
|
||||
tools = tool.get_tools()
|
||||
assert len(tools) == 1
|
||||
|
||||
|
||||
def test_apihub_toolset_get_tools_lazy_load_empty_spec():
|
||||
|
||||
class MockAPIHubClientEmptySpec(BaseAPIHubClient):
|
||||
|
||||
def get_spec_content(self, apihub_resource_name: str) -> str:
|
||||
return ''
|
||||
|
||||
apihub_client = MockAPIHubClientEmptySpec()
|
||||
tool = APIHubToolset(
|
||||
apihub_resource_name='test_resource',
|
||||
apihub_client=apihub_client,
|
||||
lazy_load_spec=True,
|
||||
)
|
||||
tools = tool.get_tools()
|
||||
assert not tools
|
||||
|
||||
|
||||
def test_apihub_toolset_get_tools_invalid_yaml():
|
||||
|
||||
class MockAPIHubClientInvalidYAML(BaseAPIHubClient):
|
||||
|
||||
def get_spec_content(self, apihub_resource_name: str) -> str:
|
||||
return '{invalid yaml' # Return invalid YAML
|
||||
|
||||
with pytest.raises(yaml.YAMLError):
|
||||
apihub_client = MockAPIHubClientInvalidYAML()
|
||||
tool = APIHubToolset(
|
||||
apihub_resource_name='test_resource',
|
||||
apihub_client=apihub_client,
|
||||
)
|
||||
tool.get_tools()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
@@ -0,0 +1,600 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
|
||||
import google.auth
|
||||
import pytest
|
||||
import requests
|
||||
from requests import exceptions
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project():
|
||||
return "test-project"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def location():
|
||||
return "us-central1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_name():
|
||||
return "test-connection"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credentials():
|
||||
creds = mock.create_autospec(google.auth.credentials.Credentials)
|
||||
creds.token = "test_token"
|
||||
creds.expired = False
|
||||
return creds
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_request():
|
||||
return mock.create_autospec(google.auth.transport.requests.Request)
|
||||
|
||||
|
||||
class TestConnectionsClient:
|
||||
|
||||
def test_initialization(self, project, location, connection_name):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(
|
||||
project, location, connection_name, json.dumps(credentials)
|
||||
)
|
||||
assert client.project == project
|
||||
assert client.location == location
|
||||
assert client.connection == connection_name
|
||||
assert client.connector_url == "https://connectors.googleapis.com"
|
||||
assert client.service_account_json == json.dumps(credentials)
|
||||
assert client.credential_cache is None
|
||||
|
||||
def test_execute_api_call_success(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.return_value = {"data": "test"}
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_get_access_token", return_value=mock_credentials.token
|
||||
), mock.patch("requests.get", return_value=mock_response):
|
||||
response = client._execute_api_call("https://test.url")
|
||||
assert response.json() == {"data": "test"}
|
||||
requests.get.assert_called_once_with(
|
||||
"https://test.url",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {mock_credentials.token}",
|
||||
},
|
||||
)
|
||||
|
||||
def test_execute_api_call_credential_error(
|
||||
self, project, location, connection_name
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
with mock.patch.object(
|
||||
client,
|
||||
"_get_access_token",
|
||||
side_effect=google.auth.exceptions.DefaultCredentialsError("Test"),
|
||||
):
|
||||
with pytest.raises(PermissionError, match="Credentials error: Test"):
|
||||
client._execute_api_call("https://test.url")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code, response_text",
|
||||
[(404, "Not Found"), (400, "Bad Request")],
|
||||
)
|
||||
def test_execute_api_call_request_error_not_found_or_bad_request(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
connection_name,
|
||||
mock_credentials,
|
||||
status_code,
|
||||
response_text,
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.raise_for_status.side_effect = exceptions.HTTPError(
|
||||
f"HTTP error {status_code}: {response_text}"
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_get_access_token", return_value=mock_credentials.token
|
||||
), mock.patch("requests.get", return_value=mock_response):
|
||||
with pytest.raises(
|
||||
ValueError, match="Invalid request. Please check the provided"
|
||||
):
|
||||
client._execute_api_call("https://test.url")
|
||||
|
||||
def test_execute_api_call_other_request_error(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.raise_for_status.side_effect = exceptions.HTTPError(
|
||||
"Internal Server Error"
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_get_access_token", return_value=mock_credentials.token
|
||||
), mock.patch("requests.get", return_value=mock_response):
|
||||
with pytest.raises(ValueError, match="Request error: "):
|
||||
client._execute_api_call("https://test.url")
|
||||
|
||||
def test_execute_api_call_unexpected_error(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
with mock.patch.object(
|
||||
client, "_get_access_token", return_value=mock_credentials.token
|
||||
), mock.patch(
|
||||
"requests.get", side_effect=Exception("Something went wrong")
|
||||
):
|
||||
with pytest.raises(
|
||||
Exception, match="An unexpected error occurred: Something went wrong"
|
||||
):
|
||||
client._execute_api_call("https://test.url")
|
||||
|
||||
def test_get_connection_details_success_with_host(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"serviceDirectory": "test_service",
|
||||
"host": "test.host",
|
||||
"tlsServiceDirectory": "tls_test_service",
|
||||
"authOverrideEnabled": True,
|
||||
}
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", return_value=mock_response
|
||||
):
|
||||
details = client.get_connection_details()
|
||||
assert details == {
|
||||
"serviceName": "tls_test_service",
|
||||
"host": "test.host",
|
||||
"authOverrideEnabled": True,
|
||||
}
|
||||
|
||||
def test_get_connection_details_success_without_host(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"serviceDirectory": "test_service",
|
||||
"authOverrideEnabled": False,
|
||||
}
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", return_value=mock_response
|
||||
):
|
||||
details = client.get_connection_details()
|
||||
assert details == {
|
||||
"serviceName": "test_service",
|
||||
"host": "",
|
||||
"authOverrideEnabled": False,
|
||||
}
|
||||
|
||||
def test_get_connection_details_error(
|
||||
self, project, location, connection_name
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", side_effect=ValueError("Request error")
|
||||
):
|
||||
with pytest.raises(ValueError, match="Request error"):
|
||||
client.get_connection_details()
|
||||
|
||||
def test_get_entity_schema_and_operations_success(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_execute_response_initial = mock.MagicMock()
|
||||
mock_execute_response_initial.status_code = 200
|
||||
mock_execute_response_initial.json.return_value = {
|
||||
"name": "operations/test_op"
|
||||
}
|
||||
|
||||
mock_execute_response_poll_done = mock.MagicMock()
|
||||
mock_execute_response_poll_done.status_code = 200
|
||||
mock_execute_response_poll_done.json.return_value = {
|
||||
"done": True,
|
||||
"response": {
|
||||
"jsonSchema": {"type": "object"},
|
||||
"operations": ["LIST", "GET"],
|
||||
},
|
||||
}
|
||||
|
||||
with mock.patch.object(
|
||||
client,
|
||||
"_execute_api_call",
|
||||
side_effect=[
|
||||
mock_execute_response_initial,
|
||||
mock_execute_response_poll_done,
|
||||
],
|
||||
):
|
||||
schema, operations = client.get_entity_schema_and_operations("entity1")
|
||||
assert schema == {"type": "object"}
|
||||
assert operations == ["LIST", "GET"]
|
||||
assert (
|
||||
mock.call(
|
||||
f"https://connectors.googleapis.com/v1/projects/{project}/locations/{location}/connections/{connection_name}/connectionSchemaMetadata:getEntityType?entityId=entity1"
|
||||
)
|
||||
in client._execute_api_call.mock_calls
|
||||
)
|
||||
assert (
|
||||
mock.call(f"https://connectors.googleapis.com/v1/operations/test_op")
|
||||
in client._execute_api_call.mock_calls
|
||||
)
|
||||
|
||||
def test_get_entity_schema_and_operations_no_operation_id(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_execute_response = mock.MagicMock()
|
||||
mock_execute_response.status_code = 200
|
||||
mock_execute_response.json.return_value = {}
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", return_value=mock_execute_response
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Failed to get entity schema and operations for entity: entity1"
|
||||
),
|
||||
):
|
||||
client.get_entity_schema_and_operations("entity1")
|
||||
|
||||
def test_get_entity_schema_and_operations_execute_api_call_error(
|
||||
self, project, location, connection_name
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", side_effect=ValueError("Request error")
|
||||
):
|
||||
with pytest.raises(ValueError, match="Request error"):
|
||||
client.get_entity_schema_and_operations("entity1")
|
||||
|
||||
def test_get_action_schema_success(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_execute_response_initial = mock.MagicMock()
|
||||
mock_execute_response_initial.status_code = 200
|
||||
mock_execute_response_initial.json.return_value = {
|
||||
"name": "operations/test_op"
|
||||
}
|
||||
|
||||
mock_execute_response_poll_done = mock.MagicMock()
|
||||
mock_execute_response_poll_done.status_code = 200
|
||||
mock_execute_response_poll_done.json.return_value = {
|
||||
"done": True,
|
||||
"response": {
|
||||
"inputJsonSchema": {
|
||||
"type": "object",
|
||||
"properties": {"input": {"type": "string"}},
|
||||
},
|
||||
"outputJsonSchema": {
|
||||
"type": "object",
|
||||
"properties": {"output": {"type": "string"}},
|
||||
},
|
||||
"description": "Test Action Description",
|
||||
"displayName": "TestAction",
|
||||
},
|
||||
}
|
||||
|
||||
with mock.patch.object(
|
||||
client,
|
||||
"_execute_api_call",
|
||||
side_effect=[
|
||||
mock_execute_response_initial,
|
||||
mock_execute_response_poll_done,
|
||||
],
|
||||
):
|
||||
schema = client.get_action_schema("action1")
|
||||
assert schema == {
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"input": {"type": "string"}},
|
||||
},
|
||||
"outputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"output": {"type": "string"}},
|
||||
},
|
||||
"description": "Test Action Description",
|
||||
"displayName": "TestAction",
|
||||
}
|
||||
assert (
|
||||
mock.call(
|
||||
f"https://connectors.googleapis.com/v1/projects/{project}/locations/{location}/connections/{connection_name}/connectionSchemaMetadata:getAction?actionId=action1"
|
||||
)
|
||||
in client._execute_api_call.mock_calls
|
||||
)
|
||||
assert (
|
||||
mock.call(f"https://connectors.googleapis.com/v1/operations/test_op")
|
||||
in client._execute_api_call.mock_calls
|
||||
)
|
||||
|
||||
def test_get_action_schema_no_operation_id(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
mock_execute_response = mock.MagicMock()
|
||||
mock_execute_response.status_code = 200
|
||||
mock_execute_response.json.return_value = {}
|
||||
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", return_value=mock_execute_response
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError, match="Failed to get action schema for action: action1"
|
||||
):
|
||||
client.get_action_schema("action1")
|
||||
|
||||
def test_get_action_schema_execute_api_call_error(
|
||||
self, project, location, connection_name
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
with mock.patch.object(
|
||||
client, "_execute_api_call", side_effect=ValueError("Request error")
|
||||
):
|
||||
with pytest.raises(ValueError, match="Request error"):
|
||||
client.get_action_schema("action1")
|
||||
|
||||
def test_get_connector_base_spec(self):
|
||||
spec = ConnectionsClient.get_connector_base_spec()
|
||||
assert "openapi" in spec
|
||||
assert spec["info"]["title"] == "ExecuteConnection"
|
||||
assert "components" in spec
|
||||
assert "schemas" in spec["components"]
|
||||
assert "operation" in spec["components"]["schemas"]
|
||||
|
||||
def test_get_action_operation(self):
|
||||
operation = ConnectionsClient.get_action_operation(
|
||||
"TestAction", "EXECUTE_ACTION", "TestActionDisplayName", "test_tool"
|
||||
)
|
||||
assert "post" in operation
|
||||
assert operation["post"]["summary"] == "TestActionDisplayName"
|
||||
assert "operationId" in operation["post"]
|
||||
assert operation["post"]["operationId"] == "test_tool_TestActionDisplayName"
|
||||
|
||||
def test_list_operation(self):
|
||||
operation = ConnectionsClient.list_operation(
|
||||
"Entity1", '{"type": "object"}', "test_tool"
|
||||
)
|
||||
assert "post" in operation
|
||||
assert operation["post"]["summary"] == "List Entity1"
|
||||
assert "operationId" in operation["post"]
|
||||
assert operation["post"]["operationId"] == "test_tool_list_Entity1"
|
||||
|
||||
def test_get_operation_static(self):
|
||||
operation = ConnectionsClient.get_operation(
|
||||
"Entity1", '{"type": "object"}', "test_tool"
|
||||
)
|
||||
assert "post" in operation
|
||||
assert operation["post"]["summary"] == "Get Entity1"
|
||||
assert "operationId" in operation["post"]
|
||||
assert operation["post"]["operationId"] == "test_tool_get_Entity1"
|
||||
|
||||
def test_create_operation(self):
|
||||
operation = ConnectionsClient.create_operation("Entity1", "test_tool")
|
||||
assert "post" in operation
|
||||
assert operation["post"]["summary"] == "Create Entity1"
|
||||
assert "operationId" in operation["post"]
|
||||
assert operation["post"]["operationId"] == "test_tool_create_Entity1"
|
||||
|
||||
def test_update_operation(self):
|
||||
operation = ConnectionsClient.update_operation("Entity1", "test_tool")
|
||||
assert "post" in operation
|
||||
assert operation["post"]["summary"] == "Update Entity1"
|
||||
assert "operationId" in operation["post"]
|
||||
assert operation["post"]["operationId"] == "test_tool_update_Entity1"
|
||||
|
||||
def test_delete_operation(self):
|
||||
operation = ConnectionsClient.delete_operation("Entity1", "test_tool")
|
||||
assert "post" in operation
|
||||
assert operation["post"]["summary"] == "Delete Entity1"
|
||||
assert operation["post"]["operationId"] == "test_tool_delete_Entity1"
|
||||
|
||||
def test_create_operation_request(self):
|
||||
schema = ConnectionsClient.create_operation_request("Entity1")
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "connectorInputPayload" in schema["properties"]
|
||||
|
||||
def test_update_operation_request(self):
|
||||
schema = ConnectionsClient.update_operation_request("Entity1")
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "entityId" in schema["properties"]
|
||||
|
||||
def test_get_operation_request_static(self):
|
||||
schema = ConnectionsClient.get_operation_request()
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "entityId" in schema["properties"]
|
||||
|
||||
def test_delete_operation_request(self):
|
||||
schema = ConnectionsClient.delete_operation_request()
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "entityId" in schema["properties"]
|
||||
|
||||
def test_list_operation_request(self):
|
||||
schema = ConnectionsClient.list_operation_request()
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "filterClause" in schema["properties"]
|
||||
|
||||
def test_action_request(self):
|
||||
schema = ConnectionsClient.action_request("TestAction")
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "connectorInputPayload" in schema["properties"]
|
||||
|
||||
def test_action_response(self):
|
||||
schema = ConnectionsClient.action_response("TestAction")
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "connectorOutputPayload" in schema["properties"]
|
||||
|
||||
def test_execute_custom_query_request(self):
|
||||
schema = ConnectionsClient.execute_custom_query_request()
|
||||
assert "type" in schema
|
||||
assert schema["type"] == "object"
|
||||
assert "properties" in schema
|
||||
assert "query" in schema["properties"]
|
||||
|
||||
def test_connector_payload(self):
|
||||
client = ConnectionsClient("test-project", "us-central1", "test-connection")
|
||||
schema = client.connector_payload(
|
||||
json_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {
|
||||
"type": ["null", "string"],
|
||||
"description": "description",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
assert schema == {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {
|
||||
"type": "string",
|
||||
"nullable": True,
|
||||
"description": "description",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def test_get_access_token_uses_cached_token(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
credentials = {"email": "test@example.com"}
|
||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||
client.credential_cache = mock_credentials
|
||||
token = client._get_access_token()
|
||||
assert token == "test_token"
|
||||
|
||||
def test_get_access_token_with_service_account_credentials(
|
||||
self, project, location, connection_name
|
||||
):
|
||||
service_account_json = json.dumps({
|
||||
"client_email": "test@example.com",
|
||||
"private_key": "test_key",
|
||||
})
|
||||
client = ConnectionsClient(
|
||||
project, location, connection_name, service_account_json
|
||||
)
|
||||
mock_creds = mock.create_autospec(google.oauth2.service_account.Credentials)
|
||||
mock_creds.token = "sa_token"
|
||||
mock_creds.expired = False
|
||||
|
||||
with mock.patch(
|
||||
"google.oauth2.service_account.Credentials.from_service_account_info",
|
||||
return_value=mock_creds,
|
||||
), mock.patch.object(mock_creds, "refresh", return_value=None):
|
||||
token = client._get_access_token()
|
||||
assert token == "sa_token"
|
||||
google.oauth2.service_account.Credentials.from_service_account_info.assert_called_once_with(
|
||||
json.loads(service_account_json),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
mock_creds.refresh.assert_called_once()
|
||||
|
||||
def test_get_access_token_with_default_credentials(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
client = ConnectionsClient(project, location, connection_name, None)
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
|
||||
return_value=(mock_credentials, "test_project_id"),
|
||||
), mock.patch.object(mock_credentials, "refresh", return_value=None):
|
||||
token = client._get_access_token()
|
||||
assert token == "test_token"
|
||||
|
||||
def test_get_access_token_no_valid_credentials(
|
||||
self, project, location, connection_name
|
||||
):
|
||||
client = ConnectionsClient(project, location, connection_name, None)
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
|
||||
return_value=(None, None),
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Please provide a service account that has the required"
|
||||
" permissions"
|
||||
),
|
||||
):
|
||||
client._get_access_token()
|
||||
|
||||
def test_get_access_token_refreshes_expired_token(
|
||||
self, project, location, connection_name, mock_credentials
|
||||
):
|
||||
client = ConnectionsClient(project, location, connection_name, None)
|
||||
mock_credentials.expired = True
|
||||
mock_credentials.token = "old_token"
|
||||
mock_credentials.refresh.return_value = None
|
||||
|
||||
client.credential_cache = mock_credentials
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
|
||||
return_value=(mock_credentials, "test_project_id"),
|
||||
):
|
||||
# Mock the refresh method directly on the instance within the context
|
||||
with mock.patch.object(mock_credentials, "refresh") as mock_refresh:
|
||||
mock_credentials.token = "new_token" # Set the expected new token
|
||||
token = client._get_access_token()
|
||||
assert token == "new_token"
|
||||
mock_refresh.assert_called_once()
|
||||
@@ -0,0 +1,630 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
|
||||
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
|
||||
import google.auth
|
||||
import google.auth.transport.requests
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import pytest
|
||||
import requests
|
||||
from requests import exceptions
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project():
|
||||
return "test-project"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def location():
|
||||
return "us-central1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def integration_name():
|
||||
return "test-integration"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trigger_name():
|
||||
return "test-trigger"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_name():
|
||||
return "test-connection"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credentials():
|
||||
creds = mock.create_autospec(google.auth.credentials.Credentials)
|
||||
creds.token = "test_token"
|
||||
return creds
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_request():
|
||||
return mock.create_autospec(Request)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connections_client():
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.integration_client.ConnectionsClient"
|
||||
) as mock_client:
|
||||
mock_instance = mock.create_autospec(ConnectionsClient)
|
||||
mock_client.return_value = mock_instance
|
||||
yield mock_client
|
||||
|
||||
|
||||
class TestIntegrationClient:
|
||||
|
||||
def test_initialization(
|
||||
self, project, location, integration_name, trigger_name, connection_name
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=connection_name,
|
||||
entity_operations={"entity": ["LIST"]},
|
||||
actions=["action1"],
|
||||
service_account_json=json.dumps({"email": "test@example.com"}),
|
||||
)
|
||||
assert client.project == project
|
||||
assert client.location == location
|
||||
assert client.integration == integration_name
|
||||
assert client.trigger == trigger_name
|
||||
assert client.connection == connection_name
|
||||
assert client.entity_operations == {"entity": ["LIST"]}
|
||||
assert client.actions == ["action1"]
|
||||
assert client.service_account_json == json.dumps(
|
||||
{"email": "test@example.com"}
|
||||
)
|
||||
assert client.credential_cache is None
|
||||
|
||||
def test_get_openapi_spec_for_integration_success(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
mock_credentials,
|
||||
mock_connections_client,
|
||||
):
|
||||
expected_spec = {"openapi": "3.0.0", "info": {"title": "Test Integration"}}
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"openApiSpec": json.dumps(expected_spec)}
|
||||
|
||||
with mock.patch.object(
|
||||
IntegrationClient,
|
||||
"_get_access_token",
|
||||
return_value=mock_credentials.token,
|
||||
), mock.patch("requests.post", return_value=mock_response):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=None,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
spec = client.get_openapi_spec_for_integration()
|
||||
assert spec == expected_spec
|
||||
requests.post.assert_called_once_with(
|
||||
f"https://{location}-integrations.googleapis.com/v1/projects/{project}/locations/{location}:generateOpenApiSpec",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {mock_credentials.token}",
|
||||
},
|
||||
json={
|
||||
"apiTriggerResources": [{
|
||||
"integrationResource": integration_name,
|
||||
"triggerId": [trigger_name],
|
||||
}],
|
||||
"fileFormat": "JSON",
|
||||
},
|
||||
)
|
||||
|
||||
def test_get_openapi_spec_for_integration_credential_error(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
mock_connections_client,
|
||||
):
|
||||
with mock.patch.object(
|
||||
IntegrationClient,
|
||||
"_get_access_token",
|
||||
side_effect=ValueError(
|
||||
"Please provide a service account that has the required permissions"
|
||||
" to access the connection."
|
||||
),
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=None,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
with pytest.raises(
|
||||
Exception,
|
||||
match=(
|
||||
"An unexpected error occurred: Please provide a service account"
|
||||
" that has the required permissions to access the connection."
|
||||
),
|
||||
):
|
||||
client.get_openapi_spec_for_integration()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code, response_text",
|
||||
[(404, "Not Found"), (400, "Bad Request"), (404, ""), (400, "")],
|
||||
)
|
||||
def test_get_openapi_spec_for_integration_request_error_not_found_or_bad_request(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
mock_credentials,
|
||||
status_code,
|
||||
response_text,
|
||||
mock_connections_client,
|
||||
):
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.raise_for_status.side_effect = exceptions.HTTPError(
|
||||
f"HTTP error {status_code}: {response_text}"
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
IntegrationClient,
|
||||
"_get_access_token",
|
||||
return_value=mock_credentials.token,
|
||||
), mock.patch("requests.post", return_value=mock_response):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=None,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Invalid request. Please check the provided values of"
|
||||
f" project\\({project}\\), location\\({location}\\),"
|
||||
f" integration\\({integration_name}\\) and"
|
||||
f" trigger\\({trigger_name}\\)."
|
||||
),
|
||||
):
|
||||
client.get_openapi_spec_for_integration()
|
||||
|
||||
def test_get_openapi_spec_for_integration_other_request_error(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
mock_credentials,
|
||||
mock_connections_client,
|
||||
):
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.raise_for_status.side_effect = exceptions.HTTPError(
|
||||
"Internal Server Error"
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
IntegrationClient,
|
||||
"_get_access_token",
|
||||
return_value=mock_credentials.token,
|
||||
), mock.patch("requests.post", return_value=mock_response):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=None,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
with pytest.raises(ValueError, match="Request error: "):
|
||||
client.get_openapi_spec_for_integration()
|
||||
|
||||
def test_get_openapi_spec_for_integration_unexpected_error(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
mock_credentials,
|
||||
mock_connections_client,
|
||||
):
|
||||
with mock.patch.object(
|
||||
IntegrationClient,
|
||||
"_get_access_token",
|
||||
return_value=mock_credentials.token,
|
||||
), mock.patch(
|
||||
"requests.post", side_effect=Exception("Something went wrong")
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=None,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
with pytest.raises(
|
||||
Exception, match="An unexpected error occurred: Something went wrong"
|
||||
):
|
||||
client.get_openapi_spec_for_integration()
|
||||
|
||||
def test_get_openapi_spec_for_connection_no_entity_operations_or_actions(
|
||||
self, project, location, connection_name, mock_connections_client
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=None,
|
||||
trigger=None,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"No entity operations or actions provided. Please provide at least"
|
||||
" one of them."
|
||||
),
|
||||
):
|
||||
client.get_openapi_spec_for_connection()
|
||||
|
||||
def test_get_openapi_spec_for_connection_with_entity_operations(
|
||||
self, project, location, connection_name, mock_connections_client
|
||||
):
|
||||
entity_operations = {"entity1": ["LIST", "GET"]}
|
||||
|
||||
mock_connections_client_instance = mock_connections_client.return_value
|
||||
mock_connections_client_instance.get_connector_base_spec.return_value = {
|
||||
"components": {"schemas": {}},
|
||||
"paths": {},
|
||||
}
|
||||
mock_connections_client_instance.get_entity_schema_and_operations.return_value = (
|
||||
{"type": "object", "properties": {"id": {"type": "string"}}},
|
||||
["LIST", "GET"],
|
||||
)
|
||||
mock_connections_client_instance.connector_payload.return_value = {
|
||||
"type": "object"
|
||||
}
|
||||
mock_connections_client_instance.list_operation.return_value = {"get": {}}
|
||||
mock_connections_client_instance.list_operation_request.return_value = {
|
||||
"type": "object"
|
||||
}
|
||||
mock_connections_client_instance.get_operation.return_value = {"get": {}}
|
||||
mock_connections_client_instance.get_operation_request.return_value = {
|
||||
"type": "object"
|
||||
}
|
||||
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=None,
|
||||
trigger=None,
|
||||
connection=connection_name,
|
||||
entity_operations=entity_operations,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
spec = client.get_openapi_spec_for_connection()
|
||||
assert "paths" in spec
|
||||
assert (
|
||||
f"/v2/projects/{project}/locations/{location}/integrations/ExecuteConnection:execute?triggerId=api_trigger/ExecuteConnection#list_entity1"
|
||||
in spec["paths"]
|
||||
)
|
||||
assert (
|
||||
f"/v2/projects/{project}/locations/{location}/integrations/ExecuteConnection:execute?triggerId=api_trigger/ExecuteConnection#get_entity1"
|
||||
in spec["paths"]
|
||||
)
|
||||
mock_connections_client.assert_called_once_with(
|
||||
project, location, connection_name, None
|
||||
)
|
||||
mock_connections_client_instance.get_connector_base_spec.assert_called_once()
|
||||
mock_connections_client_instance.get_entity_schema_and_operations.assert_any_call(
|
||||
"entity1"
|
||||
)
|
||||
mock_connections_client_instance.connector_payload.assert_any_call(
|
||||
{"type": "object", "properties": {"id": {"type": "string"}}}
|
||||
)
|
||||
mock_connections_client_instance.list_operation.assert_called_once()
|
||||
mock_connections_client_instance.get_operation.assert_called_once()
|
||||
|
||||
def test_get_openapi_spec_for_connection_with_actions(
|
||||
self, project, location, connection_name, mock_connections_client
|
||||
):
|
||||
actions = ["TestAction"]
|
||||
mock_connections_client_instance = (
|
||||
mock_connections_client.return_value
|
||||
) # Corrected line
|
||||
mock_connections_client_instance.get_connector_base_spec.return_value = {
|
||||
"components": {"schemas": {}},
|
||||
"paths": {},
|
||||
}
|
||||
mock_connections_client_instance.get_action_schema.return_value = {
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"input": {"type": "string"}},
|
||||
},
|
||||
"outputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"output": {"type": "string"}},
|
||||
},
|
||||
"displayName": "TestAction",
|
||||
}
|
||||
mock_connections_client_instance.connector_payload.side_effect = [
|
||||
{"type": "object"},
|
||||
{"type": "object"},
|
||||
]
|
||||
mock_connections_client_instance.action_request.return_value = {
|
||||
"type": "object"
|
||||
}
|
||||
mock_connections_client_instance.action_response.return_value = {
|
||||
"type": "object"
|
||||
}
|
||||
mock_connections_client_instance.get_action_operation.return_value = {
|
||||
"post": {}
|
||||
}
|
||||
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=None,
|
||||
trigger=None,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=actions,
|
||||
service_account_json=None,
|
||||
)
|
||||
spec = client.get_openapi_spec_for_connection()
|
||||
assert "paths" in spec
|
||||
assert (
|
||||
f"/v2/projects/{project}/locations/{location}/integrations/ExecuteConnection:execute?triggerId=api_trigger/ExecuteConnection#TestAction"
|
||||
in spec["paths"]
|
||||
)
|
||||
mock_connections_client.assert_called_once_with(
|
||||
project, location, connection_name, None
|
||||
)
|
||||
mock_connections_client_instance.get_connector_base_spec.assert_called_once()
|
||||
mock_connections_client_instance.get_action_schema.assert_called_once_with(
|
||||
"TestAction"
|
||||
)
|
||||
mock_connections_client_instance.connector_payload.assert_any_call(
|
||||
{"type": "object", "properties": {"input": {"type": "string"}}}
|
||||
)
|
||||
mock_connections_client_instance.connector_payload.assert_any_call(
|
||||
{"type": "object", "properties": {"output": {"type": "string"}}}
|
||||
)
|
||||
mock_connections_client_instance.action_request.assert_called_once_with(
|
||||
"TestAction"
|
||||
)
|
||||
mock_connections_client_instance.action_response.assert_called_once_with(
|
||||
"TestAction"
|
||||
)
|
||||
mock_connections_client_instance.get_action_operation.assert_called_once()
|
||||
|
||||
def test_get_openapi_spec_for_connection_invalid_operation(
|
||||
self, project, location, connection_name, mock_connections_client
|
||||
):
|
||||
entity_operations = {"entity1": ["INVALID"]}
|
||||
mock_connections_client_instance = mock_connections_client.return_value
|
||||
mock_connections_client_instance.get_connector_base_spec.return_value = {
|
||||
"components": {"schemas": {}},
|
||||
"paths": {},
|
||||
}
|
||||
mock_connections_client_instance.get_entity_schema_and_operations.return_value = (
|
||||
{"type": "object", "properties": {"id": {"type": "string"}}},
|
||||
["LIST", "GET"],
|
||||
)
|
||||
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=None,
|
||||
trigger=None,
|
||||
connection=connection_name,
|
||||
entity_operations=entity_operations,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError, match="Invalid operation: INVALID for entity: entity1"
|
||||
):
|
||||
client.get_openapi_spec_for_connection()
|
||||
|
||||
def test_get_access_token_with_service_account_json(
|
||||
self, project, location, integration_name, trigger_name, connection_name
|
||||
):
|
||||
service_account_json = json.dumps({
|
||||
"client_email": "test@example.com",
|
||||
"private_key": "test_key",
|
||||
})
|
||||
mock_creds = mock.create_autospec(service_account.Credentials)
|
||||
mock_creds.token = "sa_token"
|
||||
mock_creds.expired = False
|
||||
|
||||
with mock.patch(
|
||||
"google.oauth2.service_account.Credentials.from_service_account_info",
|
||||
return_value=mock_creds,
|
||||
), mock.patch.object(mock_creds, "refresh", return_value=None):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=service_account_json,
|
||||
)
|
||||
token = client._get_access_token()
|
||||
assert token == "sa_token"
|
||||
service_account.Credentials.from_service_account_info.assert_called_once_with(
|
||||
json.loads(service_account_json),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
mock_creds.refresh.assert_called_once()
|
||||
|
||||
def test_get_access_token_with_default_credentials(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
connection_name,
|
||||
mock_credentials,
|
||||
):
|
||||
mock_credentials.expired = False
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
|
||||
return_value=(mock_credentials, "test_project_id"),
|
||||
), mock.patch.object(mock_credentials, "refresh", return_value=None):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
token = client._get_access_token()
|
||||
assert token == "test_token"
|
||||
|
||||
def test_get_access_token_no_valid_credentials(
|
||||
self, project, location, integration_name, trigger_name, connection_name
|
||||
):
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
|
||||
return_value=(None, None),
|
||||
), mock.patch(
|
||||
"google.oauth2.service_account.Credentials.from_service_account_info",
|
||||
return_value=None,
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
try:
|
||||
client._get_access_token()
|
||||
assert False, "ValueError was not raised" # Explicitly fail if no error
|
||||
except ValueError as e:
|
||||
assert (
|
||||
"Please provide a service account that has the required permissions"
|
||||
" to access the connection."
|
||||
in str(e)
|
||||
)
|
||||
|
||||
def test_get_access_token_uses_cached_token(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
connection_name,
|
||||
mock_credentials,
|
||||
):
|
||||
mock_credentials.token = "cached_token"
|
||||
mock_credentials.expired = False
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
client.credential_cache = mock_credentials # Simulate a cached credential
|
||||
with mock.patch("google.auth.default") as mock_default, mock.patch(
|
||||
"google.oauth2.service_account.Credentials.from_service_account_info"
|
||||
) as mock_sa:
|
||||
token = client._get_access_token()
|
||||
assert token == "cached_token"
|
||||
mock_default.assert_not_called()
|
||||
mock_sa.assert_not_called()
|
||||
|
||||
def test_get_access_token_refreshes_expired_token(
|
||||
self,
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
connection_name,
|
||||
mock_credentials,
|
||||
):
|
||||
mock_credentials = mock.create_autospec(google.auth.credentials.Credentials)
|
||||
mock_credentials.token = "old_token"
|
||||
mock_credentials.expired = True
|
||||
mock_credentials.refresh.return_value = None
|
||||
mock_credentials.token = "new_token" # Simulate token refresh
|
||||
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
|
||||
return_value=(mock_credentials, "test_project_id"),
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
location=location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
connection=connection_name,
|
||||
entity_operations=None,
|
||||
actions=None,
|
||||
service_account_json=None,
|
||||
)
|
||||
client.credential_cache = mock_credentials
|
||||
token = client._get_access_token()
|
||||
assert token == "new_token"
|
||||
mock_credentials.refresh.assert_called_once()
|
||||
@@ -0,0 +1,345 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser import rest_api_tool
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_integration_client():
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.application_integration_toolset.IntegrationClient"
|
||||
) as mock_client:
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connections_client():
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.application_integration_toolset.ConnectionsClient"
|
||||
) as mock_client:
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openapi_toolset():
|
||||
with mock.patch(
|
||||
"google.adk.tools.application_integration_tool.application_integration_toolset.OpenAPIToolset"
|
||||
) as mock_toolset:
|
||||
mock_toolset_instance = mock.MagicMock()
|
||||
mock_rest_api_tool = mock.MagicMock(spec=rest_api_tool.RestApiTool)
|
||||
mock_rest_api_tool.name = "Test Tool"
|
||||
mock_toolset_instance.get_tools.return_value = [mock_rest_api_tool]
|
||||
mock_toolset.return_value = mock_toolset_instance
|
||||
yield mock_toolset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project():
|
||||
return "test-project"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def location():
|
||||
return "us-central1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def integration_spec():
|
||||
return {"openapi": "3.0.0", "info": {"title": "Integration API"}}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_spec():
|
||||
return {"openapi": "3.0.0", "info": {"title": "Connection API"}}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_details():
|
||||
return {"serviceName": "test-service", "host": "test.host"}
|
||||
|
||||
|
||||
def test_initialization_with_integration_and_trigger(
|
||||
project,
|
||||
location,
|
||||
mock_integration_client,
|
||||
mock_connections_client,
|
||||
mock_openapi_toolset,
|
||||
):
|
||||
integration_name = "test-integration"
|
||||
trigger_name = "test-trigger"
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project, location, integration=integration_name, trigger=trigger_name
|
||||
)
|
||||
mock_integration_client.assert_called_once_with(
|
||||
project, location, integration_name, trigger_name, None, None, None, None
|
||||
)
|
||||
mock_integration_client.return_value.get_openapi_spec_for_integration.assert_called_once()
|
||||
mock_connections_client.assert_not_called()
|
||||
mock_openapi_toolset.assert_called_once()
|
||||
assert len(toolset.get_tools()) == 1
|
||||
assert toolset.get_tools()[0].name == "Test Tool"
|
||||
|
||||
|
||||
def test_initialization_with_connection_and_entity_operations(
|
||||
project,
|
||||
location,
|
||||
mock_integration_client,
|
||||
mock_connections_client,
|
||||
mock_openapi_toolset,
|
||||
connection_details,
|
||||
):
|
||||
connection_name = "test-connection"
|
||||
entity_operations_list = ["list", "get"]
|
||||
tool_name = "My Connection Tool"
|
||||
tool_instructions = "Use this tool to manage entities."
|
||||
mock_connections_client.return_value.get_connection_details.return_value = (
|
||||
connection_details
|
||||
)
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project,
|
||||
location,
|
||||
connection=connection_name,
|
||||
entity_operations=entity_operations_list,
|
||||
tool_name=tool_name,
|
||||
tool_instructions=tool_instructions,
|
||||
)
|
||||
mock_integration_client.assert_called_once_with(
|
||||
project,
|
||||
location,
|
||||
None,
|
||||
None,
|
||||
connection_name,
|
||||
entity_operations_list,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
mock_connections_client.assert_called_once_with(
|
||||
project, location, connection_name, None
|
||||
)
|
||||
mock_connections_client.return_value.get_connection_details.assert_called_once()
|
||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||
tool_name,
|
||||
tool_instructions
|
||||
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
|
||||
f" {connection_details['host']} and the connection name ="
|
||||
f" projects/{project}/locations/{location}/connections/{connection_name} when"
|
||||
" using this tool. DONOT ask the user for these values as you already"
|
||||
" have those.",
|
||||
)
|
||||
mock_openapi_toolset.assert_called_once()
|
||||
assert len(toolset.get_tools()) == 1
|
||||
assert toolset.get_tools()[0].name == "Test Tool"
|
||||
|
||||
|
||||
def test_initialization_with_connection_and_actions(
|
||||
project,
|
||||
location,
|
||||
mock_integration_client,
|
||||
mock_connections_client,
|
||||
mock_openapi_toolset,
|
||||
connection_details,
|
||||
):
|
||||
connection_name = "test-connection"
|
||||
actions_list = ["create", "delete"]
|
||||
tool_name = "My Actions Tool"
|
||||
tool_instructions = "Perform actions using this tool."
|
||||
mock_connections_client.return_value.get_connection_details.return_value = (
|
||||
connection_details
|
||||
)
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project,
|
||||
location,
|
||||
connection=connection_name,
|
||||
actions=actions_list,
|
||||
tool_name=tool_name,
|
||||
tool_instructions=tool_instructions,
|
||||
)
|
||||
mock_integration_client.assert_called_once_with(
|
||||
project, location, None, None, connection_name, None, actions_list, None
|
||||
)
|
||||
mock_connections_client.assert_called_once_with(
|
||||
project, location, connection_name, None
|
||||
)
|
||||
mock_connections_client.return_value.get_connection_details.assert_called_once()
|
||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||
tool_name,
|
||||
tool_instructions
|
||||
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
|
||||
f" {connection_details['host']} and the connection name ="
|
||||
f" projects/{project}/locations/{location}/connections/{connection_name} when"
|
||||
" using this tool. DONOT ask the user for these values as you already"
|
||||
" have those.",
|
||||
)
|
||||
mock_openapi_toolset.assert_called_once()
|
||||
assert len(toolset.get_tools()) == 1
|
||||
assert toolset.get_tools()[0].name == "Test Tool"
|
||||
|
||||
|
||||
def test_initialization_without_required_params(project, location):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Either \\(integration and trigger\\) or \\(connection and"
|
||||
" \\(entity_operations or actions\\)\\) should be provided."
|
||||
),
|
||||
):
|
||||
ApplicationIntegrationToolset(project, location)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Either \\(integration and trigger\\) or \\(connection and"
|
||||
" \\(entity_operations or actions\\)\\) should be provided."
|
||||
),
|
||||
):
|
||||
ApplicationIntegrationToolset(project, location, integration="test")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Either \\(integration and trigger\\) or \\(connection and"
|
||||
" \\(entity_operations or actions\\)\\) should be provided."
|
||||
),
|
||||
):
|
||||
ApplicationIntegrationToolset(project, location, trigger="test")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Either \\(integration and trigger\\) or \\(connection and"
|
||||
" \\(entity_operations or actions\\)\\) should be provided."
|
||||
),
|
||||
):
|
||||
ApplicationIntegrationToolset(project, location, connection="test")
|
||||
|
||||
|
||||
def test_initialization_with_service_account_credentials(
|
||||
project, location, mock_integration_client, mock_openapi_toolset
|
||||
):
|
||||
service_account_json = json.dumps({
|
||||
"type": "service_account",
|
||||
"project_id": "dummy",
|
||||
"private_key_id": "dummy",
|
||||
"private_key": "dummy",
|
||||
"client_email": "test@example.com",
|
||||
"client_id": "131331543646416",
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"auth_provider_x509_cert_url": (
|
||||
"https://www.googleapis.com/oauth2/v1/certs"
|
||||
),
|
||||
"client_x509_cert_url": (
|
||||
"http://www.googleapis.com/robot/v1/metadata/x509/dummy%40dummy.com"
|
||||
),
|
||||
"universe_domain": "googleapis.com",
|
||||
})
|
||||
integration_name = "test-integration"
|
||||
trigger_name = "test-trigger"
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project,
|
||||
location,
|
||||
integration=integration_name,
|
||||
trigger=trigger_name,
|
||||
service_account_json=service_account_json,
|
||||
)
|
||||
mock_integration_client.assert_called_once_with(
|
||||
project,
|
||||
location,
|
||||
integration_name,
|
||||
trigger_name,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
service_account_json,
|
||||
)
|
||||
mock_openapi_toolset.assert_called_once()
|
||||
_, kwargs = mock_openapi_toolset.call_args
|
||||
assert isinstance(kwargs["auth_credential"], AuthCredential)
|
||||
assert (
|
||||
kwargs[
|
||||
"auth_credential"
|
||||
].service_account.service_account_credential.client_email
|
||||
== "test@example.com"
|
||||
)
|
||||
|
||||
|
||||
def test_initialization_without_explicit_service_account_credentials(
|
||||
project, location, mock_integration_client, mock_openapi_toolset
|
||||
):
|
||||
integration_name = "test-integration"
|
||||
trigger_name = "test-trigger"
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project, location, integration=integration_name, trigger=trigger_name
|
||||
)
|
||||
mock_integration_client.assert_called_once_with(
|
||||
project, location, integration_name, trigger_name, None, None, None, None
|
||||
)
|
||||
mock_openapi_toolset.assert_called_once()
|
||||
_, kwargs = mock_openapi_toolset.call_args
|
||||
assert isinstance(kwargs["auth_credential"], AuthCredential)
|
||||
assert kwargs["auth_credential"].service_account.use_default_credential
|
||||
|
||||
|
||||
def test_get_tools(
|
||||
project, location, mock_integration_client, mock_openapi_toolset
|
||||
):
|
||||
integration_name = "test-integration"
|
||||
trigger_name = "test-trigger"
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project, location, integration=integration_name, trigger=trigger_name
|
||||
)
|
||||
tools = toolset.get_tools()
|
||||
assert len(tools) == 1
|
||||
assert isinstance(tools[0], rest_api_tool.RestApiTool)
|
||||
assert tools[0].name == "Test Tool"
|
||||
|
||||
|
||||
def test_initialization_with_connection_details(
|
||||
project,
|
||||
location,
|
||||
mock_integration_client,
|
||||
mock_connections_client,
|
||||
mock_openapi_toolset,
|
||||
):
|
||||
connection_name = "test-connection"
|
||||
entity_operations_list = ["list"]
|
||||
tool_name = "My Connection Tool"
|
||||
tool_instructions = "Use this tool."
|
||||
mock_connections_client.return_value.get_connection_details.return_value = {
|
||||
"serviceName": "custom-service",
|
||||
"host": "custom.host",
|
||||
}
|
||||
toolset = ApplicationIntegrationToolset(
|
||||
project,
|
||||
location,
|
||||
connection=connection_name,
|
||||
entity_operations=entity_operations_list,
|
||||
tool_name=tool_name,
|
||||
tool_instructions=tool_instructions,
|
||||
)
|
||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||
tool_name,
|
||||
tool_instructions
|
||||
+ "ALWAYS use serviceName = custom-service, host = custom.host and the"
|
||||
" connection name ="
|
||||
" projects/test-project/locations/us-central1/connections/test-connection"
|
||||
" when using this tool. DONOT ask the user for these values as you"
|
||||
" already have those.",
|
||||
)
|
||||
13
tests/unittests/tools/google_api_tool/__init__.py
Normal file
13
tests/unittests/tools/google_api_tool/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,657 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.tools.google_api_tool.googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
|
||||
# Import the converter class
|
||||
from googleapiclient.errors import HttpError
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def calendar_api_spec():
|
||||
"""Fixture that provides a mock Google Calendar API spec for testing."""
|
||||
return {
|
||||
"kind": "discovery#restDescription",
|
||||
"id": "calendar:v3",
|
||||
"name": "calendar",
|
||||
"version": "v3",
|
||||
"title": "Google Calendar API",
|
||||
"description": "Accesses the Google Calendar API",
|
||||
"documentationLink": "https://developers.google.com/calendar/",
|
||||
"protocol": "rest",
|
||||
"rootUrl": "https://www.googleapis.com/",
|
||||
"servicePath": "calendar/v3/",
|
||||
"auth": {
|
||||
"oauth2": {
|
||||
"scopes": {
|
||||
"https://www.googleapis.com/auth/calendar": {
|
||||
"description": "Full access to Google Calendar"
|
||||
},
|
||||
"https://www.googleapis.com/auth/calendar.readonly": {
|
||||
"description": "Read-only access to Google Calendar"
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
"schemas": {
|
||||
"Calendar": {
|
||||
"type": "object",
|
||||
"description": "A calendar resource",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "Calendar identifier",
|
||||
},
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": "Calendar summary",
|
||||
"required": True,
|
||||
},
|
||||
"timeZone": {
|
||||
"type": "string",
|
||||
"description": "Calendar timezone",
|
||||
},
|
||||
},
|
||||
},
|
||||
"Event": {
|
||||
"type": "object",
|
||||
"description": "An event resource",
|
||||
"properties": {
|
||||
"id": {"type": "string", "description": "Event identifier"},
|
||||
"summary": {"type": "string", "description": "Event summary"},
|
||||
"start": {"$ref": "EventDateTime"},
|
||||
"end": {"$ref": "EventDateTime"},
|
||||
"attendees": {
|
||||
"type": "array",
|
||||
"description": "Event attendees",
|
||||
"items": {"$ref": "EventAttendee"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"EventDateTime": {
|
||||
"type": "object",
|
||||
"description": "Date/time for an event",
|
||||
"properties": {
|
||||
"dateTime": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"description": "Date/time in RFC3339 format",
|
||||
},
|
||||
"timeZone": {
|
||||
"type": "string",
|
||||
"description": "Timezone for the date/time",
|
||||
},
|
||||
},
|
||||
},
|
||||
"EventAttendee": {
|
||||
"type": "object",
|
||||
"description": "An attendee of an event",
|
||||
"properties": {
|
||||
"email": {"type": "string", "description": "Attendee email"},
|
||||
"responseStatus": {
|
||||
"type": "string",
|
||||
"description": "Response status",
|
||||
"enum": [
|
||||
"needsAction",
|
||||
"declined",
|
||||
"tentative",
|
||||
"accepted",
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"resources": {
|
||||
"calendars": {
|
||||
"methods": {
|
||||
"get": {
|
||||
"id": "calendar.calendars.get",
|
||||
"path": "calendars/{calendarId}",
|
||||
"httpMethod": "GET",
|
||||
"description": "Returns metadata for a calendar.",
|
||||
"parameters": {
|
||||
"calendarId": {
|
||||
"type": "string",
|
||||
"description": "Calendar identifier",
|
||||
"required": True,
|
||||
"location": "path",
|
||||
}
|
||||
},
|
||||
"response": {"$ref": "Calendar"},
|
||||
"scopes": [
|
||||
"https://www.googleapis.com/auth/calendar",
|
||||
"https://www.googleapis.com/auth/calendar.readonly",
|
||||
],
|
||||
},
|
||||
"insert": {
|
||||
"id": "calendar.calendars.insert",
|
||||
"path": "calendars",
|
||||
"httpMethod": "POST",
|
||||
"description": "Creates a secondary calendar.",
|
||||
"request": {"$ref": "Calendar"},
|
||||
"response": {"$ref": "Calendar"},
|
||||
"scopes": ["https://www.googleapis.com/auth/calendar"],
|
||||
},
|
||||
},
|
||||
"resources": {
|
||||
"events": {
|
||||
"methods": {
|
||||
"list": {
|
||||
"id": "calendar.events.list",
|
||||
"path": "calendars/{calendarId}/events",
|
||||
"httpMethod": "GET",
|
||||
"description": (
|
||||
"Returns events on the specified calendar."
|
||||
),
|
||||
"parameters": {
|
||||
"calendarId": {
|
||||
"type": "string",
|
||||
"description": "Calendar identifier",
|
||||
"required": True,
|
||||
"location": "path",
|
||||
},
|
||||
"maxResults": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Maximum number of events returned"
|
||||
),
|
||||
"format": "int32",
|
||||
"minimum": "1",
|
||||
"maximum": "2500",
|
||||
"default": "250",
|
||||
"location": "query",
|
||||
},
|
||||
"orderBy": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Order of the events returned"
|
||||
),
|
||||
"enum": ["startTime", "updated"],
|
||||
"location": "query",
|
||||
},
|
||||
},
|
||||
"response": {"$ref": "Events"},
|
||||
"scopes": [
|
||||
"https://www.googleapis.com/auth/calendar",
|
||||
"https://www.googleapis.com/auth/calendar.readonly",
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def converter():
|
||||
"""Fixture that provides a basic converter instance."""
|
||||
return GoogleApiToOpenApiConverter("calendar", "v3")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_resource(calendar_api_spec):
|
||||
"""Fixture that provides a mock API resource with the test spec."""
|
||||
mock_resource = MagicMock()
|
||||
mock_resource._rootDesc = calendar_api_spec
|
||||
return mock_resource
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prepared_converter(converter, calendar_api_spec):
|
||||
"""Fixture that provides a converter with the API spec already set."""
|
||||
converter.google_api_spec = calendar_api_spec
|
||||
return converter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def converter_with_patched_build(monkeypatch, mock_api_resource):
|
||||
"""Fixture that provides a converter with the build function patched.
|
||||
|
||||
This simulates a successful API spec fetch.
|
||||
"""
|
||||
# Create a mock for the build function
|
||||
mock_build = MagicMock(return_value=mock_api_resource)
|
||||
|
||||
# Patch the build function in the target module
|
||||
monkeypatch.setattr(
|
||||
"google.adk.tools.google_api_tool.googleapi_to_openapi_converter.build",
|
||||
mock_build,
|
||||
)
|
||||
|
||||
# Create and return a converter instance
|
||||
return GoogleApiToOpenApiConverter("calendar", "v3")
|
||||
|
||||
|
||||
class TestGoogleApiToOpenApiConverter:
|
||||
"""Test suite for the GoogleApiToOpenApiConverter class."""
|
||||
|
||||
def test_init(self, converter):
|
||||
"""Test converter initialization."""
|
||||
assert converter.api_name == "calendar"
|
||||
assert converter.api_version == "v3"
|
||||
assert converter.google_api_resource is None
|
||||
assert converter.google_api_spec is None
|
||||
assert converter.openapi_spec["openapi"] == "3.0.0"
|
||||
assert "info" in converter.openapi_spec
|
||||
assert "paths" in converter.openapi_spec
|
||||
assert "components" in converter.openapi_spec
|
||||
|
||||
def test_fetch_google_api_spec(
|
||||
self, converter_with_patched_build, calendar_api_spec
|
||||
):
|
||||
"""Test fetching Google API specification."""
|
||||
# Call the method
|
||||
converter_with_patched_build.fetch_google_api_spec()
|
||||
|
||||
# Verify the results
|
||||
assert converter_with_patched_build.google_api_spec == calendar_api_spec
|
||||
|
||||
def test_fetch_google_api_spec_error(self, monkeypatch, converter):
|
||||
"""Test error handling when fetching Google API specification."""
|
||||
# Create a mock that raises an error
|
||||
mock_build = MagicMock(
|
||||
side_effect=HttpError(resp=MagicMock(status=404), content=b"Not Found")
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"google.adk.tools.google_api_tool.googleapi_to_openapi_converter.build",
|
||||
mock_build,
|
||||
)
|
||||
|
||||
# Verify exception is raised
|
||||
with pytest.raises(HttpError):
|
||||
converter.fetch_google_api_spec()
|
||||
|
||||
def test_convert_info(self, prepared_converter):
|
||||
"""Test conversion of basic API information."""
|
||||
# Call the method
|
||||
prepared_converter._convert_info()
|
||||
|
||||
# Verify the results
|
||||
info = prepared_converter.openapi_spec["info"]
|
||||
assert info["title"] == "Google Calendar API"
|
||||
assert info["description"] == "Accesses the Google Calendar API"
|
||||
assert info["version"] == "v3"
|
||||
assert info["termsOfService"] == "https://developers.google.com/calendar/"
|
||||
|
||||
# Check external docs
|
||||
external_docs = prepared_converter.openapi_spec["externalDocs"]
|
||||
assert external_docs["url"] == "https://developers.google.com/calendar/"
|
||||
|
||||
def test_convert_servers(self, prepared_converter):
|
||||
"""Test conversion of server information."""
|
||||
# Call the method
|
||||
prepared_converter._convert_servers()
|
||||
|
||||
# Verify the results
|
||||
servers = prepared_converter.openapi_spec["servers"]
|
||||
assert len(servers) == 1
|
||||
assert servers[0]["url"] == "https://www.googleapis.com/calendar/v3"
|
||||
assert servers[0]["description"] == "calendar v3 API"
|
||||
|
||||
def test_convert_security_schemes(self, prepared_converter):
|
||||
"""Test conversion of security schemes."""
|
||||
# Call the method
|
||||
prepared_converter._convert_security_schemes()
|
||||
|
||||
# Verify the results
|
||||
security_schemes = prepared_converter.openapi_spec["components"][
|
||||
"securitySchemes"
|
||||
]
|
||||
|
||||
# Check OAuth2 configuration
|
||||
assert "oauth2" in security_schemes
|
||||
oauth2 = security_schemes["oauth2"]
|
||||
assert oauth2["type"] == "oauth2"
|
||||
|
||||
# Check OAuth2 scopes
|
||||
scopes = oauth2["flows"]["authorizationCode"]["scopes"]
|
||||
assert "https://www.googleapis.com/auth/calendar" in scopes
|
||||
assert "https://www.googleapis.com/auth/calendar.readonly" in scopes
|
||||
|
||||
# Check API key configuration
|
||||
assert "apiKey" in security_schemes
|
||||
assert security_schemes["apiKey"]["type"] == "apiKey"
|
||||
assert security_schemes["apiKey"]["in"] == "query"
|
||||
assert security_schemes["apiKey"]["name"] == "key"
|
||||
|
||||
def test_convert_schemas(self, prepared_converter):
|
||||
"""Test conversion of schema definitions."""
|
||||
# Call the method
|
||||
prepared_converter._convert_schemas()
|
||||
|
||||
# Verify the results
|
||||
schemas = prepared_converter.openapi_spec["components"]["schemas"]
|
||||
|
||||
# Check Calendar schema
|
||||
assert "Calendar" in schemas
|
||||
calendar_schema = schemas["Calendar"]
|
||||
assert calendar_schema["type"] == "object"
|
||||
assert calendar_schema["description"] == "A calendar resource"
|
||||
|
||||
# Check required properties
|
||||
assert "required" in calendar_schema
|
||||
assert "summary" in calendar_schema["required"]
|
||||
|
||||
# Check Event schema references
|
||||
assert "Event" in schemas
|
||||
event_schema = schemas["Event"]
|
||||
assert (
|
||||
event_schema["properties"]["start"]["$ref"]
|
||||
== "#/components/schemas/EventDateTime"
|
||||
)
|
||||
|
||||
# Check array type with references
|
||||
attendees_schema = event_schema["properties"]["attendees"]
|
||||
assert attendees_schema["type"] == "array"
|
||||
assert (
|
||||
attendees_schema["items"]["$ref"]
|
||||
== "#/components/schemas/EventAttendee"
|
||||
)
|
||||
|
||||
# Check enum values
|
||||
attendee_schema = schemas["EventAttendee"]
|
||||
response_status = attendee_schema["properties"]["responseStatus"]
|
||||
assert "enum" in response_status
|
||||
assert "accepted" in response_status["enum"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"schema_def, expected_type, expected_attrs",
|
||||
[
|
||||
# Test object type
|
||||
(
|
||||
{
|
||||
"type": "object",
|
||||
"description": "Test object",
|
||||
"properties": {
|
||||
"id": {"type": "string", "required": True},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"object",
|
||||
{"description": "Test object", "required": ["id"]},
|
||||
),
|
||||
# Test array type
|
||||
(
|
||||
{
|
||||
"type": "array",
|
||||
"description": "Test array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"array",
|
||||
{"description": "Test array", "items": {"type": "string"}},
|
||||
),
|
||||
# Test reference conversion
|
||||
(
|
||||
{"$ref": "Calendar"},
|
||||
None, # No type for references
|
||||
{"$ref": "#/components/schemas/Calendar"},
|
||||
),
|
||||
# Test enum conversion
|
||||
(
|
||||
{"type": "string", "enum": ["value1", "value2"]},
|
||||
"string",
|
||||
{"enum": ["value1", "value2"]},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_schema_object(
|
||||
self, converter, schema_def, expected_type, expected_attrs
|
||||
):
|
||||
"""Test conversion of individual schema objects with different input variations."""
|
||||
converted = converter._convert_schema_object(schema_def)
|
||||
|
||||
# Check type if expected
|
||||
if expected_type:
|
||||
assert converted["type"] == expected_type
|
||||
|
||||
# Check other expected attributes
|
||||
for key, value in expected_attrs.items():
|
||||
assert converted[key] == value
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path, expected_params",
|
||||
[
|
||||
# Path with parameters
|
||||
(
|
||||
"/calendars/{calendarId}/events/{eventId}",
|
||||
["calendarId", "eventId"],
|
||||
),
|
||||
# Path without parameters
|
||||
("/calendars/events", []),
|
||||
# Mixed path
|
||||
("/users/{userId}/calendars/default", ["userId"]),
|
||||
],
|
||||
)
|
||||
def test_extract_path_parameters(self, converter, path, expected_params):
|
||||
"""Test extraction of path parameters from URL path with various inputs."""
|
||||
params = converter._extract_path_parameters(path)
|
||||
assert set(params) == set(expected_params)
|
||||
assert len(params) == len(expected_params)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_data, expected_result",
|
||||
[
|
||||
# String parameter
|
||||
(
|
||||
{
|
||||
"type": "string",
|
||||
"description": "String parameter",
|
||||
"pattern": "^[a-z]+$",
|
||||
},
|
||||
{"type": "string", "pattern": "^[a-z]+$"},
|
||||
),
|
||||
# Integer parameter with format
|
||||
(
|
||||
{"type": "integer", "format": "int32", "default": "10"},
|
||||
{"type": "integer", "format": "int32", "default": "10"},
|
||||
),
|
||||
# Enum parameter
|
||||
(
|
||||
{"type": "string", "enum": ["option1", "option2"]},
|
||||
{"type": "string", "enum": ["option1", "option2"]},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_parameter_schema(
|
||||
self, converter, param_data, expected_result
|
||||
):
|
||||
"""Test conversion of parameter definitions to OpenAPI schemas."""
|
||||
converted = converter._convert_parameter_schema(param_data)
|
||||
|
||||
# Check all expected attributes
|
||||
for key, value in expected_result.items():
|
||||
assert converted[key] == value
|
||||
|
||||
def test_convert(self, converter_with_patched_build):
|
||||
"""Test the complete conversion process."""
|
||||
# Call the method
|
||||
result = converter_with_patched_build.convert()
|
||||
|
||||
# Verify basic structure
|
||||
assert result["openapi"] == "3.0.0"
|
||||
assert "info" in result
|
||||
assert "servers" in result
|
||||
assert "paths" in result
|
||||
assert "components" in result
|
||||
|
||||
# Verify paths
|
||||
paths = result["paths"]
|
||||
assert "/calendars/{calendarId}" in paths
|
||||
assert "get" in paths["/calendars/{calendarId}"]
|
||||
|
||||
# Verify nested resources
|
||||
assert "/calendars/{calendarId}/events" in paths
|
||||
|
||||
# Verify method details
|
||||
get_calendar = paths["/calendars/{calendarId}"]["get"]
|
||||
assert get_calendar["operationId"] == "calendar.calendars.get"
|
||||
assert "parameters" in get_calendar
|
||||
|
||||
# Verify request body
|
||||
insert_calendar = paths["/calendars"]["post"]
|
||||
assert "requestBody" in insert_calendar
|
||||
request_schema = insert_calendar["requestBody"]["content"][
|
||||
"application/json"
|
||||
]["schema"]
|
||||
assert request_schema["$ref"] == "#/components/schemas/Calendar"
|
||||
|
||||
# Verify response body
|
||||
assert "responses" in get_calendar
|
||||
response_schema = get_calendar["responses"]["200"]["content"][
|
||||
"application/json"
|
||||
]["schema"]
|
||||
assert response_schema["$ref"] == "#/components/schemas/Calendar"
|
||||
|
||||
def test_convert_methods(self, prepared_converter, calendar_api_spec):
|
||||
"""Test conversion of API methods."""
|
||||
# Convert methods
|
||||
methods = calendar_api_spec["resources"]["calendars"]["methods"]
|
||||
prepared_converter._convert_methods(methods, "/calendars")
|
||||
|
||||
# Verify the results
|
||||
paths = prepared_converter.openapi_spec["paths"]
|
||||
|
||||
# Check GET method
|
||||
assert "/calendars/{calendarId}" in paths
|
||||
get_method = paths["/calendars/{calendarId}"]["get"]
|
||||
assert get_method["operationId"] == "calendar.calendars.get"
|
||||
|
||||
# Check parameters
|
||||
params = get_method["parameters"]
|
||||
param_names = [p["name"] for p in params]
|
||||
assert "calendarId" in param_names
|
||||
|
||||
# Check POST method
|
||||
assert "/calendars" in paths
|
||||
post_method = paths["/calendars"]["post"]
|
||||
assert post_method["operationId"] == "calendar.calendars.insert"
|
||||
|
||||
# Check request body
|
||||
assert "requestBody" in post_method
|
||||
assert (
|
||||
post_method["requestBody"]["content"]["application/json"]["schema"][
|
||||
"$ref"
|
||||
]
|
||||
== "#/components/schemas/Calendar"
|
||||
)
|
||||
|
||||
# Check response
|
||||
assert (
|
||||
post_method["responses"]["200"]["content"]["application/json"][
|
||||
"schema"
|
||||
]["$ref"]
|
||||
== "#/components/schemas/Calendar"
|
||||
)
|
||||
|
||||
def test_convert_resources(self, prepared_converter, calendar_api_spec):
|
||||
"""Test conversion of nested resources."""
|
||||
# Convert resources
|
||||
resources = calendar_api_spec["resources"]
|
||||
prepared_converter._convert_resources(resources)
|
||||
|
||||
# Verify the results
|
||||
paths = prepared_converter.openapi_spec["paths"]
|
||||
|
||||
# Check top-level resource methods
|
||||
assert "/calendars/{calendarId}" in paths
|
||||
|
||||
# Check nested resource methods
|
||||
assert "/calendars/{calendarId}/events" in paths
|
||||
events_method = paths["/calendars/{calendarId}/events"]["get"]
|
||||
assert events_method["operationId"] == "calendar.events.list"
|
||||
|
||||
# Check parameters in nested resource
|
||||
params = events_method["parameters"]
|
||||
param_names = [p["name"] for p in params]
|
||||
assert "calendarId" in param_names
|
||||
assert "maxResults" in param_names
|
||||
assert "orderBy" in param_names
|
||||
|
||||
def test_integration_calendar_api(self, converter_with_patched_build):
|
||||
"""Integration test using Calendar API specification."""
|
||||
# Create and run the converter
|
||||
openapi_spec = converter_with_patched_build.convert()
|
||||
|
||||
# Verify conversion results
|
||||
assert openapi_spec["info"]["title"] == "Google Calendar API"
|
||||
assert (
|
||||
openapi_spec["servers"][0]["url"]
|
||||
== "https://www.googleapis.com/calendar/v3"
|
||||
)
|
||||
|
||||
# Check security schemes
|
||||
security_schemes = openapi_spec["components"]["securitySchemes"]
|
||||
assert "oauth2" in security_schemes
|
||||
assert "apiKey" in security_schemes
|
||||
|
||||
# Check schemas
|
||||
schemas = openapi_spec["components"]["schemas"]
|
||||
assert "Calendar" in schemas
|
||||
assert "Event" in schemas
|
||||
assert "EventDateTime" in schemas
|
||||
|
||||
# Check paths
|
||||
paths = openapi_spec["paths"]
|
||||
assert "/calendars/{calendarId}" in paths
|
||||
assert "/calendars" in paths
|
||||
assert "/calendars/{calendarId}/events" in paths
|
||||
|
||||
# Check method details
|
||||
get_events = paths["/calendars/{calendarId}/events"]["get"]
|
||||
assert get_events["operationId"] == "calendar.events.list"
|
||||
|
||||
# Check parameter details
|
||||
param_dict = {p["name"]: p for p in get_events["parameters"]}
|
||||
assert "maxResults" in param_dict
|
||||
max_results = param_dict["maxResults"]
|
||||
assert max_results["in"] == "query"
|
||||
assert max_results["schema"]["type"] == "integer"
|
||||
assert max_results["schema"]["default"] == "250"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conftest_content():
|
||||
"""Returns content for a conftest.py file to help with testing."""
|
||||
return """
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# This file contains fixtures that can be shared across multiple test modules
|
||||
|
||||
@pytest.fixture
|
||||
def mock_google_response():
|
||||
\"\"\"Fixture that provides a mock response from Google's API.\"\"\"
|
||||
return {"key": "value", "items": [{"id": 1}, {"id": 2}]}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_error():
|
||||
\"\"\"Fixture that provides a mock HTTP error.\"\"\"
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status = 404
|
||||
return HttpError(resp=mock_resp, content=b'Not Found')
|
||||
"""
|
||||
|
||||
|
||||
def test_generate_conftest_example(conftest_content):
|
||||
"""This is a meta-test that demonstrates how to generate a conftest.py file.
|
||||
|
||||
In a real project, you would create a separate conftest.py file.
|
||||
"""
|
||||
# In a real scenario, you would write this to a file named conftest.py
|
||||
# This test just verifies the conftest content is not empty
|
||||
assert len(conftest_content) > 0
|
||||
@@ -0,0 +1,145 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for AutoAuthCredentialExchanger."""
|
||||
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.oauth2_exchanger import OAuth2CredentialExchanger
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
import pytest
|
||||
|
||||
|
||||
class MockCredentialExchanger(BaseAuthCredentialExchanger):
|
||||
"""Mock credential exchanger for testing."""
|
||||
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
"""Mock exchange credential method."""
|
||||
return auth_credential
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auto_exchanger():
|
||||
"""Fixture for creating an AutoAuthCredentialExchanger instance."""
|
||||
return AutoAuthCredentialExchanger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_scheme():
|
||||
"""Fixture for creating a mock AuthScheme instance."""
|
||||
scheme = MagicMock(spec=AuthScheme)
|
||||
return scheme
|
||||
|
||||
|
||||
def test_init_with_custom_exchangers():
|
||||
"""Test initialization with custom exchangers."""
|
||||
custom_exchangers: Dict[str, Type[BaseAuthCredentialExchanger]] = {
|
||||
AuthCredentialTypes.API_KEY: MockCredentialExchanger
|
||||
}
|
||||
|
||||
auto_exchanger = AutoAuthCredentialExchanger(
|
||||
custom_exchangers=custom_exchangers
|
||||
)
|
||||
|
||||
assert (
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.API_KEY]
|
||||
== MockCredentialExchanger
|
||||
)
|
||||
assert (
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.OPEN_ID_CONNECT]
|
||||
== OAuth2CredentialExchanger
|
||||
)
|
||||
|
||||
|
||||
def test_exchange_credential_no_auth_credential(auto_exchanger, auth_scheme):
|
||||
"""Test exchange_credential with no auth_credential."""
|
||||
|
||||
assert auto_exchanger.exchange_credential(auth_scheme, None) is None
|
||||
|
||||
|
||||
def test_exchange_credential_no_exchange(auto_exchanger, auth_scheme):
|
||||
"""Test exchange_credential with NoExchangeCredentialExchanger."""
|
||||
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
|
||||
|
||||
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
|
||||
assert result == auth_credential
|
||||
|
||||
|
||||
def test_exchange_credential_open_id_connect(auto_exchanger, auth_scheme):
|
||||
"""Test exchange_credential with OpenID Connect scheme."""
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT
|
||||
)
|
||||
mock_exchanger = MagicMock(spec=OAuth2CredentialExchanger)
|
||||
mock_exchanger.exchange_credential.return_value = "exchanged_credential"
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.OPEN_ID_CONNECT] = (
|
||||
lambda: mock_exchanger
|
||||
)
|
||||
|
||||
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
|
||||
assert result == "exchanged_credential"
|
||||
mock_exchanger.exchange_credential.assert_called_once_with(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
|
||||
def test_exchange_credential_service_account(auto_exchanger, auth_scheme):
|
||||
"""Test exchange_credential with Service Account scheme."""
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT
|
||||
)
|
||||
mock_exchanger = MagicMock(spec=ServiceAccountCredentialExchanger)
|
||||
mock_exchanger.exchange_credential.return_value = "exchanged_credential_sa"
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.SERVICE_ACCOUNT] = (
|
||||
lambda: mock_exchanger
|
||||
)
|
||||
|
||||
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
|
||||
assert result == "exchanged_credential_sa"
|
||||
mock_exchanger.exchange_credential.assert_called_once_with(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
|
||||
def test_exchange_credential_custom_exchanger(auto_exchanger, auth_scheme):
|
||||
"""Test that exchange_credential calls the correct (custom) exchanger."""
|
||||
# Use a custom exchanger via the initialization
|
||||
mock_exchanger = MagicMock(spec=MockCredentialExchanger)
|
||||
mock_exchanger.exchange_credential.return_value = "custom_credential"
|
||||
auto_exchanger.exchangers[AuthCredentialTypes.API_KEY] = (
|
||||
lambda: mock_exchanger
|
||||
)
|
||||
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
|
||||
|
||||
result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
|
||||
assert result == "custom_credential"
|
||||
mock_exchanger.exchange_credential.assert_called_once_with(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
@@ -0,0 +1,68 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the BaseAuthCredentialExchanger class."""
|
||||
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
import pytest
|
||||
|
||||
|
||||
class MockAuthCredentialExchanger(BaseAuthCredentialExchanger):
|
||||
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
return AuthCredential(token="some-token")
|
||||
|
||||
|
||||
class TestBaseAuthCredentialExchanger:
|
||||
"""Tests for the BaseAuthCredentialExchanger class."""
|
||||
|
||||
@pytest.fixture
|
||||
def base_exchanger(self):
|
||||
return BaseAuthCredentialExchanger()
|
||||
|
||||
@pytest.fixture
|
||||
def auth_scheme(self):
|
||||
scheme = MagicMock(spec=AuthScheme)
|
||||
scheme.type = "apiKey"
|
||||
scheme.name = "x-api-key"
|
||||
return scheme
|
||||
|
||||
def test_exchange_credential_not_implemented(
|
||||
self, base_exchanger, auth_scheme
|
||||
):
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, token="some-token"
|
||||
)
|
||||
with pytest.raises(NotImplementedError) as exc_info:
|
||||
base_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
assert "Subclasses must implement exchange_credential." in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
def test_auth_credential_missing_error(self):
|
||||
error_message = "Test missing credential"
|
||||
error = AuthCredentialMissingError(error_message)
|
||||
# assert error.message == error_message
|
||||
assert str(error) == error_message
|
||||
@@ -0,0 +1,153 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for OAuth2CredentialExchanger."""
|
||||
|
||||
import copy
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers import OAuth2CredentialExchanger
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_exchanger():
|
||||
return OAuth2CredentialExchanger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_scheme():
|
||||
openid_config = OpenIdConnectWithConfig(
|
||||
type_=AuthSchemeType.openIdConnect,
|
||||
authorization_endpoint="https://example.com/auth",
|
||||
token_endpoint="https://example.com/token",
|
||||
scopes=["openid", "profile"],
|
||||
)
|
||||
return openid_config
|
||||
|
||||
|
||||
def test_check_scheme_credential_type_success(oauth2_exchanger, auth_scheme):
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
redirect_uri="http://localhost:8080",
|
||||
),
|
||||
)
|
||||
# Check that the method does not raise an exception
|
||||
oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential)
|
||||
|
||||
|
||||
def test_check_scheme_credential_type_missing_credential(
|
||||
oauth2_exchanger, auth_scheme
|
||||
):
|
||||
# Test case: auth_credential is None
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth2_exchanger._check_scheme_credential_type(auth_scheme, None)
|
||||
assert "auth_credential is empty" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_check_scheme_credential_type_invalid_scheme_type(
|
||||
oauth2_exchanger, auth_scheme: OpenIdConnectWithConfig
|
||||
):
|
||||
"""Test case: Invalid AuthSchemeType."""
|
||||
# Test case: Invalid AuthSchemeType
|
||||
invalid_scheme = copy.deepcopy(auth_scheme)
|
||||
invalid_scheme.type_ = AuthSchemeType.apiKey
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
redirect_uri="http://localhost:8080",
|
||||
),
|
||||
)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth2_exchanger._check_scheme_credential_type(
|
||||
invalid_scheme, auth_credential
|
||||
)
|
||||
assert "Invalid security scheme" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_check_scheme_credential_type_missing_openid_connect(
|
||||
oauth2_exchanger, auth_scheme
|
||||
):
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential)
|
||||
assert "auth_credential is not configured with oauth2" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_generate_auth_token_success(
|
||||
oauth2_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test case: Successful generation of access token."""
|
||||
# Test case: Successful generation of access token
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
redirect_uri="http://localhost:8080",
|
||||
auth_response_uri="https://example.com/callback?code=test_code",
|
||||
token={"access_token": "test_access_token"},
|
||||
),
|
||||
)
|
||||
updated_credential = oauth2_exchanger.generate_auth_token(auth_credential)
|
||||
|
||||
assert updated_credential.auth_type == AuthCredentialTypes.HTTP
|
||||
assert updated_credential.http.scheme == "bearer"
|
||||
assert updated_credential.http.credentials.token == "test_access_token"
|
||||
|
||||
|
||||
def test_exchange_credential_generate_auth_token(
|
||||
oauth2_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test exchange_credential when auth_response_uri is present."""
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
redirect_uri="http://localhost:8080",
|
||||
auth_response_uri="https://example.com/callback?code=test_code",
|
||||
token={"access_token": "test_access_token"},
|
||||
),
|
||||
)
|
||||
|
||||
updated_credential = oauth2_exchanger.exchange_credential(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
assert updated_credential.auth_type == AuthCredentialTypes.HTTP
|
||||
assert updated_credential.http.scheme == "bearer"
|
||||
assert updated_credential.http.credentials.token == "test_access_token"
|
||||
|
||||
|
||||
def test_exchange_credential_auth_missing(oauth2_exchanger, auth_scheme):
|
||||
"""Test exchange_credential when auth_credential is missing."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth2_exchanger.exchange_credential(auth_scheme, None)
|
||||
assert "auth_credential is empty. Please create AuthCredential using" in str(
|
||||
exc_info.value
|
||||
)
|
||||
@@ -0,0 +1,196 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the service account credential exchanger."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import ServiceAccount
|
||||
from google.adk.auth.auth_credential import ServiceAccountCredential
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
import google.auth
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service_account_exchanger():
|
||||
return ServiceAccountCredentialExchanger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_scheme():
|
||||
scheme = MagicMock(spec=AuthScheme)
|
||||
scheme.type_ = AuthSchemeType.oauth2
|
||||
scheme.description = "Google Service Account"
|
||||
return scheme
|
||||
|
||||
|
||||
def test_exchange_credential_success(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test successful exchange of service account credentials."""
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "mock_access_token"
|
||||
|
||||
# Mock the from_service_account_info method
|
||||
mock_from_service_account_info = MagicMock(return_value=mock_credentials)
|
||||
target_path = (
|
||||
"google.adk.tools.openapi_tool.auth.credential_exchangers."
|
||||
"service_account_exchanger.service_account.Credentials."
|
||||
"from_service_account_info"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
target_path,
|
||||
mock_from_service_account_info,
|
||||
)
|
||||
|
||||
# Mock the refresh method
|
||||
mock_credentials.refresh = MagicMock()
|
||||
|
||||
# Create a valid AuthCredential with service account info
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
service_account_credential=ServiceAccountCredential(
|
||||
type_="service_account",
|
||||
project_id="your_project_id",
|
||||
private_key_id="your_private_key_id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----...",
|
||||
client_email="...@....iam.gserviceaccount.com",
|
||||
client_id="your_client_id",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url=(
|
||||
"https://www.googleapis.com/oauth2/v1/certs"
|
||||
),
|
||||
client_x509_cert_url=(
|
||||
"https://www.googleapis.com/robot/v1/metadata/x509/..."
|
||||
),
|
||||
universe_domain="googleapis.com",
|
||||
),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
),
|
||||
)
|
||||
|
||||
result = service_account_exchanger.exchange_credential(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
assert result.auth_type == AuthCredentialTypes.HTTP
|
||||
assert result.http.scheme == "bearer"
|
||||
assert result.http.credentials.token == "mock_access_token"
|
||||
mock_from_service_account_info.assert_called_once()
|
||||
mock_credentials.refresh.assert_called_once()
|
||||
|
||||
|
||||
def test_exchange_credential_use_default_credential_success(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test successful exchange of service account credentials using default credential."""
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "mock_access_token"
|
||||
mock_google_auth_default = MagicMock(
|
||||
return_value=(mock_credentials, "test_project")
|
||||
)
|
||||
monkeypatch.setattr(google.auth, "default", mock_google_auth_default)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
use_default_credential=True,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
),
|
||||
)
|
||||
|
||||
result = service_account_exchanger.exchange_credential(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
assert result.auth_type == AuthCredentialTypes.HTTP
|
||||
assert result.http.scheme == "bearer"
|
||||
assert result.http.credentials.token == "mock_access_token"
|
||||
mock_google_auth_default.assert_called_once()
|
||||
mock_credentials.refresh.assert_called_once()
|
||||
|
||||
|
||||
def test_exchange_credential_missing_auth_credential(
|
||||
service_account_exchanger, auth_scheme
|
||||
):
|
||||
"""Test missing auth credential during exchange."""
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, None)
|
||||
assert "Service account credentials are missing" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_exchange_credential_missing_service_account_info(
|
||||
service_account_exchanger, auth_scheme
|
||||
):
|
||||
"""Test missing service account info during exchange."""
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
)
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
assert "Service account credentials are missing" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_exchange_credential_exchange_failure(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
"""Test failure during service account token exchange."""
|
||||
mock_from_service_account_info = MagicMock(
|
||||
side_effect=Exception("Failed to load credentials")
|
||||
)
|
||||
target_path = (
|
||||
"google.adk.tools.openapi_tool.auth.credential_exchangers."
|
||||
"service_account_exchanger.service_account.Credentials."
|
||||
"from_service_account_info"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
target_path,
|
||||
mock_from_service_account_info,
|
||||
)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
service_account_credential=ServiceAccountCredential(
|
||||
type_="service_account",
|
||||
project_id="your_project_id",
|
||||
private_key_id="your_private_key_id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----...",
|
||||
client_email="...@....iam.gserviceaccount.com",
|
||||
client_id="your_client_id",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url=(
|
||||
"https://www.googleapis.com/oauth2/v1/certs"
|
||||
),
|
||||
client_x509_cert_url=(
|
||||
"https://www.googleapis.com/robot/v1/metadata/x509/..."
|
||||
),
|
||||
universe_domain="googleapis.com",
|
||||
),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
),
|
||||
)
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
assert "Failed to exchange service account token" in str(exc_info.value)
|
||||
mock_from_service_account_info.assert_called_once()
|
||||
573
tests/unittests/tools/openapi_tool/auth/test_auth_helper.py
Normal file
573
tests/unittests/tools/openapi_tool/auth/test_auth_helper.py
Normal file
@@ -0,0 +1,573 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from fastapi.openapi.models import HTTPBase
|
||||
from fastapi.openapi.models import HTTPBearer
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OpenIdConnect
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import HttpAuth
|
||||
from google.adk.auth.auth_credential import HttpCredentials
|
||||
from google.adk.auth.auth_credential import ServiceAccount
|
||||
from google.adk.auth.auth_credential import ServiceAccountCredential
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import credential_to_param
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import dict_to_auth_scheme
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import INTERNAL_AUTH_PREFIX
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_url_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_dict_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_api_key_header():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"apikey", "header", "X-API-Key", "test_key"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert scheme.type_ == AuthSchemeType.apiKey
|
||||
assert scheme.in_ == APIKeyIn.header
|
||||
assert scheme.name == "X-API-Key"
|
||||
assert credential == AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_api_key_query():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"apikey", "query", "api_key", "test_key"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert scheme.type_ == AuthSchemeType.apiKey
|
||||
assert scheme.in_ == APIKeyIn.query
|
||||
assert scheme.name == "api_key"
|
||||
assert credential == AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_api_key_cookie():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"apikey", "cookie", "session_id", "test_key"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert scheme.type_ == AuthSchemeType.apiKey
|
||||
assert scheme.in_ == APIKeyIn.cookie
|
||||
assert scheme.name == "session_id"
|
||||
assert credential == AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_api_key_no_credential():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"apikey", "cookie", "session_id"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert credential is None
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_oauth2_token():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"oauth2Token", "header", "Authorization", "test_token"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert scheme.bearerFormat == "JWT"
|
||||
assert credential == AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="test_token")
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_token_to_scheme_credential_oauth2_no_credential():
|
||||
scheme, credential = token_to_scheme_credential(
|
||||
"oauth2Token", "header", "Authorization"
|
||||
)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert credential is None
|
||||
|
||||
|
||||
def test_service_account_dict_to_scheme_credential():
|
||||
config = {
|
||||
"type": "service_account",
|
||||
"project_id": "project_id",
|
||||
"private_key_id": "private_key_id",
|
||||
"private_key": "private_key",
|
||||
"client_email": "client_email",
|
||||
"client_id": "client_id",
|
||||
"auth_uri": "auth_uri",
|
||||
"token_uri": "token_uri",
|
||||
"auth_provider_x509_cert_url": "auth_provider_x509_cert_url",
|
||||
"client_x509_cert_url": "client_x509_cert_url",
|
||||
"universe_domain": "universe_domain",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = service_account_dict_to_scheme_credential(config, scopes)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert scheme.bearerFormat == "JWT"
|
||||
assert credential.auth_type == AuthCredentialTypes.SERVICE_ACCOUNT
|
||||
assert credential.service_account.scopes == scopes
|
||||
assert (
|
||||
credential.service_account.service_account_credential.project_id
|
||||
== "project_id"
|
||||
)
|
||||
|
||||
|
||||
def test_service_account_scheme_credential():
|
||||
config = ServiceAccount(
|
||||
service_account_credential=ServiceAccountCredential(
|
||||
type="service_account",
|
||||
project_id="project_id",
|
||||
private_key_id="private_key_id",
|
||||
private_key="private_key",
|
||||
client_email="client_email",
|
||||
client_id="client_id",
|
||||
auth_uri="auth_uri",
|
||||
token_uri="token_uri",
|
||||
auth_provider_x509_cert_url="auth_provider_x509_cert_url",
|
||||
client_x509_cert_url="client_x509_cert_url",
|
||||
universe_domain="universe_domain",
|
||||
),
|
||||
scopes=["scope1", "scope2"],
|
||||
)
|
||||
|
||||
scheme, credential = service_account_scheme_credential(config)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert scheme.bearerFormat == "JWT"
|
||||
assert credential.auth_type == AuthCredentialTypes.SERVICE_ACCOUNT
|
||||
assert credential.service_account == config
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential():
|
||||
config_dict = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
"openIdConnectUrl": "openid_url",
|
||||
}
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_dict_to_scheme_credential(
|
||||
config_dict, scopes, credential_dict
|
||||
)
|
||||
|
||||
assert isinstance(scheme, OpenIdConnectWithConfig)
|
||||
assert scheme.authorization_endpoint == "auth_url"
|
||||
assert scheme.token_endpoint == "token_url"
|
||||
assert scheme.scopes == scopes
|
||||
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
|
||||
assert credential.oauth2.client_id == "client_id"
|
||||
assert credential.oauth2.client_secret == "client_secret"
|
||||
assert credential.oauth2.redirect_uri == "redirect_uri"
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential_no_openid_url():
|
||||
config_dict = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
}
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_dict_to_scheme_credential(
|
||||
config_dict, scopes, credential_dict
|
||||
)
|
||||
|
||||
assert scheme.openIdConnectUrl == ""
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential_google_oauth_credential():
|
||||
config_dict = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
"openIdConnectUrl": "openid_url",
|
||||
}
|
||||
credential_dict = {
|
||||
"web": {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_dict_to_scheme_credential(
|
||||
config_dict, scopes, credential_dict
|
||||
)
|
||||
|
||||
assert isinstance(scheme, OpenIdConnectWithConfig)
|
||||
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
|
||||
assert credential.oauth2.client_id == "client_id"
|
||||
assert credential.oauth2.client_secret == "client_secret"
|
||||
assert credential.oauth2.redirect_uri == "redirect_uri"
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential_invalid_config():
|
||||
config_dict = {
|
||||
"invalid_field": "value",
|
||||
}
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid OpenID Connect configuration"):
|
||||
openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
|
||||
|
||||
|
||||
def test_openid_dict_to_scheme_credential_missing_credential_fields():
|
||||
config_dict = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
}
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Missing required fields in credential_dict: client_secret",
|
||||
):
|
||||
openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_openid_url_to_scheme_credential(mock_get):
|
||||
mock_response = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
"userinfo_endpoint": "userinfo_url",
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
mock_get.return_value.raise_for_status.return_value = None
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_url_to_scheme_credential(
|
||||
"openid_url", scopes, credential_dict
|
||||
)
|
||||
|
||||
assert isinstance(scheme, OpenIdConnectWithConfig)
|
||||
assert scheme.authorization_endpoint == "auth_url"
|
||||
assert scheme.token_endpoint == "token_url"
|
||||
assert scheme.scopes == scopes
|
||||
assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
|
||||
assert credential.oauth2.client_id == "client_id"
|
||||
assert credential.oauth2.client_secret == "client_secret"
|
||||
assert credential.oauth2.redirect_uri == "redirect_uri"
|
||||
mock_get.assert_called_once_with("openid_url", timeout=10)
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_openid_url_to_scheme_credential_no_openid_url(mock_get):
|
||||
mock_response = {
|
||||
"authorization_endpoint": "auth_url",
|
||||
"token_endpoint": "token_url",
|
||||
"userinfo_endpoint": "userinfo_url",
|
||||
}
|
||||
mock_get.return_value.json.return_value = mock_response
|
||||
mock_get.return_value.raise_for_status.return_value = None
|
||||
credential_dict = {
|
||||
"client_id": "client_id",
|
||||
"client_secret": "client_secret",
|
||||
"redirect_uri": "redirect_uri",
|
||||
}
|
||||
scopes = ["scope1", "scope2"]
|
||||
|
||||
scheme, credential = openid_url_to_scheme_credential(
|
||||
"openid_url", scopes, credential_dict
|
||||
)
|
||||
|
||||
assert scheme.openIdConnectUrl == "openid_url"
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_openid_url_to_scheme_credential_request_exception(mock_get):
|
||||
mock_get.side_effect = requests.exceptions.RequestException("Test Error")
|
||||
credential_dict = {"client_id": "client_id", "client_secret": "client_secret"}
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Failed to fetch OpenID configuration from openid_url"
|
||||
):
|
||||
openid_url_to_scheme_credential("openid_url", [], credential_dict)
|
||||
|
||||
|
||||
@patch("requests.get")
|
||||
def test_openid_url_to_scheme_credential_invalid_json(mock_get):
|
||||
mock_get.return_value.json.side_effect = ValueError("Invalid JSON")
|
||||
mock_get.return_value.raise_for_status.return_value = None
|
||||
credential_dict = {"client_id": "client_id", "client_secret": "client_secret"}
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Invalid JSON response from OpenID configuration endpoint openid_url"
|
||||
),
|
||||
):
|
||||
openid_url_to_scheme_credential("openid_url", [], credential_dict)
|
||||
|
||||
|
||||
def test_credential_to_param_api_key_header():
|
||||
auth_scheme = APIKey(
|
||||
**{"type": "apiKey", "in": "header", "name": "X-API-Key"}
|
||||
)
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "X-API-Key"
|
||||
assert param.param_location == "header"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "X-API-Key": "test_key"}
|
||||
|
||||
|
||||
def test_credential_to_param_api_key_query():
|
||||
auth_scheme = APIKey(**{"type": "apiKey", "in": "query", "name": "api_key"})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "api_key"
|
||||
assert param.param_location == "query"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "api_key": "test_key"}
|
||||
|
||||
|
||||
def test_credential_to_param_api_key_cookie():
|
||||
auth_scheme = APIKey(
|
||||
**{"type": "apiKey", "in": "cookie", "name": "session_id"}
|
||||
)
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "session_id"
|
||||
assert param.param_location == "cookie"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "session_id": "test_key"}
|
||||
|
||||
|
||||
def test_credential_to_param_http_bearer():
|
||||
auth_scheme = HTTPBearer(bearerFormat="JWT")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="test_token")
|
||||
),
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "Authorization"
|
||||
assert param.param_location == "header"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
def test_credential_to_param_http_basic_not_supported():
|
||||
auth_scheme = HTTPBase(scheme="basic")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="basic",
|
||||
credentials=HttpCredentials(username="user", password="password"),
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
NotImplementedError, match="Basic Authentication is not supported."
|
||||
):
|
||||
credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
|
||||
def test_credential_to_param_http_invalid_credentials_no_http():
|
||||
auth_scheme = HTTPBase(scheme="basic")
|
||||
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.HTTP)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid HTTP auth credentials"):
|
||||
credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
|
||||
def test_credential_to_param_oauth2():
|
||||
auth_scheme = OAuth2(flows={})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="test_token")
|
||||
),
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "Authorization"
|
||||
assert param.param_location == "header"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
def test_credential_to_param_openid_connect():
|
||||
auth_scheme = OpenIdConnect(openIdConnectUrl="openid_url")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="test_token")
|
||||
),
|
||||
)
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, auth_credential)
|
||||
|
||||
assert param.original_name == "Authorization"
|
||||
assert param.param_location == "header"
|
||||
assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
def test_credential_to_param_openid_no_credential():
|
||||
auth_scheme = OpenIdConnect(openIdConnectUrl="openid_url")
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, None)
|
||||
|
||||
assert param == None
|
||||
assert kwargs == None
|
||||
|
||||
|
||||
def test_credential_to_param_oauth2_no_credential():
|
||||
auth_scheme = OAuth2(flows={})
|
||||
|
||||
param, kwargs = credential_to_param(auth_scheme, None)
|
||||
|
||||
assert param == None
|
||||
assert kwargs == None
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_api_key():
|
||||
data = {"type": "apiKey", "in": "header", "name": "X-API-Key"}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, APIKey)
|
||||
assert scheme.type_ == AuthSchemeType.apiKey
|
||||
assert scheme.in_ == APIKeyIn.header
|
||||
assert scheme.name == "X-API-Key"
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_http_bearer():
|
||||
data = {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, HTTPBearer)
|
||||
assert scheme.scheme == "bearer"
|
||||
assert scheme.bearerFormat == "JWT"
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_http_base():
|
||||
data = {"type": "http", "scheme": "basic"}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, HTTPBase)
|
||||
assert scheme.scheme == "basic"
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_oauth2():
|
||||
data = {
|
||||
"type": "oauth2",
|
||||
"flows": {
|
||||
"authorizationCode": {
|
||||
"authorizationUrl": "https://example.com/auth",
|
||||
"tokenUrl": "https://example.com/token",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, OAuth2)
|
||||
assert hasattr(scheme.flows, "authorizationCode")
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_openid_connect():
|
||||
data = {
|
||||
"type": "openIdConnect",
|
||||
"openIdConnectUrl": (
|
||||
"https://example.com/.well-known/openid-configuration"
|
||||
),
|
||||
}
|
||||
|
||||
scheme = dict_to_auth_scheme(data)
|
||||
|
||||
assert isinstance(scheme, OpenIdConnect)
|
||||
assert (
|
||||
scheme.openIdConnectUrl
|
||||
== "https://example.com/.well-known/openid-configuration"
|
||||
)
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_missing_type():
|
||||
data = {"in": "header", "name": "X-API-Key"}
|
||||
with pytest.raises(
|
||||
ValueError, match="Missing 'type' field in security scheme dictionary."
|
||||
):
|
||||
dict_to_auth_scheme(data)
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_invalid_type():
|
||||
data = {"type": "invalid", "in": "header", "name": "X-API-Key"}
|
||||
with pytest.raises(ValueError, match="Invalid security scheme type: invalid"):
|
||||
dict_to_auth_scheme(data)
|
||||
|
||||
|
||||
def test_dict_to_auth_scheme_invalid_data():
|
||||
data = {"type": "apiKey", "in": "header"} # Missing 'name'
|
||||
with pytest.raises(ValueError, match="Invalid security scheme data"):
|
||||
dict_to_auth_scheme(data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
436
tests/unittests/tools/openapi_tool/common/test_common.py
Normal file
436
tests/unittests/tools/openapi_tool/common/test_common.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
from fastapi.openapi.models import Response, Schema
|
||||
from google.adk.tools.openapi_tool.common.common import ApiParameter
|
||||
from google.adk.tools.openapi_tool.common.common import PydocHelper
|
||||
from google.adk.tools.openapi_tool.common.common import rename_python_keywords
|
||||
from google.adk.tools.openapi_tool.common.common import to_snake_case
|
||||
from google.adk.tools.openapi_tool.common.common import TypeHintHelper
|
||||
import pytest
|
||||
|
||||
|
||||
def dict_to_responses(input: Dict[str, Any]) -> Dict[str, Response]:
|
||||
return {k: Response.model_validate(input[k]) for k in input}
|
||||
|
||||
|
||||
class TestToSnakeCase:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'input_str, expected_output',
|
||||
[
|
||||
('lowerCamelCase', 'lower_camel_case'),
|
||||
('UpperCamelCase', 'upper_camel_case'),
|
||||
('space separated', 'space_separated'),
|
||||
('REST API', 'rest_api'),
|
||||
('Mixed_CASE with_Spaces', 'mixed_case_with_spaces'),
|
||||
('__init__', 'init'),
|
||||
('APIKey', 'api_key'),
|
||||
('SomeLongURL', 'some_long_url'),
|
||||
('CONSTANT_CASE', 'constant_case'),
|
||||
('already_snake_case', 'already_snake_case'),
|
||||
('single', 'single'),
|
||||
('', ''),
|
||||
(' spaced ', 'spaced'),
|
||||
('with123numbers', 'with123numbers'),
|
||||
('With_Mixed_123_and_SPACES', 'with_mixed_123_and_spaces'),
|
||||
('HTMLParser', 'html_parser'),
|
||||
('HTTPResponseCode', 'http_response_code'),
|
||||
('a_b_c', 'a_b_c'),
|
||||
('A_B_C', 'a_b_c'),
|
||||
('fromAtoB', 'from_ato_b'),
|
||||
('XMLHTTPRequest', 'xmlhttp_request'),
|
||||
('_leading', 'leading'),
|
||||
('trailing_', 'trailing'),
|
||||
(' leading_and_trailing_ ', 'leading_and_trailing'),
|
||||
('Multiple___Underscores', 'multiple_underscores'),
|
||||
(' spaces_and___underscores ', 'spaces_and_underscores'),
|
||||
(' _mixed_Case ', 'mixed_case'),
|
||||
('123Start', '123_start'),
|
||||
('End123', 'end123'),
|
||||
('Mid123dle', 'mid123dle'),
|
||||
],
|
||||
)
|
||||
def test_to_snake_case(self, input_str, expected_output):
|
||||
assert to_snake_case(input_str) == expected_output
|
||||
|
||||
|
||||
class TestRenamePythonKeywords:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'input_str, expected_output',
|
||||
[
|
||||
('in', 'param_in'),
|
||||
('for', 'param_for'),
|
||||
('class', 'param_class'),
|
||||
('normal', 'normal'),
|
||||
('param_if', 'param_if'),
|
||||
('', ''),
|
||||
],
|
||||
)
|
||||
def test_rename_python_keywords(self, input_str, expected_output):
|
||||
assert rename_python_keywords(input_str) == expected_output
|
||||
|
||||
|
||||
class TestApiParameter:
|
||||
|
||||
def test_api_parameter_initialization(self):
|
||||
schema = Schema(type='string', description='A string parameter')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
description='A string description',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.original_name == 'testParam'
|
||||
assert param.param_location == 'query'
|
||||
assert param.param_schema.type == 'string'
|
||||
assert param.param_schema.description == 'A string parameter'
|
||||
assert param.py_name == 'test_param'
|
||||
assert param.type_hint == 'str'
|
||||
assert param.type_value == str
|
||||
assert param.description == 'A string description'
|
||||
|
||||
def test_api_parameter_keyword_rename(self):
|
||||
schema = Schema(type='string')
|
||||
param = ApiParameter(
|
||||
original_name='in',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.py_name == 'param_in'
|
||||
|
||||
def test_api_parameter_custom_py_name(self):
|
||||
schema = Schema(type='integer')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
py_name='custom_name',
|
||||
)
|
||||
assert param.py_name == 'custom_name'
|
||||
|
||||
def test_api_parameter_str_representation(self):
|
||||
schema = Schema(type='number')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert str(param) == 'test_param: float'
|
||||
|
||||
def test_api_parameter_to_arg_string(self):
|
||||
schema = Schema(type='boolean')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.to_arg_string() == 'test_param=test_param'
|
||||
|
||||
def test_api_parameter_to_dict_property(self):
|
||||
schema = Schema(type='string')
|
||||
param = ApiParameter(
|
||||
original_name='testParam',
|
||||
param_location='path',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.to_dict_property() == '"test_param": test_param'
|
||||
|
||||
def test_api_parameter_model_serializer(self):
|
||||
schema = Schema(type='string', description='test description')
|
||||
param = ApiParameter(
|
||||
original_name='TestParam',
|
||||
param_location='path',
|
||||
param_schema=schema,
|
||||
py_name='test_param_custom',
|
||||
description='test description',
|
||||
)
|
||||
|
||||
serialized_param = param.model_dump(mode='json', exclude_none=True)
|
||||
|
||||
assert serialized_param == {
|
||||
'original_name': 'TestParam',
|
||||
'param_location': 'path',
|
||||
'param_schema': {'type': 'string', 'description': 'test description'},
|
||||
'description': 'test description',
|
||||
'py_name': 'test_param_custom',
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'schema, expected_type_value, expected_type_hint',
|
||||
[
|
||||
({'type': 'integer'}, int, 'int'),
|
||||
({'type': 'number'}, float, 'float'),
|
||||
({'type': 'boolean'}, bool, 'bool'),
|
||||
({'type': 'string'}, str, 'str'),
|
||||
(
|
||||
{'type': 'string', 'format': 'date'},
|
||||
str,
|
||||
'str',
|
||||
),
|
||||
(
|
||||
{'type': 'string', 'format': 'date-time'},
|
||||
str,
|
||||
'str',
|
||||
),
|
||||
(
|
||||
{'type': 'array', 'items': {'type': 'integer'}},
|
||||
List[int],
|
||||
'List[int]',
|
||||
),
|
||||
(
|
||||
{'type': 'array', 'items': {'type': 'string'}},
|
||||
List[str],
|
||||
'List[str]',
|
||||
),
|
||||
(
|
||||
{
|
||||
'type': 'array',
|
||||
'items': {'type': 'object'},
|
||||
},
|
||||
List[Dict[str, Any]],
|
||||
'List[Dict[str, Any]]',
|
||||
),
|
||||
({'type': 'object'}, Dict[str, Any], 'Dict[str, Any]'),
|
||||
({'type': 'unknown'}, Any, 'Any'),
|
||||
({}, Any, 'Any'),
|
||||
],
|
||||
)
|
||||
def test_api_parameter_type_hint_helper(
|
||||
self, schema, expected_type_value, expected_type_hint
|
||||
):
|
||||
param = ApiParameter(
|
||||
original_name='test', param_location='query', param_schema=schema
|
||||
)
|
||||
assert param.type_value == expected_type_value
|
||||
assert param.type_hint == expected_type_hint
|
||||
assert (
|
||||
TypeHintHelper.get_type_hint(param.param_schema) == expected_type_hint
|
||||
)
|
||||
assert (
|
||||
TypeHintHelper.get_type_value(param.param_schema) == expected_type_value
|
||||
)
|
||||
|
||||
def test_api_parameter_description(self):
|
||||
schema = Schema(type='string')
|
||||
param = ApiParameter(
|
||||
original_name='param1',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='The description',
|
||||
)
|
||||
assert param.description == 'The description'
|
||||
|
||||
def test_api_parameter_description_use_schema_fallback(self):
|
||||
schema = Schema(type='string', description='The description')
|
||||
param = ApiParameter(
|
||||
original_name='param1',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
assert param.description == 'The description'
|
||||
|
||||
|
||||
class TestTypeHintHelper:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'schema, expected_type_value, expected_type_hint',
|
||||
[
|
||||
({'type': 'integer'}, int, 'int'),
|
||||
({'type': 'number'}, float, 'float'),
|
||||
({'type': 'string'}, str, 'str'),
|
||||
(
|
||||
{
|
||||
'type': 'array',
|
||||
'items': {'type': 'string'},
|
||||
},
|
||||
List[str],
|
||||
'List[str]',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_type_value_and_hint(
|
||||
self, schema, expected_type_value, expected_type_hint
|
||||
):
|
||||
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='Test parameter',
|
||||
)
|
||||
assert (
|
||||
TypeHintHelper.get_type_value(param.param_schema) == expected_type_value
|
||||
)
|
||||
assert (
|
||||
TypeHintHelper.get_type_hint(param.param_schema) == expected_type_hint
|
||||
)
|
||||
|
||||
|
||||
class TestPydocHelper:
|
||||
|
||||
def test_generate_param_doc_simple(self):
|
||||
schema = Schema(type='string')
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='Test description',
|
||||
)
|
||||
|
||||
expected_doc = 'test_param (str): Test description'
|
||||
assert PydocHelper.generate_param_doc(param) == expected_doc
|
||||
|
||||
def test_generate_param_doc_no_description(self):
|
||||
schema = Schema(type='integer')
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
)
|
||||
expected_doc = 'test_param (int): '
|
||||
assert PydocHelper.generate_param_doc(param) == expected_doc
|
||||
|
||||
def test_generate_param_doc_object(self):
|
||||
schema = Schema(
|
||||
type='object',
|
||||
properties={
|
||||
'prop1': {'type': 'string', 'description': 'Prop1 desc'},
|
||||
'prop2': {'type': 'integer'},
|
||||
},
|
||||
)
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='Test object parameter',
|
||||
)
|
||||
expected_doc = (
|
||||
'test_param (Dict[str, Any]): Test object parameter Object'
|
||||
' properties:\n prop1 (str): Prop1 desc\n prop2'
|
||||
' (int): \n'
|
||||
)
|
||||
assert PydocHelper.generate_param_doc(param) == expected_doc
|
||||
|
||||
def test_generate_param_doc_object_no_properties(self):
|
||||
schema = Schema(type='object', description='A test schema')
|
||||
param = ApiParameter(
|
||||
original_name='test_param',
|
||||
param_location='query',
|
||||
param_schema=schema,
|
||||
description='The description.',
|
||||
)
|
||||
expected_doc = 'test_param (Dict[str, Any]): The description.'
|
||||
assert PydocHelper.generate_param_doc(param) == expected_doc
|
||||
|
||||
def test_generate_return_doc_simple(self):
|
||||
responses = {
|
||||
'200': {
|
||||
'description': 'Successful response',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
}
|
||||
}
|
||||
expected_doc = 'Returns (str): Successful response'
|
||||
assert (
|
||||
PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
== expected_doc
|
||||
)
|
||||
|
||||
def test_generate_return_doc_no_content(self):
|
||||
responses = {'204': {'description': 'No content'}}
|
||||
assert not PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
|
||||
def test_generate_return_doc_object(self):
|
||||
responses = {
|
||||
'200': {
|
||||
'description': 'Successful object response',
|
||||
'content': {
|
||||
'application/json': {
|
||||
'schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'prop1': {
|
||||
'type': 'string',
|
||||
'description': 'Prop1 desc',
|
||||
},
|
||||
'prop2': {'type': 'integer'},
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return_doc = PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
|
||||
assert 'Returns (Dict[str, Any]): Successful object response' in return_doc
|
||||
assert 'prop1 (str): Prop1 desc' in return_doc
|
||||
assert 'prop2 (int):' in return_doc
|
||||
|
||||
def test_generate_return_doc_multiple_success(self):
|
||||
responses = {
|
||||
'200': {
|
||||
'description': 'Successful response',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
},
|
||||
'400': {'description': 'Bad request'},
|
||||
}
|
||||
expected_doc = 'Returns (str): Successful response'
|
||||
assert (
|
||||
PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
== expected_doc
|
||||
)
|
||||
|
||||
def test_generate_return_doc_2xx_smallest_status_code_response(self):
|
||||
responses = {
|
||||
'201': {
|
||||
'description': '201 response',
|
||||
'content': {'application/json': {'schema': {'type': 'integer'}}},
|
||||
},
|
||||
'200': {
|
||||
'description': '200 response',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
},
|
||||
'400': {'description': 'Bad request'},
|
||||
}
|
||||
|
||||
expected_doc = 'Returns (str): 200 response'
|
||||
assert (
|
||||
PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
== expected_doc
|
||||
)
|
||||
|
||||
def test_generate_return_doc_contentful_response(self):
|
||||
responses = {
|
||||
'200': {'description': 'No content response'},
|
||||
'201': {
|
||||
'description': '201 response',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
},
|
||||
'400': {'description': 'Bad request'},
|
||||
}
|
||||
expected_doc = 'Returns (str): 201 response'
|
||||
assert (
|
||||
PydocHelper.generate_return_doc(dict_to_responses(responses))
|
||||
== expected_doc
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
1367
tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml
Normal file
1367
tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,628 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
|
||||
import pytest
|
||||
|
||||
|
||||
def create_minimal_openapi_spec() -> Dict[str, Any]:
|
||||
"""Creates a minimal valid OpenAPI spec."""
|
||||
return {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Minimal API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/test": {
|
||||
"get": {
|
||||
"summary": "Test GET endpoint",
|
||||
"operationId": "testGet",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"application/json": {"schema": {"type": "string"}}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_spec_generator():
|
||||
"""Fixture for creating an OperationGenerator instance."""
|
||||
return OpenApiSpecParser()
|
||||
|
||||
|
||||
def test_parse_minimal_spec(openapi_spec_generator):
|
||||
"""Test parsing a minimal OpenAPI specification."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert op.name == "test_get"
|
||||
assert op.endpoint.path == "/test"
|
||||
assert op.endpoint.method == "get"
|
||||
assert op.return_value.type_value == str
|
||||
|
||||
|
||||
def test_parse_spec_with_no_operation_id(openapi_spec_generator):
|
||||
"""Test parsing a spec where operationId is missing (auto-generation)."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
del openapi_spec["paths"]["/test"]["get"]["operationId"] # Remove operationId
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
# Check if operationId is auto generated based on path and method.
|
||||
assert parsed_operations[0].name == "test_get"
|
||||
|
||||
|
||||
def test_parse_spec_with_multiple_methods(openapi_spec_generator):
|
||||
"""Test parsing a spec with multiple HTTP methods for the same path."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["post"] = {
|
||||
"summary": "Test POST endpoint",
|
||||
"operationId": "testPost",
|
||||
"responses": {"200": {"description": "Successful response"}},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
operation_names = {op.name for op in parsed_operations}
|
||||
|
||||
assert len(parsed_operations) == 2
|
||||
assert "test_get" in operation_names
|
||||
assert "test_post" in operation_names
|
||||
|
||||
|
||||
def test_parse_spec_with_parameters(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["get"]["parameters"] = [
|
||||
{"name": "param1", "in": "query", "schema": {"type": "string"}},
|
||||
{"name": "param2", "in": "header", "schema": {"type": "integer"}},
|
||||
]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations[0].parameters) == 2
|
||||
assert parsed_operations[0].parameters[0].original_name == "param1"
|
||||
assert parsed_operations[0].parameters[0].param_location == "query"
|
||||
assert parsed_operations[0].parameters[1].original_name == "param2"
|
||||
assert parsed_operations[0].parameters[1].param_location == "header"
|
||||
|
||||
|
||||
def test_parse_spec_with_request_body(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["post"] = {
|
||||
"summary": "Endpoint with request body",
|
||||
"operationId": "testPostWithBody",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
post_operations = [
|
||||
op for op in parsed_operations if op.endpoint.method == "post"
|
||||
]
|
||||
op = post_operations[0]
|
||||
|
||||
assert len(post_operations) == 1
|
||||
assert op.name == "test_post_with_body"
|
||||
assert len(op.parameters) == 1
|
||||
assert op.parameters[0].original_name == "name"
|
||||
assert op.parameters[0].type_value == str
|
||||
|
||||
|
||||
def test_parse_spec_with_reference(openapi_spec_generator):
|
||||
"""Test parsing a specification with $ref."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "API with Refs", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/test_ref": {
|
||||
"get": {
|
||||
"summary": "Endpoint with ref",
|
||||
"operationId": "testGetRef",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Success",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/MySchema"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"MySchema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert op.return_value.type_value.__origin__ is dict
|
||||
|
||||
|
||||
def test_parse_spec_with_circular_reference(openapi_spec_generator):
|
||||
"""Test correct handling of circular $ref (important!)."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Circular Ref API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/circular": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {"$ref": "#/components/schemas/A"}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"A": {
|
||||
"type": "object",
|
||||
"properties": {"b": {"$ref": "#/components/schemas/B"}},
|
||||
},
|
||||
"B": {
|
||||
"type": "object",
|
||||
"properties": {"a": {"$ref": "#/components/schemas/A"}},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 1
|
||||
|
||||
op = parsed_operations[0]
|
||||
assert op.return_value.type_value.__origin__ is dict
|
||||
assert op.return_value.type_hint == "Dict[str, Any]"
|
||||
|
||||
|
||||
def test_parse_no_paths(openapi_spec_generator):
|
||||
"""Test with a spec that has no paths defined."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "No Paths API", "version": "1.0.0"},
|
||||
}
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 0 # Should be empty
|
||||
|
||||
|
||||
def test_parse_empty_path_item(openapi_spec_generator):
|
||||
"""Test a path item that is present but empty."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Empty Path Item API", "version": "1.0.0"},
|
||||
"paths": {"/empty": None},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 0
|
||||
|
||||
|
||||
def test_parse_spec_with_global_auth_scheme(openapi_spec_generator):
|
||||
"""Test parsing with a global security scheme."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["security"] = [{"api_key": []}]
|
||||
openapi_spec["components"] = {
|
||||
"securitySchemes": {
|
||||
"api_key": {"type": "apiKey", "in": "header", "name": "X-API-Key"}
|
||||
}
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert op.auth_scheme is not None
|
||||
assert op.auth_scheme.type_.value == "apiKey"
|
||||
|
||||
|
||||
def test_parse_spec_with_local_auth_scheme(openapi_spec_generator):
|
||||
"""Test parsing with a local (operation-level) security scheme."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["get"]["security"] = [{"local_auth": []}]
|
||||
openapi_spec["components"] = {
|
||||
"securitySchemes": {"local_auth": {"type": "http", "scheme": "bearer"}}
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
op = parsed_operations[0]
|
||||
|
||||
assert op.auth_scheme is not None
|
||||
assert op.auth_scheme.type_.value == "http"
|
||||
assert op.auth_scheme.scheme == "bearer"
|
||||
|
||||
|
||||
def test_parse_spec_with_servers(openapi_spec_generator):
|
||||
"""Test parsing with server URLs."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["servers"] = [
|
||||
{"url": "https://api.example.com"},
|
||||
{"url": "http://localhost:8000"},
|
||||
]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].endpoint.base_url == "https://api.example.com"
|
||||
|
||||
|
||||
def test_parse_spec_with_no_servers(openapi_spec_generator):
|
||||
"""Test with no servers defined (should default to empty string)."""
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
if "servers" in openapi_spec:
|
||||
del openapi_spec["servers"]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].endpoint.base_url == ""
|
||||
|
||||
|
||||
def test_parse_spec_with_description(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
expected_description = "This is a test description."
|
||||
openapi_spec["paths"]["/test"]["get"]["description"] = expected_description
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].description == expected_description
|
||||
|
||||
|
||||
def test_parse_spec_with_empty_description(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
openapi_spec["paths"]["/test"]["get"]["description"] = ""
|
||||
openapi_spec["paths"]["/test"]["get"]["summary"] = ""
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert parsed_operations[0].description == ""
|
||||
|
||||
|
||||
def test_parse_spec_with_no_description(openapi_spec_generator):
|
||||
openapi_spec = create_minimal_openapi_spec()
|
||||
|
||||
# delete description
|
||||
if "description" in openapi_spec["paths"]["/test"]["get"]:
|
||||
del openapi_spec["paths"]["/test"]["get"]["description"]
|
||||
if "summary" in openapi_spec["paths"]["/test"]["get"]:
|
||||
del openapi_spec["paths"]["/test"]["get"]["summary"]
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
assert len(parsed_operations) == 1
|
||||
assert (
|
||||
parsed_operations[0].description == ""
|
||||
) # it should be initialized with empty string
|
||||
|
||||
|
||||
def test_parse_invalid_openapi_spec_type(openapi_spec_generator):
|
||||
"""Test that passing a non-dict object to parse raises TypeError"""
|
||||
with pytest.raises(AttributeError):
|
||||
openapi_spec_generator.parse(123) # type: ignore
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
openapi_spec_generator.parse("openapi_spec") # type: ignore
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
openapi_spec_generator.parse([]) # type: ignore
|
||||
|
||||
|
||||
def test_parse_external_ref_raises_error(openapi_spec_generator):
|
||||
"""Check that external references (not starting with #) raise ValueError."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "External Ref API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/external": {
|
||||
"get": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": (
|
||||
"external_file.json#/components/schemas/ExternalSchema"
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
openapi_spec_generator.parse(openapi_spec)
|
||||
|
||||
|
||||
def test_parse_spec_with_multiple_paths_deep_refs(openapi_spec_generator):
|
||||
"""Test specs with multiple paths, request/response bodies using deep refs."""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Multiple Paths Deep Refs API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/path1": {
|
||||
"post": {
|
||||
"operationId": "postPath1",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Request1"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Response1"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"/path2": {
|
||||
"put": {
|
||||
"operationId": "putPath2",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Request2"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Response2"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
"get": {
|
||||
"operationId": "getPath2",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/Response2"
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"components": {
|
||||
"schemas": {
|
||||
"Request1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"req1_prop1": {"$ref": "#/components/schemas/Level1_1"}
|
||||
},
|
||||
},
|
||||
"Response1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"res1_prop1": {"$ref": "#/components/schemas/Level1_2"}
|
||||
},
|
||||
},
|
||||
"Request2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"req2_prop1": {"$ref": "#/components/schemas/Level1_1"}
|
||||
},
|
||||
},
|
||||
"Response2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"res2_prop1": {"$ref": "#/components/schemas/Level1_2"}
|
||||
},
|
||||
},
|
||||
"Level1_1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level1_1_prop1": {
|
||||
"$ref": "#/components/schemas/Level2_1"
|
||||
}
|
||||
},
|
||||
},
|
||||
"Level1_2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level1_2_prop1": {
|
||||
"$ref": "#/components/schemas/Level2_2"
|
||||
}
|
||||
},
|
||||
},
|
||||
"Level2_1": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"level2_1_prop1": {"$ref": "#/components/schemas/Level3"}
|
||||
},
|
||||
},
|
||||
"Level2_2": {
|
||||
"type": "object",
|
||||
"properties": {"level2_2_prop1": {"type": "string"}},
|
||||
},
|
||||
"Level3": {"type": "integer"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 3
|
||||
|
||||
# Verify Path 1
|
||||
path1_ops = [op for op in parsed_operations if op.endpoint.path == "/path1"]
|
||||
assert len(path1_ops) == 1
|
||||
path1_op = path1_ops[0]
|
||||
assert path1_op.name == "post_path1"
|
||||
|
||||
assert len(path1_op.parameters) == 1
|
||||
assert path1_op.parameters[0].original_name == "req1_prop1"
|
||||
assert (
|
||||
path1_op.parameters[0]
|
||||
.param_schema.properties["level1_1_prop1"]
|
||||
.properties["level2_1_prop1"]
|
||||
.type
|
||||
== "integer"
|
||||
)
|
||||
assert (
|
||||
path1_op.return_value.param_schema.properties["res1_prop1"]
|
||||
.properties["level1_2_prop1"]
|
||||
.properties["level2_2_prop1"]
|
||||
.type
|
||||
== "string"
|
||||
)
|
||||
|
||||
# Verify Path 2
|
||||
path2_ops = [
|
||||
op
|
||||
for op in parsed_operations
|
||||
if op.endpoint.path == "/path2" and op.name == "put_path2"
|
||||
]
|
||||
path2_op = path2_ops[0]
|
||||
assert path2_op is not None
|
||||
assert len(path2_op.parameters) == 1
|
||||
assert path2_op.parameters[0].original_name == "req2_prop1"
|
||||
assert (
|
||||
path2_op.parameters[0]
|
||||
.param_schema.properties["level1_1_prop1"]
|
||||
.properties["level2_1_prop1"]
|
||||
.type
|
||||
== "integer"
|
||||
)
|
||||
assert (
|
||||
path2_op.return_value.param_schema.properties["res2_prop1"]
|
||||
.properties["level1_2_prop1"]
|
||||
.properties["level2_2_prop1"]
|
||||
.type
|
||||
== "string"
|
||||
)
|
||||
|
||||
|
||||
def test_parse_spec_with_duplicate_parameter_names(openapi_spec_generator):
|
||||
"""Test handling of duplicate parameter names (one in query, one in body).
|
||||
|
||||
The expected behavior is that both parameters should be captured but with
|
||||
different suffix, and
|
||||
their `original_name` attributes should reflect their origin (query or body).
|
||||
"""
|
||||
openapi_spec = {
|
||||
"openapi": "3.1.0",
|
||||
"info": {"title": "Duplicate Parameter Names API", "version": "1.0.0"},
|
||||
"paths": {
|
||||
"/duplicate": {
|
||||
"post": {
|
||||
"operationId": "createWithDuplicate",
|
||||
"parameters": [{
|
||||
"name": "name",
|
||||
"in": "query",
|
||||
"schema": {"type": "string"},
|
||||
}],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "integer"}},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {"200": {"description": "OK"}},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
parsed_operations = openapi_spec_generator.parse(openapi_spec)
|
||||
assert len(parsed_operations) == 1
|
||||
op = parsed_operations[0]
|
||||
assert op.name == "create_with_duplicate"
|
||||
assert len(op.parameters) == 2
|
||||
|
||||
query_param = None
|
||||
body_param = None
|
||||
for param in op.parameters:
|
||||
if param.param_location == "query" and param.original_name == "name":
|
||||
query_param = param
|
||||
elif param.param_location == "body" and param.original_name == "name":
|
||||
body_param = param
|
||||
|
||||
assert query_param is not None
|
||||
assert query_param.original_name == "name"
|
||||
assert query_param.py_name == "name"
|
||||
|
||||
assert body_param is not None
|
||||
assert body_param.original_name == "name"
|
||||
assert body_param.py_name == "name_0"
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from fastapi.openapi.models import MediaType
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import ParameterInType
|
||||
from fastapi.openapi.models import SecuritySchemeType
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
|
||||
def load_spec(file_path: str) -> Dict:
|
||||
"""Loads the OpenAPI specification from a YAML file."""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_spec() -> Dict:
|
||||
"""Fixture to load the OpenAPI specification."""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# Join the directory path with the filename
|
||||
yaml_path = os.path.join(current_dir, "test.yaml")
|
||||
return load_spec(yaml_path)
|
||||
|
||||
|
||||
def test_openapi_toolset_initialization_from_dict(openapi_spec: Dict):
|
||||
"""Test initialization of OpenAPIToolset with a dictionary."""
|
||||
toolset = OpenAPIToolset(spec_dict=openapi_spec)
|
||||
assert isinstance(toolset.tools, list)
|
||||
assert len(toolset.tools) == 5
|
||||
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
|
||||
|
||||
|
||||
def test_openapi_toolset_initialization_from_yaml_string(openapi_spec: Dict):
|
||||
"""Test initialization of OpenAPIToolset with a YAML string."""
|
||||
spec_str = yaml.dump(openapi_spec)
|
||||
toolset = OpenAPIToolset(spec_str=spec_str, spec_str_type="yaml")
|
||||
assert isinstance(toolset.tools, list)
|
||||
assert len(toolset.tools) == 5
|
||||
assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
|
||||
|
||||
|
||||
def test_openapi_toolset_tool_existing(openapi_spec: Dict):
|
||||
"""Test the tool() method for an existing tool."""
|
||||
toolset = OpenAPIToolset(spec_dict=openapi_spec)
|
||||
tool_name = "calendar_calendars_insert" # Example operationId from the spec
|
||||
tool = toolset.get_tool(tool_name)
|
||||
assert isinstance(tool, RestApiTool)
|
||||
assert tool.name == tool_name
|
||||
assert tool.description == "Creates a secondary calendar."
|
||||
assert tool.endpoint.method == "post"
|
||||
assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
|
||||
assert tool.endpoint.path == "/calendars"
|
||||
assert tool.is_long_running is False
|
||||
assert tool.operation.operationId == "calendar.calendars.insert"
|
||||
assert tool.operation.description == "Creates a secondary calendar."
|
||||
assert isinstance(
|
||||
tool.operation.requestBody.content["application/json"], MediaType
|
||||
)
|
||||
assert len(tool.operation.responses) == 1
|
||||
response = tool.operation.responses["200"]
|
||||
assert response.description == "Successful response"
|
||||
assert isinstance(response.content["application/json"], MediaType)
|
||||
assert isinstance(tool.auth_scheme, OAuth2)
|
||||
|
||||
tool_name = "calendar_calendars_get"
|
||||
tool = toolset.get_tool(tool_name)
|
||||
assert isinstance(tool, RestApiTool)
|
||||
assert tool.name == tool_name
|
||||
assert tool.description == "Returns metadata for a calendar."
|
||||
assert tool.endpoint.method == "get"
|
||||
assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
|
||||
assert tool.endpoint.path == "/calendars/{calendarId}"
|
||||
assert tool.is_long_running is False
|
||||
assert tool.operation.operationId == "calendar.calendars.get"
|
||||
assert tool.operation.description == "Returns metadata for a calendar."
|
||||
assert len(tool.operation.parameters) == 1
|
||||
assert tool.operation.parameters[0].name == "calendarId"
|
||||
assert tool.operation.parameters[0].in_ == ParameterInType.path
|
||||
assert tool.operation.parameters[0].required is True
|
||||
assert tool.operation.parameters[0].schema_.type == "string"
|
||||
assert (
|
||||
tool.operation.parameters[0].description
|
||||
== "Calendar identifier. To retrieve calendar IDs call the"
|
||||
" calendarList.list method. If you want to access the primary calendar"
|
||||
' of the currently logged in user, use the "primary" keyword.'
|
||||
)
|
||||
assert isinstance(tool.auth_scheme, OAuth2)
|
||||
|
||||
assert isinstance(toolset.get_tool("calendar_calendars_update"), RestApiTool)
|
||||
assert isinstance(toolset.get_tool("calendar_calendars_delete"), RestApiTool)
|
||||
assert isinstance(toolset.get_tool("calendar_calendars_patch"), RestApiTool)
|
||||
|
||||
|
||||
def test_openapi_toolset_tool_non_existing(openapi_spec: Dict):
|
||||
"""Test the tool() method for a non-existing tool."""
|
||||
toolset = OpenAPIToolset(spec_dict=openapi_spec)
|
||||
tool = toolset.get_tool("non_existent_tool")
|
||||
assert tool is None
|
||||
|
||||
|
||||
def test_openapi_toolset_configure_auth_on_init(openapi_spec: Dict):
|
||||
"""Test configuring auth during initialization."""
|
||||
|
||||
auth_scheme = APIKey(**{
|
||||
"in": APIKeyIn.header, # Use alias name in dict
|
||||
"name": "api_key",
|
||||
"type": SecuritySchemeType.http,
|
||||
})
|
||||
auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
|
||||
toolset = OpenAPIToolset(
|
||||
spec_dict=openapi_spec,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
for tool in toolset.tools:
|
||||
assert tool.auth_scheme == auth_scheme
|
||||
assert tool.auth_credential == auth_credential
|
||||
@@ -0,0 +1,406 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from fastapi.openapi.models import MediaType
|
||||
from fastapi.openapi.models import Operation
|
||||
from fastapi.openapi.models import Parameter
|
||||
from fastapi.openapi.models import RequestBody
|
||||
from fastapi.openapi.models import Response
|
||||
from fastapi.openapi.models import Schema
|
||||
from google.adk.tools.openapi_tool.common.common import ApiParameter
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_operation() -> Operation:
|
||||
"""Fixture to provide a sample OpenAPI Operation object."""
|
||||
return Operation(
|
||||
operationId='test_operation',
|
||||
summary='Test Summary',
|
||||
description='Test Description',
|
||||
parameters=[
|
||||
Parameter(**{
|
||||
'name': 'param1',
|
||||
'in': 'query',
|
||||
'schema': Schema(type='string'),
|
||||
'description': 'Parameter 1',
|
||||
}),
|
||||
Parameter(**{
|
||||
'name': 'param2',
|
||||
'in': 'header',
|
||||
'schema': Schema(type='string'),
|
||||
'description': 'Parameter 2',
|
||||
}),
|
||||
],
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
'application/json': MediaType(
|
||||
schema=Schema(
|
||||
type='object',
|
||||
properties={
|
||||
'prop1': Schema(
|
||||
type='string', description='Property 1'
|
||||
),
|
||||
'prop2': Schema(
|
||||
type='integer', description='Property 2'
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
},
|
||||
description='Request body description',
|
||||
),
|
||||
responses={
|
||||
'200': Response(
|
||||
description='Success',
|
||||
content={
|
||||
'application/json': MediaType(schema=Schema(type='string'))
|
||||
},
|
||||
),
|
||||
'400': Response(description='Client Error'),
|
||||
},
|
||||
security=[{'oauth2': ['resource: read', 'resource: write']}],
|
||||
)
|
||||
|
||||
|
||||
def test_operation_parser_initialization(sample_operation):
|
||||
"""Test initialization of OperationParser."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.operation == sample_operation
|
||||
assert len(parser.params) == 4 # 2 params + 2 request body props
|
||||
assert parser.return_value is not None
|
||||
|
||||
|
||||
def test_process_operation_parameters(sample_operation):
|
||||
"""Test _process_operation_parameters method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
parser._process_operation_parameters()
|
||||
assert len(parser.params) == 2
|
||||
assert parser.params[0].original_name == 'param1'
|
||||
assert parser.params[0].param_location == 'query'
|
||||
assert parser.params[1].original_name == 'param2'
|
||||
assert parser.params[1].param_location == 'header'
|
||||
|
||||
|
||||
def test_process_request_body(sample_operation):
|
||||
"""Test _process_request_body method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
parser._process_request_body()
|
||||
assert len(parser.params) == 2 # 2 properties in request body
|
||||
assert parser.params[0].original_name == 'prop1'
|
||||
assert parser.params[0].param_location == 'body'
|
||||
assert parser.params[1].original_name == 'prop2'
|
||||
assert parser.params[1].param_location == 'body'
|
||||
|
||||
|
||||
def test_process_request_body_array():
|
||||
"""Test _process_request_body method with array schema."""
|
||||
operation = Operation(
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
'application/json': MediaType(
|
||||
schema=Schema(
|
||||
type='array',
|
||||
items=Schema(
|
||||
type='object',
|
||||
properties={
|
||||
'item_prop1': Schema(
|
||||
type='string', description='Item Property 1'
|
||||
),
|
||||
'item_prop2': Schema(
|
||||
type='integer', description='Item Property 2'
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
parser = OperationParser(operation, should_parse=False)
|
||||
parser._process_request_body()
|
||||
assert len(parser.params) == 1
|
||||
assert parser.params[0].original_name == 'array'
|
||||
assert parser.params[0].param_location == 'body'
|
||||
# Check that schema is correctly propagated and is a dictionary
|
||||
assert parser.params[0].param_schema.type == 'array'
|
||||
assert parser.params[0].param_schema.items.type == 'object'
|
||||
assert 'item_prop1' in parser.params[0].param_schema.items.properties
|
||||
assert 'item_prop2' in parser.params[0].param_schema.items.properties
|
||||
assert (
|
||||
parser.params[0].param_schema.items.properties['item_prop1'].description
|
||||
== 'Item Property 1'
|
||||
)
|
||||
assert (
|
||||
parser.params[0].param_schema.items.properties['item_prop2'].description
|
||||
== 'Item Property 2'
|
||||
)
|
||||
|
||||
|
||||
def test_process_request_body_no_name():
|
||||
"""Test _process_request_body with a schema that has no properties (unnamed)"""
|
||||
operation = Operation(
|
||||
requestBody=RequestBody(
|
||||
content={'application/json': MediaType(schema=Schema(type='string'))}
|
||||
)
|
||||
)
|
||||
parser = OperationParser(operation, should_parse=False)
|
||||
parser._process_request_body()
|
||||
assert len(parser.params) == 1
|
||||
assert parser.params[0].original_name == '' # No name
|
||||
assert parser.params[0].param_location == 'body'
|
||||
|
||||
|
||||
def test_dedupe_param_names(sample_operation):
|
||||
"""Test _dedupe_param_names method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
# Add duplicate named parameters.
|
||||
parser.params = [
|
||||
ApiParameter(original_name='test', param_location='', param_schema={}),
|
||||
ApiParameter(original_name='test', param_location='', param_schema={}),
|
||||
ApiParameter(original_name='test', param_location='', param_schema={}),
|
||||
]
|
||||
parser._dedupe_param_names()
|
||||
assert parser.params[0].py_name == 'test'
|
||||
assert parser.params[1].py_name == 'test_0'
|
||||
assert parser.params[2].py_name == 'test_1'
|
||||
|
||||
|
||||
def test_process_return_value(sample_operation):
|
||||
"""Test _process_return_value method."""
|
||||
parser = OperationParser(sample_operation, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value is not None
|
||||
assert parser.return_value.type_hint == 'str'
|
||||
|
||||
|
||||
def test_process_return_value_no_2xx(sample_operation):
|
||||
"""Tests _process_return_value when no 2xx response exists."""
|
||||
operation_no_2xx = Operation(
|
||||
responses={'400': Response(description='Client Error')}
|
||||
)
|
||||
parser = OperationParser(operation_no_2xx, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value is not None
|
||||
assert parser.return_value.type_hint == 'Any'
|
||||
|
||||
|
||||
def test_process_return_value_multiple_2xx(sample_operation):
|
||||
"""Tests _process_return_value when multiple 2xx responses exist."""
|
||||
operation_multi_2xx = Operation(
|
||||
responses={
|
||||
'201': Response(
|
||||
description='Success',
|
||||
content={
|
||||
'application/json': MediaType(schema=Schema(type='integer'))
|
||||
},
|
||||
),
|
||||
'202': Response(
|
||||
description='Success',
|
||||
content={'text/plain': MediaType(schema=Schema(type='string'))},
|
||||
),
|
||||
'200': Response(
|
||||
description='Success',
|
||||
content={
|
||||
'application/pdf': MediaType(schema=Schema(type='boolean'))
|
||||
},
|
||||
),
|
||||
'400': Response(
|
||||
description='Failure',
|
||||
content={
|
||||
'application/xml': MediaType(schema=Schema(type='object'))
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
parser = OperationParser(operation_multi_2xx, should_parse=False)
|
||||
parser._process_return_value()
|
||||
|
||||
assert parser.return_value is not None
|
||||
# Take the content type of the 200 response since it's the smallest response
|
||||
# code
|
||||
assert parser.return_value.param_schema.type == 'boolean'
|
||||
|
||||
|
||||
def test_process_return_value_no_content(sample_operation):
|
||||
"""Test when 2xx response has no content"""
|
||||
operation_no_content = Operation(
|
||||
responses={'200': Response(description='Success', content={})}
|
||||
)
|
||||
parser = OperationParser(operation_no_content, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value.type_hint == 'Any'
|
||||
|
||||
|
||||
def test_process_return_value_no_schema(sample_operation):
|
||||
"""Tests when the 2xx response's content has no schema."""
|
||||
operation_no_schema = Operation(
|
||||
responses={
|
||||
'200': Response(
|
||||
description='Success',
|
||||
content={'application/json': MediaType(schema=None)},
|
||||
)
|
||||
}
|
||||
)
|
||||
parser = OperationParser(operation_no_schema, should_parse=False)
|
||||
parser._process_return_value()
|
||||
assert parser.return_value.type_hint == 'Any'
|
||||
|
||||
|
||||
def test_get_function_name(sample_operation):
|
||||
"""Test get_function_name method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_function_name() == 'test_operation'
|
||||
|
||||
|
||||
def test_get_function_name_missing_id():
|
||||
"""Tests get_function_name when operationId is missing"""
|
||||
operation = Operation() # No ID
|
||||
parser = OperationParser(operation)
|
||||
with pytest.raises(ValueError, match='Operation ID is missing'):
|
||||
parser.get_function_name()
|
||||
|
||||
|
||||
def test_get_return_type_hint(sample_operation):
|
||||
"""Test get_return_type_hint method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_return_type_hint() == 'str'
|
||||
|
||||
|
||||
def test_get_return_type_value(sample_operation):
|
||||
"""Test get_return_type_value method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_return_type_value() == str
|
||||
|
||||
|
||||
def test_get_parameters(sample_operation):
|
||||
"""Test get_parameters method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
params = parser.get_parameters()
|
||||
assert len(params) == 4 # Correct count after processing
|
||||
assert all(isinstance(p, ApiParameter) for p in params)
|
||||
|
||||
|
||||
def test_get_return_value(sample_operation):
|
||||
"""Test get_return_value method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
return_value = parser.get_return_value()
|
||||
assert isinstance(return_value, ApiParameter)
|
||||
|
||||
|
||||
def test_get_auth_scheme_name(sample_operation):
|
||||
"""Test get_auth_scheme_name method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
assert parser.get_auth_scheme_name() == 'oauth2'
|
||||
|
||||
|
||||
def test_get_auth_scheme_name_no_security():
|
||||
"""Test get_auth_scheme_name when no security is present."""
|
||||
operation = Operation(responses={})
|
||||
parser = OperationParser(operation)
|
||||
assert parser.get_auth_scheme_name() == ''
|
||||
|
||||
|
||||
def test_get_pydoc_string(sample_operation):
|
||||
"""Test get_pydoc_string method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
pydoc_string = parser.get_pydoc_string()
|
||||
assert 'Test Summary' in pydoc_string
|
||||
assert 'Args:' in pydoc_string
|
||||
assert 'param1 (str): Parameter 1' in pydoc_string
|
||||
assert 'prop1 (str): Property 1' in pydoc_string
|
||||
assert 'Returns (str):' in pydoc_string
|
||||
assert 'Success' in pydoc_string
|
||||
|
||||
|
||||
def test_get_json_schema(sample_operation):
|
||||
"""Test get_json_schema method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
json_schema = parser.get_json_schema()
|
||||
assert json_schema['title'] == 'test_operation_Arguments'
|
||||
assert json_schema['type'] == 'object'
|
||||
assert 'param1' in json_schema['properties']
|
||||
assert 'prop1' in json_schema['properties']
|
||||
assert 'param1' in json_schema['required']
|
||||
assert 'prop1' in json_schema['required']
|
||||
|
||||
|
||||
def test_get_signature_parameters(sample_operation):
|
||||
"""Test get_signature_parameters method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
signature_params = parser.get_signature_parameters()
|
||||
assert len(signature_params) == 4
|
||||
assert signature_params[0].name == 'param1'
|
||||
assert signature_params[0].annotation == str
|
||||
assert signature_params[2].name == 'prop1'
|
||||
assert signature_params[2].annotation == str
|
||||
|
||||
|
||||
def test_get_annotations(sample_operation):
|
||||
"""Test get_annotations method."""
|
||||
parser = OperationParser(sample_operation)
|
||||
annotations = parser.get_annotations()
|
||||
assert len(annotations) == 5 # 4 parameters + return
|
||||
assert annotations['param1'] == str
|
||||
assert annotations['prop1'] == str
|
||||
assert annotations['return'] == str
|
||||
|
||||
|
||||
def test_load():
|
||||
"""Test the load classmethod."""
|
||||
operation = Operation(operationId='my_op') # Minimal operation
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name='p1',
|
||||
param_location='',
|
||||
param_schema={'type': 'integer'},
|
||||
)
|
||||
]
|
||||
return_value = ApiParameter(
|
||||
original_name='', param_location='', param_schema={'type': 'string'}
|
||||
)
|
||||
|
||||
parser = OperationParser.load(operation, params, return_value)
|
||||
|
||||
assert isinstance(parser, OperationParser)
|
||||
assert parser.operation == operation
|
||||
assert parser.params == params
|
||||
assert parser.return_value == return_value
|
||||
assert (
|
||||
parser.get_function_name() == 'my_op'
|
||||
) # Check that the operation is loaded
|
||||
|
||||
|
||||
def test_operation_parser_with_dict():
|
||||
"""Test initialization of OperationParser with a dictionary."""
|
||||
operation_dict = {
|
||||
'operationId': 'test_dict_operation',
|
||||
'parameters': [
|
||||
{'name': 'dict_param', 'in': 'query', 'schema': {'type': 'string'}}
|
||||
],
|
||||
'responses': {
|
||||
'200': {
|
||||
'description': 'Dict Success',
|
||||
'content': {'application/json': {'schema': {'type': 'string'}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
parser = OperationParser(operation_dict)
|
||||
assert parser.operation.operationId == 'test_dict_operation'
|
||||
assert len(parser.params) == 1
|
||||
assert parser.params[0].original_name == 'dict_param'
|
||||
assert parser.return_value.type_hint == 'str'
|
||||
@@ -0,0 +1,966 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.openapi.models import MediaType
|
||||
from fastapi.openapi.models import Operation
|
||||
from fastapi.openapi.models import Parameter as OpenAPIParameter
|
||||
from fastapi.openapi.models import RequestBody
|
||||
from fastapi.openapi.models import Schema as OpenAPISchema
|
||||
from google.adk.sessions.state import State
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.common.common import ApiParameter
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import snake_to_lower_camel
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from google.genai.types import Schema
|
||||
from google.genai.types import Type
|
||||
import pytest
|
||||
|
||||
|
||||
class TestRestApiTool:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_context(self):
|
||||
"""Fixture for a mock OperationParser."""
|
||||
mock_context = MagicMock(spec=ToolContext)
|
||||
mock_context.state = State({}, {})
|
||||
mock_context.get_auth_response.return_value = {}
|
||||
mock_context.request_credential.return_value = {}
|
||||
return mock_context
|
||||
|
||||
@pytest.fixture
|
||||
def mock_operation_parser(self):
|
||||
"""Fixture for a mock OperationParser."""
|
||||
mock_parser = MagicMock(spec=OperationParser)
|
||||
mock_parser.get_function_name.return_value = "mock_function_name"
|
||||
mock_parser.get_json_schema.return_value = {}
|
||||
mock_parser.get_parameters.return_value = []
|
||||
mock_parser.get_return_type_hint.return_value = "str"
|
||||
mock_parser.get_pydoc_string.return_value = "Mock docstring"
|
||||
mock_parser.get_signature_parameters.return_value = []
|
||||
mock_parser.get_return_type_value.return_value = str
|
||||
mock_parser.get_annotations.return_value = {}
|
||||
return mock_parser
|
||||
|
||||
@pytest.fixture
|
||||
def sample_endpiont(self):
|
||||
return OperationEndpoint(
|
||||
base_url="https://example.com", path="/test", method="GET"
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_operation(self):
|
||||
return Operation(
|
||||
operationId="testOperation",
|
||||
description="Test operation",
|
||||
parameters=[],
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={
|
||||
"testBodyParam": OpenAPISchema(type="string")
|
||||
},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_api_parameters(self):
|
||||
return [
|
||||
ApiParameter(
|
||||
original_name="test_param",
|
||||
py_name="test_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
is_required=True,
|
||||
),
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="test_body_param",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
is_required=True,
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def sample_return_parameter(self):
|
||||
return ApiParameter(
|
||||
original_name="test_param",
|
||||
py_name="test_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
is_required=True,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_auth_scheme(self):
|
||||
scheme, _ = token_to_scheme_credential(
|
||||
"apikey", "header", "", "sample_auth_credential_internal_test"
|
||||
)
|
||||
return scheme
|
||||
|
||||
@pytest.fixture
|
||||
def sample_auth_credential(self):
|
||||
_, credential = token_to_scheme_credential(
|
||||
"apikey", "header", "", "sample_auth_credential_internal_test"
|
||||
)
|
||||
return credential
|
||||
|
||||
def test_init(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
)
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.description == "Test Tool"
|
||||
assert tool.endpoint == sample_endpiont
|
||||
assert tool.operation == sample_operation
|
||||
assert tool.auth_credential == sample_auth_credential
|
||||
assert tool.auth_scheme == sample_auth_scheme
|
||||
assert tool.credential_exchanger is not None
|
||||
|
||||
def test_from_parsed_operation_str(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_api_parameters,
|
||||
sample_return_parameter,
|
||||
sample_operation,
|
||||
):
|
||||
parsed_operation_str = json.dumps({
|
||||
"name": "test_operation",
|
||||
"description": "Test Description",
|
||||
"endpoint": sample_endpiont.model_dump(),
|
||||
"operation": sample_operation.model_dump(),
|
||||
"auth_scheme": None,
|
||||
"auth_credential": None,
|
||||
"parameters": [p.model_dump() for p in sample_api_parameters],
|
||||
"return_value": sample_return_parameter.model_dump(),
|
||||
})
|
||||
|
||||
tool = RestApiTool.from_parsed_operation_str(parsed_operation_str)
|
||||
assert tool.name == "test_operation"
|
||||
|
||||
def test_get_declaration(
|
||||
self, sample_endpiont, sample_operation, mock_operation_parser
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test description",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
should_parse_operation=False,
|
||||
)
|
||||
tool._operation_parser = mock_operation_parser
|
||||
|
||||
declaration = tool._get_declaration()
|
||||
assert isinstance(declaration, FunctionDeclaration)
|
||||
assert declaration.name == "test_tool"
|
||||
assert declaration.description == "Test description"
|
||||
assert isinstance(declaration.parameters, Schema)
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
|
||||
)
|
||||
def test_call_success(
|
||||
self,
|
||||
mock_request,
|
||||
mock_tool_context,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
)
|
||||
|
||||
# Call the method
|
||||
result = tool.call(args={}, tool_context=mock_tool_context)
|
||||
|
||||
# Check the result
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
|
||||
)
|
||||
def test_call_auth_pending(
|
||||
self,
|
||||
mock_request,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
sample_auth_scheme,
|
||||
sample_auth_credential,
|
||||
):
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
auth_credential=sample_auth_credential,
|
||||
)
|
||||
with patch(
|
||||
"google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
|
||||
) as mock_from_tool_context:
|
||||
mock_tool_auth_handler_instance = MagicMock()
|
||||
mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
|
||||
"pending"
|
||||
)
|
||||
mock_from_tool_context.return_value = mock_tool_auth_handler_instance
|
||||
|
||||
response = tool.call(args={}, tool_context=None)
|
||||
assert response == {
|
||||
"pending": True,
|
||||
"message": "Needs your authorization to access your data.",
|
||||
}
|
||||
|
||||
def test_prepare_request_params_query_body(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
# Create a mock Operation object
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
parameters=[
|
||||
OpenAPIParameter(**{
|
||||
"name": "testQueryParam",
|
||||
"in": "query",
|
||||
"schema": OpenAPISchema(type="string"),
|
||||
})
|
||||
],
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={
|
||||
"param1": OpenAPISchema(type="string"),
|
||||
"param2": OpenAPISchema(type="integer"),
|
||||
},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="param1",
|
||||
py_name="param1",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
),
|
||||
ApiParameter(
|
||||
original_name="param2",
|
||||
py_name="param2",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="integer"),
|
||||
),
|
||||
ApiParameter(
|
||||
original_name="testQueryParam",
|
||||
py_name="test_query_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
),
|
||||
]
|
||||
kwargs = {
|
||||
"param1": "value1",
|
||||
"param2": 123,
|
||||
"test_query_param": "query_value",
|
||||
}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
assert request_params["method"] == "get"
|
||||
assert request_params["url"] == "https://example.com/test"
|
||||
assert request_params["json"] == {"param1": "value1", "param2": 123}
|
||||
assert request_params["params"] == {"testQueryParam": "query_value"}
|
||||
|
||||
def test_prepare_request_params_array(
|
||||
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="array", items=OpenAPISchema(type="string")
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="array", # Match the parameter name
|
||||
py_name="array",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(
|
||||
type="array", items=OpenAPISchema(type="string")
|
||||
),
|
||||
)
|
||||
]
|
||||
kwargs = {"array": ["item1", "item2"]}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["json"] == ["item1", "item2"]
|
||||
|
||||
def test_prepare_request_params_string(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"text/plain": MediaType(schema=OpenAPISchema(type="string"))
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="input_string",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"input_string": "test_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["data"] == "test_value"
|
||||
assert request_params["headers"]["Content-Type"] == "text/plain"
|
||||
|
||||
def test_prepare_request_params_form_data(
|
||||
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/x-www-form-urlencoded": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={"key1": OpenAPISchema(type="string")},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="key1",
|
||||
py_name="key1",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"key1": "value1"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["data"] == {"key1": "value1"}
|
||||
assert (
|
||||
request_params["headers"]["Content-Type"]
|
||||
== "application/x-www-form-urlencoded"
|
||||
)
|
||||
|
||||
def test_prepare_request_params_multipart(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"multipart/form-data": MediaType(
|
||||
schema=OpenAPISchema(
|
||||
type="object",
|
||||
properties={
|
||||
"file1": OpenAPISchema(
|
||||
type="string", format="binary"
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="file1",
|
||||
py_name="file1",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string", format="binary"),
|
||||
)
|
||||
]
|
||||
kwargs = {"file1": b"file_content"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["files"] == {"file1": b"file_content"}
|
||||
assert request_params["headers"]["Content-Type"] == "multipart/form-data"
|
||||
|
||||
def test_prepare_request_params_octet_stream(
|
||||
self, sample_endpiont, sample_auth_scheme, sample_auth_credential
|
||||
):
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/octet-stream": MediaType(
|
||||
schema=OpenAPISchema(type="string", format="binary")
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="test",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="data",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string", format="binary"),
|
||||
)
|
||||
]
|
||||
kwargs = {"data": b"binary_data"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["data"] == b"binary_data"
|
||||
assert (
|
||||
request_params["headers"]["Content-Type"] == "application/octet-stream"
|
||||
)
|
||||
|
||||
def test_prepare_request_params_path_param(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
mock_operation = Operation(operationId="test_op")
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="user_id",
|
||||
py_name="user_id",
|
||||
param_location="path",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"user_id": "123"}
|
||||
endpoint_with_path = OperationEndpoint(
|
||||
base_url="https://example.com", path="/test/{user_id}", method="get"
|
||||
)
|
||||
tool.endpoint = endpoint_with_path
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert (
|
||||
request_params["url"] == "https://example.com/test/123"
|
||||
) # Path param replaced
|
||||
|
||||
def test_prepare_request_params_header_param(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="X-Custom-Header",
|
||||
py_name="x_custom_header",
|
||||
param_location="header",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"x_custom_header": "header_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["headers"]["X-Custom-Header"] == "header_value"
|
||||
|
||||
def test_prepare_request_params_cookie_param(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="session_id",
|
||||
py_name="session_id",
|
||||
param_location="cookie",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"session_id": "cookie_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["cookies"]["session_id"] == "cookie_value"
|
||||
|
||||
def test_prepare_request_params_multiple_mime_types(
|
||||
self, sample_endpiont, sample_auth_credential, sample_auth_scheme
|
||||
):
|
||||
# Test what happens when multiple mime types are specified. It should take
|
||||
# the first one.
|
||||
mock_operation = Operation(
|
||||
operationId="test_op",
|
||||
requestBody=RequestBody(
|
||||
content={
|
||||
"application/json": MediaType(
|
||||
schema=OpenAPISchema(type="string")
|
||||
),
|
||||
"text/plain": MediaType(schema=OpenAPISchema(type="string")),
|
||||
}
|
||||
),
|
||||
)
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=mock_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="",
|
||||
py_name="input",
|
||||
param_location="body",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"input": "some_value"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert request_params["headers"]["Content-Type"] == "application/json"
|
||||
|
||||
def test_prepare_request_params_unknown_parameter(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="known_param",
|
||||
py_name="known_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"known_param": "value", "unknown_param": "unknown"}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
# Make sure unknown parameters are ignored and do not raise errors.
|
||||
assert "unknown_param" not in request_params["params"]
|
||||
|
||||
def test_prepare_request_params_base_url_handling(
|
||||
self, sample_auth_credential, sample_auth_scheme, sample_operation
|
||||
):
|
||||
# No base_url provided, should use path as is
|
||||
tool_no_base = RestApiTool(
|
||||
name="test_tool_no_base",
|
||||
description="Test Tool",
|
||||
endpoint=OperationEndpoint(base_url="", path="/no_base", method="get"),
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = []
|
||||
kwargs = {}
|
||||
|
||||
request_params_no_base = tool_no_base._prepare_request_params(
|
||||
params, kwargs
|
||||
)
|
||||
assert request_params_no_base["url"] == "/no_base"
|
||||
|
||||
tool_trailing_slash = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=OperationEndpoint(
|
||||
base_url="https://example.com/", path="/trailing", method="get"
|
||||
),
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
|
||||
request_params_trailing = tool_trailing_slash._prepare_request_params(
|
||||
params, kwargs
|
||||
)
|
||||
assert request_params_trailing["url"] == "https://example.com/trailing"
|
||||
|
||||
def test_prepare_request_params_no_unrecognized_query_parameter(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_auth_credential,
|
||||
sample_auth_scheme,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=sample_auth_credential,
|
||||
auth_scheme=sample_auth_scheme,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="unrecognized_param",
|
||||
py_name="unrecognized_param",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"unrecognized_param": None} # Explicitly passing None
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
# Query param not in sample_operation. It should be ignored.
|
||||
assert "unrecognized_param" not in request_params["params"]
|
||||
|
||||
def test_prepare_request_params_no_credential(
|
||||
self,
|
||||
sample_endpiont,
|
||||
sample_operation,
|
||||
):
|
||||
tool = RestApiTool(
|
||||
name="test_tool",
|
||||
description="Test Tool",
|
||||
endpoint=sample_endpiont,
|
||||
operation=sample_operation,
|
||||
auth_credential=None,
|
||||
auth_scheme=None,
|
||||
)
|
||||
params = [
|
||||
ApiParameter(
|
||||
original_name="param_name",
|
||||
py_name="param_name",
|
||||
param_location="query",
|
||||
param_schema=OpenAPISchema(type="string"),
|
||||
)
|
||||
]
|
||||
kwargs = {"param_name": "aaa", "empty_param": ""}
|
||||
|
||||
request_params = tool._prepare_request_params(params, kwargs)
|
||||
|
||||
assert "param_name" in request_params["params"]
|
||||
assert "empty_param" not in request_params["params"]
|
||||
|
||||
|
||||
class TestToGeminiSchema:
|
||||
|
||||
def test_to_gemini_schema_none(self):
|
||||
assert to_gemini_schema(None) is None
|
||||
|
||||
def test_to_gemini_schema_not_dict(self):
|
||||
with pytest.raises(TypeError, match="openapi_schema must be a dictionary"):
|
||||
to_gemini_schema("not a dict")
|
||||
|
||||
def test_to_gemini_schema_empty_dict(self):
|
||||
result = to_gemini_schema({})
|
||||
assert isinstance(result, Schema)
|
||||
assert result.type == Type.OBJECT
|
||||
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
|
||||
|
||||
def test_to_gemini_schema_dict_with_only_object_type(self):
|
||||
result = to_gemini_schema({"type": "object"})
|
||||
assert isinstance(result, Schema)
|
||||
assert result.type == Type.OBJECT
|
||||
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
|
||||
|
||||
def test_to_gemini_schema_basic_types(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"is_active": {"type": "boolean"},
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert isinstance(gemini_schema, Schema)
|
||||
assert gemini_schema.type == Type.OBJECT
|
||||
assert gemini_schema.properties["name"].type == Type.STRING
|
||||
assert gemini_schema.properties["age"].type == Type.INTEGER
|
||||
assert gemini_schema.properties["is_active"].type == Type.BOOLEAN
|
||||
|
||||
def test_to_gemini_schema_nested_objects(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"street": {"type": "string"},
|
||||
"city": {"type": "string"},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.properties["address"].type == Type.OBJECT
|
||||
assert (
|
||||
gemini_schema.properties["address"].properties["street"].type
|
||||
== Type.STRING
|
||||
)
|
||||
assert (
|
||||
gemini_schema.properties["address"].properties["city"].type
|
||||
== Type.STRING
|
||||
)
|
||||
|
||||
def test_to_gemini_schema_array(self):
|
||||
openapi_schema = {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.type == Type.ARRAY
|
||||
assert gemini_schema.items.type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_nested_array(self):
|
||||
openapi_schema = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.items.properties["name"].type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_any_of(self):
|
||||
openapi_schema = {
|
||||
"anyOf": [{"type": "string"}, {"type": "integer"}],
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert len(gemini_schema.any_of) == 2
|
||||
assert gemini_schema.any_of[0].type == Type.STRING
|
||||
assert gemini_schema.any_of[1].type == Type.INTEGER
|
||||
|
||||
def test_to_gemini_schema_general_list(self):
|
||||
openapi_schema = {
|
||||
"type": "array",
|
||||
"properties": {
|
||||
"list_field": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.properties["list_field"].type == Type.ARRAY
|
||||
assert gemini_schema.properties["list_field"].items.type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_enum(self):
|
||||
openapi_schema = {"type": "string", "enum": ["a", "b", "c"]}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.enum == ["a", "b", "c"]
|
||||
|
||||
def test_to_gemini_schema_required(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"required": ["name"],
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.required == ["name"]
|
||||
|
||||
def test_to_gemini_schema_nested_dict(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"properties": {"metadata": {"key1": "value1", "key2": 123}},
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
# Since metadata is not properties nor item, it will call to_gemini_schema recursively.
|
||||
assert isinstance(gemini_schema.properties["metadata"], Schema)
|
||||
assert (
|
||||
gemini_schema.properties["metadata"].type == Type.OBJECT
|
||||
) # add object type by default
|
||||
assert gemini_schema.properties["metadata"].properties == {
|
||||
"dummy_DO_NOT_GENERATE": Schema(type="string")
|
||||
}
|
||||
|
||||
def test_to_gemini_schema_ignore_title_default_format(self):
|
||||
openapi_schema = {
|
||||
"type": "string",
|
||||
"title": "Test Title",
|
||||
"default": "default_value",
|
||||
"format": "date",
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
|
||||
assert gemini_schema.title is None
|
||||
assert gemini_schema.default is None
|
||||
assert gemini_schema.format is None
|
||||
|
||||
def test_to_gemini_schema_property_ordering(self):
|
||||
openapi_schema = {
|
||||
"type": "object",
|
||||
"propertyOrdering": ["name", "age"],
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.property_ordering == ["name", "age"]
|
||||
|
||||
def test_to_gemini_schema_converts_property_dict(self):
|
||||
openapi_schema = {
|
||||
"properties": {
|
||||
"name": {"type": "string", "description": "The property key"},
|
||||
"value": {"type": "string", "description": "The property value"},
|
||||
},
|
||||
"type": "object",
|
||||
"description": "A single property entry in the Properties message.",
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.type == Type.OBJECT
|
||||
assert gemini_schema.properties["name"].type == Type.STRING
|
||||
assert gemini_schema.properties["value"].type == Type.STRING
|
||||
|
||||
def test_to_gemini_schema_remove_unrecognized_fields(self):
|
||||
openapi_schema = {
|
||||
"type": "string",
|
||||
"description": "A single date string.",
|
||||
"format": "date",
|
||||
}
|
||||
gemini_schema = to_gemini_schema(openapi_schema)
|
||||
assert gemini_schema.type == Type.STRING
|
||||
assert not gemini_schema.format
|
||||
|
||||
|
||||
def test_snake_to_lower_camel():
|
||||
assert snake_to_lower_camel("single") == "single"
|
||||
assert snake_to_lower_camel("two_words") == "twoWords"
|
||||
assert snake_to_lower_camel("three_word_example") == "threeWordExample"
|
||||
assert not snake_to_lower_camel("")
|
||||
assert snake_to_lower_camel("alreadyCamelCase") == "alreadyCamelCase"
|
||||
@@ -0,0 +1,201 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import HttpAuth
|
||||
from google.adk.auth.auth_credential import HttpCredentials
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_schemes import AuthScheme
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import OAuth2CredentialExchanger
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolAuthHandler
|
||||
from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolContextCredentialStore
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
import pytest
|
||||
|
||||
|
||||
# Helper function to create a mock ToolContext
|
||||
def create_mock_tool_context():
|
||||
return ToolContext(
|
||||
function_call_id='test-fc-id',
|
||||
invocation_context=InvocationContext(
|
||||
agent=LlmAgent(name='test'),
|
||||
session=Session(app_name='test', user_id='123', id='123'),
|
||||
invocation_id='123',
|
||||
session_service=InMemorySessionService(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Test cases for OpenID Connect
|
||||
class MockOpenIdConnectCredentialExchanger(OAuth2CredentialExchanger):
|
||||
|
||||
def __init__(
|
||||
self, expected_scheme, expected_credential, expected_access_token
|
||||
):
|
||||
self.expected_scheme = expected_scheme
|
||||
self.expected_credential = expected_credential
|
||||
self.expected_access_token = expected_access_token
|
||||
|
||||
def exchange_credential(
|
||||
self,
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
if auth_credential.oauth2 and (
|
||||
auth_credential.oauth2.auth_response_uri
|
||||
or auth_credential.oauth2.auth_code
|
||||
):
|
||||
auth_code = (
|
||||
auth_credential.oauth2.auth_response_uri
|
||||
if auth_credential.oauth2.auth_response_uri
|
||||
else auth_credential.oauth2.auth_code
|
||||
)
|
||||
# Simulate the token exchange
|
||||
updated_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
|
||||
http=HttpAuth(
|
||||
scheme='bearer',
|
||||
credentials=HttpCredentials(
|
||||
token=auth_code + self.expected_access_token
|
||||
),
|
||||
),
|
||||
)
|
||||
return updated_credential
|
||||
|
||||
# simulate the case of getting auth_uri
|
||||
return None
|
||||
|
||||
|
||||
def get_mock_openid_scheme_credential():
|
||||
config_dict = {
|
||||
'authorization_endpoint': 'test.com',
|
||||
'token_endpoint': 'test.com',
|
||||
}
|
||||
scopes = ['test_scope']
|
||||
credential_dict = {
|
||||
'client_id': '123',
|
||||
'client_secret': '456',
|
||||
'redirect_uri': 'test.com',
|
||||
}
|
||||
return openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
|
||||
|
||||
|
||||
# Fixture for the OpenID Connect security scheme
|
||||
@pytest.fixture
|
||||
def openid_connect_scheme():
|
||||
scheme, _ = get_mock_openid_scheme_credential()
|
||||
return scheme
|
||||
|
||||
|
||||
# Fixture for a base OpenID Connect credential
|
||||
@pytest.fixture
|
||||
def openid_connect_credential():
|
||||
_, credential = get_mock_openid_scheme_credential()
|
||||
return credential
|
||||
|
||||
|
||||
def test_openid_connect_no_auth_response(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
):
|
||||
# Setup Mock exchanger
|
||||
mock_exchanger = MockOpenIdConnectCredentialExchanger(
|
||||
openid_connect_scheme, openid_connect_credential, None
|
||||
)
|
||||
tool_context = create_mock_tool_context()
|
||||
credential_store = ToolContextCredentialStore(tool_context=tool_context)
|
||||
handler = ToolAuthHandler(
|
||||
tool_context,
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
credential_exchanger=mock_exchanger,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
result = handler.prepare_auth_credentials()
|
||||
assert result.state == 'pending'
|
||||
assert result.auth_credential == openid_connect_credential
|
||||
|
||||
|
||||
def test_openid_connect_with_auth_response(
|
||||
openid_connect_scheme, openid_connect_credential, monkeypatch
|
||||
):
|
||||
mock_exchanger = MockOpenIdConnectCredentialExchanger(
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
'test_access_token',
|
||||
)
|
||||
tool_context = create_mock_tool_context()
|
||||
|
||||
mock_auth_handler = MagicMock()
|
||||
mock_auth_handler.get_auth_response.return_value = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
|
||||
oauth2=OAuth2Auth(auth_response_uri='test_auth_response_uri'),
|
||||
)
|
||||
mock_auth_handler_path = 'google.adk.tools.tool_context.AuthHandler'
|
||||
monkeypatch.setattr(
|
||||
mock_auth_handler_path, lambda *args, **kwargs: mock_auth_handler
|
||||
)
|
||||
|
||||
credential_store = ToolContextCredentialStore(tool_context=tool_context)
|
||||
handler = ToolAuthHandler(
|
||||
tool_context,
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
credential_exchanger=mock_exchanger,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
result = handler.prepare_auth_credentials()
|
||||
assert result.state == 'done'
|
||||
assert result.auth_credential.auth_type == AuthCredentialTypes.HTTP
|
||||
assert 'test_access_token' in result.auth_credential.http.credentials.token
|
||||
# Verify that the credential was stored:
|
||||
stored_credential = credential_store.get_credential(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
)
|
||||
assert stored_credential == result.auth_credential
|
||||
mock_auth_handler.get_auth_response.assert_called_once()
|
||||
|
||||
|
||||
def test_openid_connect_existing_token(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
):
|
||||
_, existing_credential = token_to_scheme_credential(
|
||||
'oauth2Token', 'header', 'bearer', '123123123'
|
||||
)
|
||||
tool_context = create_mock_tool_context()
|
||||
# Store the credential to simulate existing credential
|
||||
credential_store = ToolContextCredentialStore(tool_context=tool_context)
|
||||
key = credential_store.get_credential_key(
|
||||
openid_connect_scheme, openid_connect_credential
|
||||
)
|
||||
credential_store.store_credential(key, existing_credential)
|
||||
|
||||
handler = ToolAuthHandler(
|
||||
tool_context,
|
||||
openid_connect_scheme,
|
||||
openid_connect_credential,
|
||||
credential_store=credential_store,
|
||||
)
|
||||
result = handler.prepare_auth_credentials()
|
||||
assert result.state == 'done'
|
||||
assert result.auth_credential == existing_credential
|
||||
14
tests/unittests/tools/retrieval/__init__.py
Normal file
14
tests/unittests/tools/retrieval/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
147
tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py
Normal file
147
tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.adk.tools.retrieval.vertex_ai_rag_retrieval import VertexAiRagRetrieval
|
||||
from google.genai import types
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def noop_tool(x: str) -> str:
|
||||
return x
|
||||
|
||||
|
||||
def test_vertex_rag_retrieval_for_gemini_1_x():
|
||||
responses = [
|
||||
'response1',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=responses)
|
||||
mockModel.model = 'gemini-1.5-pro'
|
||||
|
||||
# Calls the first time.
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
tools=[
|
||||
VertexAiRagRetrieval(
|
||||
name='rag_retrieval',
|
||||
description='rag_retrieval',
|
||||
rag_corpora=[
|
||||
'projects/123456789/locations/us-central1/ragCorpora/1234567890'
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test1')
|
||||
|
||||
# Asserts the requests.
|
||||
assert len(mockModel.requests) == 1
|
||||
assert utils.simplify_contents(mockModel.requests[0].contents) == [
|
||||
('user', 'test1'),
|
||||
]
|
||||
assert len(mockModel.requests[0].config.tools) == 1
|
||||
assert (
|
||||
mockModel.requests[0].config.tools[0].function_declarations[0].name
|
||||
== 'rag_retrieval'
|
||||
)
|
||||
assert mockModel.requests[0].tools_dict['rag_retrieval'] is not None
|
||||
|
||||
|
||||
def test_vertex_rag_retrieval_for_gemini_1_x_with_another_function_tool():
|
||||
responses = [
|
||||
'response1',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=responses)
|
||||
mockModel.model = 'gemini-1.5-pro'
|
||||
|
||||
# Calls the first time.
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
tools=[
|
||||
VertexAiRagRetrieval(
|
||||
name='rag_retrieval',
|
||||
description='rag_retrieval',
|
||||
rag_corpora=[
|
||||
'projects/123456789/locations/us-central1/ragCorpora/1234567890'
|
||||
],
|
||||
),
|
||||
FunctionTool(func=noop_tool),
|
||||
],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test1')
|
||||
|
||||
# Asserts the requests.
|
||||
assert len(mockModel.requests) == 1
|
||||
assert utils.simplify_contents(mockModel.requests[0].contents) == [
|
||||
('user', 'test1'),
|
||||
]
|
||||
assert len(mockModel.requests[0].config.tools[0].function_declarations) == 2
|
||||
assert (
|
||||
mockModel.requests[0].config.tools[0].function_declarations[0].name
|
||||
== 'rag_retrieval'
|
||||
)
|
||||
assert (
|
||||
mockModel.requests[0].config.tools[0].function_declarations[1].name
|
||||
== 'noop_tool'
|
||||
)
|
||||
assert mockModel.requests[0].tools_dict['rag_retrieval'] is not None
|
||||
|
||||
|
||||
def test_vertex_rag_retrieval_for_gemini_2_x():
|
||||
responses = [
|
||||
'response1',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=responses)
|
||||
mockModel.model = 'gemini-2.0-flash'
|
||||
|
||||
# Calls the first time.
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
tools=[
|
||||
VertexAiRagRetrieval(
|
||||
name='rag_retrieval',
|
||||
description='rag_retrieval',
|
||||
rag_corpora=[
|
||||
'projects/123456789/locations/us-central1/ragCorpora/1234567890'
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test1')
|
||||
|
||||
# Asserts the requests.
|
||||
assert len(mockModel.requests) == 1
|
||||
assert utils.simplify_contents(mockModel.requests[0].contents) == [
|
||||
('user', 'test1'),
|
||||
]
|
||||
assert len(mockModel.requests[0].config.tools) == 1
|
||||
assert mockModel.requests[0].config.tools == [
|
||||
types.Tool(
|
||||
retrieval=types.Retrieval(
|
||||
vertex_rag_store=types.VertexRagStore(
|
||||
rag_corpora=[
|
||||
'projects/123456789/locations/us-central1/ragCorpora/1234567890'
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
assert 'rag_retrieval' not in mockModel.requests[0].tools_dict
|
||||
167
tests/unittests/tools/test_agent_tool.py
Normal file
167
tests/unittests/tools/test_agent_tool.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.tools.agent_tool import AgentTool
|
||||
from google.genai.types import Part
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
from .. import utils
|
||||
|
||||
pytestmark = pytest.mark.skip(
|
||||
reason='Skipping until tool.func evaluations are fixed (async)'
|
||||
)
|
||||
|
||||
|
||||
function_call_custom = Part.from_function_call(
|
||||
name='tool_agent', args={'custom_input': 'test1'}
|
||||
)
|
||||
|
||||
function_call_no_schema = Part.from_function_call(
|
||||
name='tool_agent', args={'request': 'test1'}
|
||||
)
|
||||
|
||||
function_response_custom = Part.from_function_response(
|
||||
name='tool_agent', response={'custom_output': 'response1'}
|
||||
)
|
||||
|
||||
function_response_no_schema = Part.from_function_response(
|
||||
name='tool_agent', response={'result': 'response1'}
|
||||
)
|
||||
|
||||
|
||||
def change_state_callback(callback_context: CallbackContext):
|
||||
callback_context.state['state_1'] = 'changed_value'
|
||||
print('change_state_callback: ', callback_context.state)
|
||||
|
||||
|
||||
def test_no_schema():
|
||||
mock_model = utils.MockModel.create(
|
||||
responses=[
|
||||
function_call_no_schema,
|
||||
'response1',
|
||||
'response2',
|
||||
]
|
||||
)
|
||||
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
model=mock_model,
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[AgentTool(agent=tool_agent)],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', function_call_no_schema),
|
||||
('root_agent', function_response_no_schema),
|
||||
('root_agent', 'response2'),
|
||||
]
|
||||
|
||||
|
||||
def test_update_state():
|
||||
"""The agent tool can read and change parent state."""
|
||||
|
||||
mock_model = utils.MockModel.create(
|
||||
responses=[
|
||||
function_call_no_schema,
|
||||
'{"custom_output": "response1"}',
|
||||
'response2',
|
||||
]
|
||||
)
|
||||
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
model=mock_model,
|
||||
instruction='input: {state_1}',
|
||||
before_agent_callback=change_state_callback,
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[AgentTool(agent=tool_agent)],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
runner.session.state['state_1'] = 'state1_value'
|
||||
|
||||
runner.run('test1')
|
||||
assert (
|
||||
'input: changed_value' in mock_model.requests[1].config.system_instruction
|
||||
)
|
||||
assert runner.session.state['state_1'] == 'changed_value'
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
'env_variables',
|
||||
[
|
||||
'GOOGLE_AI',
|
||||
# TODO(wanyif): re-enable after fix.
|
||||
# 'VERTEX',
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_custom_schema():
|
||||
class CustomInput(BaseModel):
|
||||
custom_input: str
|
||||
|
||||
class CustomOutput(BaseModel):
|
||||
custom_output: str
|
||||
|
||||
mock_model = utils.MockModel.create(
|
||||
responses=[
|
||||
function_call_custom,
|
||||
'{"custom_output": "response1"}',
|
||||
'response2',
|
||||
]
|
||||
)
|
||||
|
||||
tool_agent = Agent(
|
||||
name='tool_agent',
|
||||
model=mock_model,
|
||||
input_schema=CustomInput,
|
||||
output_schema=CustomOutput,
|
||||
output_key='tool_output',
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[AgentTool(agent=tool_agent)],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
runner.session.state['state_1'] = 'state1_value'
|
||||
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', function_call_custom),
|
||||
('root_agent', function_response_custom),
|
||||
('root_agent', 'response2'),
|
||||
]
|
||||
|
||||
assert runner.session.state['tool_output'] == {'custom_output': 'response1'}
|
||||
|
||||
assert len(mock_model.requests) == 3
|
||||
# The second request is the tool agent request.
|
||||
assert mock_model.requests[1].config.response_schema == CustomOutput
|
||||
assert mock_model.requests[1].config.response_mime_type == 'application/json'
|
||||
141
tests/unittests/tools/test_base_tool.py
Normal file
141
tests/unittests/tools/test_base_tool.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.sequential_agent import SequentialAgent
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.tools.base_tool import BaseTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
|
||||
class _TestingTool(BaseTool):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
declaration: Optional[types.FunctionDeclaration] = None,
|
||||
):
|
||||
super().__init__(name='test_tool', description='test_description')
|
||||
self.declaration = declaration
|
||||
|
||||
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
|
||||
return self.declaration
|
||||
|
||||
|
||||
def _create_tool_context() -> ToolContext:
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
agent = SequentialAgent(name='test_agent')
|
||||
invocation_context = InvocationContext(
|
||||
invocation_id='invocation_id',
|
||||
agent=agent,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
)
|
||||
return ToolContext(invocation_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_llm_request_no_declaration():
|
||||
tool = _TestingTool()
|
||||
tool_context = _create_tool_context()
|
||||
llm_request = LlmRequest()
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
assert llm_request.config is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_llm_request_with_declaration():
|
||||
declaration = types.FunctionDeclaration(
|
||||
name='test_tool',
|
||||
description='test_description',
|
||||
parameters=types.Schema(
|
||||
type=types.Type.STRING,
|
||||
title='param_1',
|
||||
),
|
||||
)
|
||||
tool = _TestingTool(declaration)
|
||||
llm_request = LlmRequest()
|
||||
tool_context = _create_tool_context()
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
assert llm_request.config.tools[0].function_declarations == [declaration]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_llm_request_with_builtin_tool():
|
||||
declaration = types.FunctionDeclaration(
|
||||
name='test_tool',
|
||||
description='test_description',
|
||||
parameters=types.Schema(
|
||||
type=types.Type.STRING,
|
||||
title='param_1',
|
||||
),
|
||||
)
|
||||
tool = _TestingTool(declaration)
|
||||
llm_request = LlmRequest(
|
||||
config=types.GenerateContentConfig(
|
||||
tools=[types.Tool(google_search=types.GoogleSearch())]
|
||||
)
|
||||
)
|
||||
tool_context = _create_tool_context()
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
# function_declaration is added to another types.Tool without builtin tool.
|
||||
assert llm_request.config.tools[1].function_declarations == [declaration]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_llm_request_with_builtin_tool_and_another_declaration():
|
||||
declaration = types.FunctionDeclaration(
|
||||
name='test_tool',
|
||||
description='test_description',
|
||||
parameters=types.Schema(
|
||||
type=types.Type.STRING,
|
||||
title='param_1',
|
||||
),
|
||||
)
|
||||
tool = _TestingTool(declaration)
|
||||
llm_request = LlmRequest(
|
||||
config=types.GenerateContentConfig(
|
||||
tools=[
|
||||
types.Tool(google_search=types.GoogleSearch()),
|
||||
types.Tool(function_declarations=[types.FunctionDeclaration()]),
|
||||
]
|
||||
)
|
||||
)
|
||||
tool_context = _create_tool_context()
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
# function_declaration is added to existing types.Tool with function_declaration.
|
||||
assert llm_request.config.tools[1].function_declarations[1] == declaration
|
||||
277
tests/unittests/tools/test_build_function_declaration.py
Normal file
277
tests/unittests/tools/test_build_function_declaration.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
from google.adk.tools import _automatic_function_calling_util
|
||||
from google.adk.tools.agent_tool import ToolContext
|
||||
from google.adk.tools.langchain_tool import LangchainTool
|
||||
# TODO: crewai requires python 3.10 as minimum
|
||||
# from crewai_tools import FileReadTool
|
||||
from langchain_community.tools import ShellTool
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
|
||||
def test_unsupported_variant():
|
||||
def simple_function(input_str: str) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function, variant='Unsupported'
|
||||
)
|
||||
|
||||
|
||||
def test_string_input():
|
||||
def simple_function(input_str: str) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'STRING'
|
||||
|
||||
|
||||
def test_int_input():
|
||||
def simple_function(input_str: int) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'INTEGER'
|
||||
|
||||
|
||||
def test_float_input():
|
||||
def simple_function(input_str: float) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'NUMBER'
|
||||
|
||||
|
||||
def test_bool_input():
|
||||
def simple_function(input_str: bool) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'BOOLEAN'
|
||||
|
||||
|
||||
def test_array_input():
|
||||
def simple_function(input_str: List[str]) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'ARRAY'
|
||||
|
||||
|
||||
def test_dict_input():
|
||||
def simple_function(input_str: Dict[str, str]) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'OBJECT'
|
||||
|
||||
|
||||
def test_basemodel_input():
|
||||
class CustomInput(BaseModel):
|
||||
input_str: str
|
||||
|
||||
def simple_function(input: CustomInput) -> str:
|
||||
return {'result': input}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input'].type == 'OBJECT'
|
||||
assert (
|
||||
function_decl.parameters.properties['input'].properties['input_str'].type
|
||||
== 'STRING'
|
||||
)
|
||||
|
||||
|
||||
def test_toolcontext_ignored():
|
||||
def simple_function(input_str: str, tool_context: ToolContext) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function, ignore_params=['tool_context']
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'STRING'
|
||||
assert 'tool_context' not in function_decl.parameters.properties
|
||||
|
||||
|
||||
def test_basemodel():
|
||||
class SimpleFunction(BaseModel):
|
||||
input_str: str
|
||||
custom_input: int
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=SimpleFunction, ignore_params=['custom_input']
|
||||
)
|
||||
|
||||
assert function_decl.name == 'SimpleFunction'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'STRING'
|
||||
assert 'custom_input' not in function_decl.parameters.properties
|
||||
|
||||
|
||||
def test_nested_basemodel_input():
|
||||
class ChildInput(BaseModel):
|
||||
input_str: str
|
||||
|
||||
class CustomInput(BaseModel):
|
||||
child: ChildInput
|
||||
|
||||
def simple_function(input: CustomInput) -> str:
|
||||
return {'result': input}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input'].type == 'OBJECT'
|
||||
assert (
|
||||
function_decl.parameters.properties['input'].properties['child'].type
|
||||
== 'OBJECT'
|
||||
)
|
||||
assert (
|
||||
function_decl.parameters.properties['input']
|
||||
.properties['child']
|
||||
.properties['input_str']
|
||||
.type
|
||||
== 'STRING'
|
||||
)
|
||||
|
||||
|
||||
def test_basemodel_with_nested_basemodel():
|
||||
class ChildInput(BaseModel):
|
||||
input_str: str
|
||||
|
||||
class CustomInput(BaseModel):
|
||||
child: ChildInput
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=CustomInput, ignore_params=['custom_input']
|
||||
)
|
||||
|
||||
assert function_decl.name == 'CustomInput'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['child'].type == 'OBJECT'
|
||||
assert (
|
||||
function_decl.parameters.properties['child'].properties['input_str'].type
|
||||
== 'STRING'
|
||||
)
|
||||
assert 'custom_input' not in function_decl.parameters.properties
|
||||
|
||||
|
||||
def test_list():
|
||||
def simple_function(
|
||||
input_str: List[str], input_dir: List[Dict[str, str]]
|
||||
) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'ARRAY'
|
||||
assert function_decl.parameters.properties['input_str'].items.type == 'STRING'
|
||||
assert function_decl.parameters.properties['input_dir'].type == 'ARRAY'
|
||||
assert function_decl.parameters.properties['input_dir'].items.type == 'OBJECT'
|
||||
|
||||
|
||||
def test_basemodel_list():
|
||||
class ChildInput(BaseModel):
|
||||
input_str: str
|
||||
|
||||
class CustomInput(BaseModel):
|
||||
child: ChildInput
|
||||
|
||||
def simple_function(input_str: List[CustomInput]) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input_str'].type == 'ARRAY'
|
||||
assert function_decl.parameters.properties['input_str'].items.type == 'OBJECT'
|
||||
assert (
|
||||
function_decl.parameters.properties['input_str']
|
||||
.items.properties['child']
|
||||
.type
|
||||
== 'OBJECT'
|
||||
)
|
||||
assert (
|
||||
function_decl.parameters.properties['input_str']
|
||||
.items.properties['child']
|
||||
.properties['input_str']
|
||||
.type
|
||||
== 'STRING'
|
||||
)
|
||||
|
||||
|
||||
# TODO: comment out this test for now as crewai requires python 3.10 as minimum
|
||||
# def test_crewai_tool():
|
||||
# docs_tool = CrewaiTool(
|
||||
# name='direcotry_read_tool',
|
||||
# description='use this to find files for you.',
|
||||
# tool=FileReadTool(),
|
||||
# )
|
||||
# function_decl = docs_tool.get_declaration()
|
||||
# assert function_decl.name == 'direcotry_read_tool'
|
||||
# assert function_decl.parameters.type == 'OBJECT'
|
||||
# assert function_decl.parameters.properties['file_path'].type == 'STRING'
|
||||
304
tests/unittests/utils.py
Normal file
304
tests/unittests/utils.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generator
|
||||
from typing import Union
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.live_request_queue import LiveRequestQueue
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.artifacts import InMemoryArtifactService
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
|
||||
from google.adk.models.base_llm import BaseLlm
|
||||
from google.adk.models.base_llm_connection import BaseLlmConnection
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.models.llm_response import LlmResponse
|
||||
from google.adk.runners import InMemoryRunner as AfInMemoryRunner
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.sessions.session import Session
|
||||
from google.genai import types
|
||||
from google.genai.types import Part
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class UserContent(types.Content):
|
||||
|
||||
def __init__(self, text_or_part: str):
|
||||
parts = [
|
||||
types.Part.from_text(text=text_or_part)
|
||||
if isinstance(text_or_part, str)
|
||||
else text_or_part
|
||||
]
|
||||
super().__init__(role='user', parts=parts)
|
||||
|
||||
|
||||
class ModelContent(types.Content):
|
||||
|
||||
def __init__(self, parts: list[types.Part]):
|
||||
super().__init__(role='model', parts=parts)
|
||||
|
||||
|
||||
def create_invocation_context(agent: Agent, user_content: str = ''):
|
||||
invocation_id = 'test_id'
|
||||
artifact_service = InMemoryArtifactService()
|
||||
session_service = InMemorySessionService()
|
||||
memory_service = InMemoryMemoryService()
|
||||
invocation_context = InvocationContext(
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
memory_service=memory_service,
|
||||
invocation_id=invocation_id,
|
||||
agent=agent,
|
||||
session=session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
),
|
||||
user_content=types.Content(
|
||||
role='user', parts=[types.Part.from_text(text=user_content)]
|
||||
),
|
||||
run_config=RunConfig(),
|
||||
)
|
||||
if user_content:
|
||||
append_user_content(
|
||||
invocation_context, [types.Part.from_text(text=user_content)]
|
||||
)
|
||||
return invocation_context
|
||||
|
||||
|
||||
def append_user_content(
|
||||
invocation_context: InvocationContext, parts: list[types.Part]
|
||||
) -> Event:
|
||||
session = invocation_context.session
|
||||
event = Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author='user',
|
||||
content=types.Content(role='user', parts=parts),
|
||||
)
|
||||
session.events.append(event)
|
||||
return event
|
||||
|
||||
|
||||
# Extracts the contents from the events and transform them into a list of
|
||||
# (author, simplified_content) tuples.
|
||||
def simplify_events(events: list[Event]) -> list[(str, types.Part)]:
|
||||
return [(event.author, simplify_content(event.content)) for event in events]
|
||||
|
||||
|
||||
# Simplifies the contents into a list of (author, simplified_content) tuples.
|
||||
def simplify_contents(contents: list[types.Content]) -> list[(str, types.Part)]:
|
||||
return [(content.role, simplify_content(content)) for content in contents]
|
||||
|
||||
|
||||
# Simplifies the content so it's easier to assert.
|
||||
# - If there is only one part, return part
|
||||
# - If the only part is pure text, return stripped_text
|
||||
# - If there are multiple parts, return parts
|
||||
# - remove function_call_id if it exists
|
||||
def simplify_content(
|
||||
content: types.Content,
|
||||
) -> Union[str, types.Part, list[types.Part]]:
|
||||
for part in content.parts:
|
||||
if part.function_call and part.function_call.id:
|
||||
part.function_call.id = None
|
||||
if part.function_response and part.function_response.id:
|
||||
part.function_response.id = None
|
||||
if len(content.parts) == 1:
|
||||
if content.parts[0].text:
|
||||
return content.parts[0].text.strip()
|
||||
else:
|
||||
return content.parts[0]
|
||||
return content.parts
|
||||
|
||||
|
||||
def get_user_content(message: types.ContentUnion) -> types.Content:
|
||||
return message if isinstance(message, types.Content) else UserContent(message)
|
||||
|
||||
|
||||
class TestInMemoryRunner(AfInMemoryRunner):
|
||||
"""InMemoryRunner that is tailored for tests, features async run method.
|
||||
|
||||
app_name is hardcoded as InMemoryRunner in the parent class.
|
||||
"""
|
||||
|
||||
async def run_async_with_new_session(
|
||||
self, new_message: types.ContentUnion
|
||||
) -> list[Event]:
|
||||
|
||||
session = self.session_service.create_session(
|
||||
app_name='InMemoryRunner', user_id='test_user'
|
||||
)
|
||||
collected_events = []
|
||||
|
||||
async for event in self.run_async(
|
||||
user_id=session.user_id,
|
||||
session_id=session.id,
|
||||
new_message=get_user_content(new_message),
|
||||
):
|
||||
collected_events.append(event)
|
||||
|
||||
return collected_events
|
||||
|
||||
|
||||
class InMemoryRunner:
|
||||
"""InMemoryRunner that is tailored for tests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_agent: Union[Agent, LlmAgent],
|
||||
response_modalities: list[str] = None,
|
||||
):
|
||||
self.root_agent = root_agent
|
||||
self.runner = Runner(
|
||||
app_name='test_app',
|
||||
agent=root_agent,
|
||||
artifact_service=InMemoryArtifactService(),
|
||||
session_service=InMemorySessionService(),
|
||||
memory_service=InMemoryMemoryService(),
|
||||
)
|
||||
self.session_id = self.runner.session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
).id
|
||||
|
||||
@property
|
||||
def session(self) -> Session:
|
||||
return self.runner.session_service.get_session(
|
||||
app_name='test_app', user_id='test_user', session_id=self.session_id
|
||||
)
|
||||
|
||||
def run(self, new_message: types.ContentUnion) -> list[Event]:
|
||||
return list(
|
||||
self.runner.run(
|
||||
user_id=self.session.user_id,
|
||||
session_id=self.session.id,
|
||||
new_message=get_user_content(new_message),
|
||||
)
|
||||
)
|
||||
|
||||
def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]:
|
||||
collected_responses = []
|
||||
|
||||
async def consume_responses():
|
||||
run_res = self.runner.run_live(
|
||||
session=self.session,
|
||||
live_request_queue=live_request_queue,
|
||||
)
|
||||
|
||||
async for response in run_res:
|
||||
collected_responses.append(response)
|
||||
# When we have enough response, we should return
|
||||
if len(collected_responses) >= 1:
|
||||
return
|
||||
|
||||
try:
|
||||
asyncio.run(consume_responses())
|
||||
except asyncio.TimeoutError:
|
||||
print('Returning any partial results collected so far.')
|
||||
|
||||
return collected_responses
|
||||
|
||||
|
||||
class MockModel(BaseLlm):
|
||||
model: str = 'mock'
|
||||
|
||||
requests: list[LlmRequest] = []
|
||||
responses: list[LlmResponse]
|
||||
response_index: int = -1
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
responses: Union[
|
||||
list[types.Part], list[LlmResponse], list[str], list[list[types.Part]]
|
||||
],
|
||||
):
|
||||
if not responses:
|
||||
return cls(responses=[])
|
||||
elif isinstance(responses[0], LlmResponse):
|
||||
# responses is list[LlmResponse]
|
||||
return cls(responses=responses)
|
||||
else:
|
||||
responses = [
|
||||
LlmResponse(content=ModelContent(item))
|
||||
if isinstance(item, list) and isinstance(item[0], types.Part)
|
||||
# responses is list[list[Part]]
|
||||
else LlmResponse(
|
||||
content=ModelContent(
|
||||
# responses is list[str] or list[Part]
|
||||
[Part(text=item) if isinstance(item, str) else item]
|
||||
)
|
||||
)
|
||||
for item in responses
|
||||
if item
|
||||
]
|
||||
|
||||
return cls(responses=responses)
|
||||
|
||||
@staticmethod
|
||||
def supported_models() -> list[str]:
|
||||
return ['mock']
|
||||
|
||||
def generate_content(
|
||||
self, llm_request: LlmRequest, stream: bool = False
|
||||
) -> Generator[LlmResponse, None, None]:
|
||||
# Increasement of the index has to happen before the yield.
|
||||
self.response_index += 1
|
||||
self.requests.append(llm_request)
|
||||
# yield LlmResponse(content=self.responses[self.response_index])
|
||||
yield self.responses[self.response_index]
|
||||
|
||||
@override
|
||||
async def generate_content_async(
|
||||
self, llm_request: LlmRequest, stream: bool = False
|
||||
) -> AsyncGenerator[LlmResponse, None]:
|
||||
# Increasement of the index has to happen before the yield.
|
||||
self.response_index += 1
|
||||
self.requests.append(llm_request)
|
||||
yield self.responses[self.response_index]
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
|
||||
"""Creates a live connection to the LLM."""
|
||||
yield MockLlmConnection(self.responses)
|
||||
|
||||
|
||||
class MockLlmConnection(BaseLlmConnection):
|
||||
|
||||
def __init__(self, llm_responses: list[LlmResponse]):
|
||||
self.llm_responses = llm_responses
|
||||
|
||||
async def send_history(self, history: list[types.Content]):
|
||||
pass
|
||||
|
||||
async def send_content(self, content: types.Content):
|
||||
pass
|
||||
|
||||
async def send(self, data):
|
||||
pass
|
||||
|
||||
async def send_realtime(self, blob: types.Blob):
|
||||
pass
|
||||
|
||||
async def receive(self) -> AsyncGenerator[LlmResponse, None]:
|
||||
"""Yield each of the pre-defined LlmResponses."""
|
||||
for response in self.llm_responses:
|
||||
yield response
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user