diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index acdaa55..aeb4107 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -136,49 +136,61 @@ def _safe_json_serialize(obj) -> str: def _content_to_message_param( content: types.Content, -) -> Message: - """Converts a types.Content to a litellm Message. +) -> Union[Message, list[Message]]: + """Converts a types.Content to a litellm Message or list of Messages. + + Handles multipart function responses by returning a list of + ChatCompletionToolMessage objects if multiple function_response parts exist. Args: content: The content to convert. Returns: - The litellm Message. + A litellm Message, a list of litellm Messages. """ - if content.parts and content.parts[0].function_response: - return ChatCompletionToolMessage( - role="tool", - tool_call_id=content.parts[0].function_response.id, - content=_safe_json_serialize( - content.parts[0].function_response.response - ), - ) + tool_messages = [] + for part in content.parts: + if part.function_response: + tool_messages.append( + ChatCompletionToolMessage( + role="tool", + tool_call_id=part.function_response.id, + content=_safe_json_serialize(part.function_response.response), + ) + ) + if tool_messages: + return tool_messages if len(tool_messages) > 1 else tool_messages[0] + # Handle user or assistant messages role = _to_litellm_role(content.role) + message_content = _get_content(content.parts) or None if role == "user": - return ChatCompletionUserMessage( - role="user", content=_get_content(content.parts) - ) - else: + return ChatCompletionUserMessage(role="user", content=message_content) + else: # assistant/model + tool_calls = [] + content_present = False + for part in content.parts: + if part.function_call: + tool_calls.append( + ChatCompletionMessageToolCall( + type="function", + id=part.function_call.id, + function=Function( + name=part.function_call.name, + arguments=part.function_call.args, + ), + ) + ) + elif part.text or part.inline_data: + content_present = True - tool_calls = [ - ChatCompletionMessageToolCall( - type="function", - id=part.function_call.id, - function=Function( - name=part.function_call.name, - arguments=part.function_call.args, - ), - ) - for part in content.parts - if part.function_call - ] + final_content = message_content if content_present else None return ChatCompletionAssistantMessage( role=role, - content=_get_content(content.parts) or None, + content=final_content, tool_calls=tool_calls or None, ) @@ -437,10 +449,13 @@ def _get_completion_inputs( Returns: The litellm inputs (message list and tool dictionary). """ - messages = [ - _content_to_message_param(content) - for content in llm_request.contents or [] - ] + messages = [] + for content in llm_request.contents or []: + message_param_or_list = _content_to_message_param(content) + if isinstance(message_param_or_list, list): + messages.extend(message_param_or_list) + elif message_param_or_list: # Ensure it's not None before appending + messages.append(message_param_or_list) if llm_request.config.system_instruction: messages.insert( diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 9d9bd71..613ab38 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -515,6 +515,36 @@ def test_content_to_message_param_user_message(): assert message["content"] == "Test prompt" +def test_content_to_message_param_multi_part_function_response(): + part1 = types.Part.from_function_response( + name="function_one", + response={"result": "result_one"}, + ) + part1.function_response.id = "tool_call_1" + + part2 = types.Part.from_function_response( + name="function_two", + response={"value": 123}, + ) + part2.function_response.id = "tool_call_2" + + content = types.Content( + role="tool", + parts=[part1, part2], + ) + messages = _content_to_message_param(content) + assert isinstance(messages, list) + assert len(messages) == 2 + + assert messages[0]["role"] == "tool" + assert messages[0]["tool_call_id"] == "tool_call_1" + assert messages[0]["content"] == '{"result": "result_one"}' + + assert messages[1]["role"] == "tool" + assert messages[1]["tool_call_id"] == "tool_call_2" + assert messages[1]["content"] == '{"value": 123}' + + def test_content_to_message_param_assistant_message(): content = types.Content( role="assistant", parts=[types.Part.from_text(text="Test response")]