From 39f78dc28f86199417d7e7dc27f3d6bf3bba06d5 Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Tue, 13 May 2025 11:12:35 -0700 Subject: [PATCH] feat(live): Support native(model-side) audio transcription for agent transferring in live mode. The old implementation: 1. We only started transcription at the beginning of agent transferring. 2. The transcription service we used is not as good / fast as the model/native transcription. In the current implementation, the live agent will rely on the llm's transcription, instead of our transcription when llm support audio transcription in the input. And in that case, the live agent won't use our own audio transcriber. This reduces the latency from 5secs to 2 secs during agent transferring. It also improves the transcription quality. When the llm doesn't support audio transcription, we still use our audio transcriber to transcribe audio input. PiperOrigin-RevId: 758296647 --- .../adk/flows/llm_flows/audio_transcriber.py | 7 ++-- .../adk/flows/llm_flows/base_llm_flow.py | 34 ++++++++++++++----- src/google/adk/runners.py | 3 ++ 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/src/google/adk/flows/llm_flows/audio_transcriber.py b/src/google/adk/flows/llm_flows/audio_transcriber.py index 6709bb5..a64ab9c 100644 --- a/src/google/adk/flows/llm_flows/audio_transcriber.py +++ b/src/google/adk/flows/llm_flows/audio_transcriber.py @@ -25,8 +25,9 @@ if TYPE_CHECKING: class AudioTranscriber: """Transcribes audio using Google Cloud Speech-to-Text.""" - def __init__(self): - self.client = speech.SpeechClient() + def __init__(self, init_client=False): + if init_client: + self.client = speech.SpeechClient() def transcribe_file( self, invocation_context: InvocationContext @@ -84,7 +85,7 @@ class AudioTranscriber: # Step2: transcription for speaker, data in bundled_audio: - if speaker == 'user': + if isinstance(data, genai_types.Blob): audio = speech.RecognitionAudio(content=data) config = speech.RecognitionConfig( diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index d5ab713..6b7caef 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -89,7 +89,12 @@ class BaseLlmFlow(ABC): if invocation_context.transcription_cache: from . import audio_transcriber - audio_transcriber = audio_transcriber.AudioTranscriber() + audio_transcriber = audio_transcriber.AudioTranscriber( + init_client=True + if invocation_context.run_config.input_audio_transcription + is None + else False + ) contents = audio_transcriber.transcribe_file(invocation_context) logger.debug('Sending history to model: %s', contents) await llm_connection.send_history(contents) @@ -177,9 +182,12 @@ class BaseLlmFlow(ABC): # Cache audio data here for transcription if not invocation_context.transcription_cache: invocation_context.transcription_cache = [] - invocation_context.transcription_cache.append( - TranscriptionEntry(role='user', data=live_request.blob) - ) + if not invocation_context.run_config.input_audio_transcription: + # if the live model's input transcription is not enabled, then + # we use our onwn audio transcriber to achieve that. + invocation_context.transcription_cache.append( + TranscriptionEntry(role='user', data=live_request.blob) + ) await llm_connection.send_realtime(live_request.blob) if live_request.content: await llm_connection.send_content(live_request.content) @@ -193,11 +201,14 @@ class BaseLlmFlow(ABC): ) -> AsyncGenerator[Event, None]: """Receive data from model and process events using BaseLlmConnection.""" - def get_author(llm_response): + def get_author_for_event(llm_response): """Get the author of the event. When the model returns transcription, the author is "user". Otherwise, the - author is the agent. + author is the agent name(not 'model'). + + Args: + llm_response: The LLM response from the LLM call. """ if ( llm_response @@ -215,7 +226,7 @@ class BaseLlmFlow(ABC): model_response_event = Event( id=Event.new_id(), invocation_id=invocation_context.invocation_id, - author=get_author(llm_response), + author=get_author_for_event(llm_response), ) async for event in self._postprocess_live( invocation_context, @@ -229,10 +240,17 @@ class BaseLlmFlow(ABC): and event.content.parts[0].text and not event.partial ): + # This can be either user data or transcription data. + # when output transcription enabled, it will contain model's + # transcription. + # when input transcription enabled, it will contain user + # transcription. if not invocation_context.transcription_cache: invocation_context.transcription_cache = [] invocation_context.transcription_cache.append( - TranscriptionEntry(role='model', data=event.content) + TranscriptionEntry( + role=event.content.role, data=event.content + ) ) yield event # Give opportunity for other tasks to run. diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 28cf655..1209e03 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -456,6 +456,9 @@ class Runner: run_config.output_audio_transcription = ( types.AudioTranscriptionConfig() ) + if not run_config.input_audio_transcription: + # need this input transcription for agent transferring in live mode. + run_config.input_audio_transcription = types.AudioTranscriptionConfig() return self._new_invocation_context( session, live_request_queue=live_request_queue,