Add parallel tool call support for litellm (#172)

Co-authored-by: Hangfei Lin <hangfei@google.com>
This commit is contained in:
Selcuk Gun
2025-04-14 17:42:33 -07:00
committed by GitHub
parent 4e8b944e09
commit 9a44831a08
2 changed files with 77 additions and 32 deletions

View File

@@ -136,49 +136,61 @@ def _safe_json_serialize(obj) -> str:
def _content_to_message_param(
content: types.Content,
) -> Message:
"""Converts a types.Content to a litellm Message.
) -> Union[Message, list[Message]]:
"""Converts a types.Content to a litellm Message or list of Messages.
Handles multipart function responses by returning a list of
ChatCompletionToolMessage objects if multiple function_response parts exist.
Args:
content: The content to convert.
Returns:
The litellm Message.
A litellm Message, a list of litellm Messages.
"""
if content.parts and content.parts[0].function_response:
return ChatCompletionToolMessage(
role="tool",
tool_call_id=content.parts[0].function_response.id,
content=_safe_json_serialize(
content.parts[0].function_response.response
),
)
tool_messages = []
for part in content.parts:
if part.function_response:
tool_messages.append(
ChatCompletionToolMessage(
role="tool",
tool_call_id=part.function_response.id,
content=_safe_json_serialize(part.function_response.response),
)
)
if tool_messages:
return tool_messages if len(tool_messages) > 1 else tool_messages[0]
# Handle user or assistant messages
role = _to_litellm_role(content.role)
message_content = _get_content(content.parts) or None
if role == "user":
return ChatCompletionUserMessage(
role="user", content=_get_content(content.parts)
)
else:
return ChatCompletionUserMessage(role="user", content=message_content)
else: # assistant/model
tool_calls = []
content_present = False
for part in content.parts:
if part.function_call:
tool_calls.append(
ChatCompletionMessageToolCall(
type="function",
id=part.function_call.id,
function=Function(
name=part.function_call.name,
arguments=part.function_call.args,
),
)
)
elif part.text or part.inline_data:
content_present = True
tool_calls = [
ChatCompletionMessageToolCall(
type="function",
id=part.function_call.id,
function=Function(
name=part.function_call.name,
arguments=part.function_call.args,
),
)
for part in content.parts
if part.function_call
]
final_content = message_content if content_present else None
return ChatCompletionAssistantMessage(
role=role,
content=_get_content(content.parts) or None,
content=final_content,
tool_calls=tool_calls or None,
)
@@ -437,10 +449,13 @@ def _get_completion_inputs(
Returns:
The litellm inputs (message list and tool dictionary).
"""
messages = [
_content_to_message_param(content)
for content in llm_request.contents or []
]
messages = []
for content in llm_request.contents or []:
message_param_or_list = _content_to_message_param(content)
if isinstance(message_param_or_list, list):
messages.extend(message_param_or_list)
elif message_param_or_list: # Ensure it's not None before appending
messages.append(message_param_or_list)
if llm_request.config.system_instruction:
messages.insert(