From b181cbc8bc629d1c9bfd50054e47a0a1b04f7410 Mon Sep 17 00:00:00 2001
From: Selcuk Gun
Date: Mon, 2 Jun 2025 10:51:15 -0700
Subject: [PATCH] fix: Handle non-indexed function call chunks with
 incremental fallback index

This is in response to litellm v1.71.2 + ollama v0.9.0 sending function
call chunks whose indices are all 0 across multiple calls and which lack
call ids.

Solutions introduced:
1. Increment a fallback index when the accumulated arg string becomes
   JSON-parsable.
2. Tolerate finish_reason == "stop" when tool calls are present.
3. Fall back to the index when the tool call id is None.

Fixes https://github.com/google/adk-python/issues/294

PiperOrigin-RevId: 766258344
---
 pyproject.toml                         |   2 +-
 src/google/adk/models/lite_llm.py      |  20 ++-
 tests/unittests/models/test_litellm.py | 172 +++++++++++++++++++++++++
 3 files changed, 189 insertions(+), 5 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1c66e92..f64c438 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -83,7 +83,7 @@ test = [
   "anthropic>=0.43.0",  # For anthropic model tests
   "langchain-community>=0.3.17",
   "langgraph>=0.2.60",  # For LangGraphAgent
-  "litellm>=1.63.11",  # For LiteLLM tests
+  "litellm>=1.71.2",  # For LiteLLM tests
   "llama-index-readers-file>=0.4.0",  # For retrieval tests
 
   "pytest-asyncio>=0.25.0",
diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index b2d068c..b6c201e 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
 
 import base64
 import json
@@ -669,11 +670,11 @@ class LiteLlm(BaseLlm):
       aggregated_llm_response = None
       aggregated_llm_response_with_tool_call = None
       usage_metadata = None
-
+      fallback_index = 0
       for part in self.llm_client.completion(**completion_args):
         for chunk, finish_reason in _model_response_to_chunk(part):
          if isinstance(chunk, FunctionChunk):
-            index = chunk.index or 0
+            index = chunk.index or fallback_index
             if index not in function_calls:
               function_calls[index] = {"name": "", "args": "", "id": None}
 
@@ -681,8 +682,17 @@
               function_calls[index]["name"] += chunk.name
             if chunk.args:
               function_calls[index]["args"] += chunk.args
+
+              # check if args is completed (workaround for improper chunk
+              # indexing)
+              try:
+                json.loads(function_calls[index]["args"])
+                fallback_index += 1
+              except json.JSONDecodeError:
+                pass
+
             function_calls[index]["id"] = (
-                chunk.id or function_calls[index]["id"]
+                chunk.id or function_calls[index]["id"] or str(index)
             )
           elif isinstance(chunk, TextChunk):
             text += chunk.text
@@ -700,7 +710,9 @@
                 total_token_count=chunk.total_tokens,
             )
 
-          if finish_reason == "tool_calls" and function_calls:
+          if (
+              finish_reason == "tool_calls" or finish_reason == "stop"
+          ) and function_calls:
             tool_calls = []
             for index, func_data in function_calls.items():
               if func_data["id"]:
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index a62f096..1617863 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -290,6 +290,105 @@ def mock_response():
   )
 
 
+# Test case reflecting litellm v1.71.2, ollama v0.9.0 streaming response
+# no tool call ids
+# indices all 0
+# finish_reason stop instead of tool_calls
+NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name="function_1",
+                                arguments='{"arg": "val',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue1"}',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name="function_2",
+                                arguments='{"arg": "val',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue2"}',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason="stop",
+            )
+        ]
+    ),
+]
+
+
 @pytest.fixture
 def mock_acompletion(mock_response):
   return AsyncMock(return_value=mock_response)
@@ -1257,3 +1356,76 @@ async def test_generate_content_async_multiple_function_calls(
   assert final_response.content.parts[1].function_call.name == "function_2"
   assert final_response.content.parts[1].function_call.id == "call_2"
   assert final_response.content.parts[1].function_call.args == {"arg": "value2"}
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_non_compliant_multiple_function_calls(
+    mock_completion, lite_llm_instance
+):
+  """Test handling of multiple function calls with same 0 indices in streaming mode.
+
+  This test verifies that:
+  1. Multiple function calls with same indices (0) are handled correctly
+  2. Arguments and names are properly accumulated for each function call
+  3. The final response contains all function calls with correct incremented indices
+  """
+  mock_completion.return_value = NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[types.Part.from_text(text="Test multiple function calls")],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="function_1",
+                          description="First test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                      types.FunctionDeclaration(
+                          name="function_2",
+                          description="Second test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                  ]
+              )
+          ],
+      ),
+  )
+
+  responses = []
+  async for response in lite_llm_instance.generate_content_async(
+      llm_request, stream=True
+  ):
+    responses.append(response)
+
+  # Verify we got the final response with both function calls
+  assert len(responses) > 0
+  final_response = responses[-1]
+  assert final_response.content.role == "model"
+  assert len(final_response.content.parts) == 2
+
+  # Verify first function call
+  assert final_response.content.parts[0].function_call.name == "function_1"
+  assert final_response.content.parts[0].function_call.id == "0"
+  assert final_response.content.parts[0].function_call.args == {"arg": "value1"}
+
+  # Verify second function call
+  assert final_response.content.parts[1].function_call.name == "function_2"
+  assert final_response.content.parts[1].function_call.id == "1"
+  assert final_response.content.parts[1].function_call.args == {"arg": "value2"}
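
For reference, the accumulation strategy the patch relies on can be illustrated outside ADK with a minimal, self-contained sketch (not part of the commit). It shows how chunks that all report index 0 and carry no call id can still be split into separate tool calls: a fallback index advances as soon as the accumulated argument string parses as JSON, and the index doubles as the call id when none is provided. The Chunk tuple and group_function_chunks helper below are illustrative names for this sketch, not ADK or litellm APIs.

import json
from typing import NamedTuple, Optional


class Chunk(NamedTuple):
  """Simplified stand-in for a streamed function-call delta (illustrative only)."""

  index: Optional[int]
  name: Optional[str]
  args: str
  id: Optional[str] = None


def group_function_chunks(chunks: list[Chunk]) -> dict[int, dict]:
  """Groups chunks into calls, advancing a fallback index once args parse as JSON."""
  function_calls: dict[int, dict] = {}
  fallback_index = 0
  for chunk in chunks:
    # Mirrors `chunk.index or fallback_index`: a missing or zero index
    # falls back to the running counter.
    index = chunk.index or fallback_index
    call = function_calls.setdefault(index, {"name": "", "args": "", "id": None})
    if chunk.name:
      call["name"] += chunk.name
    if chunk.args:
      call["args"] += chunk.args
      try:
        # Once the accumulated args form valid JSON, treat the call as
        # complete and route subsequent chunks to a new slot.
        json.loads(call["args"])
        fallback_index += 1
      except json.JSONDecodeError:
        pass
    # Without a provider-supplied id, fall back to the index as the id.
    call["id"] = chunk.id or call["id"] or str(index)
  return function_calls


if __name__ == "__main__":
  stream = [
      Chunk(index=0, name="function_1", args='{"arg": "val'),
      Chunk(index=0, name=None, args='ue1"}'),
      Chunk(index=0, name="function_2", args='{"arg": "val'),
      Chunk(index=0, name=None, args='ue2"}'),
  ]
  for idx, call in group_function_chunks(stream).items():
    print(idx, call["id"], call["name"], json.loads(call["args"]))
  # Prints:
  # 0 0 function_1 {'arg': 'value1'}
  # 1 1 function_2 {'arg': 'value2'}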