diff --git a/pyproject.toml b/pyproject.toml
index f49482d..2ed9d63 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,7 @@ dependencies = [
   "google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool
   "google-cloud-speech>=2.30.0", # For Audo Transcription
   "google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service
-  "google-genai>=1.11.0", # Google GenAI SDK
+  "google-genai>=1.12.1", # Google GenAI SDK
   "graphviz>=0.20.2", # Graphviz for graph rendering
   "mcp>=1.5.0;python_version>='3.10'", # For MCP Toolset
   "opentelemetry-api>=1.31.0", # OpenTelemetry
diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py
index e121750..f19ae0f 100644
--- a/src/google/adk/agents/run_config.py
+++ b/src/google/adk/agents/run_config.py
@@ -65,6 +65,9 @@ class RunConfig(BaseModel):
   output_audio_transcription: Optional[types.AudioTranscriptionConfig] = None
   """Output transcription for live agents with audio response."""
 
+  input_audio_transcription: Optional[types.AudioTranscriptionConfig] = None
+  """Input transcription for live agents with audio input from user."""
+
   max_llm_calls: int = 500
   """
   A limit on the total number of llm calls for a given run.
diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py
index 188f3a5..31904e3 100644
--- a/src/google/adk/flows/llm_flows/base_llm_flow.py
+++ b/src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -190,6 +190,16 @@ class BaseLlmFlow(ABC):
       llm_request: LlmRequest,
   ) -> AsyncGenerator[Event, None]:
     """Receive data from model and process events using BaseLlmConnection."""
+    def get_author(llm_response):
+      """Get the author of the event.
+
+      When the model returns transcription, the author is "user". Otherwise, the author is the agent.
+      """
+      if llm_response and llm_response.content and llm_response.content.role == "user":
+        return "user"
+      else:
+        return invocation_context.agent.name
+
     assert invocation_context.live_request_queue
     try:
       while True:
@@ -197,7 +207,7 @@ class BaseLlmFlow(ABC):
         model_response_event = Event(
             id=Event.new_id(),
             invocation_id=invocation_context.invocation_id,
-            author=invocation_context.agent.name,
+            author=get_author(llm_response),
         )
         async for event in self._postprocess_live(
             invocation_context,
diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py
index 278b4cf..d48c8cd 100644
--- a/src/google/adk/flows/llm_flows/basic.py
+++ b/src/google/adk/flows/llm_flows/basic.py
@@ -62,6 +62,9 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
     llm_request.live_connect_config.output_audio_transcription = (
         invocation_context.run_config.output_audio_transcription
     )
+    llm_request.live_connect_config.input_audio_transcription = (
+        invocation_context.run_config.input_audio_transcription
+    )
 
     # TODO: handle tool append here, instead of in BaseTool.process_llm_request.
 
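For review context, a minimal usage sketch (not part of this diff) of how the new `RunConfig.input_audio_transcription` field is meant to be set by callers. The `StreamingMode.BIDI` setting is an assumption about the intended live/bidi use case; only the two transcription fields come from this change:

```python
# Sketch only, not part of this diff. Assumes the google-genai SDK
# version pinned above, which provides types.AudioTranscriptionConfig.
from google.genai import types

from google.adk.agents.run_config import RunConfig, StreamingMode

run_config = RunConfig(
    streaming_mode=StreamingMode.BIDI,  # assumed: transcription applies to live runs
    # New in this change: ask the live backend to transcribe the
    # user's audio input; the text comes back in events with role "user".
    input_audio_transcription=types.AudioTranscriptionConfig(),
    # Pre-existing field: transcribe the model's audio responses.
    output_audio_transcription=types.AudioTranscriptionConfig(),
)
```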
diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index 30f1fb2..4018975 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -145,7 +145,20 @@
             yield self.__build_full_text_response(text)
             text = ''
           yield llm_response
-
+        if (
+            message.server_content.input_transcription
+            and message.server_content.input_transcription.text
+        ):
+          user_text = message.server_content.input_transcription.text
+          parts = [
+              types.Part.from_text(
+                  text=user_text,
+              )
+          ]
+          llm_response = LlmResponse(
+              content=types.Content(role='user', parts=parts)
+          )
+          yield llm_response
         if (
             message.server_content.output_transcription
             and message.server_content.output_transcription.text
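Because input transcriptions now surface as `LlmResponse`s with `role='user'`, and `get_author()` in base_llm_flow.py attributes those events to `"user"`, downstream consumers can separate what the user said from what the agent said. A minimal sketch, assuming a `live_events` async iterator obtained from the application's own runner setup (the iterator name and print handling are illustrative):

```python
# Sketch, not part of this diff: consuming the live event stream and
# splitting user input transcriptions from agent responses, mirroring
# the get_author() logic added in base_llm_flow.py.
async def relay_events(live_events):
  async for event in live_events:
    if not event.content or not event.content.parts:
      continue
    text = event.content.parts[0].text
    if event.author == 'user':
      # Built from server_content.input_transcription above.
      print(f'[user transcript] {text}')
    else:
      print(f'[{event.author}] {text}')
```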