Mirror of https://github.com/EvolutionAPI/adk-python.git, synced 2025-07-16 12:12:56 -06:00.
Copybara import of the project:

-- cef3ca1ed3493eebaeab3e03bdf5e56b35c0b8ef by Lucas Nobre <lucaas.sn@gmail.com>:

feat: Add index tracking to handle parallel tool calls using litellm

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/759 from lucasnobre212:fix/issue_484 65e22934bf839f9ea03963b9ec6c23fdce03f59f
PiperOrigin-RevId: 764902433

This commit is contained in:
parent 841e10ae35
commit 05f4834759
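Why index tracking matters here: litellm streams parallel tool calls as interleaved deltas, each tagged with an index identifying which call a fragment belongs to. The sketch below, with plain dicts standing in for litellm's delta types and all names hypothetical, shows how a single shared buffer garbles two parallel calls and how keying the accumulator by index fixes it; it mirrors, but is not, the patched code.

deltas = [  # hypothetical interleaved fragments from two parallel calls
    {"index": 0, "id": "call_1", "name": "get_weather", "args": '{"city": "Par'},
    {"index": 1, "id": "call_2", "name": "get_time", "args": '{"tz": "UT'},
    {"index": 0, "id": None, "name": None, "args": 'is"}'},
    {"index": 1, "id": None, "name": None, "args": 'C"}'},
]

# Old approach: one shared buffer fuses the parallel calls together.
name, args = "", ""
for d in deltas:
    name += d["name"] or ""
    args += d["args"] or ""
assert name == "get_weatherget_time"             # two names run together
assert args == '{"city": "Par{"tz": "UTis"}C"}'  # unparseable JSON

# New approach: accumulate per index, as this commit does.
calls = {}
for d in deltas:
    entry = calls.setdefault(d["index"], {"name": "", "args": "", "id": None})
    entry["name"] += d["name"] or ""
    entry["args"] += d["args"] or ""
    entry["id"] = d["id"] or entry["id"]
assert calls[0] == {"name": "get_weather", "args": '{"city": "Paris"}', "id": "call_1"}
assert calls[1] == {"name": "get_time", "args": '{"tz": "UTC"}', "id": "call_2"}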
@@ -62,6 +62,7 @@ class FunctionChunk(BaseModel):
   id: Optional[str]
   name: Optional[str]
   args: Optional[str]
+  index: Optional[int] = 0
 
 
 class TextChunk(BaseModel):
@@ -386,6 +387,7 @@ def _model_response_to_chunk(
             id=tool_call.id,
             name=tool_call.function.name,
             args=tool_call.function.arguments,
+            index=tool_call.index,
         ), finish_reason
 
   if finish_reason and not (
@@ -661,9 +663,8 @@ class LiteLlm(BaseLlm):
 
     if stream:
       text = ""
-      function_name = ""
-      function_args = ""
-      function_id = None
+      # Track function calls by index
+      function_calls = {}  # index -> {name, args, id}
       completion_args["stream"] = True
       aggregated_llm_response = None
       aggregated_llm_response_with_tool_call = None
@@ -672,11 +673,17 @@ class LiteLlm(BaseLlm):
       for part in self.llm_client.completion(**completion_args):
         for chunk, finish_reason in _model_response_to_chunk(part):
           if isinstance(chunk, FunctionChunk):
+            index = chunk.index or 0
+            if index not in function_calls:
+              function_calls[index] = {"name": "", "args": "", "id": None}
+
             if chunk.name:
-              function_name += chunk.name
+              function_calls[index]["name"] += chunk.name
             if chunk.args:
-              function_args += chunk.args
+              function_calls[index]["args"] += chunk.args
-            function_id = chunk.id or function_id
+            function_calls[index]["id"] = (
+                chunk.id or function_calls[index]["id"]
+            )
           elif isinstance(chunk, TextChunk):
             text += chunk.text
             yield _message_to_generate_content_response(
@@ -693,28 +700,31 @@ class LiteLlm(BaseLlm):
                     total_token_count=chunk.total_tokens,
                 )
 
-          if finish_reason == "tool_calls" and function_id:
+          if finish_reason == "tool_calls" and function_calls:
+            tool_calls = []
+            for index, func_data in function_calls.items():
+              if func_data["id"]:
+                tool_calls.append(
+                    ChatCompletionMessageToolCall(
+                        type="function",
+                        id=func_data["id"],
+                        function=Function(
+                            name=func_data["name"],
+                            arguments=func_data["args"],
+                            index=index,
+                        ),
+                    )
+                )
             aggregated_llm_response_with_tool_call = (
                 _message_to_generate_content_response(
                     ChatCompletionAssistantMessage(
                         role="assistant",
                         content="",
-                        tool_calls=[
-                            ChatCompletionMessageToolCall(
-                                type="function",
-                                id=function_id,
-                                function=Function(
-                                    name=function_name,
-                                    arguments=function_args,
-                                ),
-                            )
-                        ],
+                        tool_calls=tool_calls,
                     )
                 )
             )
-            function_name = ""
-            function_args = ""
-            function_id = None
+            function_calls.clear()
           elif finish_reason == "stop" and text:
             aggregated_llm_response = _message_to_generate_content_response(
                 ChatCompletionAssistantMessage(role="assistant", content=text)
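Two details in the hunk above are easy to miss: entries whose accumulated id is still None are skipped when building tool_calls, so incomplete fragments never reach the aggregated response, and the single function_calls.clear() replaces the old trio of per-variable resets, leaving the accumulator empty for any later tool-call turn in the same stream. The remaining hunks update the litellm unit tests to cover the parallel case.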
@@ -38,6 +38,7 @@ from litellm.types.utils import Delta
 from litellm.types.utils import ModelResponse
 from litellm.types.utils import StreamingChoices
 import pytest
+import json
 
 LLM_REQUEST_WITH_FUNCTION_DECLARATION = LlmRequest(
     contents=[
@@ -170,6 +171,100 @@ STREAMING_MODEL_RESPONSE = [
     ),
 ]
 
 
+MULTIPLE_FUNCTION_CALLS_STREAM = [
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id="call_1",
+                            function=Function(
+                                name="function_1",
+                                arguments='{"arg": "val',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue1"}',
+                            ),
+                            index=0,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id="call_2",
+                            function=Function(
+                                name="function_2",
+                                arguments='{"arg": "val',
+                            ),
+                            index=1,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason=None,
+                delta=Delta(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionDeltaToolCall(
+                            type="function",
+                            id=None,
+                            function=Function(
+                                name=None,
+                                arguments='ue2"}',
+                            ),
+                            index=1,
+                        )
+                    ],
+                ),
+            )
+        ]
+    ),
+    ModelResponse(
+        choices=[
+            StreamingChoices(
+                finish_reason="tool_calls",
+            )
+        ]
+    ),
+]
 
 
 @pytest.fixture
 def mock_response():
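One detail worth noticing in the fixture above: each call's JSON arguments are deliberately split across two chunks ('{"arg": "val' then 'ue1"}'), so the test only passes if fragments are concatenated per index before parsing. A quick illustration, plain Python with nothing beyond the standard library:

import json

first, second = '{"arg": "val', 'ue1"}'
try:
    json.loads(first)          # a partial fragment is not valid JSON
except json.JSONDecodeError:
    pass                       # expected
assert json.loads(first + second) == {"arg": "value1"}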
@@ -1089,3 +1184,80 @@ async def test_generate_content_async_stream_with_usage_metadata(
         ]
         == "string"
     )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_multiple_function_calls(
+    mock_completion, lite_llm_instance
+):
+  """Test handling of multiple function calls with different indices in streaming mode.
+
+  This test verifies that:
+  1. Multiple function calls with different indices are handled correctly
+  2. Arguments and names are properly accumulated for each function call
+  3. The final response contains all function calls with correct indices
+  """
+  mock_completion.return_value = MULTIPLE_FUNCTION_CALLS_STREAM
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[types.Part.from_text(text="Test multiple function calls")],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="function_1",
+                          description="First test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                      types.FunctionDeclaration(
+                          name="function_2",
+                          description="Second test function",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "arg": types.Schema(type=types.Type.STRING),
+                              },
+                          ),
+                      ),
+                  ]
+              )
+          ],
+      ),
+  )
+
+  responses = []
+  async for response in lite_llm_instance.generate_content_async(
+      llm_request, stream=True
+  ):
+    responses.append(response)
+
+  # Verify we got the final response with both function calls
+  assert len(responses) > 0
+  final_response = responses[-1]
+  assert final_response.content.role == "model"
+  assert len(final_response.content.parts) == 2
+
+  # Verify first function call
+  assert final_response.content.parts[0].function_call.name == "function_1"
+  assert final_response.content.parts[0].function_call.id == "call_1"
+  assert final_response.content.parts[0].function_call.args == {
+      "arg": "value1"
+  }
+
+  # Verify second function call
+  assert final_response.content.parts[1].function_call.name == "function_2"
+  assert final_response.content.parts[1].function_call.id == "call_2"
+  assert final_response.content.parts[1].function_call.args == {
+      "arg": "value2"
+  }
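To exercise just the new case, filtering by the test name should work, e.g. pytest -k test_generate_content_async_multiple_function_calls run against the litellm test module (exact path assumed, not shown in this diff).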