diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 57cb0b7..a62f096 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -13,6 +13,7 @@ # limitations under the License. +import json from unittest.mock import AsyncMock from unittest.mock import Mock @@ -38,7 +39,6 @@ from litellm.types.utils import Delta from litellm.types.utils import ModelResponse from litellm.types.utils import StreamingChoices import pytest -import json LLM_REQUEST_WITH_FUNCTION_DECLARATION = LlmRequest( contents=[ @@ -1190,74 +1190,70 @@ async def test_generate_content_async_stream_with_usage_metadata( async def test_generate_content_async_multiple_function_calls( mock_completion, lite_llm_instance ): - """Test handling of multiple function calls with different indices in streaming mode. + """Test handling of multiple function calls with different indices in streaming mode. - This test verifies that: - 1. Multiple function calls with different indices are handled correctly - 2. Arguments and names are properly accumulated for each function call - 3. The final response contains all function calls with correct indices - """ - mock_completion.return_value = MULTIPLE_FUNCTION_CALLS_STREAM + This test verifies that: + 1. Multiple function calls with different indices are handled correctly + 2. Arguments and names are properly accumulated for each function call + 3. The final response contains all function calls with correct indices + """ + mock_completion.return_value = MULTIPLE_FUNCTION_CALLS_STREAM - llm_request = LlmRequest( - contents=[ - types.Content( - role="user", - parts=[types.Part.from_text(text="Test multiple function calls")], - ) - ], - config=types.GenerateContentConfig( - tools=[ - types.Tool( - function_declarations=[ - types.FunctionDeclaration( - name="function_1", - description="First test function", - parameters=types.Schema( - type=types.Type.OBJECT, - properties={ - "arg": types.Schema(type=types.Type.STRING), - }, - ), - ), - types.FunctionDeclaration( - name="function_2", - description="Second test function", - parameters=types.Schema( - type=types.Type.OBJECT, - properties={ - "arg": types.Schema(type=types.Type.STRING), - }, - ), - ), - ] - ) - ], - ), - ) + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", + parts=[types.Part.from_text(text="Test multiple function calls")], + ) + ], + config=types.GenerateContentConfig( + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="function_1", + description="First test function", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "arg": types.Schema(type=types.Type.STRING), + }, + ), + ), + types.FunctionDeclaration( + name="function_2", + description="Second test function", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "arg": types.Schema(type=types.Type.STRING), + }, + ), + ), + ] + ) + ], + ), + ) - responses = [] - async for response in lite_llm_instance.generate_content_async( - llm_request, stream=True - ): - responses.append(response) + responses = [] + async for response in lite_llm_instance.generate_content_async( + llm_request, stream=True + ): + responses.append(response) - # Verify we got the final response with both function calls - assert len(responses) > 0 - final_response = responses[-1] - assert final_response.content.role == "model" - assert len(final_response.content.parts) == 2 + # Verify we got the final response with both function calls + assert len(responses) > 0 + final_response = responses[-1] + assert final_response.content.role == "model" + assert len(final_response.content.parts) == 2 - # Verify first function call - assert final_response.content.parts[0].function_call.name == "function_1" - assert final_response.content.parts[0].function_call.id == "call_1" - assert final_response.content.parts[0].function_call.args == { - "arg": "value1" - } + # Verify first function call + assert final_response.content.parts[0].function_call.name == "function_1" + assert final_response.content.parts[0].function_call.id == "call_1" + assert final_response.content.parts[0].function_call.args == {"arg": "value1"} - # Verify second function call - assert final_response.content.parts[1].function_call.name == "function_2" - assert final_response.content.parts[1].function_call.id == "call_2" - assert final_response.content.parts[1].function_call.args == { - "arg": "value2" - } + # Verify second function call + assert final_response.content.parts[1].function_call.name == "function_2" + assert final_response.content.parts[1].function_call.id == "call_2" + assert final_response.content.parts[1].function_call.args == {"arg": "value2"}