mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 07:04:51 -06:00
fix: Call all tools in parallel calls during partial authentication
Copybara import of the project: -- ffd6184d7e402b0787b0fa37fc09cd519adcc7f3 by Calvin Giles <calvin.giles@trademe.co.nz>: fix: Call all tools in parallel calls during partial authentication -- c71782a582ba825dbe2246cdb5be3f273ca90dca by seanzhou1023 <seanzhou1023@gmail.com>: Update auth_preprocessor.py -- 843af6b1bc0bc6291cb9cb23acf11840098ba6dd by seanzhou1023 <seanzhou1023@gmail.com>: Update test_functions_request_euc.py -- 955e3fa852420ecbf196583caa3cf86b7b80ab56 by seanzhou1023 <seanzhou1023@gmail.com>: Update test_functions_request_euc.py COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/853 from calvingiles:fix-parallel-auth-tool-calls f44671e37b9fe44a25c9b1c0c25a26fc634b011c PiperOrigin-RevId: 765639904
This commit is contained in:
parent
036f954a2a
commit
0e72efb439
@ -100,23 +100,24 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
function_calls = event.get_function_calls()
|
||||
if not function_calls:
|
||||
continue
|
||||
for function_call in function_calls:
|
||||
function_response_event = None
|
||||
if function_call.id in tools_to_resume:
|
||||
function_response_event = await functions.handle_function_calls_async(
|
||||
invocation_context,
|
||||
event,
|
||||
{
|
||||
tool.name: tool
|
||||
for tool in await agent.canonical_tools(
|
||||
ReadonlyContext(invocation_context)
|
||||
)
|
||||
},
|
||||
# there could be parallel function calls that require auth
|
||||
# auth response would be a dict keyed by function call id
|
||||
tools_to_resume,
|
||||
)
|
||||
if function_response_event:
|
||||
|
||||
if any([
|
||||
function_call.id in tools_to_resume
|
||||
for function_call in function_calls
|
||||
]):
|
||||
if function_response_event := await functions.handle_function_calls_async(
|
||||
invocation_context,
|
||||
event,
|
||||
{
|
||||
tool.name: tool
|
||||
for tool in await agent.canonical_tools(
|
||||
ReadonlyContext(invocation_context)
|
||||
)
|
||||
},
|
||||
# there could be parallel function calls that require auth
|
||||
# auth response would be a dict keyed by function call id
|
||||
tools_to_resume,
|
||||
):
|
||||
yield function_response_event
|
||||
return
|
||||
return
|
||||
|
@ -170,10 +170,10 @@ def _rearrange_events_for_latest_function_response(
|
||||
for idx in range(function_call_event_idx + 1, len(events) - 1):
|
||||
event = events[idx]
|
||||
function_responses = event.get_function_responses()
|
||||
if (
|
||||
function_responses
|
||||
and function_responses[0].id in function_responses_ids
|
||||
):
|
||||
if function_responses and any([
|
||||
function_response.id in function_responses_ids
|
||||
for function_response in function_responses
|
||||
]):
|
||||
function_response_events.append(event)
|
||||
function_response_events.append(events[-1])
|
||||
|
||||
|
@ -344,3 +344,215 @@ def test_function_get_auth_response():
|
||||
assert parts[0].function_response.response == {'result': 1}
|
||||
assert parts[1].function_response.name == 'call_external_api2'
|
||||
assert parts[1].function_response.response == {'result': 2}
|
||||
|
||||
|
||||
def test_function_get_auth_response_partial():
|
||||
id_1 = 'id_1'
|
||||
id_2 = 'id_2'
|
||||
responses = [
|
||||
[
|
||||
function_call(id_1, 'call_external_api1', {}),
|
||||
function_call(id_2, 'call_external_api2', {}),
|
||||
],
|
||||
[
|
||||
types.Part.from_text(text='response1'),
|
||||
],
|
||||
[
|
||||
types.Part.from_text(text='response2'),
|
||||
],
|
||||
]
|
||||
|
||||
mock_model = testing_utils.MockModel.create(responses=responses)
|
||||
function_invoked = 0
|
||||
|
||||
auth_config1 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_1',
|
||||
client_secret='oauth_client_secret1',
|
||||
),
|
||||
),
|
||||
)
|
||||
auth_config2 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_2',
|
||||
client_secret='oauth_client_secret2',
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
auth_response1 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_1',
|
||||
client_secret='oauth_client_secret1',
|
||||
),
|
||||
),
|
||||
exchanged_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_1',
|
||||
client_secret='oauth_client_secret1',
|
||||
access_token='token1',
|
||||
),
|
||||
),
|
||||
)
|
||||
auth_response2 = AuthConfig(
|
||||
auth_scheme=OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token',
|
||||
scopes={
|
||||
'https://www.googleapis.com/auth/calendar': (
|
||||
'See, edit, share, and permanently delete all the'
|
||||
' calendars you can access using Google Calendar'
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
raw_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_2',
|
||||
client_secret='oauth_client_secret2',
|
||||
),
|
||||
),
|
||||
exchanged_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='oauth_client_id_2',
|
||||
client_secret='oauth_client_secret2',
|
||||
access_token='token2',
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def call_external_api1(tool_context: ToolContext) -> int:
|
||||
nonlocal function_invoked
|
||||
function_invoked += 1
|
||||
auth_response = tool_context.get_auth_response(auth_config1)
|
||||
if not auth_response:
|
||||
tool_context.request_credential(auth_config1)
|
||||
return
|
||||
assert auth_response == auth_response1.exchanged_auth_credential
|
||||
return 1
|
||||
|
||||
def call_external_api2(tool_context: ToolContext) -> int:
|
||||
nonlocal function_invoked
|
||||
function_invoked += 1
|
||||
auth_response = tool_context.get_auth_response(auth_config2)
|
||||
if not auth_response:
|
||||
tool_context.request_credential(auth_config2)
|
||||
return
|
||||
assert auth_response == auth_response2.exchanged_auth_credential
|
||||
return 2
|
||||
|
||||
agent = Agent(
|
||||
name='root_agent',
|
||||
model=mock_model,
|
||||
tools=[call_external_api1, call_external_api2],
|
||||
)
|
||||
runner = testing_utils.InMemoryRunner(agent)
|
||||
runner.run('test')
|
||||
request_euc_function_call_event = runner.session.events[-3]
|
||||
function_response1 = types.FunctionResponse(
|
||||
name=request_euc_function_call_event.content.parts[0].function_call.name,
|
||||
response=auth_response1.model_dump(),
|
||||
)
|
||||
function_response1.id = request_euc_function_call_event.content.parts[
|
||||
0
|
||||
].function_call.id
|
||||
|
||||
function_response2 = types.FunctionResponse(
|
||||
name=request_euc_function_call_event.content.parts[1].function_call.name,
|
||||
response=auth_response2.model_dump(),
|
||||
)
|
||||
function_response2.id = request_euc_function_call_event.content.parts[
|
||||
1
|
||||
].function_call.id
|
||||
runner.run(
|
||||
new_message=types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(function_response=function_response1),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
assert function_invoked == 3
|
||||
assert len(mock_model.requests) == 3
|
||||
request = mock_model.requests[-1]
|
||||
content = request.contents[-1]
|
||||
parts = content.parts
|
||||
assert len(parts) == 2
|
||||
assert parts[0].function_response.name == 'call_external_api1'
|
||||
assert parts[0].function_response.response == {'result': 1}
|
||||
assert parts[1].function_response.name == 'call_external_api2'
|
||||
assert parts[1].function_response.response == {'result': None}
|
||||
|
||||
runner.run(
|
||||
new_message=types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(function_response=function_response2),
|
||||
],
|
||||
),
|
||||
)
|
||||
# assert function_invoked == 4
|
||||
assert len(mock_model.requests) == 4
|
||||
request = mock_model.requests[-1]
|
||||
content = request.contents[-1]
|
||||
parts = content.parts
|
||||
assert len(parts) == 2
|
||||
assert parts[0].function_response.name == 'call_external_api1'
|
||||
assert parts[0].function_response.response == {'result': None}
|
||||
assert parts[1].function_response.name == 'call_external_api2'
|
||||
assert parts[1].function_response.response == {'result': 2}
|
||||
|
Loading…
Reference in New Issue
Block a user