fix: Handle non-indexed function call chunks with incremental fallback index

This addresses litellm v1.71.2 + ollama v0.9.0 sending function call chunks that reuse index 0 across multiple calls and lack call ids.

Solutions introduced:
1. Increment a fallback index when the accumulated args string becomes JSON-parsable (illustrated in the sketch after this list).
2. Tolerate finish_reason == "stop" when tool calls are present.
3. Fall back to the index when the tool call id is None.
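
For illustration only, a minimal standalone sketch of the fallback-index bookkeeping, assuming dict-shaped chunks; the accumulate_function_calls helper and the chunk format are simplified stand-ins, not the actual LiteLlm streaming code:

import json

def accumulate_function_calls(chunks):
  """Groups streamed tool-call chunks, splitting calls on JSON-complete args."""
  function_calls = {}
  fallback_index = 0
  for chunk in chunks:
    # Providers that report index 0 for every chunk fall through to the
    # fallback index.
    index = chunk.get("index") or fallback_index
    entry = function_calls.setdefault(index, {"name": "", "args": "", "id": None})
    if chunk.get("name"):
      entry["name"] += chunk["name"]
    if chunk.get("args"):
      entry["args"] += chunk["args"]
      try:
        # Once the accumulated args parse as JSON, the call is complete, so
        # the next chunk should open a new entry.
        json.loads(entry["args"])
        fallback_index += 1
      except json.JSONDecodeError:
        pass
    # Missing call ids fall back to the (possibly synthesized) index.
    entry["id"] = chunk.get("id") or entry["id"] or str(index)
  return function_calls

# Two calls that both arrive with index 0 end up as separate entries:
stream = [
    {"index": 0, "name": "function_1", "args": '{"arg": "val'},
    {"index": 0, "args": 'ue1"}'},
    {"index": 0, "name": "function_2", "args": '{"arg": "val'},
    {"index": 0, "args": 'ue2"}'},
]
print(accumulate_function_calls(stream))
# {0: {'name': 'function_1', 'args': '{"arg": "value1"}', 'id': '0'},
#  1: {'name': 'function_2', 'args': '{"arg": "value2"}', 'id': '1'}}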

Fixes https://github.com/google/adk-python/issues/294

PiperOrigin-RevId: 766258344
Authored by Selcuk Gun on 2025-06-02 10:51:15 -07:00, committed by Copybara-Service
parent bd588bce50
commit b181cbc8bc
3 changed files with 189 additions and 5 deletions


@@ -83,7 +83,7 @@ test = [
   "anthropic>=0.43.0", # For anthropic model tests
   "langchain-community>=0.3.17",
   "langgraph>=0.2.60", # For LangGraphAgent
-  "litellm>=1.63.11", # For LiteLLM tests
+  "litellm>=1.71.2", # For LiteLLM tests
   "llama-index-readers-file>=0.4.0", # For retrieval tests
   "pytest-asyncio>=0.25.0",


@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
 import base64
 import json
@@ -669,11 +670,11 @@ class LiteLlm(BaseLlm):
       aggregated_llm_response = None
       aggregated_llm_response_with_tool_call = None
       usage_metadata = None
+      fallback_index = 0
       for part in self.llm_client.completion(**completion_args):
         for chunk, finish_reason in _model_response_to_chunk(part):
           if isinstance(chunk, FunctionChunk):
-            index = chunk.index or 0
+            index = chunk.index or fallback_index
             if index not in function_calls:
               function_calls[index] = {"name": "", "args": "", "id": None}
@@ -681,8 +682,17 @@ class LiteLlm(BaseLlm):
               function_calls[index]["name"] += chunk.name
             if chunk.args:
               function_calls[index]["args"] += chunk.args
+              # check if args is completed (workaround for improper chunk
+              # indexing)
+              try:
+                json.loads(function_calls[index]["args"])
+                fallback_index += 1
+              except json.JSONDecodeError:
+                pass
             function_calls[index]["id"] = (
-                chunk.id or function_calls[index]["id"]
+                chunk.id or function_calls[index]["id"] or str(index)
             )
           elif isinstance(chunk, TextChunk):
             text += chunk.text
@@ -700,7 +710,9 @@ class LiteLlm(BaseLlm):
                 total_token_count=chunk.total_tokens,
             )
-          if finish_reason == "tool_calls" and function_calls:
+          if (
+              finish_reason == "tool_calls" or finish_reason == "stop"
+          ) and function_calls:
            tool_calls = []
            for index, func_data in function_calls.items():
              if func_data["id"]:


@@ -290,6 +290,105 @@ def mock_response():
   )
 
 
+# Test case reflecting litellm v1.71.2, ollama v0.9.0 streaming response
+# no tool call ids
+# indices all 0
+# finish_reason stop instead of tool_calls
+NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name="function_1",
+                                arguments='{"arg": "val',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue1"}',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name="function_2",
+                                arguments='{"arg": "val',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue2"}',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason="stop",
+            )
+        ]
+    ),
+]
+
+
 @pytest.fixture
 def mock_acompletion(mock_response):
   return AsyncMock(return_value=mock_response)
@@ -1257,3 +1356,76 @@ async def test_generate_content_async_multiple_function_calls(
   assert final_response.content.parts[1].function_call.name == "function_2"
   assert final_response.content.parts[1].function_call.id == "call_2"
   assert final_response.content.parts[1].function_call.args == {"arg": "value2"}
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_non_compliant_multiple_function_calls(
+    mock_completion, lite_llm_instance
+):
+  """Test handling of multiple function calls with same 0 indices in streaming mode.
+
+  This test verifies that:
+  1. Multiple function calls with same indices (0) are handled correctly
+  2. Arguments and names are properly accumulated for each function call
+  3. The final response contains all function calls with correct incremented indices
+  """
+  mock_completion.return_value = NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[types.Part.from_text(text="Test multiple function calls")],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="function_1",
+                          description="First test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                      types.FunctionDeclaration(
+                          name="function_2",
+                          description="Second test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                  ]
+              )
+          ],
+      ),
+  )
+
+  responses = []
+  async for response in lite_llm_instance.generate_content_async(
+      llm_request, stream=True
+  ):
+    responses.append(response)
+
+  # Verify we got the final response with both function calls
+  assert len(responses) > 0
+  final_response = responses[-1]
+  assert final_response.content.role == "model"
+  assert len(final_response.content.parts) == 2
+
+  # Verify first function call
+  assert final_response.content.parts[0].function_call.name == "function_1"
+  assert final_response.content.parts[0].function_call.id == "0"
+  assert final_response.content.parts[0].function_call.args == {"arg": "value1"}
+
+  # Verify second function call
+  assert final_response.content.parts[1].function_call.name == "function_2"
+  assert final_response.content.parts[1].function_call.id == "1"
+  assert final_response.content.parts[1].function_call.args == {"arg": "value2"}