From 9e9dfa7472de563f005d807b79214b8099b21623 Mon Sep 17 00:00:00 2001 From: Selcuk Gun Date: Fri, 16 May 2025 12:08:24 -0700 Subject: [PATCH] Prevent session state injection for provider supplied instructions When the user provides instruction provider, we assume that they will inject the session state parameters if needed. This assumption allows users to return code snippets in the instruction provider without any template replacement. PiperOrigin-RevId: 759705471 --- src/google/adk/agents/llm_agent.py | 34 +++++++++++++++---- .../adk/flows/llm_flows/instructions.py | 16 ++++++--- .../unittests/agents/test_llm_agent_fields.py | 32 ++++++++++++----- .../flows/llm_flows/test_instructions.py | 15 ++++++-- 4 files changed, 74 insertions(+), 23 deletions(-) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 4e31f46..4a419ef 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -307,31 +307,53 @@ class LlmAgent(BaseAgent): ancestor_agent = ancestor_agent.parent_agent raise ValueError(f'No model found for {self.name}.') - async def canonical_instruction(self, ctx: ReadonlyContext) -> str: + async def canonical_instruction( + self, ctx: ReadonlyContext + ) -> tuple[str, bool]: """The resolved self.instruction field to construct instruction for this agent. This method is only for use by Agent Development Kit. + + Args: + ctx: The context to retrieve the session state. + + Returns: + A tuple of (instruction, bypass_state_injection). + instruction: The resolved self.instruction field. + bypass_state_injection: Whether the instruction is based on + InstructionProvider. """ if isinstance(self.instruction, str): - return self.instruction + return self.instruction, False else: instruction = self.instruction(ctx) if inspect.isawaitable(instruction): instruction = await instruction - return instruction + return instruction, True - async def canonical_global_instruction(self, ctx: ReadonlyContext) -> str: + async def canonical_global_instruction( + self, ctx: ReadonlyContext + ) -> tuple[str, bool]: """The resolved self.instruction field to construct global instruction. This method is only for use by Agent Development Kit. + + Args: + ctx: The context to retrieve the session state. + + Returns: + A tuple of (instruction, bypass_state_injection). + instruction: The resolved self.global_instruction field. + bypass_state_injection: Whether the instruction is based on + InstructionProvider. """ if isinstance(self.global_instruction, str): - return self.global_instruction + return self.global_instruction, False else: global_instruction = self.global_instruction(ctx) if inspect.isawaitable(global_instruction): global_instruction = await global_instruction - return global_instruction + return global_instruction, True async def canonical_tools( self, ctx: ReadonlyContext = None diff --git a/src/google/adk/flows/llm_flows/instructions.py b/src/google/adk/flows/llm_flows/instructions.py index d2ae683..2956d6f 100644 --- a/src/google/adk/flows/llm_flows/instructions.py +++ b/src/google/adk/flows/llm_flows/instructions.py @@ -53,18 +53,24 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor): if ( isinstance(root_agent, LlmAgent) and root_agent.global_instruction ): # not empty str - raw_si = await root_agent.canonical_global_instruction( - ReadonlyContext(invocation_context) + raw_si, bypass_state_injection = ( + await root_agent.canonical_global_instruction( + ReadonlyContext(invocation_context) + ) ) - si = await _populate_values(raw_si, invocation_context) + si = raw_si + if not bypass_state_injection: + si = await _populate_values(raw_si, invocation_context) llm_request.append_instructions([si]) # Appends agent instructions if set. if agent.instruction: # not empty str - raw_si = await agent.canonical_instruction( + raw_si, bypass_state_injection = await agent.canonical_instruction( ReadonlyContext(invocation_context) ) - si = await _populate_values(raw_si, invocation_context) + si = raw_si + if not bypass_state_injection: + si = await _populate_values(raw_si, invocation_context) llm_request.append_instructions([si]) # Maintain async generator behavior diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index 106e20d..287ef3b 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -79,8 +79,11 @@ async def test_canonical_instruction_str(): agent = LlmAgent(name='test_agent', instruction='instruction') ctx = await _create_readonly_context(agent) - canonical_instruction = await agent.canonical_instruction(ctx) + canonical_instruction, bypass_state_injection = ( + await agent.canonical_instruction(ctx) + ) assert canonical_instruction == 'instruction' + assert not bypass_state_injection async def test_canonical_instruction(): @@ -92,8 +95,11 @@ async def test_canonical_instruction(): agent, state={'state_var': 'state_value'} ) - canonical_instruction = await agent.canonical_instruction(ctx) + canonical_instruction, bypass_state_injection = ( + await agent.canonical_instruction(ctx) + ) assert canonical_instruction == 'instruction: state_value' + assert bypass_state_injection async def test_async_canonical_instruction(): @@ -105,16 +111,22 @@ async def test_async_canonical_instruction(): agent, state={'state_var': 'state_value'} ) - canonical_instruction = await agent.canonical_instruction(ctx) + canonical_instruction, bypass_state_injection = ( + await agent.canonical_instruction(ctx) + ) assert canonical_instruction == 'instruction: state_value' + assert bypass_state_injection async def test_canonical_global_instruction_str(): agent = LlmAgent(name='test_agent', global_instruction='global instruction') ctx = await _create_readonly_context(agent) - canonical_instruction = await agent.canonical_global_instruction(ctx) + canonical_instruction, bypass_state_injection = ( + await agent.canonical_global_instruction(ctx) + ) assert canonical_instruction == 'global instruction' + assert not bypass_state_injection async def test_canonical_global_instruction(): @@ -128,9 +140,11 @@ async def test_canonical_global_instruction(): agent, state={'state_var': 'state_value'} ) - canonical_global_instruction = await agent.canonical_global_instruction(ctx) + canonical_global_instruction, bypass_state_injection = ( + await agent.canonical_global_instruction(ctx) + ) assert canonical_global_instruction == 'global instruction: state_value' - + assert bypass_state_injection async def test_async_canonical_global_instruction(): async def _global_instruction_provider(ctx: ReadonlyContext) -> str: @@ -142,11 +156,11 @@ async def test_async_canonical_global_instruction(): ctx = await _create_readonly_context( agent, state={'state_var': 'state_value'} ) - - assert ( + canonical_global_instruction, bypass_state_injection = ( await agent.canonical_global_instruction(ctx) - == 'global instruction: state_value' ) + assert canonical_global_instruction == 'global instruction: state_value' + assert bypass_state_injection def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture): diff --git a/tests/unittests/flows/llm_flows/test_instructions.py b/tests/unittests/flows/llm_flows/test_instructions.py index 73117d4..d6e2234 100644 --- a/tests/unittests/flows/llm_flows/test_instructions.py +++ b/tests/unittests/flows/llm_flows/test_instructions.py @@ -61,6 +61,8 @@ async def test_function_system_instruction(): def build_function_instruction(readonly_context: ReadonlyContext) -> str: return ( "This is the function agent instruction for invocation:" + " provider template intact { customerId }" + " provider template intact { customer_int }" f" {readonly_context.invocation_id}." ) @@ -88,10 +90,12 @@ async def test_function_system_instruction(): pass assert request.config.system_instruction == ( - "This is the function agent instruction for invocation: test_id." + "This is the function agent instruction for invocation:" + " provider template intact { customerId }" + " provider template intact { customer_int }" + " test_id." ) - @pytest.mark.asyncio async def test_async_function_system_instruction(): async def build_function_instruction( @@ -99,6 +103,8 @@ async def test_async_function_system_instruction(): ) -> str: return ( "This is the function agent instruction for invocation:" + " provider template intact { customerId }" + " provider template intact { customer_int }" f" {readonly_context.invocation_id}." ) @@ -126,7 +132,10 @@ async def test_async_function_system_instruction(): pass assert request.config.system_instruction == ( - "This is the function agent instruction for invocation: test_id." + "This is the function agent instruction for invocation:" + " provider template intact { customerId }" + " provider template intact { customer_int }" + " test_id." )