From 05f4834759c9b1f0c0af9d89adb7b81ea67d82c8 Mon Sep 17 00:00:00 2001
From: LucasNobre
Date: Thu, 29 May 2025 15:10:49 -0700
Subject: [PATCH] Copybara import of the project:

--
cef3ca1ed3493eebaeab3e03bdf5e56b35c0b8ef by Lucas Nobre:

feat: Add index tracking to handle parallel tool call using litellm

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/759 from lucasnobre212:fix/issue_484 65e22934bf839f9ea03963b9ec6c23fdce03f59f
PiperOrigin-RevId: 764902433
---
 src/google/adk/models/lite_llm.py      |  50 ++++---
 tests/unittests/models/test_litellm.py | 172 +++++++++++++++++++++++++
 2 files changed, 202 insertions(+), 20 deletions(-)

diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index ca57a8d..b2d068c 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -62,6 +62,7 @@ class FunctionChunk(BaseModel):
   id: Optional[str]
   name: Optional[str]
   args: Optional[str]
+  index: Optional[int] = 0


 class TextChunk(BaseModel):
@@ -386,6 +387,7 @@ def _model_response_to_chunk(
               id=tool_call.id,
               name=tool_call.function.name,
               args=tool_call.function.arguments,
+              index=tool_call.index,
           ), finish_reason

     if finish_reason and not (
@@ -661,9 +663,8 @@ class LiteLlm(BaseLlm):

     if stream:
       text = ""
-      function_name = ""
-      function_args = ""
-      function_id = None
+      # Track function calls by index
+      function_calls = {}  # index -> {name, args, id}
       completion_args["stream"] = True
       aggregated_llm_response = None
       aggregated_llm_response_with_tool_call = None
@@ -672,11 +673,17 @@
       for part in self.llm_client.completion(**completion_args):
         for chunk, finish_reason in _model_response_to_chunk(part):
           if isinstance(chunk, FunctionChunk):
+            index = chunk.index or 0
+            if index not in function_calls:
+              function_calls[index] = {"name": "", "args": "", "id": None}
+
             if chunk.name:
-              function_name += chunk.name
+              function_calls[index]["name"] += chunk.name
             if chunk.args:
-              function_args += chunk.args
-              function_id = chunk.id or function_id
+              function_calls[index]["args"] += chunk.args
+              function_calls[index]["id"] = (
+                  chunk.id or function_calls[index]["id"]
+              )
           elif isinstance(chunk, TextChunk):
             text += chunk.text
             yield _message_to_generate_content_response(
@@ -693,28 +700,31 @@
                  total_token_count=chunk.total_tokens,
              )

-          if finish_reason == "tool_calls" and function_id:
+          if finish_reason == "tool_calls" and function_calls:
+            tool_calls = []
+            for index, func_data in function_calls.items():
+              if func_data["id"]:
+                tool_calls.append(
+                    ChatCompletionMessageToolCall(
+                        type="function",
+                        id=func_data["id"],
+                        function=Function(
+                            name=func_data["name"],
+                            arguments=func_data["args"],
+                            index=index,
+                        ),
+                    )
+                )
             aggregated_llm_response_with_tool_call = (
                 _message_to_generate_content_response(
                     ChatCompletionAssistantMessage(
                         role="assistant",
                         content="",
-                        tool_calls=[
-                            ChatCompletionMessageToolCall(
-                                type="function",
-                                id=function_id,
-                                function=Function(
-                                    name=function_name,
-                                    arguments=function_args,
-                                ),
-                            )
-                        ],
+                        tool_calls=tool_calls,
                     )
                 )
             )
-            function_name = ""
-            function_args = ""
-            function_id = None
+            function_calls.clear()
           elif finish_reason == "stop" and text:
             aggregated_llm_response = _message_to_generate_content_response(
                 ChatCompletionAssistantMessage(role="assistant", content=text)
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index 771fd93..57cb0b7 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -38,6 +38,7 @@ from litellm.types.utils import Delta
 from litellm.types.utils import ModelResponse
 from litellm.types.utils import StreamingChoices
 import pytest
+import json

 LLM_REQUEST_WITH_FUNCTION_DECLARATION = LlmRequest(
     contents=[
@@ -170,6 +171,100 @@ STREAMING_MODEL_RESPONSE = [
     ),
 ]

+MULTIPLE_FUNCTION_CALLS_STREAM = [
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id="call_1",
+                            function=Function(
+                                name="function_1",
+                                arguments='{"arg": "val',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue1"}',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id="call_2",
+                            function=Function(
+                                name="function_2",
+                                arguments='{"arg": "val',
+                            ),
+                            index=1,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue2"}',
+                            ),
+                            index=1,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason="tool_calls",
+            )
+        ]
+    ),
+]
+

 @pytest.fixture
 def mock_response():
@@ -1089,3 +1184,80 @@ async def test_generate_content_async_stream_with_usage_metadata(
         ]
         == "string"
     )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_multiple_function_calls(
+    mock_completion, lite_llm_instance
+):
+  """Test handling of multiple function calls with different indices in streaming mode.
+
+  This test verifies that:
+  1. Multiple function calls with different indices are handled correctly
+  2. Arguments and names are properly accumulated for each function call
+  3. The final response contains all function calls with correct indices
+  """
+  mock_completion.return_value = MULTIPLE_FUNCTION_CALLS_STREAM
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[types.Part.from_text(text="Test multiple function calls")],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="function_1",
+                          description="First test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                      types.FunctionDeclaration(
+                          name="function_2",
+                          description="Second test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                  ]
+              )
+          ],
+      ),
+  )
+
+  responses = []
+  async for response in lite_llm_instance.generate_content_async(
+      llm_request, stream=True
+  ):
+    responses.append(response)
+
+  # Verify we got the final response with both function calls
+  assert len(responses) > 0
+  final_response = responses[-1]
+  assert final_response.content.role == "model"
+  assert len(final_response.content.parts) == 2
+
+  # Verify first function call
+  assert final_response.content.parts[0].function_call.name == "function_1"
+  assert final_response.content.parts[0].function_call.id == "call_1"
+  assert final_response.content.parts[0].function_call.args == {
+      "arg": "value1"
+  }
+
+  # Verify second function call
+  assert final_response.content.parts[1].function_call.name == "function_2"
+  assert final_response.content.parts[1].function_call.id == "call_2"
+  assert final_response.content.parts[1].function_call.args == {
+      "arg": "value2"
+  }
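
Note (illustration, not part of the patch): the change above replaces the single function_name/function_args/function_id accumulators with a dict keyed by the tool_call index that litellm streams for parallel tool calls, so interleaved argument fragments from different calls no longer overwrite each other. Below is a minimal standalone sketch of that accumulation pattern, using plain dicts instead of the FunctionChunk/litellm types in the patch; the names and fragment shapes are illustrative only.

# Illustrative sketch only -- not part of the patch. Assumes each streamed
# fragment carries (index, id, name, args) pieces, as FunctionChunk does above.
from typing import Dict, Optional


def accumulate(fragments) -> Dict[int, dict]:
  """Merge streamed tool-call fragments into one entry per parallel call."""
  calls: Dict[int, dict] = {}
  for fragment in fragments:
    index = fragment.get("index") or 0
    entry = calls.setdefault(index, {"name": "", "args": "", "id": None})
    if fragment.get("name"):
      entry["name"] += fragment["name"]
    if fragment.get("args"):
      entry["args"] += fragment["args"]
    # Keep the first non-empty id seen for this index.
    entry["id"] = fragment.get("id") or entry["id"]
  return calls


# Two parallel calls interleaved across four deltas resolve to two entries:
fragments = [
    {"index": 0, "id": "call_1", "name": "function_1", "args": '{"arg": "val'},
    {"index": 1, "id": "call_2", "name": "function_2", "args": '{"arg": "val'},
    {"index": 0, "args": 'ue1"}'},
    {"index": 1, "args": 'ue2"}'},
]
assert accumulate(fragments) == {
    0: {"name": "function_1", "args": '{"arg": "value1"}', "id": "call_1"},
    1: {"name": "function_2", "args": '{"arg": "value2"}', "id": "call_2"},
}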