mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 11:22:22 -06:00
Moves unittests to root folder and adds github action to run unit tests. (#72)
* Move unit tests to root package. * Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py * Adds github workflow * minor fix in lite_llm.py for python 3.9. * format pyproject.toml
This commit is contained in:
14
tests/unittests/flows/llm_flows/__init__.py
Normal file
14
tests/unittests/flows/llm_flows/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
142
tests/unittests/flows/llm_flows/_test_examples.py
Normal file
142
tests/unittests/flows/llm_flows/_test_examples.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO: delete and rewrite unit tests
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.examples import BaseExampleProvider
|
||||
from google.adk.examples import Example
|
||||
from google.adk.flows.llm_flows import examples
|
||||
from google.adk.models.base_llm import LlmRequest
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_examples():
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(model="gemini-1.5-flash", name="agent", examples=[])
|
||||
invocation_context = utils.create_invocation_context(
|
||||
agent=agent, user_content=""
|
||||
)
|
||||
|
||||
async for _ in examples.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_examples():
|
||||
example_list = [
|
||||
Example(
|
||||
input=types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text="test1")],
|
||||
),
|
||||
output=[
|
||||
types.Content(
|
||||
role="model",
|
||||
parts=[types.Part.from_text(text="response1")],
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
examples=example_list,
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(
|
||||
agent=agent, user_content="test"
|
||||
)
|
||||
|
||||
async for _ in examples.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert (
|
||||
request.config.system_instruction
|
||||
== "<EXAMPLES>\nBegin few-shot\nThe following are examples of user"
|
||||
" queries and model responses using the available tools.\n\nEXAMPLE"
|
||||
" 1:\nBegin example\n[user]\ntest1\n\n[model]\nresponse1\nEnd"
|
||||
" example\n\nEnd few-shot\nNow, try to follow these examples and"
|
||||
" complete the following conversation\n<EXAMPLES>"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_base_example_provider():
|
||||
class TestExampleProvider(BaseExampleProvider):
|
||||
|
||||
def get_examples(self, query: str) -> list[Example]:
|
||||
if query == "test":
|
||||
return [
|
||||
Example(
|
||||
input=types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text="test")],
|
||||
),
|
||||
output=[
|
||||
types.Content(
|
||||
role="model",
|
||||
parts=[types.Part.from_text(text="response1")],
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
provider = TestExampleProvider()
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
examples=provider,
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(
|
||||
agent=agent, user_content="test"
|
||||
)
|
||||
|
||||
async for _ in examples.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert (
|
||||
request.config.system_instruction
|
||||
== "<EXAMPLES>\nBegin few-shot\nThe following are examples of user"
|
||||
" queries and model responses using the available tools.\n\nEXAMPLE"
|
||||
" 1:\nBegin example\n[user]\ntest\n\n[model]\nresponse1\nEnd"
|
||||
" example\n\nEnd few-shot\nNow, try to follow these examples and"
|
||||
" complete the following conversation\n<EXAMPLES>"
|
||||
)
|
||||
311
tests/unittests/flows/llm_flows/test_agent_transfer.py
Normal file
311
tests/unittests/flows/llm_flows/test_agent_transfer.py
Normal file
@@ -0,0 +1,311 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.agents.loop_agent import LoopAgent
|
||||
from google.adk.agents.sequential_agent import SequentialAgent
|
||||
from google.adk.tools import exit_loop
|
||||
from google.genai.types import Part
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def transfer_call_part(agent_name: str) -> Part:
|
||||
return Part.from_function_call(
|
||||
name='transfer_to_agent', args={'agent_name': agent_name}
|
||||
)
|
||||
|
||||
|
||||
TRANSFER_RESPONSE_PART = Part.from_function_response(
|
||||
name='transfer_to_agent', response={}
|
||||
)
|
||||
|
||||
|
||||
def test_auto_to_auto():
|
||||
response = [
|
||||
transfer_call_part('sub_agent_1'),
|
||||
'response1',
|
||||
'response2',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
# root (auto) - sub_agent_1 (auto)
|
||||
sub_agent_1 = Agent(name='sub_agent_1', model=mockModel)
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
sub_agents=[sub_agent_1],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
# Asserts the transfer.
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', transfer_call_part('sub_agent_1')),
|
||||
('root_agent', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1', 'response1'),
|
||||
]
|
||||
|
||||
# sub_agent_1 should still be the current agent.
|
||||
assert utils.simplify_events(runner.run('test2')) == [
|
||||
('sub_agent_1', 'response2'),
|
||||
]
|
||||
|
||||
|
||||
def test_auto_to_single():
|
||||
response = [
|
||||
transfer_call_part('sub_agent_1'),
|
||||
'response1',
|
||||
'response2',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
# root (auto) - sub_agent_1 (single)
|
||||
sub_agent_1 = Agent(
|
||||
name='sub_agent_1',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
root_agent = Agent(
|
||||
name='root_agent', model=mockModel, sub_agents=[sub_agent_1]
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
# Asserts the responses.
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', transfer_call_part('sub_agent_1')),
|
||||
('root_agent', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1', 'response1'),
|
||||
]
|
||||
|
||||
# root_agent should still be the current agent, becaues sub_agent_1 is single.
|
||||
assert utils.simplify_events(runner.run('test2')) == [
|
||||
('root_agent', 'response2'),
|
||||
]
|
||||
|
||||
|
||||
def test_auto_to_auto_to_single():
|
||||
response = [
|
||||
transfer_call_part('sub_agent_1'),
|
||||
# sub_agent_1 transfers to sub_agent_1_1.
|
||||
transfer_call_part('sub_agent_1_1'),
|
||||
'response1',
|
||||
'response2',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
# root (auto) - sub_agent_1 (auto) - sub_agent_1_1 (single)
|
||||
sub_agent_1_1 = Agent(
|
||||
name='sub_agent_1_1',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
sub_agent_1 = Agent(
|
||||
name='sub_agent_1', model=mockModel, sub_agents=[sub_agent_1_1]
|
||||
)
|
||||
root_agent = Agent(
|
||||
name='root_agent', model=mockModel, sub_agents=[sub_agent_1]
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
# Asserts the responses.
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', transfer_call_part('sub_agent_1')),
|
||||
('root_agent', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1', transfer_call_part('sub_agent_1_1')),
|
||||
('sub_agent_1', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1_1', 'response1'),
|
||||
]
|
||||
|
||||
# sub_agent_1 should still be the current agent. sub_agent_1_1 is single so it should
|
||||
# not be the current agent, otherwise the conversation will be tied to
|
||||
# sub_agent_1_1 forever.
|
||||
assert utils.simplify_events(runner.run('test2')) == [
|
||||
('sub_agent_1', 'response2'),
|
||||
]
|
||||
|
||||
|
||||
def test_auto_to_sequential():
|
||||
response = [
|
||||
transfer_call_part('sub_agent_1'),
|
||||
# sub_agent_1 responds directly instead of transfering.
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
# root (auto) - sub_agent_1 (sequential) - sub_agent_1_1 (single)
|
||||
# \ sub_agent_1_2 (single)
|
||||
sub_agent_1_1 = Agent(
|
||||
name='sub_agent_1_1',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
sub_agent_1_2 = Agent(
|
||||
name='sub_agent_1_2',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
sub_agent_1 = SequentialAgent(
|
||||
name='sub_agent_1',
|
||||
sub_agents=[sub_agent_1_1, sub_agent_1_2],
|
||||
)
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
sub_agents=[sub_agent_1],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
# Asserts the transfer.
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', transfer_call_part('sub_agent_1')),
|
||||
('root_agent', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1_1', 'response1'),
|
||||
('sub_agent_1_2', 'response2'),
|
||||
]
|
||||
|
||||
# root_agent should still be the current agent because sub_agent_1 is sequential.
|
||||
assert utils.simplify_events(runner.run('test2')) == [
|
||||
('root_agent', 'response3'),
|
||||
]
|
||||
|
||||
|
||||
def test_auto_to_sequential_to_auto():
|
||||
response = [
|
||||
transfer_call_part('sub_agent_1'),
|
||||
# sub_agent_1 responds directly instead of transfering.
|
||||
'response1',
|
||||
transfer_call_part('sub_agent_1_2_1'),
|
||||
'response2',
|
||||
'response3',
|
||||
'response4',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
# root (auto) - sub_agent_1 (seq) - sub_agent_1_1 (single)
|
||||
# \ sub_agent_1_2 (auto) - sub_agent_1_2_1 (auto)
|
||||
# \ sub_agent_1_3 (single)
|
||||
sub_agent_1_1 = Agent(
|
||||
name='sub_agent_1_1',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
sub_agent_1_2_1 = Agent(name='sub_agent_1_2_1', model=mockModel)
|
||||
sub_agent_1_2 = Agent(
|
||||
name='sub_agent_1_2',
|
||||
model=mockModel,
|
||||
sub_agents=[sub_agent_1_2_1],
|
||||
)
|
||||
sub_agent_1_3 = Agent(
|
||||
name='sub_agent_1_3',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
sub_agent_1 = SequentialAgent(
|
||||
name='sub_agent_1',
|
||||
sub_agents=[sub_agent_1_1, sub_agent_1_2, sub_agent_1_3],
|
||||
)
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
sub_agents=[sub_agent_1],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
# Asserts the transfer.
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', transfer_call_part('sub_agent_1')),
|
||||
('root_agent', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1_1', 'response1'),
|
||||
('sub_agent_1_2', transfer_call_part('sub_agent_1_2_1')),
|
||||
('sub_agent_1_2', TRANSFER_RESPONSE_PART),
|
||||
('sub_agent_1_2_1', 'response2'),
|
||||
('sub_agent_1_3', 'response3'),
|
||||
]
|
||||
|
||||
# root_agent should still be the current agent because sub_agent_1 is sequential.
|
||||
assert utils.simplify_events(runner.run('test2')) == [
|
||||
('root_agent', 'response4'),
|
||||
]
|
||||
|
||||
|
||||
def test_auto_to_loop():
|
||||
response = [
|
||||
transfer_call_part('sub_agent_1'),
|
||||
# sub_agent_1 responds directly instead of transfering.
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
Part.from_function_call(name='exit_loop', args={}),
|
||||
'response4',
|
||||
'response5',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
# root (auto) - sub_agent_1 (loop) - sub_agent_1_1 (single)
|
||||
# \ sub_agent_1_2 (single)
|
||||
sub_agent_1_1 = Agent(
|
||||
name='sub_agent_1_1',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
sub_agent_1_2 = Agent(
|
||||
name='sub_agent_1_2',
|
||||
model=mockModel,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
tools=[exit_loop],
|
||||
)
|
||||
sub_agent_1 = LoopAgent(
|
||||
name='sub_agent_1',
|
||||
sub_agents=[sub_agent_1_1, sub_agent_1_2],
|
||||
)
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
sub_agents=[sub_agent_1],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
# Asserts the transfer.
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
# Transfers to sub_agent_1.
|
||||
('root_agent', transfer_call_part('sub_agent_1')),
|
||||
('root_agent', TRANSFER_RESPONSE_PART),
|
||||
# Loops.
|
||||
('sub_agent_1_1', 'response1'),
|
||||
('sub_agent_1_2', 'response2'),
|
||||
('sub_agent_1_1', 'response3'),
|
||||
# Exits.
|
||||
('sub_agent_1_2', Part.from_function_call(name='exit_loop', args={})),
|
||||
(
|
||||
'sub_agent_1_2',
|
||||
Part.from_function_response(name='exit_loop', response={}),
|
||||
),
|
||||
# root_agent summarizes.
|
||||
('root_agent', 'response4'),
|
||||
]
|
||||
|
||||
# root_agent should still be the current agent because sub_agent_1 is loop.
|
||||
assert utils.simplify_events(runner.run('test2')) == [
|
||||
('root_agent', 'response5'),
|
||||
]
|
||||
244
tests/unittests/flows/llm_flows/test_functions_long_running.py
Normal file
244
tests/unittests/flows/llm_flows/test_functions_long_running.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.tools import ToolContext
|
||||
from google.adk.tools.long_running_tool import LongRunningFunctionTool
|
||||
from google.genai.types import Part
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def test_async_function():
|
||||
responses = [
|
||||
Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
'response4',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=responses)
|
||||
function_called = 0
|
||||
|
||||
def increase_by_one(x: int, tool_context: ToolContext) -> int:
|
||||
nonlocal function_called
|
||||
|
||||
function_called += 1
|
||||
return {'status': 'pending'}
|
||||
|
||||
# Calls the first time.
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
tools=[LongRunningFunctionTool(func=increase_by_one)],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test1')
|
||||
|
||||
# Asserts the requests.
|
||||
assert len(mockModel.requests) == 2
|
||||
# 1 item: user content
|
||||
assert mockModel.requests[0].contents == [
|
||||
utils.UserContent('test1'),
|
||||
]
|
||||
increase_by_one_call = Part.from_function_call(
|
||||
name='increase_by_one', args={'x': 1}
|
||||
)
|
||||
pending_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'status': 'pending'}
|
||||
)
|
||||
|
||||
assert utils.simplify_contents(mockModel.requests[1].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', pending_response),
|
||||
]
|
||||
|
||||
# Asserts the function calls.
|
||||
assert function_called == 1
|
||||
|
||||
# Asserts the responses.
|
||||
assert utils.simplify_events(events) == [
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='increase_by_one', response={'status': 'pending'}
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
assert events[0].long_running_tool_ids
|
||||
|
||||
# Updates with another pending progress.
|
||||
still_waiting_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'status': 'still waiting'}
|
||||
)
|
||||
events = runner.run(utils.UserContent(still_waiting_response))
|
||||
# We have one new request.
|
||||
assert len(mockModel.requests) == 3
|
||||
assert utils.simplify_contents(mockModel.requests[2].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', still_waiting_response),
|
||||
]
|
||||
|
||||
assert utils.simplify_events(events) == [('root_agent', 'response2')]
|
||||
|
||||
# Calls when the result is ready.
|
||||
result_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 2}
|
||||
)
|
||||
events = runner.run(utils.UserContent(result_response))
|
||||
# We have one new request.
|
||||
assert len(mockModel.requests) == 4
|
||||
assert utils.simplify_contents(mockModel.requests[3].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', result_response),
|
||||
]
|
||||
assert utils.simplify_events(events) == [('root_agent', 'response3')]
|
||||
|
||||
# Calls when the result is ready. Here we still accept the result and do
|
||||
# another summarization. Whether this is the right behavior is TBD.
|
||||
another_result_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 3}
|
||||
)
|
||||
events = runner.run(utils.UserContent(another_result_response))
|
||||
# We have one new request.
|
||||
assert len(mockModel.requests) == 5
|
||||
assert utils.simplify_contents(mockModel.requests[4].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', another_result_response),
|
||||
]
|
||||
assert utils.simplify_events(events) == [('root_agent', 'response4')]
|
||||
|
||||
# At the end, function_called should still be 1.
|
||||
assert function_called == 1
|
||||
|
||||
|
||||
def test_async_function_with_none_response():
|
||||
responses = [
|
||||
Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
'response4',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=responses)
|
||||
function_called = 0
|
||||
|
||||
def increase_by_one(x: int, tool_context: ToolContext) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return 'pending'
|
||||
|
||||
# Calls the first time.
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
tools=[LongRunningFunctionTool(func=increase_by_one)],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test1')
|
||||
|
||||
# Asserts the requests.
|
||||
assert len(mockModel.requests) == 2
|
||||
# 1 item: user content
|
||||
assert mockModel.requests[0].contents == [
|
||||
utils.UserContent('test1'),
|
||||
]
|
||||
increase_by_one_call = Part.from_function_call(
|
||||
name='increase_by_one', args={'x': 1}
|
||||
)
|
||||
|
||||
assert utils.simplify_contents(mockModel.requests[1].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
(
|
||||
'user',
|
||||
Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 'pending'}
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
# Asserts the function calls.
|
||||
assert function_called == 1
|
||||
|
||||
# Asserts the responses.
|
||||
assert utils.simplify_events(events) == [
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 'pending'}
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
# Updates with another pending progress.
|
||||
still_waiting_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'status': 'still waiting'}
|
||||
)
|
||||
events = runner.run(utils.UserContent(still_waiting_response))
|
||||
# We have one new request.
|
||||
assert len(mockModel.requests) == 3
|
||||
assert utils.simplify_contents(mockModel.requests[2].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', still_waiting_response),
|
||||
]
|
||||
|
||||
assert utils.simplify_events(events) == [('root_agent', 'response2')]
|
||||
|
||||
# Calls when the result is ready.
|
||||
result_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 2}
|
||||
)
|
||||
events = runner.run(utils.UserContent(result_response))
|
||||
# We have one new request.
|
||||
assert len(mockModel.requests) == 4
|
||||
assert utils.simplify_contents(mockModel.requests[3].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', result_response),
|
||||
]
|
||||
assert utils.simplify_events(events) == [('root_agent', 'response3')]
|
||||
|
||||
# Calls when the result is ready. Here we still accept the result and do
|
||||
# another summarization. Whether this is the right behavior is TBD.
|
||||
another_result_response = Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 3}
|
||||
)
|
||||
events = runner.run(utils.UserContent(another_result_response))
|
||||
# We have one new request.
|
||||
assert len(mockModel.requests) == 5
|
||||
assert utils.simplify_contents(mockModel.requests[4].contents) == [
|
||||
('user', 'test1'),
|
||||
('model', increase_by_one_call),
|
||||
('user', another_result_response),
|
||||
]
|
||||
assert utils.simplify_events(events) == [('root_agent', 'response4')]
|
||||
|
||||
# At the end, function_called should still be 1.
|
||||
assert function_called == 1
|
||||
346
tests/unittests/flows/llm_flows/test_functions_request_euc.py
Normal file
346
tests/unittests/flows/llm_flows/test_functions_request_euc.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OAuthFlowAuthorizationCode
|
||||
from fastapi.openapi.models import OAuthFlows
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.auth import AuthConfig
|
||||
from google.adk.auth import AuthCredential
|
||||
from google.adk.auth import AuthCredentialTypes
|
||||
from google.adk.auth import OAuth2Auth
|
||||
from google.adk.flows.llm_flows import functions
|
||||
from google.adk.tools import AuthToolArguments
|
||||
from google.adk.tools import ToolContext
|
||||
from google.genai import types
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def function_call(function_call_id, name, args: dict[str, Any]) -> types.Part:
|
||||
part = types.Part.from_function_call(name=name, args=args)
|
||||
part.function_call.id = function_call_id
|
||||
return part
|
||||
|
||||
|
||||
def test_function_request_euc():
|
||||
responses = [
|
||||
[
|
||||
types.Part.from_function_call(name='call_external_api1', args={}),
|
||||
types.Part.from_function_call(name='call_external_api2', args={}),
|
||||
],
|
||||
[
|
||||
types.Part.from_text(text='response1'),
|
||||
],
|
||||
]
|
||||
|
||||
auth_config1 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_1',
|
||||
client_secret='oauth_client_secret1',
|
||||
),
|
||||
),
|
||||
)
|
||||
auth_config2 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_2',
|
||||
client_secret='oauth_client_secret2',
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
|
||||
def call_external_api1(tool_context: ToolContext) -> Optional[int]:
|
||||
tool_context.request_credential(auth_config1)
|
||||
|
||||
def call_external_api2(tool_context: ToolContext) -> Optional[int]:
|
||||
tool_context.request_credential(auth_config2)
|
||||
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[call_external_api1, call_external_api2],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test')
|
||||
assert events[0].content.parts[0].function_call is not None
|
||||
assert events[0].content.parts[1].function_call is not None
|
||||
auth_configs = list(events[2].actions.requested_auth_configs.values())
|
||||
exchanged_auth_config1 = auth_configs[0]
|
||||
exchanged_auth_config2 = auth_configs[1]
|
||||
assert exchanged_auth_config1.auth_scheme == auth_config1.auth_scheme
|
||||
assert (
|
||||
exchanged_auth_config1.raw_auth_credential
|
||||
== auth_config1.raw_auth_credential
|
||||
)
|
||||
assert (
|
||||
exchanged_auth_config1.exchanged_auth_credential.oauth2.auth_uri
|
||||
is not None
|
||||
)
|
||||
assert exchanged_auth_config2.auth_scheme == auth_config2.auth_scheme
|
||||
assert (
|
||||
exchanged_auth_config2.raw_auth_credential
|
||||
== auth_config2.raw_auth_credential
|
||||
)
|
||||
assert (
|
||||
exchanged_auth_config2.exchanged_auth_credential.oauth2.auth_uri
|
||||
is not None
|
||||
)
|
||||
function_call_ids = list(events[2].actions.requested_auth_configs.keys())
|
||||
|
||||
for idx, part in enumerate(events[1].content.parts):
|
||||
reqeust_euc_function_call = part.function_call
|
||||
assert reqeust_euc_function_call is not None
|
||||
assert (
|
||||
reqeust_euc_function_call.name
|
||||
== functions.REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
)
|
||||
args = AuthToolArguments.model_validate(reqeust_euc_function_call.args)
|
||||
|
||||
assert args.function_call_id == function_call_ids[idx]
|
||||
args.auth_config.auth_scheme.model_extra.clear()
|
||||
assert args.auth_config.auth_scheme == auth_configs[idx].auth_scheme
|
||||
assert (
|
||||
args.auth_config.raw_auth_credential
|
||||
== auth_configs[idx].raw_auth_credential
|
||||
)
|
||||
|
||||
|
||||
def test_function_get_auth_response():
|
||||
id_1 = 'id_1'
|
||||
id_2 = 'id_2'
|
||||
responses = [
|
||||
[
|
||||
function_call(id_1, 'call_external_api1', {}),
|
||||
function_call(id_2, 'call_external_api2', {}),
|
||||
],
|
||||
[
|
||||
types.Part.from_text(text='response1'),
|
||||
],
|
||||
[
|
||||
types.Part.from_text(text='response2'),
|
||||
],
|
||||
]
|
||||
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
function_invoked = 0
|
||||
|
||||
auth_config1 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_1',
|
||||
client_secret='oauth_client_secret1',
|
||||
),
|
||||
),
|
||||
)
|
||||
auth_config2 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_2',
|
||||
client_secret='oauth_client_secret2',
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
auth_response1 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_1',
|
||||
client_secret='oauth_client_secret1',
|
||||
),
|
||||
),
|
||||
exchanged_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_1',
|
||||
client_secret='oauth_client_secret1',
|
||||
token={'access_token': 'token1'},
|
||||
),
|
||||
),
|
||||
)
|
||||
auth_response2 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_2',
|
||||
client_secret='oauth_client_secret2',
|
||||
),
|
||||
),
|
||||
exchanged_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_2',
|
||||
client_secret='oauth_client_secret2',
|
||||
token={'access_token': 'token2'},
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def call_external_api1(tool_context: ToolContext) -> int:
|
||||
nonlocal function_invoked
|
||||
function_invoked += 1
|
||||
auth_response = tool_context.get_auth_response(auth_config1)
|
||||
if not auth_response:
|
||||
tool_context.request_credential(auth_config1)
|
||||
return
|
||||
assert auth_response == auth_response1.exchanged_auth_credential
|
||||
return 1
|
||||
|
||||
def call_external_api2(tool_context: ToolContext) -> int:
|
||||
nonlocal function_invoked
|
||||
function_invoked += 1
|
||||
auth_response = tool_context.get_auth_response(auth_config2)
|
||||
if not auth_response:
|
||||
tool_context.request_credential(auth_config2)
|
||||
return
|
||||
assert auth_response == auth_response2.exchanged_auth_credential
|
||||
return 2
|
||||
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[call_external_api1, call_external_api2],
|
||||
)
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
runner.run('test')
|
||||
request_euc_function_call_event = runner.session.events[-3]
|
||||
function_response1 = types.FunctionResponse(
|
||||
name=request_euc_function_call_event.content.parts[0].function_call.name,
|
||||
response=auth_response1.model_dump(),
|
||||
)
|
||||
function_response1.id = request_euc_function_call_event.content.parts[
|
||||
0
|
||||
].function_call.id
|
||||
|
||||
function_response2 = types.FunctionResponse(
|
||||
name=request_euc_function_call_event.content.parts[1].function_call.name,
|
||||
response=auth_response2.model_dump(),
|
||||
)
|
||||
function_response2.id = request_euc_function_call_event.content.parts[
|
||||
1
|
||||
].function_call.id
|
||||
runner.run(
|
||||
new_message=types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(function_response=function_response1),
|
||||
types.Part(function_response=function_response2),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
assert function_invoked == 4
|
||||
reqeust = mock_model.requests[-1]
|
||||
content = reqeust.contents[-1]
|
||||
parts = content.parts
|
||||
assert len(parts) == 2
|
||||
assert parts[0].function_response.name == 'call_external_api1'
|
||||
assert parts[0].function_response.response == {'result': 1}
|
||||
assert parts[1].function_response.name == 'call_external_api2'
|
||||
assert parts[1].function_response.response == {'result': 2}
|
||||
93
tests/unittests/flows/llm_flows/test_functions_sequential.py
Normal file
93
tests/unittests/flows/llm_flows/test_functions_sequential.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.genai import types
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def function_call(args: dict[str, Any]) -> types.Part:
|
||||
return types.Part.from_function_call(name='increase_by_one', args=args)
|
||||
|
||||
|
||||
def function_response(response: dict[str, Any]) -> types.Part:
|
||||
return types.Part.from_function_response(
|
||||
name='increase_by_one', response=response
|
||||
)
|
||||
|
||||
|
||||
def test_sequential_calls():
|
||||
responses = [
|
||||
function_call({'x': 1}),
|
||||
function_call({'x': 2}),
|
||||
function_call({'x': 3}),
|
||||
'response1',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=responses)
|
||||
function_called = 0
|
||||
|
||||
def increase_by_one(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x + 1
|
||||
|
||||
agent = Agent(name='root_agent', model=mockModel, tools=[increase_by_one])
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
result = utils.simplify_events(runner.run('test'))
|
||||
assert result == [
|
||||
('root_agent', function_call({'x': 1})),
|
||||
('root_agent', function_response({'result': 2})),
|
||||
('root_agent', function_call({'x': 2})),
|
||||
('root_agent', function_response({'result': 3})),
|
||||
('root_agent', function_call({'x': 3})),
|
||||
('root_agent', function_response({'result': 4})),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
# Asserts the requests.
|
||||
assert len(mockModel.requests) == 4
|
||||
# 1 item: user content
|
||||
assert utils.simplify_contents(mockModel.requests[0].contents) == [
|
||||
('user', 'test')
|
||||
]
|
||||
# 3 items: user content, functaion call / response for the 1st call
|
||||
assert utils.simplify_contents(mockModel.requests[1].contents) == [
|
||||
('user', 'test'),
|
||||
('model', function_call({'x': 1})),
|
||||
('user', function_response({'result': 2})),
|
||||
]
|
||||
# 5 items: user content, functaion call / response for two calls
|
||||
assert utils.simplify_contents(mockModel.requests[2].contents) == [
|
||||
('user', 'test'),
|
||||
('model', function_call({'x': 1})),
|
||||
('user', function_response({'result': 2})),
|
||||
('model', function_call({'x': 2})),
|
||||
('user', function_response({'result': 3})),
|
||||
]
|
||||
# 7 items: user content, functaion call / response for three calls
|
||||
assert utils.simplify_contents(mockModel.requests[3].contents) == [
|
||||
('user', 'test'),
|
||||
('model', function_call({'x': 1})),
|
||||
('user', function_response({'result': 2})),
|
||||
('model', function_call({'x': 2})),
|
||||
('user', function_response({'result': 3})),
|
||||
('model', function_call({'x': 3})),
|
||||
('user', function_response({'result': 4})),
|
||||
]
|
||||
|
||||
# Asserts the function calls.
|
||||
assert function_called == 3
|
||||
258
tests/unittests/flows/llm_flows/test_functions_simple.py
Normal file
258
tests/unittests/flows/llm_flows/test_functions_simple.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import Callable
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.tools import ToolContext
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def test_simple_function():
|
||||
function_call_1 = types.Part.from_function_call(
|
||||
name='increase_by_one', args={'x': 1}
|
||||
)
|
||||
function_respones_2 = types.Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 2}
|
||||
)
|
||||
responses: list[types.Content] = [
|
||||
function_call_1,
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
'response4',
|
||||
]
|
||||
function_called = 0
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
|
||||
def increase_by_one(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x + 1
|
||||
|
||||
agent = Agent(name='root_agent', model=mock_model, tools=[increase_by_one])
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', function_call_1),
|
||||
('root_agent', function_respones_2),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
# Asserts the requests.
|
||||
assert utils.simplify_contents(mock_model.requests[0].contents) == [
|
||||
('user', 'test')
|
||||
]
|
||||
assert utils.simplify_contents(mock_model.requests[1].contents) == [
|
||||
('user', 'test'),
|
||||
('model', function_call_1),
|
||||
('user', function_respones_2),
|
||||
]
|
||||
|
||||
# Asserts the function calls.
|
||||
assert function_called == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function():
|
||||
function_calls = [
|
||||
types.Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
types.Part.from_function_call(name='multiple_by_two', args={'x': 2}),
|
||||
types.Part.from_function_call(name='multiple_by_two_sync', args={'x': 3}),
|
||||
]
|
||||
function_responses = [
|
||||
types.Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 2}
|
||||
),
|
||||
types.Part.from_function_response(
|
||||
name='multiple_by_two', response={'result': 4}
|
||||
),
|
||||
types.Part.from_function_response(
|
||||
name='multiple_by_two_sync', response={'result': 6}
|
||||
),
|
||||
]
|
||||
|
||||
responses: list[types.Content] = [
|
||||
function_calls,
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
'response4',
|
||||
]
|
||||
function_called = 0
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
|
||||
async def increase_by_one(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x + 1
|
||||
|
||||
async def multiple_by_two(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x * 2
|
||||
|
||||
def multiple_by_two_sync(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x * 2
|
||||
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[increase_by_one, multiple_by_two, multiple_by_two_sync],
|
||||
)
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
events = await runner.run_async_with_new_session('test')
|
||||
assert utils.simplify_events(events) == [
|
||||
('root_agent', function_calls),
|
||||
('root_agent', function_responses),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
# Asserts the requests.
|
||||
assert utils.simplify_contents(mock_model.requests[0].contents) == [
|
||||
('user', 'test')
|
||||
]
|
||||
assert utils.simplify_contents(mock_model.requests[1].contents) == [
|
||||
('user', 'test'),
|
||||
('model', function_calls),
|
||||
('user', function_responses),
|
||||
]
|
||||
|
||||
# Asserts the function calls.
|
||||
assert function_called == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_tool():
|
||||
function_calls = [
|
||||
types.Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
types.Part.from_function_call(name='multiple_by_two', args={'x': 2}),
|
||||
types.Part.from_function_call(name='multiple_by_two_sync', args={'x': 3}),
|
||||
]
|
||||
function_responses = [
|
||||
types.Part.from_function_response(
|
||||
name='increase_by_one', response={'result': 2}
|
||||
),
|
||||
types.Part.from_function_response(
|
||||
name='multiple_by_two', response={'result': 4}
|
||||
),
|
||||
types.Part.from_function_response(
|
||||
name='multiple_by_two_sync', response={'result': 6}
|
||||
),
|
||||
]
|
||||
|
||||
responses: list[types.Content] = [
|
||||
function_calls,
|
||||
'response1',
|
||||
'response2',
|
||||
'response3',
|
||||
'response4',
|
||||
]
|
||||
function_called = 0
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
|
||||
async def increase_by_one(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x + 1
|
||||
|
||||
async def multiple_by_two(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x * 2
|
||||
|
||||
def multiple_by_two_sync(x: int) -> int:
|
||||
nonlocal function_called
|
||||
function_called += 1
|
||||
return x * 2
|
||||
|
||||
class TestTool(FunctionTool):
|
||||
|
||||
def __init__(self, func: Callable[..., Any]):
|
||||
super().__init__(func=func)
|
||||
|
||||
wrapped_increase_by_one = TestTool(func=increase_by_one)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[wrapped_increase_by_one, multiple_by_two, multiple_by_two_sync],
|
||||
)
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
events = await runner.run_async_with_new_session('test')
|
||||
assert utils.simplify_events(events) == [
|
||||
('root_agent', function_calls),
|
||||
('root_agent', function_responses),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
# Asserts the requests.
|
||||
assert utils.simplify_contents(mock_model.requests[0].contents) == [
|
||||
('user', 'test')
|
||||
]
|
||||
assert utils.simplify_contents(mock_model.requests[1].contents) == [
|
||||
('user', 'test'),
|
||||
('model', function_calls),
|
||||
('user', function_responses),
|
||||
]
|
||||
|
||||
# Asserts the function calls.
|
||||
assert function_called == 3
|
||||
|
||||
|
||||
def test_update_state():
|
||||
mock_model = utils.MockModel.create(
|
||||
responses=[
|
||||
types.Part.from_function_call(name='update_state', args={}),
|
||||
'response1',
|
||||
]
|
||||
)
|
||||
|
||||
def update_state(tool_context: ToolContext):
|
||||
tool_context.state['x'] = 1
|
||||
|
||||
agent = Agent(name='root_agent', model=mock_model, tools=[update_state])
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
runner.run('test')
|
||||
assert runner.session.state['x'] == 1
|
||||
|
||||
|
||||
def test_function_call_id():
|
||||
responses = [
|
||||
types.Part.from_function_call(name='increase_by_one', args={'x': 1}),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
|
||||
def increase_by_one(x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
agent = Agent(name='root_agent', model=mock_model, tools=[increase_by_one])
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
events = runner.run('test')
|
||||
for reqeust in mock_model.requests:
|
||||
for content in reqeust.contents:
|
||||
for part in content.parts:
|
||||
if part.function_call:
|
||||
assert part.function_call.id is None
|
||||
if part.function_response:
|
||||
assert part.function_response.id is None
|
||||
assert events[0].content.parts[0].function_call.id.startswith('adk-')
|
||||
assert events[1].content.parts[0].function_response.id.startswith('adk-')
|
||||
66
tests/unittests/flows/llm_flows/test_identity.py
Normal file
66
tests/unittests/flows/llm_flows/test_identity.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.flows.llm_flows import identity
|
||||
from google.adk.models import LlmRequest
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_description():
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(model="gemini-1.5-flash", name="agent")
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
|
||||
async for _ in identity.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == (
|
||||
"""You are an agent. Your internal name is "agent"."""
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_description():
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
description="test description",
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
|
||||
async for _ in identity.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == "\n\n".join([
|
||||
'You are an agent. Your internal name is "agent".',
|
||||
' The description about you is "test description"',
|
||||
])
|
||||
164
tests/unittests/flows/llm_flows/test_instructions.py
Normal file
164
tests/unittests/flows/llm_flows/test_instructions.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.agents.readonly_context import ReadonlyContext
|
||||
from google.adk.flows.llm_flows import instructions
|
||||
from google.adk.models import LlmRequest
|
||||
from google.adk.sessions import Session
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_system_instruction():
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
instruction=("""Use the echo_info tool to echo { customerId }, \
|
||||
{{customer_int }, { non-identifier-float}}, \
|
||||
{'key1': 'value1'} and {{'key2': 'value2'}}."""),
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
id="test_id",
|
||||
state={"customerId": "1234567890", "customer_int": 30},
|
||||
)
|
||||
|
||||
async for _ in instructions.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == (
|
||||
"""Use the echo_info tool to echo 1234567890, 30, \
|
||||
{ non-identifier-float}}, {'key1': 'value1'} and {{'key2': 'value2'}}."""
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_system_instruction():
|
||||
def build_function_instruction(readonly_context: ReadonlyContext) -> str:
|
||||
return (
|
||||
"This is the function agent instruction for invocation:"
|
||||
f" {readonly_context.invocation_id}."
|
||||
)
|
||||
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
instruction=build_function_instruction,
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
id="test_id",
|
||||
state={"customerId": "1234567890", "customer_int": 30},
|
||||
)
|
||||
|
||||
async for _ in instructions.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == (
|
||||
"This is the function agent instruction for invocation: test_id."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_system_instruction():
|
||||
sub_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="sub_agent",
|
||||
instruction="This is the sub agent instruction.",
|
||||
)
|
||||
root_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="root_agent",
|
||||
global_instruction="This is the global instruction.",
|
||||
sub_agents=[sub_agent],
|
||||
)
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=sub_agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
id="test_id",
|
||||
state={"customerId": "1234567890", "customer_int": 30},
|
||||
)
|
||||
|
||||
async for _ in instructions.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == (
|
||||
"This is the global instruction.\n\nThis is the sub agent instruction."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_system_instruction_with_namespace():
|
||||
request = LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
config=types.GenerateContentConfig(system_instruction=""),
|
||||
)
|
||||
agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
instruction=(
|
||||
"""Use the echo_info tool to echo { customerId }, {app:key}, {user:key}, {a:key}."""
|
||||
),
|
||||
)
|
||||
invocation_context = utils.create_invocation_context(agent=agent)
|
||||
invocation_context.session = Session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
id="test_id",
|
||||
state={
|
||||
"customerId": "1234567890",
|
||||
"app:key": "app_value",
|
||||
"user:key": "user_value",
|
||||
},
|
||||
)
|
||||
|
||||
async for _ in instructions.request_processor.run_async(
|
||||
invocation_context,
|
||||
request,
|
||||
):
|
||||
pass
|
||||
|
||||
assert request.config.system_instruction == (
|
||||
"""Use the echo_info tool to echo 1234567890, app_value, user_value, {a:key}."""
|
||||
)
|
||||
142
tests/unittests/flows/llm_flows/test_model_callbacks.py
Normal file
142
tests/unittests/flows/llm_flows/test_model_callbacks.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.models import LlmRequest
|
||||
from google.adk.models import LlmResponse
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
class MockBeforeModelCallback(BaseModel):
|
||||
mock_response: str
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
llm_request: LlmRequest,
|
||||
) -> LlmResponse:
|
||||
return LlmResponse(
|
||||
content=utils.ModelContent(
|
||||
[types.Part.from_text(text=self.mock_response)]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class MockAfterModelCallback(BaseModel):
|
||||
mock_response: str
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
llm_response: LlmResponse,
|
||||
) -> LlmResponse:
|
||||
return LlmResponse(
|
||||
content=utils.ModelContent(
|
||||
[types.Part.from_text(text=self.mock_response)]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def noop_callback(**kwargs) -> Optional[LlmResponse]:
|
||||
pass
|
||||
|
||||
|
||||
def test_before_model_callback():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_model_callback=MockBeforeModelCallback(
|
||||
mock_response='before_model_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', 'before_model_callback'),
|
||||
]
|
||||
|
||||
|
||||
def test_before_model_callback_noop():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_model_callback=noop_callback,
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', 'model_response'),
|
||||
]
|
||||
|
||||
|
||||
def test_before_model_callback_end():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_model_callback=MockBeforeModelCallback(
|
||||
mock_response='before_model_callback',
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', 'before_model_callback'),
|
||||
]
|
||||
|
||||
|
||||
def test_after_model_callback():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_model_callback=MockAfterModelCallback(
|
||||
mock_response='after_model_callback'
|
||||
),
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', 'after_model_callback'),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_model_callback_noop():
|
||||
responses = ['model_response']
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_model_callback=noop_callback,
|
||||
)
|
||||
|
||||
runner = utils.TestInMemoryRunner(agent)
|
||||
assert utils.simplify_events(
|
||||
await runner.run_async_with_new_session('test')
|
||||
) == [('root_agent', 'model_response')]
|
||||
46
tests/unittests/flows/llm_flows/test_other_configs.py
Normal file
46
tests/unittests/flows/llm_flows/test_other_configs.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.tools import ToolContext
|
||||
from google.genai.types import Part
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def test_output_schema():
|
||||
class CustomOutput(BaseModel):
|
||||
custom_field: str
|
||||
|
||||
response = [
|
||||
'response1',
|
||||
]
|
||||
mockModel = utils.MockModel.create(responses=response)
|
||||
root_agent = Agent(
|
||||
name='root_agent',
|
||||
model=mockModel,
|
||||
output_schema=CustomOutput,
|
||||
disallow_transfer_to_parent=True,
|
||||
disallow_transfer_to_peers=True,
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(root_agent)
|
||||
|
||||
assert utils.simplify_events(runner.run('test1')) == [
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
assert len(mockModel.requests) == 1
|
||||
assert mockModel.requests[0].config.response_schema == CustomOutput
|
||||
assert mockModel.requests[0].config.response_mime_type == 'application/json'
|
||||
269
tests/unittests/flows/llm_flows/test_tool_callbacks.py
Normal file
269
tests/unittests/flows/llm_flows/test_tool_callbacks.py
Normal file
@@ -0,0 +1,269 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.tools import BaseTool
|
||||
from google.adk.tools import ToolContext
|
||||
from google.genai import types
|
||||
from google.genai.types import Part
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ... import utils
|
||||
|
||||
|
||||
def simple_function(input_str: str) -> str:
|
||||
return {'result': input_str}
|
||||
|
||||
|
||||
class MockBeforeToolCallback(BaseModel):
|
||||
mock_response: dict[str, object]
|
||||
modify_tool_request: bool = False
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
tool: BaseTool,
|
||||
args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
) -> dict[str, object]:
|
||||
if self.modify_tool_request:
|
||||
args['input_str'] = 'modified_input'
|
||||
return None
|
||||
return self.mock_response
|
||||
|
||||
|
||||
class MockAfterToolCallback(BaseModel):
|
||||
mock_response: dict[str, object]
|
||||
modify_tool_request: bool = False
|
||||
modify_tool_response: bool = False
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
tool: BaseTool,
|
||||
args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
tool_response: dict[str, Any] = None,
|
||||
) -> dict[str, object]:
|
||||
if self.modify_tool_request:
|
||||
args['input_str'] = 'modified_input'
|
||||
return None
|
||||
if self.modify_tool_response:
|
||||
tool_response['result'] = 'modified_output'
|
||||
return tool_response
|
||||
return self.mock_response
|
||||
|
||||
|
||||
def noop_callback(
|
||||
**kwargs,
|
||||
) -> dict[str, object]:
|
||||
pass
|
||||
|
||||
|
||||
def test_before_tool_callback():
|
||||
responses = [
|
||||
types.Part.from_function_call(name='simple_function', args={}),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_tool_callback=MockBeforeToolCallback(
|
||||
mock_response={'test': 'before_tool_callback'}
|
||||
),
|
||||
tools=[simple_function],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', Part.from_function_call(name='simple_function', args={})),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='simple_function', response={'test': 'before_tool_callback'}
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
|
||||
def test_before_tool_callback_noop():
|
||||
responses = [
|
||||
types.Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_tool_callback=noop_callback,
|
||||
tools=[simple_function],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='simple_function',
|
||||
response={'result': 'simple_function_call'},
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
|
||||
def test_before_tool_callback_modify_tool_request():
|
||||
responses = [
|
||||
types.Part.from_function_call(name='simple_function', args={}),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
before_tool_callback=MockBeforeToolCallback(
|
||||
mock_response={'test': 'before_tool_callback'},
|
||||
modify_tool_request=True,
|
||||
),
|
||||
tools=[simple_function],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
('root_agent', Part.from_function_call(name='simple_function', args={})),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='simple_function',
|
||||
response={'result': 'modified_input'},
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
|
||||
def test_after_tool_callback():
|
||||
responses = [
|
||||
types.Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_tool_callback=MockAfterToolCallback(
|
||||
mock_response={'test': 'after_tool_callback'}
|
||||
),
|
||||
tools=[simple_function],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='simple_function', response={'test': 'after_tool_callback'}
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
|
||||
def test_after_tool_callback_noop():
|
||||
responses = [
|
||||
types.Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_tool_callback=noop_callback,
|
||||
tools=[simple_function],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='simple_function',
|
||||
response={'result': 'simple_function_call'},
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
|
||||
|
||||
def test_after_tool_callback_modify_tool_response():
|
||||
responses = [
|
||||
types.Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
'response1',
|
||||
]
|
||||
mock_model = utils.MockModel.create(responses=responses)
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
after_tool_callback=MockAfterToolCallback(
|
||||
mock_response={'result': 'after_tool_callback'},
|
||||
modify_tool_response=True,
|
||||
),
|
||||
tools=[simple_function],
|
||||
)
|
||||
|
||||
runner = utils.InMemoryRunner(agent)
|
||||
assert utils.simplify_events(runner.run('test')) == [
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_call(
|
||||
name='simple_function', args={'input_str': 'simple_function_call'}
|
||||
),
|
||||
),
|
||||
(
|
||||
'root_agent',
|
||||
Part.from_function_response(
|
||||
name='simple_function',
|
||||
response={'result': 'modified_output'},
|
||||
),
|
||||
),
|
||||
('root_agent', 'response1'),
|
||||
]
|
||||
Reference in New Issue
Block a user