Mirror of https://github.com/EvolutionAPI/adk-python.git (synced 2025-07-13 15:14:50 -06:00)
fix: separate thinking from text parts in streaming mode
Copybara import of the project:

--
79962881ca1c17eb6d7bd9dcf31a44df93c9badd by Almas Akchabayev <almas.akchabayev@gmail.com>:

fix: separate thinking from text parts in streaming mode

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/777 from almeynman:separate-thinking-and-text-parts-in-stream-mode b63dcc7fd0fc3973888dcbb9d4cc7e7e0a66e7f7
PiperOrigin-RevId: 764561932
This commit is contained in:
parent 60ceea72bd
commit 795605a37e
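
Conceptually, the change keeps two accumulation buffers instead of one: thought chunks go into `thought_text`, answer chunks into `text`, and the merged event emits up to two parts so the `thought` flag survives. A minimal standalone sketch of that pattern, using only the public `google.genai` types that appear in the diff; the `chunks` list is a hypothetical stand-in for the streamed parts, not the actual `Gemini.generate_content_async` loop:

from google.genai import types

# Hypothetical stream: two thinking chunks, then the visible answer.
chunks = [
    types.Part(text='Think1', thought=True),
    types.Part(text='Think2', thought=True),
    types.Part.from_text(text='Answer.'),
]

# Accumulate thought and plain text separately, mirroring the fix,
# so the merged response keeps thought=True on the reasoning part.
thought_text = ''
text = ''
for part in chunks:
  if part.thought:
    thought_text += part.text
  else:
    text += part.text

parts = []
if thought_text:
  parts.append(types.Part(text=thought_text, thought=True))
if text:
  parts.append(types.Part.from_text(text=text))

merged = types.ModelContent(parts=parts)
print([(p.text, p.thought) for p in merged.parts])
# [('Think1Think2', True), ('Answer.', None)]
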
@@ -97,6 +97,7 @@ class Gemini(BaseLlm):
           config=llm_request.config,
       )
       response = None
+      thought_text = ''
       text = ''
       usage_metadata = None
       # for sse, similar as bidi (see receive method in gemini_llm_connecton.py),
@@ -113,32 +114,43 @@ class Gemini(BaseLlm):
             and llm_response.content.parts
             and llm_response.content.parts[0].text
         ):
-          text += llm_response.content.parts[0].text
+          part0 = llm_response.content.parts[0]
+          if part0.thought:
+            thought_text += part0.text
+          else:
+            text += part0.text
           llm_response.partial = True
-        elif text and (
+        elif (thought_text or text) and (
             not llm_response.content
             or not llm_response.content.parts
             # don't yield the merged text event when receiving audio data
             or not llm_response.content.parts[0].inline_data
         ):
+          parts = []
+          if thought_text:
+            parts.append(types.Part(text=thought_text, thought=True))
+          if text:
+            parts.append(types.Part.from_text(text=text))
           yield LlmResponse(
-              content=types.ModelContent(
-                  parts=[types.Part.from_text(text=text)],
-              ),
-              usage_metadata=usage_metadata,
+              content=types.ModelContent(parts=parts),
+              usage_metadata=llm_response.usage_metadata,
           )
+          thought_text = ''
           text = ''
         yield llm_response
       if (
-          text
+          (text or thought_text)
           and response
           and response.candidates
           and response.candidates[0].finish_reason == types.FinishReason.STOP
       ):
+        parts = []
+        if thought_text:
+          parts.append(types.Part(text=thought_text, thought=True))
+        if text:
+          parts.append(types.Part.from_text(text=text))
         yield LlmResponse(
-            content=types.ModelContent(
-                parts=[types.Part.from_text(text=text)],
-            ),
+            content=types.ModelContent(parts=parts),
             usage_metadata=usage_metadata,
         )
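Design note on the hunks above: the old code folded every chunk into the single `text` buffer, so a thought-then-answer stream surfaced as one unmarked part ('Think1Think2Answer.'). With separate buffers, the merged event carries at most two parts, the thought part first. Two smaller details ride along: the merged partial event now reads `usage_metadata` directly from the current `llm_response` rather than the accumulated local, and both buffers are reset after a merged event so the same text is not emitted twice.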
@@ -206,6 +206,80 @@ async def test_generate_content_async_stream(gemini_llm, llm_request):
     mock_client.aio.models.generate_content_stream.assert_called_once()


+@pytest.mark.asyncio
+async def test_generate_content_async_stream_preserves_thinking_and_text_parts(
+    gemini_llm, llm_request
+):
+  with mock.patch.object(gemini_llm, "api_client") as mock_client:
+
+    class MockAsyncIterator:
+
+      def __init__(self, seq):
+        self._iter = iter(seq)
+
+      def __aiter__(self):
+        return self
+
+      async def __anext__(self):
+        try:
+          return next(self._iter)
+        except StopIteration:
+          raise StopAsyncIteration
+
+    response1 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[Part(text="Think1", thought=True)],
+                ),
+                finish_reason=None,
+            )
+        ]
+    )
+    response2 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[Part(text="Think2", thought=True)],
+                ),
+                finish_reason=None,
+            )
+        ]
+    )
+    response3 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[Part.from_text(text="Answer.")],
+                ),
+                finish_reason=types.FinishReason.STOP,
+            )
+        ]
+    )
+
+    async def mock_coro():
+      return MockAsyncIterator([response1, response2, response3])
+
+    mock_client.aio.models.generate_content_stream.return_value = mock_coro()
+
+    responses = [
+        resp
+        async for resp in gemini_llm.generate_content_async(
+            llm_request, stream=True
+        )
+    ]
+
+    assert len(responses) == 4
+    assert responses[0].partial is True
+    assert responses[1].partial is True
+    assert responses[2].partial is True
+    assert responses[3].content.parts[0].text == "Think1Think2"
+    assert responses[3].content.parts[0].thought is True
+    assert responses[3].content.parts[1].text == "Answer."
+    mock_client.aio.models.generate_content_stream.assert_called_once()
+
+
 @pytest.mark.asyncio
 async def test_connect(gemini_llm, llm_request):
   # Create a mock connection
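Reading the assertions: each of the three streamed chunks is re-yielded as a partial event, and because the last chunk finishes with `STOP`, the loop emits a fourth, merged event whose first part is the accumulated thought text ('Think1Think2', `thought=True`) and whose second part is the plain answer text.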