Copybara import of the project:

--
cef3ca1ed3493eebaeab3e03bdf5e56b35c0b8ef by Lucas Nobre <lucaas.sn@gmail.com>:

feat: Add index tracking to handle parallel tool calls using litellm

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/759 from lucasnobre212:fix/issue_484 65e22934bf839f9ea03963b9ec6c23fdce03f59f

PiperOrigin-RevId: 764902433
LucasNobre 2025-05-29 15:10:49 -07:00 committed by Copybara-Service
parent 841e10ae35
commit 05f4834759
2 changed files with 202 additions and 20 deletions
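
The change replaces the streaming path's single name/args/id accumulator with a map keyed by each tool call's `index`, so the fragments of parallel tool calls no longer overwrite one another. Below is a minimal, self-contained sketch of that accumulation pattern (illustrative only: plain dicts stand in for the litellm delta types, and the fragment values mirror the test fixture added in this commit):

# Sketch of per-index accumulation of streamed tool-call fragments.
# With a single shared buffer, the interleaved fragments of the two calls
# below would concatenate into one invalid name/argument string.
deltas = [  # hypothetical stream order
    {"index": 0, "id": "call_1", "name": "function_1", "args": '{"arg": "val'},
    {"index": 1, "id": "call_2", "name": "function_2", "args": '{"arg": "val'},
    {"index": 0, "id": None, "name": None, "args": 'ue1"}'},
    {"index": 1, "id": None, "name": None, "args": 'ue2"}'},
]

function_calls = {}  # index -> {"name", "args", "id"}
for delta in deltas:
  index = delta["index"] or 0
  entry = function_calls.setdefault(index, {"name": "", "args": "", "id": None})
  if delta["name"]:
    entry["name"] += delta["name"]
  if delta["args"]:
    entry["args"] += delta["args"]
  entry["id"] = delta["id"] or entry["id"]

assert function_calls[0]["args"] == '{"arg": "value1"}'
assert function_calls[1]["args"] == '{"arg": "value2"}'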

File: src/google/adk/models/lite_llm.py

@@ -62,6 +62,7 @@ class FunctionChunk(BaseModel):
   id: Optional[str]
   name: Optional[str]
   args: Optional[str]
+  index: Optional[int] = 0


 class TextChunk(BaseModel):
@@ -386,6 +387,7 @@ def _model_response_to_chunk(
           id=tool_call.id,
           name=tool_call.function.name,
           args=tool_call.function.arguments,
+          index=tool_call.index,
       ), finish_reason

   if finish_reason and not (
@@ -661,9 +663,8 @@ class LiteLlm(BaseLlm):
     if stream:
       text = ""
-      function_name = ""
-      function_args = ""
-      function_id = None
+      # Track function calls by index
+      function_calls = {}  # index -> {name, args, id}
       completion_args["stream"] = True
       aggregated_llm_response = None
       aggregated_llm_response_with_tool_call = None
@@ -672,11 +673,17 @@
       for part in self.llm_client.completion(**completion_args):
         for chunk, finish_reason in _model_response_to_chunk(part):
           if isinstance(chunk, FunctionChunk):
+            index = chunk.index or 0
+            if index not in function_calls:
+              function_calls[index] = {"name": "", "args": "", "id": None}
             if chunk.name:
-              function_name += chunk.name
+              function_calls[index]["name"] += chunk.name
             if chunk.args:
-              function_args += chunk.args
-              function_id = chunk.id or function_id
+              function_calls[index]["args"] += chunk.args
+              function_calls[index]["id"] = (
+                  chunk.id or function_calls[index]["id"]
+              )
           elif isinstance(chunk, TextChunk):
             text += chunk.text
             yield _message_to_generate_content_response(
@@ -693,28 +700,31 @@
                   total_token_count=chunk.total_tokens,
               )
-          if finish_reason == "tool_calls" and function_id:
+          if finish_reason == "tool_calls" and function_calls:
+            tool_calls = []
+            for index, func_data in function_calls.items():
+              if func_data["id"]:
+                tool_calls.append(
+                    ChatCompletionMessageToolCall(
+                        type="function",
+                        id=func_data["id"],
+                        function=Function(
+                            name=func_data["name"],
+                            arguments=func_data["args"],
+                            index=index,
+                        ),
+                    )
+                )
             aggregated_llm_response_with_tool_call = (
                 _message_to_generate_content_response(
                     ChatCompletionAssistantMessage(
                         role="assistant",
                         content="",
-                        tool_calls=[
-                            ChatCompletionMessageToolCall(
-                                type="function",
-                                id=function_id,
-                                function=Function(
-                                    name=function_name,
-                                    arguments=function_args,
-                                ),
-                            )
-                        ],
+                        tool_calls=tool_calls,
                     )
                 )
             )
-            function_name = ""
-            function_args = ""
-            function_id = None
+            function_calls.clear()
           elif finish_reason == "stop" and text:
             aggregated_llm_response = _message_to_generate_content_response(
                 ChatCompletionAssistantMessage(role="assistant", content=text)
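
For reference, a sketch of the assembly step the patched code performs when the stream finishes with finish_reason == "tool_calls": the per-index map is flattened into one tool_calls list for a single assistant message. This is illustrative only; the import path is an assumption based on the litellm.types.utils imports in the test file below, and the index=index argument the patch passes to Function is omitted here.

# Sketch only, not the adk-python code: build the aggregated tool_calls list.
from litellm.types.utils import ChatCompletionMessageToolCall, Function

function_calls = {
    0: {"name": "function_1", "args": '{"arg": "value1"}', "id": "call_1"},
    1: {"name": "function_2", "args": '{"arg": "value2"}', "id": "call_2"},
}
tool_calls = [
    ChatCompletionMessageToolCall(
        type="function",
        id=func_data["id"],
        function=Function(name=func_data["name"], arguments=func_data["args"]),
    )
    for func_data in function_calls.values()
    if func_data["id"]  # calls that never received an id are dropped
]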

File: tests/unittests/models/test_litellm.py

@@ -38,6 +38,7 @@ from litellm.types.utils import Delta
 from litellm.types.utils import ModelResponse
 from litellm.types.utils import StreamingChoices
 import pytest
+import json

 LLM_REQUEST_WITH_FUNCTION_DECLARATION = LlmRequest(
     contents=[
@@ -170,6 +171,100 @@ STREAMING_MODEL_RESPONSE = [
     ),
 ]

+MULTIPLE_FUNCTION_CALLS_STREAM = [
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id="call_1",
+                            function=Function(
+                                name="function_1",
+                                arguments='{"arg": "val',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue1"}',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id="call_2",
+                            function=Function(
+                                name="function_2",
+                                arguments='{"arg": "val',
+                            ),
+                            index=1,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue2"}',
+                            ),
+                            index=1,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason="tool_calls",
+            )
+        ]
+    ),
+]
+

 @pytest.fixture
 def mock_response():
@@ -1089,3 +1184,80 @@ async def test_generate_content_async_stream_with_usage_metadata(
       ]
       == "string"
   )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_multiple_function_calls(
+    mock_completion, lite_llm_instance
+):
+  """Test handling of multiple function calls with different indices in streaming mode.
+
+  This test verifies that:
+  1. Multiple function calls with different indices are handled correctly.
+  2. Arguments and names are properly accumulated for each function call.
+  3. The final response contains all function calls with correct indices.
+  """
+  mock_completion.return_value = MULTIPLE_FUNCTION_CALLS_STREAM
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[types.Part.from_text(text="Test multiple function calls")],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="function_1",
+                          description="First test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                      types.FunctionDeclaration(
+                          name="function_2",
+                          description="Second test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                  ]
+              )
+          ],
+      ),
+  )
+
+  responses = []
+  async for response in lite_llm_instance.generate_content_async(
+      llm_request, stream=True
+  ):
+    responses.append(response)
+
+  # Verify we got the final response with both function calls
+  assert len(responses) > 0
+  final_response = responses[-1]
+  assert final_response.content.role == "model"
+  assert len(final_response.content.parts) == 2
+
+  # Verify first function call
+  assert final_response.content.parts[0].function_call.name == "function_1"
+  assert final_response.content.parts[0].function_call.id == "call_1"
+  assert final_response.content.parts[0].function_call.args == {"arg": "value1"}
+
+  # Verify second function call
+  assert final_response.content.parts[1].function_call.name == "function_2"
+  assert final_response.content.parts[1].function_call.id == "call_2"
+  assert final_response.content.parts[1].function_call.args == {"arg": "value2"}
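
The final assertions compare function_call.args as parsed dicts, which relies on each call's streamed argument fragments concatenating into valid JSON. A quick standalone check of that property, using the fixture's fragment strings:

import json

# Each call's fragments, concatenated in stream order, parse as valid JSON.
assert json.loads('{"arg": "val' + 'ue1"}') == {"arg": "value1"}
assert json.loads('{"arg": "val' + 'ue2"}') == {"arg": "value2"}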