diff --git a/src/google/adk/auth/auth_preprocessor.py b/src/google/adk/auth/auth_preprocessor.py index 8ad30b7..0c964ed 100644 --- a/src/google/adk/auth/auth_preprocessor.py +++ b/src/google/adk/auth/auth_preprocessor.py @@ -100,23 +100,24 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): function_calls = event.get_function_calls() if not function_calls: continue - for function_call in function_calls: - function_response_event = None - if function_call.id in tools_to_resume: - function_response_event = await functions.handle_function_calls_async( - invocation_context, - event, - { - tool.name: tool - for tool in await agent.canonical_tools( - ReadonlyContext(invocation_context) - ) - }, - # there could be parallel function calls that require auth - # auth response would be a dict keyed by function call id - tools_to_resume, - ) - if function_response_event: + + if any([ + function_call.id in tools_to_resume + for function_call in function_calls + ]): + if function_response_event := await functions.handle_function_calls_async( + invocation_context, + event, + { + tool.name: tool + for tool in await agent.canonical_tools( + ReadonlyContext(invocation_context) + ) + }, + # there could be parallel function calls that require auth + # auth response would be a dict keyed by function call id + tools_to_resume, + ): yield function_response_event return return diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index b37d8af..ea41888 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -170,10 +170,10 @@ def _rearrange_events_for_latest_function_response( for idx in range(function_call_event_idx + 1, len(events) - 1): event = events[idx] function_responses = event.get_function_responses() - if ( - function_responses - and function_responses[0].id in function_responses_ids - ): + if function_responses and any([ + function_response.id in function_responses_ids + for function_response in function_responses + ]): function_response_events.append(event) function_response_events.append(events[-1]) diff --git a/tests/unittests/flows/llm_flows/test_functions_request_euc.py b/tests/unittests/flows/llm_flows/test_functions_request_euc.py index 3d5fcae..6f8f112 100644 --- a/tests/unittests/flows/llm_flows/test_functions_request_euc.py +++ b/tests/unittests/flows/llm_flows/test_functions_request_euc.py @@ -344,3 +344,215 @@ def test_function_get_auth_response(): assert parts[0].function_response.response == {'result': 1} assert parts[1].function_response.name == 'call_external_api2' assert parts[1].function_response.response == {'result': 2} + + +def test_function_get_auth_response_partial(): + id_1 = 'id_1' + id_2 = 'id_2' + responses = [ + [ + function_call(id_1, 'call_external_api1', {}), + function_call(id_2, 'call_external_api2', {}), + ], + [ + types.Part.from_text(text='response1'), + ], + [ + types.Part.from_text(text='response2'), + ], + ] + + mock_model = testing_utils.MockModel.create(responses=responses) + function_invoked = 0 + + auth_config1 = AuthConfig( + auth_scheme=OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl='https://accounts.google.com/o/oauth2/auth', + tokenUrl='https://oauth2.googleapis.com/token', + scopes={ + 'https://www.googleapis.com/auth/calendar': ( + 'See, edit, share, and permanently delete all the' + ' calendars you can access using Google Calendar' + ) + }, + ) + ) + ), + raw_auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id='oauth_client_id_1', + client_secret='oauth_client_secret1', + ), + ), + ) + auth_config2 = AuthConfig( + auth_scheme=OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl='https://accounts.google.com/o/oauth2/auth', + tokenUrl='https://oauth2.googleapis.com/token', + scopes={ + 'https://www.googleapis.com/auth/calendar': ( + 'See, edit, share, and permanently delete all the' + ' calendars you can access using Google Calendar' + ) + }, + ) + ) + ), + raw_auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id='oauth_client_id_2', + client_secret='oauth_client_secret2', + ), + ), + ) + + auth_response1 = AuthConfig( + auth_scheme=OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl='https://accounts.google.com/o/oauth2/auth', + tokenUrl='https://oauth2.googleapis.com/token', + scopes={ + 'https://www.googleapis.com/auth/calendar': ( + 'See, edit, share, and permanently delete all the' + ' calendars you can access using Google Calendar' + ) + }, + ) + ) + ), + raw_auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id='oauth_client_id_1', + client_secret='oauth_client_secret1', + ), + ), + exchanged_auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id='oauth_client_id_1', + client_secret='oauth_client_secret1', + access_token='token1', + ), + ), + ) + auth_response2 = AuthConfig( + auth_scheme=OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl='https://accounts.google.com/o/oauth2/auth', + tokenUrl='https://oauth2.googleapis.com/token', + scopes={ + 'https://www.googleapis.com/auth/calendar': ( + 'See, edit, share, and permanently delete all the' + ' calendars you can access using Google Calendar' + ) + }, + ) + ) + ), + raw_auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id='oauth_client_id_2', + client_secret='oauth_client_secret2', + ), + ), + exchanged_auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id='oauth_client_id_2', + client_secret='oauth_client_secret2', + access_token='token2', + ), + ), + ) + + def call_external_api1(tool_context: ToolContext) -> int: + nonlocal function_invoked + function_invoked += 1 + auth_response = tool_context.get_auth_response(auth_config1) + if not auth_response: + tool_context.request_credential(auth_config1) + return + assert auth_response == auth_response1.exchanged_auth_credential + return 1 + + def call_external_api2(tool_context: ToolContext) -> int: + nonlocal function_invoked + function_invoked += 1 + auth_response = tool_context.get_auth_response(auth_config2) + if not auth_response: + tool_context.request_credential(auth_config2) + return + assert auth_response == auth_response2.exchanged_auth_credential + return 2 + + agent = Agent( + name='root_agent', + model=mock_model, + tools=[call_external_api1, call_external_api2], + ) + runner = testing_utils.InMemoryRunner(agent) + runner.run('test') + request_euc_function_call_event = runner.session.events[-3] + function_response1 = types.FunctionResponse( + name=request_euc_function_call_event.content.parts[0].function_call.name, + response=auth_response1.model_dump(), + ) + function_response1.id = request_euc_function_call_event.content.parts[ + 0 + ].function_call.id + + function_response2 = types.FunctionResponse( + name=request_euc_function_call_event.content.parts[1].function_call.name, + response=auth_response2.model_dump(), + ) + function_response2.id = request_euc_function_call_event.content.parts[ + 1 + ].function_call.id + runner.run( + new_message=types.Content( + role='user', + parts=[ + types.Part(function_response=function_response1), + ], + ), + ) + + assert function_invoked == 3 + assert len(mock_model.requests) == 3 + request = mock_model.requests[-1] + content = request.contents[-1] + parts = content.parts + assert len(parts) == 2 + assert parts[0].function_response.name == 'call_external_api1' + assert parts[0].function_response.response == {'result': 1} + assert parts[1].function_response.name == 'call_external_api2' + assert parts[1].function_response.response == {'result': None} + + runner.run( + new_message=types.Content( + role='user', + parts=[ + types.Part(function_response=function_response2), + ], + ), + ) + # assert function_invoked == 4 + assert len(mock_model.requests) == 4 + request = mock_model.requests[-1] + content = request.contents[-1] + parts = content.parts + assert len(parts) == 2 + assert parts[0].function_response.name == 'call_external_api1' + assert parts[0].function_response.response == {'result': None} + assert parts[1].function_response.name == 'call_external_api2' + assert parts[1].function_response.response == {'result': 2}