diff --git a/src/google/adk/flows/llm_flows/audio_transcriber.py b/src/google/adk/flows/llm_flows/audio_transcriber.py
index 6709bb5..a64ab9c 100644
--- a/src/google/adk/flows/llm_flows/audio_transcriber.py
+++ b/src/google/adk/flows/llm_flows/audio_transcriber.py
@@ -25,8 +25,9 @@ if TYPE_CHECKING:
 class AudioTranscriber:
   """Transcribes audio using Google Cloud Speech-to-Text."""

-  def __init__(self):
-    self.client = speech.SpeechClient()
+  def __init__(self, init_client=False):
+    if init_client:
+      self.client = speech.SpeechClient()

   def transcribe_file(
       self, invocation_context: InvocationContext
@@ -84,7 +85,7 @@ class AudioTranscriber:

     # Step2: transcription
     for speaker, data in bundled_audio:
-      if speaker == 'user':
+      if isinstance(data, genai_types.Blob):
         audio = speech.RecognitionAudio(content=data)
         config = speech.RecognitionConfig(
diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py
index d5ab713..6b7caef 100644
--- a/src/google/adk/flows/llm_flows/base_llm_flow.py
+++ b/src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -89,7 +89,10 @@ class BaseLlmFlow(ABC):
     if invocation_context.transcription_cache:
       from . import audio_transcriber

-      audio_transcriber = audio_transcriber.AudioTranscriber()
+      audio_transcriber = audio_transcriber.AudioTranscriber(
+          init_client=invocation_context.run_config.input_audio_transcription
+          is None
+      )
       contents = audio_transcriber.transcribe_file(invocation_context)
       logger.debug('Sending history to model: %s', contents)
       await llm_connection.send_history(contents)
@@ -177,9 +182,12 @@ class BaseLlmFlow(ABC):
           # Cache audio data here for transcription
           if not invocation_context.transcription_cache:
             invocation_context.transcription_cache = []
-          invocation_context.transcription_cache.append(
-              TranscriptionEntry(role='user', data=live_request.blob)
-          )
+          if not invocation_context.run_config.input_audio_transcription:
+            # If the live model's input transcription is not enabled, we use
+            # our own audio transcriber to achieve that.
+            invocation_context.transcription_cache.append(
+                TranscriptionEntry(role='user', data=live_request.blob)
+            )
         await llm_connection.send_realtime(live_request.blob)
       if live_request.content:
         await llm_connection.send_content(live_request.content)
@@ -193,11 +201,14 @@ class BaseLlmFlow(ABC):
     ) -> AsyncGenerator[Event, None]:
       """Receive data from model and process events using BaseLlmConnection."""

-      def get_author(llm_response):
+      def get_author_for_event(llm_response):
         """Get the author of the event.

         When the model returns transcription, the author is "user". Otherwise, the
-        author is the agent.
+        author is the agent name (not 'model').
+
+        Args:
+          llm_response: The LLM response returned by the model.
         """
         if (
             llm_response
@@ -215,7 +226,7 @@ class BaseLlmFlow(ABC):
         model_response_event = Event(
             id=Event.new_id(),
             invocation_id=invocation_context.invocation_id,
-            author=get_author(llm_response),
+            author=get_author_for_event(llm_response),
         )
         async for event in self._postprocess_live(
             invocation_context,
@@ -229,10 +240,17 @@ class BaseLlmFlow(ABC):
             and event.content.parts[0].text
             and not event.partial
         ):
+          # This can be either user transcription or model transcription:
+          # when output transcription is enabled, it contains the model's
+          # transcription;
+          # when input transcription is enabled, it contains the user's
+          # transcription.
           if not invocation_context.transcription_cache:
             invocation_context.transcription_cache = []
           invocation_context.transcription_cache.append(
-              TranscriptionEntry(role='model', data=event.content)
+              TranscriptionEntry(
+                  role=event.content.role, data=event.content
+              )
           )
         yield event
         # Give opportunity for other tasks to run.
diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py
index 28cf655..1209e03 100644
--- a/src/google/adk/runners.py
+++ b/src/google/adk/runners.py
@@ -456,6 +456,9 @@ class Runner:
       run_config.output_audio_transcription = (
           types.AudioTranscriptionConfig()
       )
+      if not run_config.input_audio_transcription:
+        # Input transcription is needed for agent transfer in live mode.
+        run_config.input_audio_transcription = types.AudioTranscriptionConfig()
       return self._new_invocation_context(
           session,
           live_request_queue=live_request_queue,
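For reviewers, a minimal, self-contained sketch of the lazy-initialization pattern the AudioTranscriber constructor change introduces. The google.cloud.speech calls follow the public google-cloud-speech v1 API; the class name, method name, and audio format here are illustrative assumptions, not the ADK implementation.

from google.cloud import speech


class LazyTranscriber:
  """Illustrative stand-in for AudioTranscriber's init_client pattern."""

  def __init__(self, init_client: bool = False):
    # Only construct the SpeechClient (which requires Google Cloud
    # credentials) when client-side transcription will actually be used,
    # i.e. when the live model's native input transcription is disabled.
    if init_client:
      self.client = speech.SpeechClient()

  def transcribe(self, audio_bytes: bytes) -> str:
    # Assumes 16 kHz LINEAR16 PCM audio; adjust to the actual stream format.
    audio = speech.RecognitionAudio(content=audio_bytes)
    config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
        sample_rate_hertz=16000,
        language_code='en-US',
    )
    response = self.client.recognize(config=config, audio=audio)
    return ' '.join(
        result.alternatives[0].transcript for result in response.results
    )

With this shape, constructing the transcriber is always safe; only a caller that passes init_client=True pays the client-construction cost, which matches how base_llm_flow.py now decides based on run_config.input_audio_transcription.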
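And a sketch of what the runners.py default means for callers: after this change, run_live behaves as if input transcription had been requested explicitly. The RunConfig field names come from this diff; the import paths are an assumption based on the ADK source layout.

from google.adk.agents.run_config import RunConfig
from google.genai import types

# Explicit equivalent of the new default: the runner now fills in
# input_audio_transcription itself when the caller leaves it unset,
# because live-mode agent transfer needs the user's speech as text.
run_config = RunConfig(
    response_modalities=['AUDIO'],
    input_audio_transcription=types.AudioTranscriptionConfig(),
    output_audio_transcription=types.AudioTranscriptionConfig(),
)

When input_audio_transcription is set, base_llm_flow.py skips the client-side transcriber (init_client resolves to False) and caches the model-provided transcription events instead, using event.content.role to record whether each entry came from the user or the model.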