fix: Handle non-indexed function call chunks with incremental fallback index

This is in response to the litellm v1.71.2 + ollama v0.9.0 sending function call chunks with 0 indices across multiple calls and lacking call ids.

Solutions introduced:
1. increment fallback index when accumulated arg becomes json parsable.
2. tolerate finish reason == stop when tool calls are present
3. fallback to index when tool call id is None

Fixes https://github.com/google/adk-python/issues/294

PiperOrigin-RevId: 766258344
This commit is contained in:
Selcuk Gun
2025-06-02 10:51:15 -07:00
committed by Copybara-Service
parent bd588bce50
commit b181cbc8bc
3 changed files with 189 additions and 5 deletions

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import base64
import json
@@ -669,11 +670,11 @@ class LiteLlm(BaseLlm):
aggregated_llm_response = None
aggregated_llm_response_with_tool_call = None
usage_metadata = None
fallback_index = 0
for part in self.llm_client.completion(**completion_args):
for chunk, finish_reason in _model_response_to_chunk(part):
if isinstance(chunk, FunctionChunk):
index = chunk.index or 0
index = chunk.index or fallback_index
if index not in function_calls:
function_calls[index] = {"name": "", "args": "", "id": None}
@@ -681,8 +682,17 @@ class LiteLlm(BaseLlm):
function_calls[index]["name"] += chunk.name
if chunk.args:
function_calls[index]["args"] += chunk.args
# check if args is completed (workaround for improper chunk
# indexing)
try:
json.loads(function_calls[index]["args"])
fallback_index += 1
except json.JSONDecodeError:
pass
function_calls[index]["id"] = (
chunk.id or function_calls[index]["id"]
chunk.id or function_calls[index]["id"] or str(index)
)
elif isinstance(chunk, TextChunk):
text += chunk.text
@@ -700,7 +710,9 @@ class LiteLlm(BaseLlm):
total_token_count=chunk.total_tokens,
)
if finish_reason == "tool_calls" and function_calls:
if (
finish_reason == "tool_calls" or finish_reason == "stop"
) and function_calls:
tool_calls = []
for index, func_data in function_calls.items():
if func_data["id"]: