mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
Add parallel tool call support for litellm (#172)
Co-authored-by: Hangfei Lin <hangfei@google.com>
This commit is contained in:
parent
4e8b944e09
commit
9a44831a08
@ -136,49 +136,61 @@ def _safe_json_serialize(obj) -> str:
|
|||||||
|
|
||||||
def _content_to_message_param(
|
def _content_to_message_param(
|
||||||
content: types.Content,
|
content: types.Content,
|
||||||
) -> Message:
|
) -> Union[Message, list[Message]]:
|
||||||
"""Converts a types.Content to a litellm Message.
|
"""Converts a types.Content to a litellm Message or list of Messages.
|
||||||
|
|
||||||
|
Handles multipart function responses by returning a list of
|
||||||
|
ChatCompletionToolMessage objects if multiple function_response parts exist.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: The content to convert.
|
content: The content to convert.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The litellm Message.
|
A litellm Message, a list of litellm Messages.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if content.parts and content.parts[0].function_response:
|
tool_messages = []
|
||||||
return ChatCompletionToolMessage(
|
for part in content.parts:
|
||||||
role="tool",
|
if part.function_response:
|
||||||
tool_call_id=content.parts[0].function_response.id,
|
tool_messages.append(
|
||||||
content=_safe_json_serialize(
|
ChatCompletionToolMessage(
|
||||||
content.parts[0].function_response.response
|
role="tool",
|
||||||
),
|
tool_call_id=part.function_response.id,
|
||||||
)
|
content=_safe_json_serialize(part.function_response.response),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if tool_messages:
|
||||||
|
return tool_messages if len(tool_messages) > 1 else tool_messages[0]
|
||||||
|
|
||||||
|
# Handle user or assistant messages
|
||||||
role = _to_litellm_role(content.role)
|
role = _to_litellm_role(content.role)
|
||||||
|
message_content = _get_content(content.parts) or None
|
||||||
|
|
||||||
if role == "user":
|
if role == "user":
|
||||||
return ChatCompletionUserMessage(
|
return ChatCompletionUserMessage(role="user", content=message_content)
|
||||||
role="user", content=_get_content(content.parts)
|
else: # assistant/model
|
||||||
)
|
tool_calls = []
|
||||||
else:
|
content_present = False
|
||||||
|
for part in content.parts:
|
||||||
|
if part.function_call:
|
||||||
|
tool_calls.append(
|
||||||
|
ChatCompletionMessageToolCall(
|
||||||
|
type="function",
|
||||||
|
id=part.function_call.id,
|
||||||
|
function=Function(
|
||||||
|
name=part.function_call.name,
|
||||||
|
arguments=part.function_call.args,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif part.text or part.inline_data:
|
||||||
|
content_present = True
|
||||||
|
|
||||||
tool_calls = [
|
final_content = message_content if content_present else None
|
||||||
ChatCompletionMessageToolCall(
|
|
||||||
type="function",
|
|
||||||
id=part.function_call.id,
|
|
||||||
function=Function(
|
|
||||||
name=part.function_call.name,
|
|
||||||
arguments=part.function_call.args,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
for part in content.parts
|
|
||||||
if part.function_call
|
|
||||||
]
|
|
||||||
|
|
||||||
return ChatCompletionAssistantMessage(
|
return ChatCompletionAssistantMessage(
|
||||||
role=role,
|
role=role,
|
||||||
content=_get_content(content.parts) or None,
|
content=final_content,
|
||||||
tool_calls=tool_calls or None,
|
tool_calls=tool_calls or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -437,10 +449,13 @@ def _get_completion_inputs(
|
|||||||
Returns:
|
Returns:
|
||||||
The litellm inputs (message list and tool dictionary).
|
The litellm inputs (message list and tool dictionary).
|
||||||
"""
|
"""
|
||||||
messages = [
|
messages = []
|
||||||
_content_to_message_param(content)
|
for content in llm_request.contents or []:
|
||||||
for content in llm_request.contents or []
|
message_param_or_list = _content_to_message_param(content)
|
||||||
]
|
if isinstance(message_param_or_list, list):
|
||||||
|
messages.extend(message_param_or_list)
|
||||||
|
elif message_param_or_list: # Ensure it's not None before appending
|
||||||
|
messages.append(message_param_or_list)
|
||||||
|
|
||||||
if llm_request.config.system_instruction:
|
if llm_request.config.system_instruction:
|
||||||
messages.insert(
|
messages.insert(
|
||||||
|
@ -515,6 +515,36 @@ def test_content_to_message_param_user_message():
|
|||||||
assert message["content"] == "Test prompt"
|
assert message["content"] == "Test prompt"
|
||||||
|
|
||||||
|
|
||||||
|
def test_content_to_message_param_multi_part_function_response():
|
||||||
|
part1 = types.Part.from_function_response(
|
||||||
|
name="function_one",
|
||||||
|
response={"result": "result_one"},
|
||||||
|
)
|
||||||
|
part1.function_response.id = "tool_call_1"
|
||||||
|
|
||||||
|
part2 = types.Part.from_function_response(
|
||||||
|
name="function_two",
|
||||||
|
response={"value": 123},
|
||||||
|
)
|
||||||
|
part2.function_response.id = "tool_call_2"
|
||||||
|
|
||||||
|
content = types.Content(
|
||||||
|
role="tool",
|
||||||
|
parts=[part1, part2],
|
||||||
|
)
|
||||||
|
messages = _content_to_message_param(content)
|
||||||
|
assert isinstance(messages, list)
|
||||||
|
assert len(messages) == 2
|
||||||
|
|
||||||
|
assert messages[0]["role"] == "tool"
|
||||||
|
assert messages[0]["tool_call_id"] == "tool_call_1"
|
||||||
|
assert messages[0]["content"] == '{"result": "result_one"}'
|
||||||
|
|
||||||
|
assert messages[1]["role"] == "tool"
|
||||||
|
assert messages[1]["tool_call_id"] == "tool_call_2"
|
||||||
|
assert messages[1]["content"] == '{"value": 123}'
|
||||||
|
|
||||||
|
|
||||||
def test_content_to_message_param_assistant_message():
|
def test_content_to_message_param_assistant_message():
|
||||||
content = types.Content(
|
content = types.Content(
|
||||||
role="assistant", parts=[types.Part.from_text(text="Test response")]
|
role="assistant", parts=[types.Part.from_text(text="Test response")]
|
||||||
|
Loading…
Reference in New Issue
Block a user