diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 0076c6a..302b1bc 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -14,6 +14,7 @@ from __future__ import annotations +import inspect import logging from typing import ( Any, @@ -96,7 +97,9 @@ AfterToolCallback: TypeAlias = Union[ list[_SingleAfterToolCallback], ] -InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str] +InstructionProvider: TypeAlias = Callable[ + [ReadonlyContext], Union[str, Awaitable[str]] +] ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset] ExamplesUnion = Union[list[Example], BaseExampleProvider] @@ -302,7 +305,7 @@ class LlmAgent(BaseAgent): ancestor_agent = ancestor_agent.parent_agent raise ValueError(f'No model found for {self.name}.') - def canonical_instruction(self, ctx: ReadonlyContext) -> str: + async def canonical_instruction(self, ctx: ReadonlyContext) -> str: """The resolved self.instruction field to construct instruction for this agent. This method is only for use by Agent Development Kit. @@ -310,9 +313,12 @@ class LlmAgent(BaseAgent): if isinstance(self.instruction, str): return self.instruction else: - return self.instruction(ctx) + instruction = self.instruction(ctx) + if inspect.isawaitable(instruction): + instruction = await instruction + return instruction - def canonical_global_instruction(self, ctx: ReadonlyContext) -> str: + async def canonical_global_instruction(self, ctx: ReadonlyContext) -> str: """The resolved self.instruction field to construct global instruction. This method is only for use by Agent Development Kit. @@ -320,7 +326,10 @@ class LlmAgent(BaseAgent): if isinstance(self.global_instruction, str): return self.global_instruction else: - return self.global_instruction(ctx) + global_instruction = self.global_instruction(ctx) + if inspect.isawaitable(global_instruction): + global_instruction = await global_instruction + return global_instruction async def canonical_tools( self, ctx: ReadonlyContext = None diff --git a/src/google/adk/flows/llm_flows/instructions.py b/src/google/adk/flows/llm_flows/instructions.py index 041c867..d2ae683 100644 --- a/src/google/adk/flows/llm_flows/instructions.py +++ b/src/google/adk/flows/llm_flows/instructions.py @@ -53,7 +53,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor): if ( isinstance(root_agent, LlmAgent) and root_agent.global_instruction ): # not empty str - raw_si = root_agent.canonical_global_instruction( + raw_si = await root_agent.canonical_global_instruction( ReadonlyContext(invocation_context) ) si = await _populate_values(raw_si, invocation_context) @@ -61,7 +61,9 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor): # Appends agent instructions if set. if agent.instruction: # not empty str - raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context)) + raw_si = await agent.canonical_instruction( + ReadonlyContext(invocation_context) + ) si = await _populate_values(raw_si, invocation_context) llm_request.append_instructions([si]) diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index 42ad5ca..a442381 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -92,6 +92,16 @@ def test_canonical_instruction(): assert agent.canonical_instruction(ctx) == 'instruction: state_value' +def test_async_canonical_instruction(): + async def _instruction_provider(ctx: ReadonlyContext) -> str: + return f'instruction: {ctx.state["state_var"]}' + + agent = LlmAgent(name='test_agent', instruction=_instruction_provider) + ctx = _create_readonly_context(agent, state={'state_var': 'state_value'}) + + assert agent.canonical_instruction(ctx) == 'instruction: state_value' + + def test_canonical_global_instruction_str(): agent = LlmAgent(name='test_agent', global_instruction='global instruction') ctx = _create_readonly_context(agent) @@ -114,6 +124,21 @@ def test_canonical_global_instruction(): ) +def test_async_canonical_global_instruction(): + async def _global_instruction_provider(ctx: ReadonlyContext) -> str: + return f'global instruction: {ctx.state["state_var"]}' + + agent = LlmAgent( + name='test_agent', global_instruction=_global_instruction_provider + ) + ctx = _create_readonly_context(agent, state={'state_var': 'state_value'}) + + assert ( + agent.canonical_global_instruction(ctx) + == 'global instruction: state_value' + ) + + def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture): with caplog.at_level('WARNING'): diff --git a/tests/unittests/flows/llm_flows/test_instructions.py b/tests/unittests/flows/llm_flows/test_instructions.py index edc7902..0d2ac5e 100644 --- a/tests/unittests/flows/llm_flows/test_instructions.py +++ b/tests/unittests/flows/llm_flows/test_instructions.py @@ -92,6 +92,44 @@ async def test_function_system_instruction(): ) +@pytest.mark.asyncio +async def test_async_function_system_instruction(): + async def build_function_instruction( + readonly_context: ReadonlyContext, + ) -> str: + return ( + "This is the function agent instruction for invocation:" + f" {readonly_context.invocation_id}." + ) + + request = LlmRequest( + model="gemini-1.5-flash", + config=types.GenerateContentConfig(system_instruction=""), + ) + agent = Agent( + model="gemini-1.5-flash", + name="agent", + instruction=build_function_instruction, + ) + invocation_context = utils.create_invocation_context(agent=agent) + invocation_context.session = Session( + app_name="test_app", + user_id="test_user", + id="test_id", + state={"customerId": "1234567890", "customer_int": 30}, + ) + + async for _ in instructions.request_processor.run_async( + invocation_context, + request, + ): + pass + + assert request.config.system_instruction == ( + "This is the function agent instruction for invocation: test_id." + ) + + @pytest.mark.asyncio async def test_global_system_instruction(): sub_agent = Agent( @@ -128,6 +166,90 @@ async def test_global_system_instruction(): ) +@pytest.mark.asyncio +async def test_function_global_system_instruction(): + def sub_agent_si(readonly_context: ReadonlyContext) -> str: + return "This is the sub agent instruction." + + def root_agent_gi(readonly_context: ReadonlyContext) -> str: + return "This is the global instruction." + + sub_agent = Agent( + model="gemini-1.5-flash", + name="sub_agent", + instruction=sub_agent_si, + ) + root_agent = Agent( + model="gemini-1.5-flash", + name="root_agent", + global_instruction=root_agent_gi, + sub_agents=[sub_agent], + ) + request = LlmRequest( + model="gemini-1.5-flash", + config=types.GenerateContentConfig(system_instruction=""), + ) + invocation_context = utils.create_invocation_context(agent=sub_agent) + invocation_context.session = Session( + app_name="test_app", + user_id="test_user", + id="test_id", + state={"customerId": "1234567890", "customer_int": 30}, + ) + + async for _ in instructions.request_processor.run_async( + invocation_context, + request, + ): + pass + + assert request.config.system_instruction == ( + "This is the global instruction.\n\nThis is the sub agent instruction." + ) + + +@pytest.mark.asyncio +async def test_async_function_global_system_instruction(): + async def sub_agent_si(readonly_context: ReadonlyContext) -> str: + return "This is the sub agent instruction." + + async def root_agent_gi(readonly_context: ReadonlyContext) -> str: + return "This is the global instruction." + + sub_agent = Agent( + model="gemini-1.5-flash", + name="sub_agent", + instruction=sub_agent_si, + ) + root_agent = Agent( + model="gemini-1.5-flash", + name="root_agent", + global_instruction=root_agent_gi, + sub_agents=[sub_agent], + ) + request = LlmRequest( + model="gemini-1.5-flash", + config=types.GenerateContentConfig(system_instruction=""), + ) + invocation_context = utils.create_invocation_context(agent=sub_agent) + invocation_context.session = Session( + app_name="test_app", + user_id="test_user", + id="test_id", + state={"customerId": "1234567890", "customer_int": 30}, + ) + + async for _ in instructions.request_processor.run_async( + invocation_context, + request, + ): + pass + + assert request.config.system_instruction == ( + "This is the global instruction.\n\nThis is the sub agent instruction." + ) + + @pytest.mark.asyncio async def test_build_system_instruction_with_namespace(): request = LlmRequest(