fix: separate thinking from text parts in streaming mode
Copybara import of the project:

-- 79962881ca1c17eb6d7bd9dcf31a44df93c9badd by Almas Akchabayev <almas.akchabayev@gmail.com>:

fix: separate thinking from text parts in streaming mode

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/777 from almeynman:separate-thinking-and-text-parts-in-stream-mode b63dcc7fd0fc3973888dcbb9d4cc7e7e0a66e7f7
PiperOrigin-RevId: 764561932
parent 60ceea72bd
commit 795605a37e
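In short, the change buffers streamed thought parts and regular text parts in separate accumulators and emits them as two distinct parts in the merged event, instead of concatenating everything into one text part. A minimal standalone sketch of that buffering logic, using an illustrative StreamPart dataclass rather than the actual ADK/genai types:

from dataclasses import dataclass


@dataclass
class StreamPart:
  # Illustrative stand-in for a streamed content part; not an ADK type.
  text: str
  thought: bool = False


def merge_stream(parts: list[StreamPart]) -> list[StreamPart]:
  # Accumulate thought text and answer text in separate buffers, then emit
  # at most one merged part per kind, mirroring the buffering in the diff below.
  thought_text = ''
  text = ''
  for part in parts:
    if part.thought:
      thought_text += part.text
    else:
      text += part.text
  merged = []
  if thought_text:
    merged.append(StreamPart(text=thought_text, thought=True))
  if text:
    merged.append(StreamPart(text=text))
  return merged


print(merge_stream([
    StreamPart('Think1', thought=True),
    StreamPart('Think2', thought=True),
    StreamPart('Answer.'),
]))
# [StreamPart(text='Think1Think2', thought=True), StreamPart(text='Answer.')]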
@@ -97,6 +97,7 @@ class Gemini(BaseLlm):
           config=llm_request.config,
       )
       response = None
+      thought_text = ''
       text = ''
       usage_metadata = None
       # for sse, similar as bidi (see receive method in gemini_llm_connecton.py),
@@ -113,32 +114,43 @@ class Gemini(BaseLlm):
             and llm_response.content.parts
             and llm_response.content.parts[0].text
         ):
-          text += llm_response.content.parts[0].text
+          part0 = llm_response.content.parts[0]
+          if part0.thought:
+            thought_text += part0.text
+          else:
+            text += part0.text
           llm_response.partial = True
-        elif text and (
+        elif (thought_text or text) and (
             not llm_response.content
             or not llm_response.content.parts
             # don't yield the merged text event when receiving audio data
             or not llm_response.content.parts[0].inline_data
         ):
+          parts = []
+          if thought_text:
+            parts.append(types.Part(text=thought_text, thought=True))
+          if text:
+            parts.append(types.Part.from_text(text=text))
           yield LlmResponse(
-              content=types.ModelContent(
-                  parts=[types.Part.from_text(text=text)],
-              ),
-              usage_metadata=usage_metadata,
+              content=types.ModelContent(parts=parts),
+              usage_metadata=llm_response.usage_metadata,
           )
+          thought_text = ''
           text = ''
         yield llm_response
       if (
-          text
+          (text or thought_text)
           and response
           and response.candidates
           and response.candidates[0].finish_reason == types.FinishReason.STOP
       ):
+        parts = []
+        if thought_text:
+          parts.append(types.Part(text=thought_text, thought=True))
+        if text:
+          parts.append(types.Part.from_text(text=text))
         yield LlmResponse(
-            content=types.ModelContent(
-                parts=[types.Part.from_text(text=text)],
-            ),
+            content=types.ModelContent(parts=parts),
             usage_metadata=usage_metadata,
         )
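With the merged event now carrying separate parts, a caller can split the model's reasoning from the final answer by checking each part's thought flag. A small hedged example against the google.genai types used above (the content literal here simply mirrors the shape of the aggregated event; it is not produced by calling the ADK):

from google.genai import types

# Shape of the final aggregated content after this change: one thought part
# followed by one regular text part.
content = types.ModelContent(
    parts=[
        types.Part(text='Think1Think2', thought=True),
        types.Part.from_text(text='Answer.'),
    ]
)

thoughts = ''.join(p.text for p in content.parts if p.thought and p.text)
answer = ''.join(p.text for p in content.parts if not p.thought and p.text)

print(thoughts)  # Think1Think2
print(answer)    # Answer.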
@@ -206,6 +206,80 @@ async def test_generate_content_async_stream(gemini_llm, llm_request):
     mock_client.aio.models.generate_content_stream.assert_called_once()
 
 
+@pytest.mark.asyncio
+async def test_generate_content_async_stream_preserves_thinking_and_text_parts(
+    gemini_llm, llm_request
+):
+  with mock.patch.object(gemini_llm, "api_client") as mock_client:
+    class MockAsyncIterator:
+      def __init__(self, seq):
+        self._iter = iter(seq)
+
+      def __aiter__(self):
+        return self
+
+      async def __anext__(self):
+        try:
+          return next(self._iter)
+        except StopIteration:
+          raise StopAsyncIteration
+
+    response1 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[Part(text="Think1", thought=True)],
+                ),
+                finish_reason=None,
+            )
+        ]
+    )
+    response2 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[Part(text="Think2", thought=True)],
+                ),
+                finish_reason=None,
+            )
+        ]
+    )
+    response3 = types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[Part.from_text(text="Answer.")],
+                ),
+                finish_reason=types.FinishReason.STOP,
+            )
+        ]
+    )
+
+    async def mock_coro():
+      return MockAsyncIterator([response1, response2, response3])
+
+    mock_client.aio.models.generate_content_stream.return_value = mock_coro()
+
+    responses = [
+        resp
+        async for resp in gemini_llm.generate_content_async(
+            llm_request, stream=True
+        )
+    ]
+
+    assert len(responses) == 4
+    assert responses[0].partial is True
+    assert responses[1].partial is True
+    assert responses[2].partial is True
+    assert responses[3].content.parts[0].text == "Think1Think2"
+    assert responses[3].content.parts[0].thought is True
+    assert responses[3].content.parts[1].text == "Answer."
+    mock_client.aio.models.generate_content_stream.assert_called_once()
 
 
 @pytest.mark.asyncio
 async def test_connect(gemini_llm, llm_request):
   # Create a mock connection