mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 15:14:50 -06:00
fix: Call all tools in parallel calls during partial authentication
Copybara import of the project: -- ffd6184d7e402b0787b0fa37fc09cd519adcc7f3 by Calvin Giles <calvin.giles@trademe.co.nz>: fix: Call all tools in parallel calls during partial authentication -- c71782a582ba825dbe2246cdb5be3f273ca90dca by seanzhou1023 <seanzhou1023@gmail.com>: Update auth_preprocessor.py -- 843af6b1bc0bc6291cb9cb23acf11840098ba6dd by seanzhou1023 <seanzhou1023@gmail.com>: Update test_functions_request_euc.py -- 955e3fa852420ecbf196583caa3cf86b7b80ab56 by seanzhou1023 <seanzhou1023@gmail.com>: Update test_functions_request_euc.py COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/853 from calvingiles:fix-parallel-auth-tool-calls f44671e37b9fe44a25c9b1c0c25a26fc634b011c PiperOrigin-RevId: 765639904
This commit is contained in:
parent
036f954a2a
commit
0e72efb439
@ -100,23 +100,24 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
|
|||||||
function_calls = event.get_function_calls()
|
function_calls = event.get_function_calls()
|
||||||
if not function_calls:
|
if not function_calls:
|
||||||
continue
|
continue
|
||||||
for function_call in function_calls:
|
|
||||||
function_response_event = None
|
if any([
|
||||||
if function_call.id in tools_to_resume:
|
function_call.id in tools_to_resume
|
||||||
function_response_event = await functions.handle_function_calls_async(
|
for function_call in function_calls
|
||||||
invocation_context,
|
]):
|
||||||
event,
|
if function_response_event := await functions.handle_function_calls_async(
|
||||||
{
|
invocation_context,
|
||||||
tool.name: tool
|
event,
|
||||||
for tool in await agent.canonical_tools(
|
{
|
||||||
ReadonlyContext(invocation_context)
|
tool.name: tool
|
||||||
)
|
for tool in await agent.canonical_tools(
|
||||||
},
|
ReadonlyContext(invocation_context)
|
||||||
# there could be parallel function calls that require auth
|
)
|
||||||
# auth response would be a dict keyed by function call id
|
},
|
||||||
tools_to_resume,
|
# there could be parallel function calls that require auth
|
||||||
)
|
# auth response would be a dict keyed by function call id
|
||||||
if function_response_event:
|
tools_to_resume,
|
||||||
|
):
|
||||||
yield function_response_event
|
yield function_response_event
|
||||||
return
|
return
|
||||||
return
|
return
|
||||||
|
@ -170,10 +170,10 @@ def _rearrange_events_for_latest_function_response(
|
|||||||
for idx in range(function_call_event_idx + 1, len(events) - 1):
|
for idx in range(function_call_event_idx + 1, len(events) - 1):
|
||||||
event = events[idx]
|
event = events[idx]
|
||||||
function_responses = event.get_function_responses()
|
function_responses = event.get_function_responses()
|
||||||
if (
|
if function_responses and any([
|
||||||
function_responses
|
function_response.id in function_responses_ids
|
||||||
and function_responses[0].id in function_responses_ids
|
for function_response in function_responses
|
||||||
):
|
]):
|
||||||
function_response_events.append(event)
|
function_response_events.append(event)
|
||||||
function_response_events.append(events[-1])
|
function_response_events.append(events[-1])
|
||||||
|
|
||||||
|
@ -344,3 +344,215 @@ def test_function_get_auth_response():
|
|||||||
assert parts[0].function_response.response == {'result': 1}
|
assert parts[0].function_response.response == {'result': 1}
|
||||||
assert parts[1].function_response.name == 'call_external_api2'
|
assert parts[1].function_response.name == 'call_external_api2'
|
||||||
assert parts[1].function_response.response == {'result': 2}
|
assert parts[1].function_response.response == {'result': 2}
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_get_auth_response_partial():
|
||||||
|
id_1 = 'id_1'
|
||||||
|
id_2 = 'id_2'
|
||||||
|
responses = [
|
||||||
|
[
|
||||||
|
function_call(id_1, 'call_external_api1', {}),
|
||||||
|
function_call(id_2, 'call_external_api2', {}),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
types.Part.from_text(text='response1'),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
types.Part.from_text(text='response2'),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_model = testing_utils.MockModel.create(responses=responses)
|
||||||
|
function_invoked = 0
|
||||||
|
|
||||||
|
auth_config1 = AuthConfig(
|
||||||
|
auth_scheme=OAuth2(
|
||||||
|
flows=OAuthFlows(
|
||||||
|
authorizationCode=OAuthFlowAuthorizationCode(
|
||||||
|
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||||
|
tokenUrl='https://oauth2.googleapis.com/token',
|
||||||
|
scopes={
|
||||||
|
'https://www.googleapis.com/auth/calendar': (
|
||||||
|
'See, edit, share, and permanently delete all the'
|
||||||
|
' calendars you can access using Google Calendar'
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
raw_auth_credential=AuthCredential(
|
||||||
|
auth_type=AuthCredentialTypes.OAUTH2,
|
||||||
|
oauth2=OAuth2Auth(
|
||||||
|
client_id='oauth_client_id_1',
|
||||||
|
client_secret='oauth_client_secret1',
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
auth_config2 = AuthConfig(
|
||||||
|
auth_scheme=OAuth2(
|
||||||
|
flows=OAuthFlows(
|
||||||
|
authorizationCode=OAuthFlowAuthorizationCode(
|
||||||
|
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||||
|
tokenUrl='https://oauth2.googleapis.com/token',
|
||||||
|
scopes={
|
||||||
|
'https://www.googleapis.com/auth/calendar': (
|
||||||
|
'See, edit, share, and permanently delete all the'
|
||||||
|
' calendars you can access using Google Calendar'
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
raw_auth_credential=AuthCredential(
|
||||||
|
auth_type=AuthCredentialTypes.OAUTH2,
|
||||||
|
oauth2=OAuth2Auth(
|
||||||
|
client_id='oauth_client_id_2',
|
||||||
|
client_secret='oauth_client_secret2',
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
auth_response1 = AuthConfig(
|
||||||
|
auth_scheme=OAuth2(
|
||||||
|
flows=OAuthFlows(
|
||||||
|
authorizationCode=OAuthFlowAuthorizationCode(
|
||||||
|
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||||
|
tokenUrl='https://oauth2.googleapis.com/token',
|
||||||
|
scopes={
|
||||||
|
'https://www.googleapis.com/auth/calendar': (
|
||||||
|
'See, edit, share, and permanently delete all the'
|
||||||
|
' calendars you can access using Google Calendar'
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
raw_auth_credential=AuthCredential(
|
||||||
|
auth_type=AuthCredentialTypes.OAUTH2,
|
||||||
|
oauth2=OAuth2Auth(
|
||||||
|
client_id='oauth_client_id_1',
|
||||||
|
client_secret='oauth_client_secret1',
|
||||||
|
),
|
||||||
|
),
|
||||||
|
exchanged_auth_credential=AuthCredential(
|
||||||
|
auth_type=AuthCredentialTypes.OAUTH2,
|
||||||
|
oauth2=OAuth2Auth(
|
||||||
|
client_id='oauth_client_id_1',
|
||||||
|
client_secret='oauth_client_secret1',
|
||||||
|
access_token='token1',
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
auth_response2 = AuthConfig(
|
||||||
|
auth_scheme=OAuth2(
|
||||||
|
flows=OAuthFlows(
|
||||||
|
authorizationCode=OAuthFlowAuthorizationCode(
|
||||||
|
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||||
|
tokenUrl='https://oauth2.googleapis.com/token',
|
||||||
|
scopes={
|
||||||
|
'https://www.googleapis.com/auth/calendar': (
|
||||||
|
'See, edit, share, and permanently delete all the'
|
||||||
|
' calendars you can access using Google Calendar'
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
raw_auth_credential=AuthCredential(
|
||||||
|
auth_type=AuthCredentialTypes.OAUTH2,
|
||||||
|
oauth2=OAuth2Auth(
|
||||||
|
client_id='oauth_client_id_2',
|
||||||
|
client_secret='oauth_client_secret2',
|
||||||
|
),
|
||||||
|
),
|
||||||
|
exchanged_auth_credential=AuthCredential(
|
||||||
|
auth_type=AuthCredentialTypes.OAUTH2,
|
||||||
|
oauth2=OAuth2Auth(
|
||||||
|
client_id='oauth_client_id_2',
|
||||||
|
client_secret='oauth_client_secret2',
|
||||||
|
access_token='token2',
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def call_external_api1(tool_context: ToolContext) -> int:
|
||||||
|
nonlocal function_invoked
|
||||||
|
function_invoked += 1
|
||||||
|
auth_response = tool_context.get_auth_response(auth_config1)
|
||||||
|
if not auth_response:
|
||||||
|
tool_context.request_credential(auth_config1)
|
||||||
|
return
|
||||||
|
assert auth_response == auth_response1.exchanged_auth_credential
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def call_external_api2(tool_context: ToolContext) -> int:
|
||||||
|
nonlocal function_invoked
|
||||||
|
function_invoked += 1
|
||||||
|
auth_response = tool_context.get_auth_response(auth_config2)
|
||||||
|
if not auth_response:
|
||||||
|
tool_context.request_credential(auth_config2)
|
||||||
|
return
|
||||||
|
assert auth_response == auth_response2.exchanged_auth_credential
|
||||||
|
return 2
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
tools=[call_external_api1, call_external_api2],
|
||||||
|
)
|
||||||
|
runner = testing_utils.InMemoryRunner(agent)
|
||||||
|
runner.run('test')
|
||||||
|
request_euc_function_call_event = runner.session.events[-3]
|
||||||
|
function_response1 = types.FunctionResponse(
|
||||||
|
name=request_euc_function_call_event.content.parts[0].function_call.name,
|
||||||
|
response=auth_response1.model_dump(),
|
||||||
|
)
|
||||||
|
function_response1.id = request_euc_function_call_event.content.parts[
|
||||||
|
0
|
||||||
|
].function_call.id
|
||||||
|
|
||||||
|
function_response2 = types.FunctionResponse(
|
||||||
|
name=request_euc_function_call_event.content.parts[1].function_call.name,
|
||||||
|
response=auth_response2.model_dump(),
|
||||||
|
)
|
||||||
|
function_response2.id = request_euc_function_call_event.content.parts[
|
||||||
|
1
|
||||||
|
].function_call.id
|
||||||
|
runner.run(
|
||||||
|
new_message=types.Content(
|
||||||
|
role='user',
|
||||||
|
parts=[
|
||||||
|
types.Part(function_response=function_response1),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert function_invoked == 3
|
||||||
|
assert len(mock_model.requests) == 3
|
||||||
|
request = mock_model.requests[-1]
|
||||||
|
content = request.contents[-1]
|
||||||
|
parts = content.parts
|
||||||
|
assert len(parts) == 2
|
||||||
|
assert parts[0].function_response.name == 'call_external_api1'
|
||||||
|
assert parts[0].function_response.response == {'result': 1}
|
||||||
|
assert parts[1].function_response.name == 'call_external_api2'
|
||||||
|
assert parts[1].function_response.response == {'result': None}
|
||||||
|
|
||||||
|
runner.run(
|
||||||
|
new_message=types.Content(
|
||||||
|
role='user',
|
||||||
|
parts=[
|
||||||
|
types.Part(function_response=function_response2),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# assert function_invoked == 4
|
||||||
|
assert len(mock_model.requests) == 4
|
||||||
|
request = mock_model.requests[-1]
|
||||||
|
content = request.contents[-1]
|
||||||
|
parts = content.parts
|
||||||
|
assert len(parts) == 2
|
||||||
|
assert parts[0].function_response.name == 'call_external_api1'
|
||||||
|
assert parts[0].function_response.response == {'result': None}
|
||||||
|
assert parts[1].function_response.name == 'call_external_api2'
|
||||||
|
assert parts[1].function_response.response == {'result': 2}
|
||||||
|
Loading…
Reference in New Issue
Block a user