Moves unittests to root folder and adds github action to run unit tests. (#72)

* Move unit tests to root package.

* Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py

* Adds github workflow

* minor fix in lite_llm.py for python 3.9.

* format pyproject.toml
This commit is contained in:
Jack Sun
2025-04-11 08:25:59 -07:00
committed by GitHub
parent 59117b9b96
commit 05142a07cc
66 changed files with 50 additions and 2 deletions

View File

@@ -0,0 +1,14 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,142 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: delete and rewrite unit tests
from google.adk.agents import Agent
from google.adk.examples import BaseExampleProvider
from google.adk.examples import Example
from google.adk.flows.llm_flows import examples
from google.adk.models.base_llm import LlmRequest
from google.genai import types
import pytest
from ... import utils
@pytest.mark.asyncio
async def test_no_examples():
request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
agent = Agent(model="gemini-1.5-flash", name="agent", examples=[])
invocation_context = utils.create_invocation_context(
agent=agent, user_content=""
)
async for _ in examples.request_processor.run_async(
invocation_context,
request,
):
pass
assert request.config.system_instruction == ""
@pytest.mark.asyncio
async def test_agent_examples():
example_list = [
Example(
input=types.Content(
role="user",
parts=[types.Part.from_text(text="test1")],
),
output=[
types.Content(
role="model",
parts=[types.Part.from_text(text="response1")],
),
],
)
]
request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
agent = Agent(
model="gemini-1.5-flash",
name="agent",
examples=example_list,
)
invocation_context = utils.create_invocation_context(
agent=agent, user_content="test"
)
async for _ in examples.request_processor.run_async(
invocation_context,
request,
):
pass
assert (
request.config.system_instruction
== "<EXAMPLES>\nBegin few-shot\nThe following are examples of user"
" queries and model responses using the available tools.\n\nEXAMPLE"
" 1:\nBegin example\n[user]\ntest1\n\n[model]\nresponse1\nEnd"
" example\n\nEnd few-shot\nNow, try to follow these examples and"
" complete the following conversation\n<EXAMPLES>"
)
@pytest.mark.asyncio
async def test_agent_base_example_provider():
class TestExampleProvider(BaseExampleProvider):
def get_examples(self, query: str) -> list[Example]:
if query == "test":
return [
Example(
input=types.Content(
role="user",
parts=[types.Part.from_text(text="test")],
),
output=[
types.Content(
role="model",
parts=[types.Part.from_text(text="response1")],
),
],
)
]
else:
return []
provider = TestExampleProvider()
request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
agent = Agent(
model="gemini-1.5-flash",
name="agent",
examples=provider,
)
invocation_context = utils.create_invocation_context(
agent=agent, user_content="test"
)
async for _ in examples.request_processor.run_async(
invocation_context,
request,
):
pass
assert (
request.config.system_instruction
== "<EXAMPLES>\nBegin few-shot\nThe following are examples of user"
" queries and model responses using the available tools.\n\nEXAMPLE"
" 1:\nBegin example\n[user]\ntest\n\n[model]\nresponse1\nEnd"
" example\n\nEnd few-shot\nNow, try to follow these examples and"
" complete the following conversation\n<EXAMPLES>"
)

View File

@@ -0,0 +1,311 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.agents.llm_agent import Agent
from google.adk.agents.loop_agent import LoopAgent
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.tools import exit_loop
from google.genai.types import Part
from ... import utils
def transfer_call_part(agent_name: str) -> Part:
return Part.from_function_call(
name='transfer_to_agent', args={'agent_name': agent_name}
)
TRANSFER_RESPONSE_PART = Part.from_function_response(
name='transfer_to_agent', response={}
)
def test_auto_to_auto():
response = [
transfer_call_part('sub_agent_1'),
'response1',
'response2',
]
mockModel = utils.MockModel.create(responses=response)
# root (auto) - sub_agent_1 (auto)
sub_agent_1 = Agent(name='sub_agent_1', model=mockModel)
root_agent = Agent(
name='root_agent',
model=mockModel,
sub_agents=[sub_agent_1],
)
runner = utils.InMemoryRunner(root_agent)
# Asserts the transfer.
assert utils.simplify_events(runner.run('test1')) == [
('root_agent', transfer_call_part('sub_agent_1')),
('root_agent', TRANSFER_RESPONSE_PART),
('sub_agent_1', 'response1'),
]
# sub_agent_1 should still be the current agent.
assert utils.simplify_events(runner.run('test2')) == [
('sub_agent_1', 'response2'),
]
def test_auto_to_single():
response = [
transfer_call_part('sub_agent_1'),
'response1',
'response2',
]
mockModel = utils.MockModel.create(responses=response)
# root (auto) - sub_agent_1 (single)
sub_agent_1 = Agent(
name='sub_agent_1',
model=mockModel,
disallow_transfer_to_parent=True,
disallow_transfer_to_peers=True,
)
root_agent = Agent(
name='root_agent', model=mockModel, sub_agents=[sub_agent_1]
)
runner = utils.InMemoryRunner(root_agent)
# Asserts the responses.
assert utils.simplify_events(runner.run('test1')) == [
('root_agent', transfer_call_part('sub_agent_1')),
('root_agent', TRANSFER_RESPONSE_PART),
('sub_agent_1', 'response1'),
]
# root_agent should still be the current agent, becaues sub_agent_1 is single.
assert utils.simplify_events(runner.run('test2')) == [
('root_agent', 'response2'),
]
def test_auto_to_auto_to_single():
response = [
transfer_call_part('sub_agent_1'),
# sub_agent_1 transfers to sub_agent_1_1.
transfer_call_part('sub_agent_1_1'),
'response1',
'response2',
]
mockModel = utils.MockModel.create(responses=response)
# root (auto) - sub_agent_1 (auto) - sub_agent_1_1 (single)
sub_agent_1_1 = Agent(
name='sub_agent_1_1',
model=mockModel,
disallow_transfer_to_parent=True,
disallow_transfer_to_peers=True,
)
sub_agent_1 = Agent(
name='sub_agent_1', model=mockModel, sub_agents=[sub_agent_1_1]
)
root_agent = Agent(
name='root_agent', model=mockModel, sub_agents=[sub_agent_1]
)
runner = utils.InMemoryRunner(root_agent)
# Asserts the responses.
assert utils.simplify_events(runner.run('test1')) == [
('root_agent', transfer_call_part('sub_agent_1')),
('root_agent', TRANSFER_RESPONSE_PART),
('sub_agent_1', transfer_call_part('sub_agent_1_1')),
('sub_agent_1', TRANSFER_RESPONSE_PART),
('sub_agent_1_1', 'response1'),
]
# sub_agent_1 should still be the current agent. sub_agent_1_1 is single so it should
# not be the current agent, otherwise the conversation will be tied to
# sub_agent_1_1 forever.
assert utils.simplify_events(runner.run('test2')) == [
('sub_agent_1', 'response2'),
]
def test_auto_to_sequential():
response = [
transfer_call_part('sub_agent_1'),
# sub_agent_1 responds directly instead of transfering.
'response1',
'response2',
'response3',
]
mockModel = utils.MockModel.create(responses=response)
# root (auto) - sub_agent_1 (sequential) - sub_agent_1_1 (single)
# \ sub_agent_1_2 (single)
sub_agent_1_1 = Agent(
name='sub_agent_1_1',
model=mockModel,
disallow_transfer_to_parent=True,
disallow_transfer_to_peers=True,
)
sub_agent_1_2 = Agent(
name='sub_agent_1_2',
model=mockModel,
disallow_transfer_to_parent=True,
disallow_transfer_to_peers=True,
)
sub_agent_1 = SequentialAgent(
name='sub_agent_1',
sub_agents=[sub_agent_1_1, sub_agent_1_2],
)
root_agent = Agent(
name='root_agent',
model=mockModel,
sub_agents=[sub_agent_1],
)
runner = utils.InMemoryRunner(root_agent)
# Asserts the transfer.
assert utils.simplify_events(runner.run('test1')) == [
('root_agent', transfer_call_part('sub_agent_1')),
('root_agent', TRANSFER_RESPONSE_PART),
('sub_agent_1_1', 'response1'),
('sub_agent_1_2', 'response2'),
]
# root_agent should still be the current agent because sub_agent_1 is sequential.
assert utils.simplify_events(runner.run('test2')) == [
('root_agent', 'response3'),
]
def test_auto_to_sequential_to_auto():
response = [
transfer_call_part('sub_agent_1'),
# sub_agent_1 responds directly instead of transfering.
'response1',
transfer_call_part('sub_agent_1_2_1'),
'response2',
'response3',
'response4',
]
mockModel = utils.MockModel.create(responses=response)
# root (auto) - sub_agent_1 (seq) - sub_agent_1_1 (single)
# \ sub_agent_1_2 (auto) - sub_agent_1_2_1 (auto)
# \ sub_agent_1_3 (single)
sub_agent_1_1 = Agent(
name='sub_agent_1_1',
model=mockModel,
disallow_transfer_to_parent=True,
disallow_transfer_to_peers=True,
)
sub_agent_1_2_1 = Agent(name='sub_agent_1_2_1', model=mockModel)
sub_agent_1_2 = Agent(
name='sub_agent_1_2',
model=mockModel,
sub_agents=[sub_agent_1_2_1],
)
sub_agent_1_3 = Agent(
name='sub_agent_1_3',
model=mockModel,
disallow_transfer_to_parent=True,
disallow_transfer_to_peers=True,
)
sub_agent_1 = SequentialAgent(
name='sub_agent_1',
sub_agents=[sub_agent_1_1, sub_agent_1_2, sub_agent_1_3],
)
root_agent = Agent(
name='root_agent',
model=mockModel,
sub_agents=[sub_agent_1],
)
runner = utils.InMemoryRunner(root_agent)
# Asserts the transfer.
assert utils.simplify_events(runner.run('test1')) == [
('root_agent', transfer_call_part('sub_agent_1')),
('root_agent', TRANSFER_RESPONSE_PART),
('sub_agent_1_1', 'response1'),
('sub_agent_1_2', transfer_call_part('sub_agent_1_2_1')),
('sub_agent_1_2', TRANSFER_RESPONSE_PART),
('sub_agent_1_2_1', 'response2'),
('sub_agent_1_3', 'response3'),
]
# root_agent should still be the current agent because sub_agent_1 is sequential.
assert utils.simplify_events(runner.run('test2')) == [
('root_agent', 'response4'),
]
def test_auto_to_loop():
response = [
transfer_call_part('sub_agent_1'),
# sub_agent_1 responds directly instead of transfering.
'response1',
'response2',
'response3',
Part.from_function_call(name='exit_loop', args={}),
'response4',
'response5',
]
mockModel = utils.MockModel.create(responses=response)
# root (auto) - sub_agent_1 (loop) - sub_agent_1_1 (single)
# \ sub_agent_1_2 (single)
sub_agent_1_1 = Agent(
name='sub_agent_1_1',
model=mockModel,
disallow_transfer_to_parent=True,
disallow_transfer_to_peers=True,
)
sub_agent_1_2 = Agent(
name='sub_agent_1_2',
model=mockModel,
disallow_transfer_to_parent=True,
disallow_transfer_to_peers=True,
tools=[exit_loop],
)
sub_agent_1 = LoopAgent(
name='sub_agent_1',
sub_agents=[sub_agent_1_1, sub_agent_1_2],
)
root_agent = Agent(
name='root_agent',
model=mockModel,
sub_agents=[sub_agent_1],
)
runner = utils.InMemoryRunner(root_agent)
# Asserts the transfer.
assert utils.simplify_events(runner.run('test1')) == [
# Transfers to sub_agent_1.
('root_agent', transfer_call_part('sub_agent_1')),
('root_agent', TRANSFER_RESPONSE_PART),
# Loops.
('sub_agent_1_1', 'response1'),
('sub_agent_1_2', 'response2'),
('sub_agent_1_1', 'response3'),
# Exits.
('sub_agent_1_2', Part.from_function_call(name='exit_loop', args={})),
(
'sub_agent_1_2',
Part.from_function_response(name='exit_loop', response={}),
),
# root_agent summarizes.
('root_agent', 'response4'),
]
# root_agent should still be the current agent because sub_agent_1 is loop.
assert utils.simplify_events(runner.run('test2')) == [
('root_agent', 'response5'),
]

View File

@@ -0,0 +1,244 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.agents import Agent
from google.adk.tools import ToolContext
from google.adk.tools.long_running_tool import LongRunningFunctionTool
from google.genai.types import Part
from ... import utils
def test_async_function():
responses = [
Part.from_function_call(name='increase_by_one', args={'x': 1}),
'response1',
'response2',
'response3',
'response4',
]
mockModel = utils.MockModel.create(responses=responses)
function_called = 0
def increase_by_one(x: int, tool_context: ToolContext) -> int:
nonlocal function_called
function_called += 1
return {'status': 'pending'}
# Calls the first time.
agent = Agent(
name='root_agent',
model=mockModel,
tools=[LongRunningFunctionTool(func=increase_by_one)],
)
runner = utils.InMemoryRunner(agent)
events = runner.run('test1')
# Asserts the requests.
assert len(mockModel.requests) == 2
# 1 item: user content
assert mockModel.requests[0].contents == [
utils.UserContent('test1'),
]
increase_by_one_call = Part.from_function_call(
name='increase_by_one', args={'x': 1}
)
pending_response = Part.from_function_response(
name='increase_by_one', response={'status': 'pending'}
)
assert utils.simplify_contents(mockModel.requests[1].contents) == [
('user', 'test1'),
('model', increase_by_one_call),
('user', pending_response),
]
# Asserts the function calls.
assert function_called == 1
# Asserts the responses.
assert utils.simplify_events(events) == [
(
'root_agent',
Part.from_function_call(name='increase_by_one', args={'x': 1}),
),
(
'root_agent',
Part.from_function_response(
name='increase_by_one', response={'status': 'pending'}
),
),
('root_agent', 'response1'),
]
assert events[0].long_running_tool_ids
# Updates with another pending progress.
still_waiting_response = Part.from_function_response(
name='increase_by_one', response={'status': 'still waiting'}
)
events = runner.run(utils.UserContent(still_waiting_response))
# We have one new request.
assert len(mockModel.requests) == 3
assert utils.simplify_contents(mockModel.requests[2].contents) == [
('user', 'test1'),
('model', increase_by_one_call),
('user', still_waiting_response),
]
assert utils.simplify_events(events) == [('root_agent', 'response2')]
# Calls when the result is ready.
result_response = Part.from_function_response(
name='increase_by_one', response={'result': 2}
)
events = runner.run(utils.UserContent(result_response))
# We have one new request.
assert len(mockModel.requests) == 4
assert utils.simplify_contents(mockModel.requests[3].contents) == [
('user', 'test1'),
('model', increase_by_one_call),
('user', result_response),
]
assert utils.simplify_events(events) == [('root_agent', 'response3')]
# Calls when the result is ready. Here we still accept the result and do
# another summarization. Whether this is the right behavior is TBD.
another_result_response = Part.from_function_response(
name='increase_by_one', response={'result': 3}
)
events = runner.run(utils.UserContent(another_result_response))
# We have one new request.
assert len(mockModel.requests) == 5
assert utils.simplify_contents(mockModel.requests[4].contents) == [
('user', 'test1'),
('model', increase_by_one_call),
('user', another_result_response),
]
assert utils.simplify_events(events) == [('root_agent', 'response4')]
# At the end, function_called should still be 1.
assert function_called == 1
def test_async_function_with_none_response():
responses = [
Part.from_function_call(name='increase_by_one', args={'x': 1}),
'response1',
'response2',
'response3',
'response4',
]
mockModel = utils.MockModel.create(responses=responses)
function_called = 0
def increase_by_one(x: int, tool_context: ToolContext) -> int:
nonlocal function_called
function_called += 1
return 'pending'
# Calls the first time.
agent = Agent(
name='root_agent',
model=mockModel,
tools=[LongRunningFunctionTool(func=increase_by_one)],
)
runner = utils.InMemoryRunner(agent)
events = runner.run('test1')
# Asserts the requests.
assert len(mockModel.requests) == 2
# 1 item: user content
assert mockModel.requests[0].contents == [
utils.UserContent('test1'),
]
increase_by_one_call = Part.from_function_call(
name='increase_by_one', args={'x': 1}
)
assert utils.simplify_contents(mockModel.requests[1].contents) == [
('user', 'test1'),
('model', increase_by_one_call),
(
'user',
Part.from_function_response(
name='increase_by_one', response={'result': 'pending'}
),
),
]
# Asserts the function calls.
assert function_called == 1
# Asserts the responses.
assert utils.simplify_events(events) == [
(
'root_agent',
Part.from_function_call(name='increase_by_one', args={'x': 1}),
),
(
'root_agent',
Part.from_function_response(
name='increase_by_one', response={'result': 'pending'}
),
),
('root_agent', 'response1'),
]
# Updates with another pending progress.
still_waiting_response = Part.from_function_response(
name='increase_by_one', response={'status': 'still waiting'}
)
events = runner.run(utils.UserContent(still_waiting_response))
# We have one new request.
assert len(mockModel.requests) == 3
assert utils.simplify_contents(mockModel.requests[2].contents) == [
('user', 'test1'),
('model', increase_by_one_call),
('user', still_waiting_response),
]
assert utils.simplify_events(events) == [('root_agent', 'response2')]
# Calls when the result is ready.
result_response = Part.from_function_response(
name='increase_by_one', response={'result': 2}
)
events = runner.run(utils.UserContent(result_response))
# We have one new request.
assert len(mockModel.requests) == 4
assert utils.simplify_contents(mockModel.requests[3].contents) == [
('user', 'test1'),
('model', increase_by_one_call),
('user', result_response),
]
assert utils.simplify_events(events) == [('root_agent', 'response3')]
# Calls when the result is ready. Here we still accept the result and do
# another summarization. Whether this is the right behavior is TBD.
another_result_response = Part.from_function_response(
name='increase_by_one', response={'result': 3}
)
events = runner.run(utils.UserContent(another_result_response))
# We have one new request.
assert len(mockModel.requests) == 5
assert utils.simplify_contents(mockModel.requests[4].contents) == [
('user', 'test1'),
('model', increase_by_one_call),
('user', another_result_response),
]
assert utils.simplify_events(events) == [('root_agent', 'response4')]
# At the end, function_called should still be 1.
assert function_called == 1

View File

@@ -0,0 +1,346 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Optional
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OAuthFlowAuthorizationCode
from fastapi.openapi.models import OAuthFlows
from google.adk.agents import Agent
from google.adk.auth import AuthConfig
from google.adk.auth import AuthCredential
from google.adk.auth import AuthCredentialTypes
from google.adk.auth import OAuth2Auth
from google.adk.flows.llm_flows import functions
from google.adk.tools import AuthToolArguments
from google.adk.tools import ToolContext
from google.genai import types
from ... import utils
def function_call(function_call_id, name, args: dict[str, Any]) -> types.Part:
part = types.Part.from_function_call(name=name, args=args)
part.function_call.id = function_call_id
return part
def test_function_request_euc():
responses = [
[
types.Part.from_function_call(name='call_external_api1', args={}),
types.Part.from_function_call(name='call_external_api2', args={}),
],
[
types.Part.from_text(text='response1'),
],
]
auth_config1 = AuthConfig(
auth_scheme=OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
tokenUrl='https://oauth2.googleapis.com/token',
scopes={
'https://www.googleapis.com/auth/calendar': (
'See, edit, share, and permanently delete all the'
' calendars you can access using Google Calendar'
)
},
)
)
),
raw_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_1',
client_secret='oauth_client_secret1',
),
),
)
auth_config2 = AuthConfig(
auth_scheme=OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
tokenUrl='https://oauth2.googleapis.com/token',
scopes={
'https://www.googleapis.com/auth/calendar': (
'See, edit, share, and permanently delete all the'
' calendars you can access using Google Calendar'
)
},
)
)
),
raw_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_2',
client_secret='oauth_client_secret2',
),
),
)
mock_model = utils.MockModel.create(responses=responses)
def call_external_api1(tool_context: ToolContext) -> Optional[int]:
tool_context.request_credential(auth_config1)
def call_external_api2(tool_context: ToolContext) -> Optional[int]:
tool_context.request_credential(auth_config2)
agent = Agent(
name='root_agent',
model=mock_model,
tools=[call_external_api1, call_external_api2],
)
runner = utils.InMemoryRunner(agent)
events = runner.run('test')
assert events[0].content.parts[0].function_call is not None
assert events[0].content.parts[1].function_call is not None
auth_configs = list(events[2].actions.requested_auth_configs.values())
exchanged_auth_config1 = auth_configs[0]
exchanged_auth_config2 = auth_configs[1]
assert exchanged_auth_config1.auth_scheme == auth_config1.auth_scheme
assert (
exchanged_auth_config1.raw_auth_credential
== auth_config1.raw_auth_credential
)
assert (
exchanged_auth_config1.exchanged_auth_credential.oauth2.auth_uri
is not None
)
assert exchanged_auth_config2.auth_scheme == auth_config2.auth_scheme
assert (
exchanged_auth_config2.raw_auth_credential
== auth_config2.raw_auth_credential
)
assert (
exchanged_auth_config2.exchanged_auth_credential.oauth2.auth_uri
is not None
)
function_call_ids = list(events[2].actions.requested_auth_configs.keys())
for idx, part in enumerate(events[1].content.parts):
reqeust_euc_function_call = part.function_call
assert reqeust_euc_function_call is not None
assert (
reqeust_euc_function_call.name
== functions.REQUEST_EUC_FUNCTION_CALL_NAME
)
args = AuthToolArguments.model_validate(reqeust_euc_function_call.args)
assert args.function_call_id == function_call_ids[idx]
args.auth_config.auth_scheme.model_extra.clear()
assert args.auth_config.auth_scheme == auth_configs[idx].auth_scheme
assert (
args.auth_config.raw_auth_credential
== auth_configs[idx].raw_auth_credential
)
def test_function_get_auth_response():
id_1 = 'id_1'
id_2 = 'id_2'
responses = [
[
function_call(id_1, 'call_external_api1', {}),
function_call(id_2, 'call_external_api2', {}),
],
[
types.Part.from_text(text='response1'),
],
[
types.Part.from_text(text='response2'),
],
]
mock_model = utils.MockModel.create(responses=responses)
function_invoked = 0
auth_config1 = AuthConfig(
auth_scheme=OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
tokenUrl='https://oauth2.googleapis.com/token',
scopes={
'https://www.googleapis.com/auth/calendar': (
'See, edit, share, and permanently delete all the'
' calendars you can access using Google Calendar'
)
},
)
)
),
raw_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_1',
client_secret='oauth_client_secret1',
),
),
)
auth_config2 = AuthConfig(
auth_scheme=OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
tokenUrl='https://oauth2.googleapis.com/token',
scopes={
'https://www.googleapis.com/auth/calendar': (
'See, edit, share, and permanently delete all the'
' calendars you can access using Google Calendar'
)
},
)
)
),
raw_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_2',
client_secret='oauth_client_secret2',
),
),
)
auth_response1 = AuthConfig(
auth_scheme=OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
tokenUrl='https://oauth2.googleapis.com/token',
scopes={
'https://www.googleapis.com/auth/calendar': (
'See, edit, share, and permanently delete all the'
' calendars you can access using Google Calendar'
)
},
)
)
),
raw_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_1',
client_secret='oauth_client_secret1',
),
),
exchanged_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_1',
client_secret='oauth_client_secret1',
token={'access_token': 'token1'},
),
),
)
auth_response2 = AuthConfig(
auth_scheme=OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
tokenUrl='https://oauth2.googleapis.com/token',
scopes={
'https://www.googleapis.com/auth/calendar': (
'See, edit, share, and permanently delete all the'
' calendars you can access using Google Calendar'
)
},
)
)
),
raw_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_2',
client_secret='oauth_client_secret2',
),
),
exchanged_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_2',
client_secret='oauth_client_secret2',
token={'access_token': 'token2'},
),
),
)
def call_external_api1(tool_context: ToolContext) -> int:
nonlocal function_invoked
function_invoked += 1
auth_response = tool_context.get_auth_response(auth_config1)
if not auth_response:
tool_context.request_credential(auth_config1)
return
assert auth_response == auth_response1.exchanged_auth_credential
return 1
def call_external_api2(tool_context: ToolContext) -> int:
nonlocal function_invoked
function_invoked += 1
auth_response = tool_context.get_auth_response(auth_config2)
if not auth_response:
tool_context.request_credential(auth_config2)
return
assert auth_response == auth_response2.exchanged_auth_credential
return 2
agent = Agent(
name='root_agent',
model=mock_model,
tools=[call_external_api1, call_external_api2],
)
runner = utils.InMemoryRunner(agent)
runner.run('test')
request_euc_function_call_event = runner.session.events[-3]
function_response1 = types.FunctionResponse(
name=request_euc_function_call_event.content.parts[0].function_call.name,
response=auth_response1.model_dump(),
)
function_response1.id = request_euc_function_call_event.content.parts[
0
].function_call.id
function_response2 = types.FunctionResponse(
name=request_euc_function_call_event.content.parts[1].function_call.name,
response=auth_response2.model_dump(),
)
function_response2.id = request_euc_function_call_event.content.parts[
1
].function_call.id
runner.run(
new_message=types.Content(
role='user',
parts=[
types.Part(function_response=function_response1),
types.Part(function_response=function_response2),
],
),
)
assert function_invoked == 4
reqeust = mock_model.requests[-1]
content = reqeust.contents[-1]
parts = content.parts
assert len(parts) == 2
assert parts[0].function_response.name == 'call_external_api1'
assert parts[0].function_response.response == {'result': 1}
assert parts[1].function_response.name == 'call_external_api2'
assert parts[1].function_response.response == {'result': 2}

View File

@@ -0,0 +1,93 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from google.adk.agents import Agent
from google.genai import types
from ... import utils
def function_call(args: dict[str, Any]) -> types.Part:
return types.Part.from_function_call(name='increase_by_one', args=args)
def function_response(response: dict[str, Any]) -> types.Part:
return types.Part.from_function_response(
name='increase_by_one', response=response
)
def test_sequential_calls():
responses = [
function_call({'x': 1}),
function_call({'x': 2}),
function_call({'x': 3}),
'response1',
]
mockModel = utils.MockModel.create(responses=responses)
function_called = 0
def increase_by_one(x: int) -> int:
nonlocal function_called
function_called += 1
return x + 1
agent = Agent(name='root_agent', model=mockModel, tools=[increase_by_one])
runner = utils.InMemoryRunner(agent)
result = utils.simplify_events(runner.run('test'))
assert result == [
('root_agent', function_call({'x': 1})),
('root_agent', function_response({'result': 2})),
('root_agent', function_call({'x': 2})),
('root_agent', function_response({'result': 3})),
('root_agent', function_call({'x': 3})),
('root_agent', function_response({'result': 4})),
('root_agent', 'response1'),
]
# Asserts the requests.
assert len(mockModel.requests) == 4
# 1 item: user content
assert utils.simplify_contents(mockModel.requests[0].contents) == [
('user', 'test')
]
# 3 items: user content, functaion call / response for the 1st call
assert utils.simplify_contents(mockModel.requests[1].contents) == [
('user', 'test'),
('model', function_call({'x': 1})),
('user', function_response({'result': 2})),
]
# 5 items: user content, functaion call / response for two calls
assert utils.simplify_contents(mockModel.requests[2].contents) == [
('user', 'test'),
('model', function_call({'x': 1})),
('user', function_response({'result': 2})),
('model', function_call({'x': 2})),
('user', function_response({'result': 3})),
]
# 7 items: user content, functaion call / response for three calls
assert utils.simplify_contents(mockModel.requests[3].contents) == [
('user', 'test'),
('model', function_call({'x': 1})),
('user', function_response({'result': 2})),
('model', function_call({'x': 2})),
('user', function_response({'result': 3})),
('model', function_call({'x': 3})),
('user', function_response({'result': 4})),
]
# Asserts the function calls.
assert function_called == 3

View File

@@ -0,0 +1,258 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import AsyncGenerator
from typing import Callable
from google.adk.agents import Agent
from google.adk.tools import ToolContext
from google.adk.tools.function_tool import FunctionTool
from google.genai import types
import pytest
from ... import utils
def test_simple_function():
function_call_1 = types.Part.from_function_call(
name='increase_by_one', args={'x': 1}
)
function_respones_2 = types.Part.from_function_response(
name='increase_by_one', response={'result': 2}
)
responses: list[types.Content] = [
function_call_1,
'response1',
'response2',
'response3',
'response4',
]
function_called = 0
mock_model = utils.MockModel.create(responses=responses)
def increase_by_one(x: int) -> int:
nonlocal function_called
function_called += 1
return x + 1
agent = Agent(name='root_agent', model=mock_model, tools=[increase_by_one])
runner = utils.InMemoryRunner(agent)
assert utils.simplify_events(runner.run('test')) == [
('root_agent', function_call_1),
('root_agent', function_respones_2),
('root_agent', 'response1'),
]
# Asserts the requests.
assert utils.simplify_contents(mock_model.requests[0].contents) == [
('user', 'test')
]
assert utils.simplify_contents(mock_model.requests[1].contents) == [
('user', 'test'),
('model', function_call_1),
('user', function_respones_2),
]
# Asserts the function calls.
assert function_called == 1
@pytest.mark.asyncio
async def test_async_function():
function_calls = [
types.Part.from_function_call(name='increase_by_one', args={'x': 1}),
types.Part.from_function_call(name='multiple_by_two', args={'x': 2}),
types.Part.from_function_call(name='multiple_by_two_sync', args={'x': 3}),
]
function_responses = [
types.Part.from_function_response(
name='increase_by_one', response={'result': 2}
),
types.Part.from_function_response(
name='multiple_by_two', response={'result': 4}
),
types.Part.from_function_response(
name='multiple_by_two_sync', response={'result': 6}
),
]
responses: list[types.Content] = [
function_calls,
'response1',
'response2',
'response3',
'response4',
]
function_called = 0
mock_model = utils.MockModel.create(responses=responses)
async def increase_by_one(x: int) -> int:
nonlocal function_called
function_called += 1
return x + 1
async def multiple_by_two(x: int) -> int:
nonlocal function_called
function_called += 1
return x * 2
def multiple_by_two_sync(x: int) -> int:
nonlocal function_called
function_called += 1
return x * 2
agent = Agent(
name='root_agent',
model=mock_model,
tools=[increase_by_one, multiple_by_two, multiple_by_two_sync],
)
runner = utils.TestInMemoryRunner(agent)
events = await runner.run_async_with_new_session('test')
assert utils.simplify_events(events) == [
('root_agent', function_calls),
('root_agent', function_responses),
('root_agent', 'response1'),
]
# Asserts the requests.
assert utils.simplify_contents(mock_model.requests[0].contents) == [
('user', 'test')
]
assert utils.simplify_contents(mock_model.requests[1].contents) == [
('user', 'test'),
('model', function_calls),
('user', function_responses),
]
# Asserts the function calls.
assert function_called == 3
@pytest.mark.asyncio
async def test_function_tool():
function_calls = [
types.Part.from_function_call(name='increase_by_one', args={'x': 1}),
types.Part.from_function_call(name='multiple_by_two', args={'x': 2}),
types.Part.from_function_call(name='multiple_by_two_sync', args={'x': 3}),
]
function_responses = [
types.Part.from_function_response(
name='increase_by_one', response={'result': 2}
),
types.Part.from_function_response(
name='multiple_by_two', response={'result': 4}
),
types.Part.from_function_response(
name='multiple_by_two_sync', response={'result': 6}
),
]
responses: list[types.Content] = [
function_calls,
'response1',
'response2',
'response3',
'response4',
]
function_called = 0
mock_model = utils.MockModel.create(responses=responses)
async def increase_by_one(x: int) -> int:
nonlocal function_called
function_called += 1
return x + 1
async def multiple_by_two(x: int) -> int:
nonlocal function_called
function_called += 1
return x * 2
def multiple_by_two_sync(x: int) -> int:
nonlocal function_called
function_called += 1
return x * 2
class TestTool(FunctionTool):
def __init__(self, func: Callable[..., Any]):
super().__init__(func=func)
wrapped_increase_by_one = TestTool(func=increase_by_one)
agent = Agent(
name='root_agent',
model=mock_model,
tools=[wrapped_increase_by_one, multiple_by_two, multiple_by_two_sync],
)
runner = utils.TestInMemoryRunner(agent)
events = await runner.run_async_with_new_session('test')
assert utils.simplify_events(events) == [
('root_agent', function_calls),
('root_agent', function_responses),
('root_agent', 'response1'),
]
# Asserts the requests.
assert utils.simplify_contents(mock_model.requests[0].contents) == [
('user', 'test')
]
assert utils.simplify_contents(mock_model.requests[1].contents) == [
('user', 'test'),
('model', function_calls),
('user', function_responses),
]
# Asserts the function calls.
assert function_called == 3
def test_update_state():
mock_model = utils.MockModel.create(
responses=[
types.Part.from_function_call(name='update_state', args={}),
'response1',
]
)
def update_state(tool_context: ToolContext):
tool_context.state['x'] = 1
agent = Agent(name='root_agent', model=mock_model, tools=[update_state])
runner = utils.InMemoryRunner(agent)
runner.run('test')
assert runner.session.state['x'] == 1
def test_function_call_id():
responses = [
types.Part.from_function_call(name='increase_by_one', args={'x': 1}),
'response1',
]
mock_model = utils.MockModel.create(responses=responses)
def increase_by_one(x: int) -> int:
return x + 1
agent = Agent(name='root_agent', model=mock_model, tools=[increase_by_one])
runner = utils.InMemoryRunner(agent)
events = runner.run('test')
for reqeust in mock_model.requests:
for content in reqeust.contents:
for part in content.parts:
if part.function_call:
assert part.function_call.id is None
if part.function_response:
assert part.function_response.id is None
assert events[0].content.parts[0].function_call.id.startswith('adk-')
assert events[1].content.parts[0].function_response.id.startswith('adk-')

View File

@@ -0,0 +1,66 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.agents import Agent
from google.adk.flows.llm_flows import identity
from google.adk.models import LlmRequest
from google.genai import types
import pytest
from ... import utils
@pytest.mark.asyncio
async def test_no_description():
request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
agent = Agent(model="gemini-1.5-flash", name="agent")
invocation_context = utils.create_invocation_context(agent=agent)
async for _ in identity.request_processor.run_async(
invocation_context,
request,
):
pass
assert request.config.system_instruction == (
"""You are an agent. Your internal name is "agent"."""
)
@pytest.mark.asyncio
async def test_with_description():
request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
agent = Agent(
model="gemini-1.5-flash",
name="agent",
description="test description",
)
invocation_context = utils.create_invocation_context(agent=agent)
async for _ in identity.request_processor.run_async(
invocation_context,
request,
):
pass
assert request.config.system_instruction == "\n\n".join([
'You are an agent. Your internal name is "agent".',
' The description about you is "test description"',
])

View File

@@ -0,0 +1,164 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.agents import Agent
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.flows.llm_flows import instructions
from google.adk.models import LlmRequest
from google.adk.sessions import Session
from google.genai import types
import pytest
from ... import utils
@pytest.mark.asyncio
async def test_build_system_instruction():
request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
agent = Agent(
model="gemini-1.5-flash",
name="agent",
instruction=("""Use the echo_info tool to echo { customerId }, \
{{customer_int }, { non-identifier-float}}, \
{'key1': 'value1'} and {{'key2': 'value2'}}."""),
)
invocation_context = utils.create_invocation_context(agent=agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
id="test_id",
state={"customerId": "1234567890", "customer_int": 30},
)
async for _ in instructions.request_processor.run_async(
invocation_context,
request,
):
pass
assert request.config.system_instruction == (
"""Use the echo_info tool to echo 1234567890, 30, \
{ non-identifier-float}}, {'key1': 'value1'} and {{'key2': 'value2'}}."""
)
@pytest.mark.asyncio
async def test_function_system_instruction():
def build_function_instruction(readonly_context: ReadonlyContext) -> str:
return (
"This is the function agent instruction for invocation:"
f" {readonly_context.invocation_id}."
)
request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
agent = Agent(
model="gemini-1.5-flash",
name="agent",
instruction=build_function_instruction,
)
invocation_context = utils.create_invocation_context(agent=agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
id="test_id",
state={"customerId": "1234567890", "customer_int": 30},
)
async for _ in instructions.request_processor.run_async(
invocation_context,
request,
):
pass
assert request.config.system_instruction == (
"This is the function agent instruction for invocation: test_id."
)
@pytest.mark.asyncio
async def test_global_system_instruction():
sub_agent = Agent(
model="gemini-1.5-flash",
name="sub_agent",
instruction="This is the sub agent instruction.",
)
root_agent = Agent(
model="gemini-1.5-flash",
name="root_agent",
global_instruction="This is the global instruction.",
sub_agents=[sub_agent],
)
request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
invocation_context = utils.create_invocation_context(agent=sub_agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
id="test_id",
state={"customerId": "1234567890", "customer_int": 30},
)
async for _ in instructions.request_processor.run_async(
invocation_context,
request,
):
pass
assert request.config.system_instruction == (
"This is the global instruction.\n\nThis is the sub agent instruction."
)
@pytest.mark.asyncio
async def test_build_system_instruction_with_namespace():
request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
agent = Agent(
model="gemini-1.5-flash",
name="agent",
instruction=(
"""Use the echo_info tool to echo { customerId }, {app:key}, {user:key}, {a:key}."""
),
)
invocation_context = utils.create_invocation_context(agent=agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
id="test_id",
state={
"customerId": "1234567890",
"app:key": "app_value",
"user:key": "user_value",
},
)
async for _ in instructions.request_processor.run_async(
invocation_context,
request,
):
pass
assert request.config.system_instruction == (
"""Use the echo_info tool to echo 1234567890, app_value, user_value, {a:key}."""
)

View File

@@ -0,0 +1,142 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Optional
from google.adk.agents import Agent
from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmRequest
from google.adk.models import LlmResponse
from google.genai import types
from pydantic import BaseModel
import pytest
from ... import utils
class MockBeforeModelCallback(BaseModel):
mock_response: str
def __call__(
self,
callback_context: CallbackContext,
llm_request: LlmRequest,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
class MockAfterModelCallback(BaseModel):
mock_response: str
def __call__(
self,
callback_context: CallbackContext,
llm_response: LlmResponse,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
def noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
def test_before_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=MockBeforeModelCallback(
mock_response='before_model_callback'
),
)
runner = utils.InMemoryRunner(agent)
assert utils.simplify_events(runner.run('test')) == [
('root_agent', 'before_model_callback'),
]
def test_before_model_callback_noop():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=noop_callback,
)
runner = utils.InMemoryRunner(agent)
assert utils.simplify_events(runner.run('test')) == [
('root_agent', 'model_response'),
]
def test_before_model_callback_end():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=MockBeforeModelCallback(
mock_response='before_model_callback',
),
)
runner = utils.InMemoryRunner(agent)
assert utils.simplify_events(runner.run('test')) == [
('root_agent', 'before_model_callback'),
]
def test_after_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_model_callback=MockAfterModelCallback(
mock_response='after_model_callback'
),
)
runner = utils.InMemoryRunner(agent)
assert utils.simplify_events(runner.run('test')) == [
('root_agent', 'after_model_callback'),
]
@pytest.mark.asyncio
async def test_after_model_callback_noop():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_model_callback=noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [('root_agent', 'model_response')]

View File

@@ -0,0 +1,46 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.agents import Agent
from google.adk.tools import ToolContext
from google.genai.types import Part
from pydantic import BaseModel
from ... import utils
def test_output_schema():
class CustomOutput(BaseModel):
custom_field: str
response = [
'response1',
]
mockModel = utils.MockModel.create(responses=response)
root_agent = Agent(
name='root_agent',
model=mockModel,
output_schema=CustomOutput,
disallow_transfer_to_parent=True,
disallow_transfer_to_peers=True,
)
runner = utils.InMemoryRunner(root_agent)
assert utils.simplify_events(runner.run('test1')) == [
('root_agent', 'response1'),
]
assert len(mockModel.requests) == 1
assert mockModel.requests[0].config.response_schema == CustomOutput
assert mockModel.requests[0].config.response_mime_type == 'application/json'

View File

@@ -0,0 +1,269 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from google.adk.agents import Agent
from google.adk.tools import BaseTool
from google.adk.tools import ToolContext
from google.genai import types
from google.genai.types import Part
from pydantic import BaseModel
from ... import utils
def simple_function(input_str: str) -> str:
return {'result': input_str}
class MockBeforeToolCallback(BaseModel):
mock_response: dict[str, object]
modify_tool_request: bool = False
def __call__(
self,
tool: BaseTool,
args: dict[str, Any],
tool_context: ToolContext,
) -> dict[str, object]:
if self.modify_tool_request:
args['input_str'] = 'modified_input'
return None
return self.mock_response
class MockAfterToolCallback(BaseModel):
mock_response: dict[str, object]
modify_tool_request: bool = False
modify_tool_response: bool = False
def __call__(
self,
tool: BaseTool,
args: dict[str, Any],
tool_context: ToolContext,
tool_response: dict[str, Any] = None,
) -> dict[str, object]:
if self.modify_tool_request:
args['input_str'] = 'modified_input'
return None
if self.modify_tool_response:
tool_response['result'] = 'modified_output'
return tool_response
return self.mock_response
def noop_callback(
**kwargs,
) -> dict[str, object]:
pass
def test_before_tool_callback():
responses = [
types.Part.from_function_call(name='simple_function', args={}),
'response1',
]
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_tool_callback=MockBeforeToolCallback(
mock_response={'test': 'before_tool_callback'}
),
tools=[simple_function],
)
runner = utils.InMemoryRunner(agent)
assert utils.simplify_events(runner.run('test')) == [
('root_agent', Part.from_function_call(name='simple_function', args={})),
(
'root_agent',
Part.from_function_response(
name='simple_function', response={'test': 'before_tool_callback'}
),
),
('root_agent', 'response1'),
]
def test_before_tool_callback_noop():
responses = [
types.Part.from_function_call(
name='simple_function', args={'input_str': 'simple_function_call'}
),
'response1',
]
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_tool_callback=noop_callback,
tools=[simple_function],
)
runner = utils.InMemoryRunner(agent)
assert utils.simplify_events(runner.run('test')) == [
(
'root_agent',
Part.from_function_call(
name='simple_function', args={'input_str': 'simple_function_call'}
),
),
(
'root_agent',
Part.from_function_response(
name='simple_function',
response={'result': 'simple_function_call'},
),
),
('root_agent', 'response1'),
]
def test_before_tool_callback_modify_tool_request():
responses = [
types.Part.from_function_call(name='simple_function', args={}),
'response1',
]
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_tool_callback=MockBeforeToolCallback(
mock_response={'test': 'before_tool_callback'},
modify_tool_request=True,
),
tools=[simple_function],
)
runner = utils.InMemoryRunner(agent)
assert utils.simplify_events(runner.run('test')) == [
('root_agent', Part.from_function_call(name='simple_function', args={})),
(
'root_agent',
Part.from_function_response(
name='simple_function',
response={'result': 'modified_input'},
),
),
('root_agent', 'response1'),
]
def test_after_tool_callback():
responses = [
types.Part.from_function_call(
name='simple_function', args={'input_str': 'simple_function_call'}
),
'response1',
]
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_tool_callback=MockAfterToolCallback(
mock_response={'test': 'after_tool_callback'}
),
tools=[simple_function],
)
runner = utils.InMemoryRunner(agent)
assert utils.simplify_events(runner.run('test')) == [
(
'root_agent',
Part.from_function_call(
name='simple_function', args={'input_str': 'simple_function_call'}
),
),
(
'root_agent',
Part.from_function_response(
name='simple_function', response={'test': 'after_tool_callback'}
),
),
('root_agent', 'response1'),
]
def test_after_tool_callback_noop():
responses = [
types.Part.from_function_call(
name='simple_function', args={'input_str': 'simple_function_call'}
),
'response1',
]
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_tool_callback=noop_callback,
tools=[simple_function],
)
runner = utils.InMemoryRunner(agent)
assert utils.simplify_events(runner.run('test')) == [
(
'root_agent',
Part.from_function_call(
name='simple_function', args={'input_str': 'simple_function_call'}
),
),
(
'root_agent',
Part.from_function_response(
name='simple_function',
response={'result': 'simple_function_call'},
),
),
('root_agent', 'response1'),
]
def test_after_tool_callback_modify_tool_response():
responses = [
types.Part.from_function_call(
name='simple_function', args={'input_str': 'simple_function_call'}
),
'response1',
]
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_tool_callback=MockAfterToolCallback(
mock_response={'result': 'after_tool_callback'},
modify_tool_response=True,
),
tools=[simple_function],
)
runner = utils.InMemoryRunner(agent)
assert utils.simplify_events(runner.run('test')) == [
(
'root_agent',
Part.from_function_call(
name='simple_function', args={'input_str': 'simple_function_call'}
),
),
(
'root_agent',
Part.from_function_response(
name='simple_function',
response={'result': 'modified_output'},
),
),
('root_agent', 'response1'),
]