From 795605a37e1141e37d86c9b3fa484a3a03e7e9a6 Mon Sep 17 00:00:00 2001
From: Almas Akchabayev
Date: Wed, 28 May 2025 22:10:39 -0700
Subject: [PATCH] fix: separate thinking from text parts in streaming mode

Copybara import of the project:

--
79962881ca1c17eb6d7bd9dcf31a44df93c9badd by Almas Akchabayev:

fix: separate thinking from text parts in streaming mode

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/777 from almeynman:separate-thinking-and-text-parts-in-stream-mode b63dcc7fd0fc3973888dcbb9d4cc7e7e0a66e7f7
PiperOrigin-RevId: 764561932
---
 src/google/adk/models/google_llm.py       | 32 +++++++---
 tests/unittests/models/test_google_llm.py | 74 +++++++++++++++++++++++
 2 files changed, 96 insertions(+), 10 deletions(-)

diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py
index 20b747c..42dc964 100644
--- a/src/google/adk/models/google_llm.py
+++ b/src/google/adk/models/google_llm.py
@@ -97,6 +97,7 @@ class Gemini(BaseLlm):
           config=llm_request.config,
       )
       response = None
+      thought_text = ''
       text = ''
       usage_metadata = None
       # for sse, similar as bidi (see receive method in gemini_llm_connecton.py),
@@ -113,32 +114,43 @@ class Gemini(BaseLlm):
             and llm_response.content.parts
             and llm_response.content.parts[0].text
         ):
-          text += llm_response.content.parts[0].text
+          part0 = llm_response.content.parts[0]
+          if part0.thought:
+            thought_text += part0.text
+          else:
+            text += part0.text
           llm_response.partial = True
-        elif text and (
+        elif (thought_text or text) and (
             not llm_response.content
             or not llm_response.content.parts
             # don't yield the merged text event when receiving audio data
             or not llm_response.content.parts[0].inline_data
         ):
+          parts = []
+          if thought_text:
+            parts.append(types.Part(text=thought_text, thought=True))
+          if text:
+            parts.append(types.Part.from_text(text=text))
           yield LlmResponse(
-              content=types.ModelContent(
-                  parts=[types.Part.from_text(text=text)],
-              ),
-              usage_metadata=usage_metadata,
+              content=types.ModelContent(parts=parts),
+              usage_metadata=llm_response.usage_metadata,
           )
+          thought_text = ''
           text = ''
         yield llm_response
       if (
-          text
+          (text or thought_text)
           and response
           and response.candidates
           and response.candidates[0].finish_reason == types.FinishReason.STOP
       ):
+        parts = []
+        if thought_text:
+          parts.append(types.Part(text=thought_text, thought=True))
+        if text:
+          parts.append(types.Part.from_text(text=text))
         yield LlmResponse(
-            content=types.ModelContent(
-                parts=[types.Part.from_text(text=text)],
-            ),
+            content=types.ModelContent(parts=parts),
             usage_metadata=usage_metadata,
         )
 
diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py
index 07c22bb..3b3e570 100644
--- a/tests/unittests/models/test_google_llm.py
+++ b/tests/unittests/models/test_google_llm.py
@@ -206,6 +206,80 @@ async def test_generate_content_async_stream(gemini_llm, llm_request):
     mock_client.aio.models.generate_content_stream.assert_called_once()
 
 
+@pytest.mark.asyncio
+async def test_generate_content_async_stream_preserves_thinking_and_text_parts(
+    gemini_llm, llm_request
+):
+  with mock.patch.object(gemini_llm, "api_client") as mock_client:
+    class MockAsyncIterator:
+      def __init__(self, seq):
+        self._iter = iter(seq)
+
+      def __aiter__(self):
+        return self
+
+      async def __anext__(self):
+        try:
+          return next(self._iter)
+        except StopIteration:
+          raise StopAsyncIteration
+
+    response1 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[Part(text="Think1", thought=True)],
+                ),
+                finish_reason=None,
+            )
+        ]
+    )
+    response2 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[Part(text="Think2", thought=True)],
+                ),
+                finish_reason=None,
+            )
+        ]
+    )
+    response3 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[Part.from_text(text="Answer.")],
+                ),
+                finish_reason=types.FinishReason.STOP,
+            )
+        ]
+    )
+
+    async def mock_coro():
+      return MockAsyncIterator([response1, response2, response3])
+
+    mock_client.aio.models.generate_content_stream.return_value = mock_coro()
+
+    responses = [
+        resp
+        async for resp in gemini_llm.generate_content_async(
+            llm_request, stream=True
+        )
+    ]
+
+    assert len(responses) == 4
+    assert responses[0].partial is True
+    assert responses[1].partial is True
+    assert responses[2].partial is True
+    assert responses[3].content.parts[0].text == "Think1Think2"
+    assert responses[3].content.parts[0].thought is True
+    assert responses[3].content.parts[1].text == "Answer."
+    mock_client.aio.models.generate_content_stream.assert_called_once()
+
+
 @pytest.mark.asyncio
 async def test_connect(gemini_llm, llm_request):
   # Create a mock connection
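For context on the behavior the hunks above introduce: while streaming, thought parts and regular text parts are now accumulated in separate buffers and flushed as two distinct parts of one aggregated model content, instead of being concatenated into a single text part. The following is a minimal, self-contained sketch of that accumulate-and-flush pattern, not the patch itself; it uses the google.genai types the patch relies on, but the chunk list is a made-up stand-in for parts received from generate_content_stream.

# Sketch of the accumulate-and-flush pattern shown in the diff above.
# The `chunks` list is hypothetical; in the real code the parts arrive
# one at a time from the streaming response.
from google.genai import types

chunks = [
    types.Part(text="Think1", thought=True),
    types.Part(text="Think2", thought=True),
    types.Part.from_text(text="Answer."),
]

thought_text = ""
text = ""
for part in chunks:
  if part.thought:
    thought_text += part.text
  else:
    text += part.text

# Flush: thinking stays in its own part, marked with thought=True,
# while the final answer text becomes a separate plain text part.
parts = []
if thought_text:
  parts.append(types.Part(text=thought_text, thought=True))
if text:
  parts.append(types.Part.from_text(text=text))
aggregated = types.ModelContent(parts=parts)

assert aggregated.parts[0].thought is True
assert aggregated.parts[0].text == "Think1Think2"
assert aggregated.parts[1].text == "Answer."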