Support async instruction and global instruction provider

PiperOrigin-RevId: 757808335
This commit is contained in:
Selcuk Gun 2025-05-12 10:02:51 -07:00 committed by Copybara-Service
parent 812485fdfa
commit 4c4cfb74ae
4 changed files with 165 additions and 7 deletions

View File

@ -14,6 +14,7 @@
from __future__ import annotations from __future__ import annotations
import inspect
import logging import logging
from typing import ( from typing import (
Any, Any,
@ -96,7 +97,9 @@ AfterToolCallback: TypeAlias = Union[
list[_SingleAfterToolCallback], list[_SingleAfterToolCallback],
] ]
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str] InstructionProvider: TypeAlias = Callable[
[ReadonlyContext], Union[str, Awaitable[str]]
]
ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset] ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset]
ExamplesUnion = Union[list[Example], BaseExampleProvider] ExamplesUnion = Union[list[Example], BaseExampleProvider]
@ -302,7 +305,7 @@ class LlmAgent(BaseAgent):
ancestor_agent = ancestor_agent.parent_agent ancestor_agent = ancestor_agent.parent_agent
raise ValueError(f'No model found for {self.name}.') raise ValueError(f'No model found for {self.name}.')
def canonical_instruction(self, ctx: ReadonlyContext) -> str: async def canonical_instruction(self, ctx: ReadonlyContext) -> str:
"""The resolved self.instruction field to construct instruction for this agent. """The resolved self.instruction field to construct instruction for this agent.
This method is only for use by Agent Development Kit. This method is only for use by Agent Development Kit.
@ -310,9 +313,12 @@ class LlmAgent(BaseAgent):
if isinstance(self.instruction, str): if isinstance(self.instruction, str):
return self.instruction return self.instruction
else: else:
return self.instruction(ctx) instruction = self.instruction(ctx)
if inspect.isawaitable(instruction):
instruction = await instruction
return instruction
def canonical_global_instruction(self, ctx: ReadonlyContext) -> str: async def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
"""The resolved self.instruction field to construct global instruction. """The resolved self.instruction field to construct global instruction.
This method is only for use by Agent Development Kit. This method is only for use by Agent Development Kit.
@ -320,7 +326,10 @@ class LlmAgent(BaseAgent):
if isinstance(self.global_instruction, str): if isinstance(self.global_instruction, str):
return self.global_instruction return self.global_instruction
else: else:
return self.global_instruction(ctx) global_instruction = self.global_instruction(ctx)
if inspect.isawaitable(global_instruction):
global_instruction = await global_instruction
return global_instruction
async def canonical_tools( async def canonical_tools(
self, ctx: ReadonlyContext = None self, ctx: ReadonlyContext = None

View File

@ -53,7 +53,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
if ( if (
isinstance(root_agent, LlmAgent) and root_agent.global_instruction isinstance(root_agent, LlmAgent) and root_agent.global_instruction
): # not empty str ): # not empty str
raw_si = root_agent.canonical_global_instruction( raw_si = await root_agent.canonical_global_instruction(
ReadonlyContext(invocation_context) ReadonlyContext(invocation_context)
) )
si = await _populate_values(raw_si, invocation_context) si = await _populate_values(raw_si, invocation_context)
@ -61,7 +61,9 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
# Appends agent instructions if set. # Appends agent instructions if set.
if agent.instruction: # not empty str if agent.instruction: # not empty str
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context)) raw_si = await agent.canonical_instruction(
ReadonlyContext(invocation_context)
)
si = await _populate_values(raw_si, invocation_context) si = await _populate_values(raw_si, invocation_context)
llm_request.append_instructions([si]) llm_request.append_instructions([si])

View File

@ -92,6 +92,16 @@ def test_canonical_instruction():
assert agent.canonical_instruction(ctx) == 'instruction: state_value' assert agent.canonical_instruction(ctx) == 'instruction: state_value'
def test_async_canonical_instruction():
async def _instruction_provider(ctx: ReadonlyContext) -> str:
return f'instruction: {ctx.state["state_var"]}'
agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
assert agent.canonical_instruction(ctx) == 'instruction: state_value'
def test_canonical_global_instruction_str(): def test_canonical_global_instruction_str():
agent = LlmAgent(name='test_agent', global_instruction='global instruction') agent = LlmAgent(name='test_agent', global_instruction='global instruction')
ctx = _create_readonly_context(agent) ctx = _create_readonly_context(agent)
@ -114,6 +124,21 @@ def test_canonical_global_instruction():
) )
def test_async_canonical_global_instruction():
async def _global_instruction_provider(ctx: ReadonlyContext) -> str:
return f'global instruction: {ctx.state["state_var"]}'
agent = LlmAgent(
name='test_agent', global_instruction=_global_instruction_provider
)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
assert (
agent.canonical_global_instruction(ctx)
== 'global instruction: state_value'
)
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture): def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):
with caplog.at_level('WARNING'): with caplog.at_level('WARNING'):

View File

@ -92,6 +92,44 @@ async def test_function_system_instruction():
) )
@pytest.mark.asyncio
async def test_async_function_system_instruction():
async def build_function_instruction(
readonly_context: ReadonlyContext,
) -> str:
return (
"This is the function agent instruction for invocation:"
f" {readonly_context.invocation_id}."
)
request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
agent = Agent(
model="gemini-1.5-flash",
name="agent",
instruction=build_function_instruction,
)
invocation_context = utils.create_invocation_context(agent=agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
id="test_id",
state={"customerId": "1234567890", "customer_int": 30},
)
async for _ in instructions.request_processor.run_async(
invocation_context,
request,
):
pass
assert request.config.system_instruction == (
"This is the function agent instruction for invocation: test_id."
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_global_system_instruction(): async def test_global_system_instruction():
sub_agent = Agent( sub_agent = Agent(
@ -128,6 +166,90 @@ async def test_global_system_instruction():
) )
@pytest.mark.asyncio
async def test_function_global_system_instruction():
def sub_agent_si(readonly_context: ReadonlyContext) -> str:
return "This is the sub agent instruction."
def root_agent_gi(readonly_context: ReadonlyContext) -> str:
return "This is the global instruction."
sub_agent = Agent(
model="gemini-1.5-flash",
name="sub_agent",
instruction=sub_agent_si,
)
root_agent = Agent(
model="gemini-1.5-flash",
name="root_agent",
global_instruction=root_agent_gi,
sub_agents=[sub_agent],
)
request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
invocation_context = utils.create_invocation_context(agent=sub_agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
id="test_id",
state={"customerId": "1234567890", "customer_int": 30},
)
async for _ in instructions.request_processor.run_async(
invocation_context,
request,
):
pass
assert request.config.system_instruction == (
"This is the global instruction.\n\nThis is the sub agent instruction."
)
@pytest.mark.asyncio
async def test_async_function_global_system_instruction():
async def sub_agent_si(readonly_context: ReadonlyContext) -> str:
return "This is the sub agent instruction."
async def root_agent_gi(readonly_context: ReadonlyContext) -> str:
return "This is the global instruction."
sub_agent = Agent(
model="gemini-1.5-flash",
name="sub_agent",
instruction=sub_agent_si,
)
root_agent = Agent(
model="gemini-1.5-flash",
name="root_agent",
global_instruction=root_agent_gi,
sub_agents=[sub_agent],
)
request = LlmRequest(
model="gemini-1.5-flash",
config=types.GenerateContentConfig(system_instruction=""),
)
invocation_context = utils.create_invocation_context(agent=sub_agent)
invocation_context.session = Session(
app_name="test_app",
user_id="test_user",
id="test_id",
state={"customerId": "1234567890", "customer_int": 30},
)
async for _ in instructions.request_processor.run_async(
invocation_context,
request,
):
pass
assert request.config.system_instruction == (
"This is the global instruction.\n\nThis is the sub agent instruction."
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_build_system_instruction_with_namespace(): async def test_build_system_instruction_with_namespace():
request = LlmRequest( request = LlmRequest(