diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 4e31f46..4a419ef 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -307,31 +307,53 @@ class LlmAgent(BaseAgent): ancestor_agent = ancestor_agent.parent_agent raise ValueError(f'No model found for {self.name}.') - async def canonical_instruction(self, ctx: ReadonlyContext) -> str: + async def canonical_instruction( + self, ctx: ReadonlyContext + ) -> tuple[str, bool]: """The resolved self.instruction field to construct instruction for this agent. This method is only for use by Agent Development Kit. + + Args: + ctx: The context to retrieve the session state. + + Returns: + A tuple of (instruction, bypass_state_injection). + instruction: The resolved self.instruction field. + bypass_state_injection: Whether the instruction is based on + InstructionProvider. """ if isinstance(self.instruction, str): - return self.instruction + return self.instruction, False else: instruction = self.instruction(ctx) if inspect.isawaitable(instruction): instruction = await instruction - return instruction + return instruction, True - async def canonical_global_instruction(self, ctx: ReadonlyContext) -> str: + async def canonical_global_instruction( + self, ctx: ReadonlyContext + ) -> tuple[str, bool]: """The resolved self.instruction field to construct global instruction. This method is only for use by Agent Development Kit. + + Args: + ctx: The context to retrieve the session state. + + Returns: + A tuple of (instruction, bypass_state_injection). + instruction: The resolved self.global_instruction field. + bypass_state_injection: Whether the instruction is based on + InstructionProvider. """ if isinstance(self.global_instruction, str): - return self.global_instruction + return self.global_instruction, False else: global_instruction = self.global_instruction(ctx) if inspect.isawaitable(global_instruction): global_instruction = await global_instruction - return global_instruction + return global_instruction, True async def canonical_tools( self, ctx: ReadonlyContext = None diff --git a/src/google/adk/flows/llm_flows/instructions.py b/src/google/adk/flows/llm_flows/instructions.py index d2ae683..2956d6f 100644 --- a/src/google/adk/flows/llm_flows/instructions.py +++ b/src/google/adk/flows/llm_flows/instructions.py @@ -53,18 +53,24 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor): if ( isinstance(root_agent, LlmAgent) and root_agent.global_instruction ): # not empty str - raw_si = await root_agent.canonical_global_instruction( - ReadonlyContext(invocation_context) + raw_si, bypass_state_injection = ( + await root_agent.canonical_global_instruction( + ReadonlyContext(invocation_context) + ) ) - si = await _populate_values(raw_si, invocation_context) + si = raw_si + if not bypass_state_injection: + si = await _populate_values(raw_si, invocation_context) llm_request.append_instructions([si]) # Appends agent instructions if set. if agent.instruction: # not empty str - raw_si = await agent.canonical_instruction( + raw_si, bypass_state_injection = await agent.canonical_instruction( ReadonlyContext(invocation_context) ) - si = await _populate_values(raw_si, invocation_context) + si = raw_si + if not bypass_state_injection: + si = await _populate_values(raw_si, invocation_context) llm_request.append_instructions([si]) # Maintain async generator behavior diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index 106e20d..287ef3b 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -79,8 +79,11 @@ async def test_canonical_instruction_str(): agent = LlmAgent(name='test_agent', instruction='instruction') ctx = await _create_readonly_context(agent) - canonical_instruction = await agent.canonical_instruction(ctx) + canonical_instruction, bypass_state_injection = ( + await agent.canonical_instruction(ctx) + ) assert canonical_instruction == 'instruction' + assert not bypass_state_injection async def test_canonical_instruction(): @@ -92,8 +95,11 @@ async def test_canonical_instruction(): agent, state={'state_var': 'state_value'} ) - canonical_instruction = await agent.canonical_instruction(ctx) + canonical_instruction, bypass_state_injection = ( + await agent.canonical_instruction(ctx) + ) assert canonical_instruction == 'instruction: state_value' + assert bypass_state_injection async def test_async_canonical_instruction(): @@ -105,16 +111,22 @@ async def test_async_canonical_instruction(): agent, state={'state_var': 'state_value'} ) - canonical_instruction = await agent.canonical_instruction(ctx) + canonical_instruction, bypass_state_injection = ( + await agent.canonical_instruction(ctx) + ) assert canonical_instruction == 'instruction: state_value' + assert bypass_state_injection async def test_canonical_global_instruction_str(): agent = LlmAgent(name='test_agent', global_instruction='global instruction') ctx = await _create_readonly_context(agent) - canonical_instruction = await agent.canonical_global_instruction(ctx) + canonical_instruction, bypass_state_injection = ( + await agent.canonical_global_instruction(ctx) + ) assert canonical_instruction == 'global instruction' + assert not bypass_state_injection async def test_canonical_global_instruction(): @@ -128,9 +140,11 @@ async def test_canonical_global_instruction(): agent, state={'state_var': 'state_value'} ) - canonical_global_instruction = await agent.canonical_global_instruction(ctx) + canonical_global_instruction, bypass_state_injection = ( + await agent.canonical_global_instruction(ctx) + ) assert canonical_global_instruction == 'global instruction: state_value' - + assert bypass_state_injection async def test_async_canonical_global_instruction(): async def _global_instruction_provider(ctx: ReadonlyContext) -> str: @@ -142,11 +156,11 @@ async def test_async_canonical_global_instruction(): ctx = await _create_readonly_context( agent, state={'state_var': 'state_value'} ) - - assert ( + canonical_global_instruction, bypass_state_injection = ( await agent.canonical_global_instruction(ctx) - == 'global instruction: state_value' ) + assert canonical_global_instruction == 'global instruction: state_value' + assert bypass_state_injection def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture): diff --git a/tests/unittests/flows/llm_flows/test_instructions.py b/tests/unittests/flows/llm_flows/test_instructions.py index 73117d4..d6e2234 100644 --- a/tests/unittests/flows/llm_flows/test_instructions.py +++ b/tests/unittests/flows/llm_flows/test_instructions.py @@ -61,6 +61,8 @@ async def test_function_system_instruction(): def build_function_instruction(readonly_context: ReadonlyContext) -> str: return ( "This is the function agent instruction for invocation:" + " provider template intact { customerId }" + " provider template intact { customer_int }" f" {readonly_context.invocation_id}." ) @@ -88,10 +90,12 @@ async def test_function_system_instruction(): pass assert request.config.system_instruction == ( - "This is the function agent instruction for invocation: test_id." + "This is the function agent instruction for invocation:" + " provider template intact { customerId }" + " provider template intact { customer_int }" + " test_id." ) - @pytest.mark.asyncio async def test_async_function_system_instruction(): async def build_function_instruction( @@ -99,6 +103,8 @@ async def test_async_function_system_instruction(): ) -> str: return ( "This is the function agent instruction for invocation:" + " provider template intact { customerId }" + " provider template intact { customer_int }" f" {readonly_context.invocation_id}." ) @@ -126,7 +132,10 @@ async def test_async_function_system_instruction(): pass assert request.config.system_instruction == ( - "This is the function agent instruction for invocation: test_id." + "This is the function agent instruction for invocation:" + " provider template intact { customerId }" + " provider template intact { customer_int }" + " test_id." )