Copybara import of the project:

--
cef3ca1ed3493eebaeab3e03bdf5e56b35c0b8ef by Lucas Nobre <lucaas.sn@gmail.com>:

feat: Add index tracking to handle parallel tool calls using litellm

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/759 from lucasnobre212:fix/issue_484 65e22934bf839f9ea03963b9ec6c23fdce03f59f
PiperOrigin-RevId: 764902433
LucasNobre authored 2025-05-29 15:10:49 -07:00, committed by Copybara-Service
parent 841e10ae35
commit 05f4834759
2 changed files with 202 additions and 20 deletions
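
Why index tracking matters: when a streamed response contains several tool calls in parallel, their argument deltas arrive interleaved, and only the first delta of each call carries its id. Accumulating everything into a single name/args/id triple (the old code) merges the calls into one garbled call; keying the accumulator by each delta's index keeps them apart. A minimal, self-contained sketch of the idea, separate from the ADK source (Delta and the get_weather/get_time payloads are invented for illustration):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Delta:  # stand-in for one streamed tool-call chunk
  index: int
  id: Optional[str] = None
  name: Optional[str] = None
  args: Optional[str] = None

# Two parallel calls: deltas for index 0 and index 1 arrive interleaved,
# and only the first delta of each call carries the id.
deltas = [
    Delta(0, id="call_a", name="get_weather", args='{"city": '),
    Delta(1, id="call_b", name="get_time", args='{"tz": '),
    Delta(0, args='"Paris"}'),
    Delta(1, args='"UTC"}'),
]

function_calls = {}  # index -> {"name": str, "args": str, "id": Optional[str]}
for d in deltas:
  bucket = function_calls.setdefault(d.index, {"name": "", "args": "", "id": None})
  if d.name:
    bucket["name"] += d.name
  if d.args:
    bucket["args"] += d.args
  bucket["id"] = d.id or bucket["id"]

assert function_calls[0] == {"name": "get_weather", "args": '{"city": "Paris"}', "id": "call_a"}
assert function_calls[1] == {"name": "get_time", "args": '{"tz": "UTC"}', "id": "call_b"}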


@@ -62,6 +62,7 @@ class FunctionChunk(BaseModel):
   id: Optional[str]
   name: Optional[str]
   args: Optional[str]
+  index: Optional[int] = 0
 
 class TextChunk(BaseModel):
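
The default of 0 on the new field matters: chunks from providers that never set an index still fall into bucket 0, so single-call streams behave exactly as before. A quick sketch of that default (assuming pydantic v2, which the surrounding Optional[...] fields suggest; this is not the ADK source):

from typing import Optional
from pydantic import BaseModel

class FunctionChunk(BaseModel):
  id: Optional[str]
  name: Optional[str]
  args: Optional[str]
  index: Optional[int] = 0

# A chunk from a provider that omits the index lands in bucket 0.
assert FunctionChunk(id=None, name="f", args="{}").index == 0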
@@ -386,6 +387,7 @@ def _model_response_to_chunk(
           id=tool_call.id,
           name=tool_call.function.name,
           args=tool_call.function.arguments,
+          index=tool_call.index,
       ), finish_reason
 
   if finish_reason and not (
@@ -661,9 +663,8 @@ class LiteLlm(BaseLlm):
     if stream:
       text = ""
-      function_name = ""
-      function_args = ""
-      function_id = None
+      # Track function calls by index
+      function_calls = {}  # index -> {name, args, id}
       completion_args["stream"] = True
       aggregated_llm_response = None
       aggregated_llm_response_with_tool_call = None
@@ -672,11 +673,17 @@ class LiteLlm(BaseLlm):
       for part in self.llm_client.completion(**completion_args):
         for chunk, finish_reason in _model_response_to_chunk(part):
           if isinstance(chunk, FunctionChunk):
+            index = chunk.index or 0
+            if index not in function_calls:
+              function_calls[index] = {"name": "", "args": "", "id": None}
             if chunk.name:
-              function_name += chunk.name
+              function_calls[index]["name"] += chunk.name
             if chunk.args:
-              function_args += chunk.args
-              function_id = chunk.id or function_id
+              function_calls[index]["args"] += chunk.args
+              function_calls[index]["id"] = (
+                  chunk.id or function_calls[index]["id"]
+              )
           elif isinstance(chunk, TextChunk):
             text += chunk.text
             yield _message_to_generate_content_response(
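
Note the bucket-key fallback on the first added line above: chunk.index or 0 sends both a missing index (None) and an explicit index 0 to bucket 0, so a provider that omits the field degrades to the old single-call behavior:

# `or 0` coerces only the falsy cases; real parallel indexes pass through.
assert (None or 0) == 0  # provider omitted the index -> bucket 0
assert (0 or 0) == 0     # first parallel call -> bucket 0
assert (3 or 0) == 3     # fourth parallel call -> bucket 3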
@@ -693,28 +700,31 @@ class LiteLlm(BaseLlm):
                       total_token_count=chunk.total_tokens,
                   )
-          if finish_reason == "tool_calls" and function_id:
+          if finish_reason == "tool_calls" and function_calls:
+            tool_calls = []
+            for index, func_data in function_calls.items():
+              if func_data["id"]:
+                tool_calls.append(
+                    ChatCompletionMessageToolCall(
+                        type="function",
+                        id=func_data["id"],
+                        function=Function(
+                            name=func_data["name"],
+                            arguments=func_data["args"],
+                            index=index,
+                        ),
+                    )
+                )
             aggregated_llm_response_with_tool_call = (
                 _message_to_generate_content_response(
                     ChatCompletionAssistantMessage(
                         role="assistant",
                         content="",
-                        tool_calls=[
-                            ChatCompletionMessageToolCall(
-                                type="function",
-                                id=function_id,
-                                function=Function(
-                                    name=function_name,
-                                    arguments=function_args,
-                                ),
-                            )
-                        ],
+                        tool_calls=tool_calls,
                     )
                 )
             )
-            function_name = ""
-            function_args = ""
-            function_id = None
+            function_calls.clear()
           elif finish_reason == "stop" and text:
             aggregated_llm_response = _message_to_generate_content_response(
                 ChatCompletionAssistantMessage(role="assistant", content=text)
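
Taken together, the flush step in the hunk above emits one ChatCompletionMessageToolCall per accumulated index, skips buckets that never received an id (incomplete calls), and clears the tracker for the next response. A compact, dependency-free sketch of that flush logic (plain dicts stand in for litellm's message types, and flush is an invented name, not an ADK function):

def flush(function_calls: dict) -> list[dict]:
  # One tool call per index; buckets without an id are dropped.
  tool_calls = [
      {"type": "function", "id": b["id"], "index": i,
       "function": {"name": b["name"], "arguments": b["args"]}}
      for i, b in sorted(function_calls.items())
      if b["id"]
  ]
  function_calls.clear()  # reset state for the next streamed response
  return tool_calls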