Mirror of https://github.com/EvolutionAPI/adk-python.git (synced 2025-07-13 07:04:51 -06:00)
fix: Handle non-indexed function call chunks with incremental fallback index
This addresses litellm v1.71.2 + ollama v0.9.0 sending function call chunks that all carry index 0 across multiple calls and lack call ids. Solutions introduced:
1. Increment a fallback index once the accumulated argument string becomes JSON-parsable.
2. Tolerate finish_reason == "stop" when tool calls are present.
3. Fall back to the index when the tool call id is None.

Fixes https://github.com/google/adk-python/issues/294

PiperOrigin-RevId: 766258344
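For illustration, the accumulation strategy described above can be sketched as a small standalone function. This is not the ADK implementation itself: the (index, call_id, name, args_fragment) tuples and the accumulate_function_call_chunks helper are simplified stand-ins for litellm's streamed FunctionChunk objects, and only the fallback-index and id-fallback behavior is modeled.

import json


def accumulate_function_call_chunks(chunks):
  """Sketch: merge streamed tool-call chunks that may all report index 0.

  Each chunk is a simplified (index, call_id, name, args_fragment) tuple.
  """
  function_calls = {}  # index -> {"name", "args", "id"}
  fallback_index = 0
  for index, call_id, name, args_fragment in chunks:
    index = index or fallback_index
    entry = function_calls.setdefault(index, {"name": "", "args": "", "id": None})
    if name:
      entry["name"] += name
    if args_fragment:
      entry["args"] += args_fragment
      try:
        # Once the accumulated args parse as JSON the call is complete, so a
        # later chunk that also claims index 0 opens a new slot (solution 1).
        json.loads(entry["args"])
        fallback_index += 1
      except json.JSONDecodeError:
        pass
    # A missing call id falls back to the slot index (solution 3).
    entry["id"] = call_id or entry["id"] or str(index)
  return function_calls


# The non-compliant stream shape from the issue: ids missing, indices all 0.
chunks = [
    (0, None, "function_1", '{"arg": "val'),
    (0, None, None, 'ue1"}'),
    (0, None, "function_2", '{"arg": "val'),
    (0, None, None, 'ue2"}'),
]
print(accumulate_function_call_chunks(chunks))
# {0: {'name': 'function_1', 'args': '{"arg": "value1"}', 'id': '0'},
#  1: {'name': 'function_2', 'args': '{"arg": "value2"}', 'id': '1'}}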
parent bd588bce50
commit b181cbc8bc
@@ -83,7 +83,7 @@ test = [
   "anthropic>=0.43.0", # For anthropic model tests
   "langchain-community>=0.3.17",
   "langgraph>=0.2.60", # For LangGraphAgent
-  "litellm>=1.63.11", # For LiteLLM tests
+  "litellm>=1.71.2", # For LiteLLM tests
   "llama-index-readers-file>=0.4.0", # For retrieval tests
 
   "pytest-asyncio>=0.25.0",

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
 
 import base64
 import json

@@ -669,11 +670,11 @@ class LiteLlm(BaseLlm):
       aggregated_llm_response = None
       aggregated_llm_response_with_tool_call = None
       usage_metadata = None
+      fallback_index = 0
       for part in self.llm_client.completion(**completion_args):
         for chunk, finish_reason in _model_response_to_chunk(part):
           if isinstance(chunk, FunctionChunk):
-            index = chunk.index or 0
+            index = chunk.index or fallback_index
             if index not in function_calls:
               function_calls[index] = {"name": "", "args": "", "id": None}
 
@@ -681,8 +682,17 @@ class LiteLlm(BaseLlm):
               function_calls[index]["name"] += chunk.name
             if chunk.args:
               function_calls[index]["args"] += chunk.args
+
+              # check if args is completed (workaround for improper chunk
+              # indexing)
+              try:
+                json.loads(function_calls[index]["args"])
+                fallback_index += 1
+              except json.JSONDecodeError:
+                pass
+
             function_calls[index]["id"] = (
-                chunk.id or function_calls[index]["id"]
+                chunk.id or function_calls[index]["id"] or str(index)
             )
           elif isinstance(chunk, TextChunk):
             text += chunk.text

@@ -700,7 +710,9 @@ class LiteLlm(BaseLlm):
                 total_token_count=chunk.total_tokens,
             )
 
-          if finish_reason == "tool_calls" and function_calls:
+          if (
+              finish_reason == "tool_calls" or finish_reason == "stop"
+          ) and function_calls:
             tool_calls = []
             for index, func_data in function_calls.items():
               if func_data["id"]:

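The relaxed finish-reason check in the hunk above can be read in isolation as roughly the predicate below. This is an illustrative helper only; should_emit_tool_call_response is a made-up name and does not exist in the ADK source.

def should_emit_tool_call_response(finish_reason, function_calls):
  # Spec-compliant backends report finish_reason == "tool_calls", but
  # litellm v1.71.2 + ollama v0.9.0 may report "stop" even though tool call
  # chunks were streamed, so both values are accepted once calls were seen.
  return finish_reason in ("tool_calls", "stop") and bool(function_calls)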
@@ -290,6 +290,105 @@ def mock_response():
 )
 
 
+# Test case reflecting litellm v1.71.2, ollama v0.9.0 streaming response
+# no tool call ids
+# indices all 0
+# finish_reason stop instead of tool_calls
+NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name="function_1",
+                                arguments='{"arg": "val',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue1"}',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name="function_2",
+                                arguments='{"arg": "val',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue2"}',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason="stop",
+            )
+        ]
+    ),
+]
+
+
 @pytest.fixture
 def mock_acompletion(mock_response):
   return AsyncMock(return_value=mock_response)

@@ -1257,3 +1356,76 @@ async def test_generate_content_async_multiple_function_calls(
   assert final_response.content.parts[1].function_call.name == "function_2"
   assert final_response.content.parts[1].function_call.id == "call_2"
   assert final_response.content.parts[1].function_call.args == {"arg": "value2"}
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_non_compliant_multiple_function_calls(
+    mock_completion, lite_llm_instance
+):
+  """Test handling of multiple function calls with same 0 indices in streaming mode.
+
+  This test verifies that:
+  1. Multiple function calls with same indices (0) are handled correctly
+  2. Arguments and names are properly accumulated for each function call
+  3. The final response contains all function calls with correct incremented indices
+  """
+  mock_completion.return_value = NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[types.Part.from_text(text="Test multiple function calls")],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="function_1",
+                          description="First test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                      types.FunctionDeclaration(
+                          name="function_2",
+                          description="Second test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                  ]
+              )
+          ],
+      ),
+  )
+
+  responses = []
+  async for response in lite_llm_instance.generate_content_async(
+      llm_request, stream=True
+  ):
+    responses.append(response)
+
+  # Verify we got the final response with both function calls
+  assert len(responses) > 0
+  final_response = responses[-1]
+  assert final_response.content.role == "model"
+  assert len(final_response.content.parts) == 2
+
+  # Verify first function call
+  assert final_response.content.parts[0].function_call.name == "function_1"
+  assert final_response.content.parts[0].function_call.id == "0"
+  assert final_response.content.parts[0].function_call.args == {"arg": "value1"}
+
+  # Verify second function call
+  assert final_response.content.parts[1].function_call.name == "function_2"
+  assert final_response.content.parts[1].function_call.id == "1"
+  assert final_response.content.parts[1].function_call.args == {"arg": "value2"}