mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
Add parallel tool call support for litellm (#172)
Co-authored-by: Hangfei Lin <hangfei@google.com>
This commit is contained in:
parent
4e8b944e09
commit
9a44831a08
@ -136,34 +136,44 @@ def _safe_json_serialize(obj) -> str:
|
||||
|
||||
def _content_to_message_param(
|
||||
content: types.Content,
|
||||
) -> Message:
|
||||
"""Converts a types.Content to a litellm Message.
|
||||
) -> Union[Message, list[Message]]:
|
||||
"""Converts a types.Content to a litellm Message or list of Messages.
|
||||
|
||||
Handles multipart function responses by returning a list of
|
||||
ChatCompletionToolMessage objects if multiple function_response parts exist.
|
||||
|
||||
Args:
|
||||
content: The content to convert.
|
||||
|
||||
Returns:
|
||||
The litellm Message.
|
||||
A litellm Message, a list of litellm Messages.
|
||||
"""
|
||||
|
||||
if content.parts and content.parts[0].function_response:
|
||||
return ChatCompletionToolMessage(
|
||||
tool_messages = []
|
||||
for part in content.parts:
|
||||
if part.function_response:
|
||||
tool_messages.append(
|
||||
ChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=content.parts[0].function_response.id,
|
||||
content=_safe_json_serialize(
|
||||
content.parts[0].function_response.response
|
||||
),
|
||||
tool_call_id=part.function_response.id,
|
||||
content=_safe_json_serialize(part.function_response.response),
|
||||
)
|
||||
)
|
||||
if tool_messages:
|
||||
return tool_messages if len(tool_messages) > 1 else tool_messages[0]
|
||||
|
||||
# Handle user or assistant messages
|
||||
role = _to_litellm_role(content.role)
|
||||
message_content = _get_content(content.parts) or None
|
||||
|
||||
if role == "user":
|
||||
return ChatCompletionUserMessage(
|
||||
role="user", content=_get_content(content.parts)
|
||||
)
|
||||
else:
|
||||
|
||||
tool_calls = [
|
||||
return ChatCompletionUserMessage(role="user", content=message_content)
|
||||
else: # assistant/model
|
||||
tool_calls = []
|
||||
content_present = False
|
||||
for part in content.parts:
|
||||
if part.function_call:
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
type="function",
|
||||
id=part.function_call.id,
|
||||
@ -172,13 +182,15 @@ def _content_to_message_param(
|
||||
arguments=part.function_call.args,
|
||||
),
|
||||
)
|
||||
for part in content.parts
|
||||
if part.function_call
|
||||
]
|
||||
)
|
||||
elif part.text or part.inline_data:
|
||||
content_present = True
|
||||
|
||||
final_content = message_content if content_present else None
|
||||
|
||||
return ChatCompletionAssistantMessage(
|
||||
role=role,
|
||||
content=_get_content(content.parts) or None,
|
||||
content=final_content,
|
||||
tool_calls=tool_calls or None,
|
||||
)
|
||||
|
||||
@ -437,10 +449,13 @@ def _get_completion_inputs(
|
||||
Returns:
|
||||
The litellm inputs (message list and tool dictionary).
|
||||
"""
|
||||
messages = [
|
||||
_content_to_message_param(content)
|
||||
for content in llm_request.contents or []
|
||||
]
|
||||
messages = []
|
||||
for content in llm_request.contents or []:
|
||||
message_param_or_list = _content_to_message_param(content)
|
||||
if isinstance(message_param_or_list, list):
|
||||
messages.extend(message_param_or_list)
|
||||
elif message_param_or_list: # Ensure it's not None before appending
|
||||
messages.append(message_param_or_list)
|
||||
|
||||
if llm_request.config.system_instruction:
|
||||
messages.insert(
|
||||
|
@ -515,6 +515,36 @@ def test_content_to_message_param_user_message():
|
||||
assert message["content"] == "Test prompt"
|
||||
|
||||
|
||||
def test_content_to_message_param_multi_part_function_response():
|
||||
part1 = types.Part.from_function_response(
|
||||
name="function_one",
|
||||
response={"result": "result_one"},
|
||||
)
|
||||
part1.function_response.id = "tool_call_1"
|
||||
|
||||
part2 = types.Part.from_function_response(
|
||||
name="function_two",
|
||||
response={"value": 123},
|
||||
)
|
||||
part2.function_response.id = "tool_call_2"
|
||||
|
||||
content = types.Content(
|
||||
role="tool",
|
||||
parts=[part1, part2],
|
||||
)
|
||||
messages = _content_to_message_param(content)
|
||||
assert isinstance(messages, list)
|
||||
assert len(messages) == 2
|
||||
|
||||
assert messages[0]["role"] == "tool"
|
||||
assert messages[0]["tool_call_id"] == "tool_call_1"
|
||||
assert messages[0]["content"] == '{"result": "result_one"}'
|
||||
|
||||
assert messages[1]["role"] == "tool"
|
||||
assert messages[1]["tool_call_id"] == "tool_call_2"
|
||||
assert messages[1]["content"] == '{"value": 123}'
|
||||
|
||||
|
||||
def test_content_to_message_param_assistant_message():
|
||||
content = types.Content(
|
||||
role="assistant", parts=[types.Part.from_text(text="Test response")]
|
||||
|
Loading…
Reference in New Issue
Block a user