Support async instruction and global instruction provider

PiperOrigin-RevId: 757808335
Author: Selcuk Gun
Date: 2025-05-12 10:02:51 -07:00
Committed by: Copybara-Service
Parent: 812485fdfa
Commit: 4c4cfb74ae
4 changed files with 165 additions and 7 deletions
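For orientation, here is a minimal sketch (not part of the diff) of how the new support is exercised by the tests below: both instruction and global_instruction may be async callables that receive a ReadonlyContext and return the instruction text, and the root agent's global instruction is prepended to the active agent's own instruction. Import paths are assumed to follow the google.adk package layout used in these tests; the agent names are illustrative.

# Sketch only: async instruction providers, assuming google.adk import paths.
from google.adk.agents import Agent
from google.adk.agents.readonly_context import ReadonlyContext


async def build_instruction(readonly_context: ReadonlyContext) -> str:
  # Per-agent instruction provider; may now be async.
  return f"Handle invocation {readonly_context.invocation_id} politely."


async def build_global_instruction(readonly_context: ReadonlyContext) -> str:
  # Global instruction provider, resolved on the root agent and prepended
  # to the active agent's own instruction.
  return "Always answer in English."


support_agent = Agent(  # illustrative name
    model="gemini-1.5-flash",
    name="support_agent",
    instruction=build_instruction,
)
root_agent = Agent(
    model="gemini-1.5-flash",
    name="root_agent",
    global_instruction=build_global_instruction,
    sub_agents=[support_agent],
)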


@@ -92,6 +92,44 @@ async def test_function_system_instruction():
  )


@pytest.mark.asyncio
async def test_async_function_system_instruction():
  async def build_function_instruction(
      readonly_context: ReadonlyContext,
  ) -> str:
    return (
        "This is the function agent instruction for invocation:"
        f" {readonly_context.invocation_id}."
    )

  request = LlmRequest(
      model="gemini-1.5-flash",
      config=types.GenerateContentConfig(system_instruction=""),
  )
  agent = Agent(
      model="gemini-1.5-flash",
      name="agent",
      instruction=build_function_instruction,
  )
  invocation_context = utils.create_invocation_context(agent=agent)
  invocation_context.session = Session(
      app_name="test_app",
      user_id="test_user",
      id="test_id",
      state={"customerId": "1234567890", "customer_int": 30},
  )

  async for _ in instructions.request_processor.run_async(
      invocation_context,
      request,
  ):
    pass

  assert request.config.system_instruction == (
      "This is the function agent instruction for invocation: test_id."
  )


@pytest.mark.asyncio
async def test_global_system_instruction():
  sub_agent = Agent(
@@ -128,6 +166,90 @@ async def test_global_system_instruction():
  )


@pytest.mark.asyncio
async def test_function_global_system_instruction():
  def sub_agent_si(readonly_context: ReadonlyContext) -> str:
    return "This is the sub agent instruction."

  def root_agent_gi(readonly_context: ReadonlyContext) -> str:
    return "This is the global instruction."

  sub_agent = Agent(
      model="gemini-1.5-flash",
      name="sub_agent",
      instruction=sub_agent_si,
  )
  root_agent = Agent(
      model="gemini-1.5-flash",
      name="root_agent",
      global_instruction=root_agent_gi,
      sub_agents=[sub_agent],
  )
  request = LlmRequest(
      model="gemini-1.5-flash",
      config=types.GenerateContentConfig(system_instruction=""),
  )
  invocation_context = utils.create_invocation_context(agent=sub_agent)
  invocation_context.session = Session(
      app_name="test_app",
      user_id="test_user",
      id="test_id",
      state={"customerId": "1234567890", "customer_int": 30},
  )

  async for _ in instructions.request_processor.run_async(
      invocation_context,
      request,
  ):
    pass

  assert request.config.system_instruction == (
      "This is the global instruction.\n\nThis is the sub agent instruction."
  )


@pytest.mark.asyncio
async def test_async_function_global_system_instruction():
  async def sub_agent_si(readonly_context: ReadonlyContext) -> str:
    return "This is the sub agent instruction."

  async def root_agent_gi(readonly_context: ReadonlyContext) -> str:
    return "This is the global instruction."

  sub_agent = Agent(
      model="gemini-1.5-flash",
      name="sub_agent",
      instruction=sub_agent_si,
  )
  root_agent = Agent(
      model="gemini-1.5-flash",
      name="root_agent",
      global_instruction=root_agent_gi,
      sub_agents=[sub_agent],
  )
  request = LlmRequest(
      model="gemini-1.5-flash",
      config=types.GenerateContentConfig(system_instruction=""),
  )
  invocation_context = utils.create_invocation_context(agent=sub_agent)
  invocation_context.session = Session(
      app_name="test_app",
      user_id="test_user",
      id="test_id",
      state={"customerId": "1234567890", "customer_int": 30},
  )

  async for _ in instructions.request_processor.run_async(
      invocation_context,
      request,
  ):
    pass

  assert request.config.system_instruction == (
      "This is the global instruction.\n\nThis is the sub agent instruction."
  )


@pytest.mark.asyncio
async def test_build_system_instruction_with_namespace():
  request = LlmRequest(