Support async instruction and global instruction provider

PiperOrigin-RevId: 757808335
This commit is contained in:
Selcuk Gun 2025-05-12 10:02:51 -07:00 committed by Copybara-Service
parent 812485fdfa
commit 4c4cfb74ae
4 changed files with 165 additions and 7 deletions

View File

@ -14,6 +14,7 @@
from __future__ import annotations
import inspect
import logging
from typing import (
Any,
@ -96,7 +97,9 @@ AfterToolCallback: TypeAlias = Union[
list[_SingleAfterToolCallback],
]
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
InstructionProvider: TypeAlias = Callable[
[ReadonlyContext], Union[str, Awaitable[str]]
]
ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset]
ExamplesUnion = Union[list[Example], BaseExampleProvider]
@ -302,7 +305,7 @@ class LlmAgent(BaseAgent):
ancestor_agent = ancestor_agent.parent_agent
raise ValueError(f'No model found for {self.name}.')
async def canonical_instruction(self, ctx: ReadonlyContext) -> str:
  """The resolved self.instruction field to construct instruction for this agent.

  This method is only for use by Agent Development Kit.

  Args:
    ctx: The read-only context passed to an instruction provider callable.

  Returns:
    The instruction string: either self.instruction verbatim when it is a
    plain string, or the result of invoking the provider (awaited when the
    provider is async).
  """
  if isinstance(self.instruction, str):
    return self.instruction
  else:
    # The provider may be a plain callable returning str, or an async
    # callable returning an awaitable of str; normalize both cases.
    instruction = self.instruction(ctx)
    if inspect.isawaitable(instruction):
      instruction = await instruction
    return instruction
async def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
  """The resolved self.global_instruction field to construct global instruction.

  This method is only for use by Agent Development Kit.

  Args:
    ctx: The read-only context passed to an instruction provider callable.

  Returns:
    The global instruction string: either self.global_instruction verbatim
    when it is a plain string, or the result of invoking the provider
    (awaited when the provider is async).
  """
  if isinstance(self.global_instruction, str):
    return self.global_instruction
  else:
    # Mirror canonical_instruction: accept both sync and async providers.
    global_instruction = self.global_instruction(ctx)
    if inspect.isawaitable(global_instruction):
      global_instruction = await global_instruction
    return global_instruction
async def canonical_tools(
self, ctx: ReadonlyContext = None

View File

@ -53,7 +53,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
if (
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
): # not empty str
raw_si = root_agent.canonical_global_instruction(
raw_si = await root_agent.canonical_global_instruction(
ReadonlyContext(invocation_context)
)
si = await _populate_values(raw_si, invocation_context)
@ -61,7 +61,9 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
# Appends agent instructions if set.
if agent.instruction: # not empty str
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
raw_si = await agent.canonical_instruction(
ReadonlyContext(invocation_context)
)
si = await _populate_values(raw_si, invocation_context)
llm_request.append_instructions([si])

View File

@ -92,6 +92,16 @@ def test_canonical_instruction():
assert agent.canonical_instruction(ctx) == 'instruction: state_value'
@pytest.mark.asyncio
async def test_async_canonical_instruction():
  """An async instruction provider is awaited by canonical_instruction."""

  async def _instruction_provider(ctx: ReadonlyContext) -> str:
    return f'instruction: {ctx.state["state_var"]}'

  agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
  ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
  # canonical_instruction is a coroutine function: it must be awaited,
  # otherwise the assertion compares a coroutine object to a str and the
  # test can never pass.
  assert await agent.canonical_instruction(ctx) == 'instruction: state_value'
def test_canonical_global_instruction_str():
agent = LlmAgent(name='test_agent', global_instruction='global instruction')
ctx = _create_readonly_context(agent)
@ -114,6 +124,21 @@ def test_canonical_global_instruction():
)
@pytest.mark.asyncio
async def test_async_canonical_global_instruction():
  """An async global provider is awaited by canonical_global_instruction."""

  async def _global_instruction_provider(ctx: ReadonlyContext) -> str:
    return f'global instruction: {ctx.state["state_var"]}'

  agent = LlmAgent(
      name='test_agent', global_instruction=_global_instruction_provider
  )
  ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
  # canonical_global_instruction is a coroutine function: without the
  # await, the assertion compares a coroutine object to a str and fails.
  assert (
      await agent.canonical_global_instruction(ctx)
      == 'global instruction: state_value'
  )
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):
with caplog.at_level('WARNING'):

View File

@ -92,6 +92,44 @@ async def test_function_system_instruction():
)
@pytest.mark.asyncio
async def test_async_function_system_instruction():
  """An async instruction provider is resolved into the request's system instruction."""

  async def _async_instruction(
      readonly_context: ReadonlyContext,
  ) -> str:
    return (
        "This is the function agent instruction for invocation:"
        f" {readonly_context.invocation_id}."
    )

  agent = Agent(
      model="gemini-1.5-flash",
      name="agent",
      instruction=_async_instruction,
  )
  llm_request = LlmRequest(
      model="gemini-1.5-flash",
      config=types.GenerateContentConfig(system_instruction=""),
  )
  invocation_ctx = utils.create_invocation_context(agent=agent)
  invocation_ctx.session = Session(
      app_name="test_app",
      user_id="test_user",
      id="test_id",
      state={"customerId": "1234567890", "customer_int": 30},
  )

  # Drain the processor; only its side effect on llm_request matters here.
  async for _ in instructions.request_processor.run_async(
      invocation_ctx,
      llm_request,
  ):
    pass

  assert llm_request.config.system_instruction == (
      "This is the function agent instruction for invocation: test_id."
  )
@pytest.mark.asyncio
async def test_global_system_instruction():
sub_agent = Agent(
@ -128,6 +166,90 @@ async def test_global_system_instruction():
)
@pytest.mark.asyncio
async def test_function_global_system_instruction():
  """Sync provider functions for global and agent instructions are both applied."""

  def sub_agent_si(readonly_context: ReadonlyContext) -> str:
    return "This is the sub agent instruction."

  def root_agent_gi(readonly_context: ReadonlyContext) -> str:
    return "This is the global instruction."

  sub_agent = Agent(
      model="gemini-1.5-flash",
      name="sub_agent",
      instruction=sub_agent_si,
  )
  root_agent = Agent(
      model="gemini-1.5-flash",
      name="root_agent",
      global_instruction=root_agent_gi,
      sub_agents=[sub_agent],
  )

  invocation_ctx = utils.create_invocation_context(agent=sub_agent)
  invocation_ctx.session = Session(
      app_name="test_app",
      user_id="test_user",
      id="test_id",
      state={"customerId": "1234567890", "customer_int": 30},
  )
  llm_request = LlmRequest(
      model="gemini-1.5-flash",
      config=types.GenerateContentConfig(system_instruction=""),
  )

  # Run the processor to completion for its side effect on llm_request.
  async for _ in instructions.request_processor.run_async(
      invocation_ctx,
      llm_request,
  ):
    pass

  # Global instruction (from the root agent) precedes the sub agent's own.
  assert llm_request.config.system_instruction == (
      "This is the global instruction.\n\nThis is the sub agent instruction."
  )
@pytest.mark.asyncio
async def test_async_function_global_system_instruction():
  """Async provider functions for global and agent instructions are both applied."""

  async def sub_agent_si(readonly_context: ReadonlyContext) -> str:
    return "This is the sub agent instruction."

  async def root_agent_gi(readonly_context: ReadonlyContext) -> str:
    return "This is the global instruction."

  sub_agent = Agent(
      model="gemini-1.5-flash",
      name="sub_agent",
      instruction=sub_agent_si,
  )
  root_agent = Agent(
      model="gemini-1.5-flash",
      name="root_agent",
      global_instruction=root_agent_gi,
      sub_agents=[sub_agent],
  )

  invocation_ctx = utils.create_invocation_context(agent=sub_agent)
  invocation_ctx.session = Session(
      app_name="test_app",
      user_id="test_user",
      id="test_id",
      state={"customerId": "1234567890", "customer_int": 30},
  )
  llm_request = LlmRequest(
      model="gemini-1.5-flash",
      config=types.GenerateContentConfig(system_instruction=""),
  )

  # Run the processor to completion for its side effect on llm_request.
  async for _ in instructions.request_processor.run_async(
      invocation_ctx,
      llm_request,
  ):
    pass

  # Global instruction (from the root agent) precedes the sub agent's own.
  assert llm_request.config.system_instruction == (
      "This is the global instruction.\n\nThis is the sub agent instruction."
  )
@pytest.mark.asyncio
async def test_build_system_instruction_with_namespace():
request = LlmRequest(