mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-19 11:52:19 -06:00
Copybara import of the project:
-- cef3ca1ed3493eebaeab3e03bdf5e56b35c0b8ef by Lucas Nobre <lucaas.sn@gmail.com>: feat: Add index tracking to handle parallel tool call using litellm COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/759 from lucasnobre212:fix/issue_484 65e22934bf839f9ea03963b9ec6c23fdce03f59f PiperOrigin-RevId: 764902433
This commit is contained in:
committed by
Copybara-Service
parent
841e10ae35
commit
05f4834759
@@ -62,6 +62,7 @@ class FunctionChunk(BaseModel):
|
||||
id: Optional[str]
|
||||
name: Optional[str]
|
||||
args: Optional[str]
|
||||
index: Optional[int] = 0
|
||||
|
||||
|
||||
class TextChunk(BaseModel):
|
||||
@@ -386,6 +387,7 @@ def _model_response_to_chunk(
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
args=tool_call.function.arguments,
|
||||
index=tool_call.index,
|
||||
), finish_reason
|
||||
|
||||
if finish_reason and not (
|
||||
@@ -661,9 +663,8 @@ class LiteLlm(BaseLlm):
|
||||
|
||||
if stream:
|
||||
text = ""
|
||||
function_name = ""
|
||||
function_args = ""
|
||||
function_id = None
|
||||
# Track function calls by index
|
||||
function_calls = {} # index -> {name, args, id}
|
||||
completion_args["stream"] = True
|
||||
aggregated_llm_response = None
|
||||
aggregated_llm_response_with_tool_call = None
|
||||
@@ -672,11 +673,17 @@ class LiteLlm(BaseLlm):
|
||||
for part in self.llm_client.completion(**completion_args):
|
||||
for chunk, finish_reason in _model_response_to_chunk(part):
|
||||
if isinstance(chunk, FunctionChunk):
|
||||
index = chunk.index or 0
|
||||
if index not in function_calls:
|
||||
function_calls[index] = {"name": "", "args": "", "id": None}
|
||||
|
||||
if chunk.name:
|
||||
function_name += chunk.name
|
||||
function_calls[index]["name"] += chunk.name
|
||||
if chunk.args:
|
||||
function_args += chunk.args
|
||||
function_id = chunk.id or function_id
|
||||
function_calls[index]["args"] += chunk.args
|
||||
function_calls[index]["id"] = (
|
||||
chunk.id or function_calls[index]["id"]
|
||||
)
|
||||
elif isinstance(chunk, TextChunk):
|
||||
text += chunk.text
|
||||
yield _message_to_generate_content_response(
|
||||
@@ -693,28 +700,31 @@ class LiteLlm(BaseLlm):
|
||||
total_token_count=chunk.total_tokens,
|
||||
)
|
||||
|
||||
if finish_reason == "tool_calls" and function_id:
|
||||
if finish_reason == "tool_calls" and function_calls:
|
||||
tool_calls = []
|
||||
for index, func_data in function_calls.items():
|
||||
if func_data["id"]:
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
type="function",
|
||||
id=func_data["id"],
|
||||
function=Function(
|
||||
name=func_data["name"],
|
||||
arguments=func_data["args"],
|
||||
index=index,
|
||||
),
|
||||
)
|
||||
)
|
||||
aggregated_llm_response_with_tool_call = (
|
||||
_message_to_generate_content_response(
|
||||
ChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
type="function",
|
||||
id=function_id,
|
||||
function=Function(
|
||||
name=function_name,
|
||||
arguments=function_args,
|
||||
),
|
||||
)
|
||||
],
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
)
|
||||
)
|
||||
function_name = ""
|
||||
function_args = ""
|
||||
function_id = None
|
||||
function_calls.clear()
|
||||
elif finish_reason == "stop" and text:
|
||||
aggregated_llm_response = _message_to_generate_content_response(
|
||||
ChatCompletionAssistantMessage(role="assistant", content=text)
|
||||
|
||||
Reference in New Issue
Block a user