Add parallel tool call support for litellm (#172)

Co-authored-by: Hangfei Lin <hangfei@google.com>
This commit is contained in:
Selcuk Gun 2025-04-14 17:42:33 -07:00 committed by GitHub
parent 4e8b944e09
commit 9a44831a08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 77 additions and 32 deletions

View File

@ -136,49 +136,61 @@ def _safe_json_serialize(obj) -> str:
def _content_to_message_param( def _content_to_message_param(
content: types.Content, content: types.Content,
) -> Message: ) -> Union[Message, list[Message]]:
"""Converts a types.Content to a litellm Message. """Converts a types.Content to a litellm Message or list of Messages.
Handles multipart function responses by returning a list of
ChatCompletionToolMessage objects if multiple function_response parts exist.
Args: Args:
content: The content to convert. content: The content to convert.
Returns: Returns:
The litellm Message. A litellm Message, a list of litellm Messages.
""" """
if content.parts and content.parts[0].function_response: tool_messages = []
return ChatCompletionToolMessage( for part in content.parts:
role="tool", if part.function_response:
tool_call_id=content.parts[0].function_response.id, tool_messages.append(
content=_safe_json_serialize( ChatCompletionToolMessage(
content.parts[0].function_response.response role="tool",
), tool_call_id=part.function_response.id,
) content=_safe_json_serialize(part.function_response.response),
)
)
if tool_messages:
return tool_messages if len(tool_messages) > 1 else tool_messages[0]
# Handle user or assistant messages
role = _to_litellm_role(content.role) role = _to_litellm_role(content.role)
message_content = _get_content(content.parts) or None
if role == "user": if role == "user":
return ChatCompletionUserMessage( return ChatCompletionUserMessage(role="user", content=message_content)
role="user", content=_get_content(content.parts) else: # assistant/model
) tool_calls = []
else: content_present = False
for part in content.parts:
if part.function_call:
tool_calls.append(
ChatCompletionMessageToolCall(
type="function",
id=part.function_call.id,
function=Function(
name=part.function_call.name,
arguments=part.function_call.args,
),
)
)
elif part.text or part.inline_data:
content_present = True
tool_calls = [ final_content = message_content if content_present else None
ChatCompletionMessageToolCall(
type="function",
id=part.function_call.id,
function=Function(
name=part.function_call.name,
arguments=part.function_call.args,
),
)
for part in content.parts
if part.function_call
]
return ChatCompletionAssistantMessage( return ChatCompletionAssistantMessage(
role=role, role=role,
content=_get_content(content.parts) or None, content=final_content,
tool_calls=tool_calls or None, tool_calls=tool_calls or None,
) )
@ -437,10 +449,13 @@ def _get_completion_inputs(
Returns: Returns:
The litellm inputs (message list and tool dictionary). The litellm inputs (message list and tool dictionary).
""" """
messages = [ messages = []
_content_to_message_param(content) for content in llm_request.contents or []:
for content in llm_request.contents or [] message_param_or_list = _content_to_message_param(content)
] if isinstance(message_param_or_list, list):
messages.extend(message_param_or_list)
elif message_param_or_list: # Ensure it's not None before appending
messages.append(message_param_or_list)
if llm_request.config.system_instruction: if llm_request.config.system_instruction:
messages.insert( messages.insert(

View File

@ -515,6 +515,36 @@ def test_content_to_message_param_user_message():
assert message["content"] == "Test prompt" assert message["content"] == "Test prompt"
def test_content_to_message_param_multi_part_function_response():
part1 = types.Part.from_function_response(
name="function_one",
response={"result": "result_one"},
)
part1.function_response.id = "tool_call_1"
part2 = types.Part.from_function_response(
name="function_two",
response={"value": 123},
)
part2.function_response.id = "tool_call_2"
content = types.Content(
role="tool",
parts=[part1, part2],
)
messages = _content_to_message_param(content)
assert isinstance(messages, list)
assert len(messages) == 2
assert messages[0]["role"] == "tool"
assert messages[0]["tool_call_id"] == "tool_call_1"
assert messages[0]["content"] == '{"result": "result_one"}'
assert messages[1]["role"] == "tool"
assert messages[1]["tool_call_id"] == "tool_call_2"
assert messages[1]["content"] == '{"value": 123}'
def test_content_to_message_param_assistant_message(): def test_content_to_message_param_assistant_message():
content = types.Content( content = types.Content(
role="assistant", parts=[types.Part.from_text(text="Test response")] role="assistant", parts=[types.Part.from_text(text="Test response")]