fix: Call all tools in parallel calls during partial authentication

Copybara import of the project:

--
ffd6184d7e402b0787b0fa37fc09cd519adcc7f3 by Calvin Giles <calvin.giles@trademe.co.nz>:

fix: Call all tools in parallel calls during partial authentication

--
c71782a582ba825dbe2246cdb5be3f273ca90dca by seanzhou1023 <seanzhou1023@gmail.com>:

Update auth_preprocessor.py
--
843af6b1bc0bc6291cb9cb23acf11840098ba6dd by seanzhou1023 <seanzhou1023@gmail.com>:

Update test_functions_request_euc.py
--
955e3fa852420ecbf196583caa3cf86b7b80ab56 by seanzhou1023 <seanzhou1023@gmail.com>:

Update test_functions_request_euc.py

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/853 from calvingiles:fix-parallel-auth-tool-calls f44671e37b9fe44a25c9b1c0c25a26fc634b011c
PiperOrigin-RevId: 765639904
This commit is contained in:
Calvin Giles 2025-05-31 13:13:08 -07:00 committed by Copybara-Service
parent 036f954a2a
commit 0e72efb439
3 changed files with 234 additions and 21 deletions

View File

@ -100,23 +100,24 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
function_calls = event.get_function_calls()
if not function_calls:
continue
for function_call in function_calls:
function_response_event = None
if function_call.id in tools_to_resume:
function_response_event = await functions.handle_function_calls_async(
invocation_context,
event,
{
tool.name: tool
for tool in await agent.canonical_tools(
ReadonlyContext(invocation_context)
)
},
# there could be parallel function calls that require auth
# auth response would be a dict keyed by function call id
tools_to_resume,
)
if function_response_event:
if any([
function_call.id in tools_to_resume
for function_call in function_calls
]):
if function_response_event := await functions.handle_function_calls_async(
invocation_context,
event,
{
tool.name: tool
for tool in await agent.canonical_tools(
ReadonlyContext(invocation_context)
)
},
# there could be parallel function calls that require auth
# auth response would be a dict keyed by function call id
tools_to_resume,
):
yield function_response_event
return
return

View File

@ -170,10 +170,10 @@ def _rearrange_events_for_latest_function_response(
for idx in range(function_call_event_idx + 1, len(events) - 1):
event = events[idx]
function_responses = event.get_function_responses()
if (
function_responses
and function_responses[0].id in function_responses_ids
):
if function_responses and any([
function_response.id in function_responses_ids
for function_response in function_responses
]):
function_response_events.append(event)
function_response_events.append(events[-1])

View File

@ -344,3 +344,215 @@ def test_function_get_auth_response():
assert parts[0].function_response.response == {'result': 1}
assert parts[1].function_response.name == 'call_external_api2'
assert parts[1].function_response.response == {'result': 2}
def test_function_get_auth_response_partial():
id_1 = 'id_1'
id_2 = 'id_2'
responses = [
[
function_call(id_1, 'call_external_api1', {}),
function_call(id_2, 'call_external_api2', {}),
],
[
types.Part.from_text(text='response1'),
],
[
types.Part.from_text(text='response2'),
],
]
mock_model = testing_utils.MockModel.create(responses=responses)
function_invoked = 0
auth_config1 = AuthConfig(
auth_scheme=OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
tokenUrl='https://oauth2.googleapis.com/token',
scopes={
'https://www.googleapis.com/auth/calendar': (
'See, edit, share, and permanently delete all the'
' calendars you can access using Google Calendar'
)
},
)
)
),
raw_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_1',
client_secret='oauth_client_secret1',
),
),
)
auth_config2 = AuthConfig(
auth_scheme=OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
tokenUrl='https://oauth2.googleapis.com/token',
scopes={
'https://www.googleapis.com/auth/calendar': (
'See, edit, share, and permanently delete all the'
' calendars you can access using Google Calendar'
)
},
)
)
),
raw_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_2',
client_secret='oauth_client_secret2',
),
),
)
auth_response1 = AuthConfig(
auth_scheme=OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
tokenUrl='https://oauth2.googleapis.com/token',
scopes={
'https://www.googleapis.com/auth/calendar': (
'See, edit, share, and permanently delete all the'
' calendars you can access using Google Calendar'
)
},
)
)
),
raw_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_1',
client_secret='oauth_client_secret1',
),
),
exchanged_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_1',
client_secret='oauth_client_secret1',
access_token='token1',
),
),
)
auth_response2 = AuthConfig(
auth_scheme=OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
tokenUrl='https://oauth2.googleapis.com/token',
scopes={
'https://www.googleapis.com/auth/calendar': (
'See, edit, share, and permanently delete all the'
' calendars you can access using Google Calendar'
)
},
)
)
),
raw_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_2',
client_secret='oauth_client_secret2',
),
),
exchanged_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='oauth_client_id_2',
client_secret='oauth_client_secret2',
access_token='token2',
),
),
)
def call_external_api1(tool_context: ToolContext) -> int:
nonlocal function_invoked
function_invoked += 1
auth_response = tool_context.get_auth_response(auth_config1)
if not auth_response:
tool_context.request_credential(auth_config1)
return
assert auth_response == auth_response1.exchanged_auth_credential
return 1
def call_external_api2(tool_context: ToolContext) -> int:
nonlocal function_invoked
function_invoked += 1
auth_response = tool_context.get_auth_response(auth_config2)
if not auth_response:
tool_context.request_credential(auth_config2)
return
assert auth_response == auth_response2.exchanged_auth_credential
return 2
agent = Agent(
name='root_agent',
model=mock_model,
tools=[call_external_api1, call_external_api2],
)
runner = testing_utils.InMemoryRunner(agent)
runner.run('test')
request_euc_function_call_event = runner.session.events[-3]
function_response1 = types.FunctionResponse(
name=request_euc_function_call_event.content.parts[0].function_call.name,
response=auth_response1.model_dump(),
)
function_response1.id = request_euc_function_call_event.content.parts[
0
].function_call.id
function_response2 = types.FunctionResponse(
name=request_euc_function_call_event.content.parts[1].function_call.name,
response=auth_response2.model_dump(),
)
function_response2.id = request_euc_function_call_event.content.parts[
1
].function_call.id
runner.run(
new_message=types.Content(
role='user',
parts=[
types.Part(function_response=function_response1),
],
),
)
assert function_invoked == 3
assert len(mock_model.requests) == 3
request = mock_model.requests[-1]
content = request.contents[-1]
parts = content.parts
assert len(parts) == 2
assert parts[0].function_response.name == 'call_external_api1'
assert parts[0].function_response.response == {'result': 1}
assert parts[1].function_response.name == 'call_external_api2'
assert parts[1].function_response.response == {'result': None}
runner.run(
new_message=types.Content(
role='user',
parts=[
types.Part(function_response=function_response2),
],
),
)
# assert function_invoked == 4
assert len(mock_model.requests) == 4
request = mock_model.requests[-1]
content = request.contents[-1]
parts = content.parts
assert len(parts) == 2
assert parts[0].function_response.name == 'call_external_api1'
assert parts[0].function_response.response == {'result': None}
assert parts[1].function_response.name == 'call_external_api2'
assert parts[1].function_response.response == {'result': 2}