Support async instruction and global instruction provider

PiperOrigin-RevId: 757808335
This commit is contained in:
Selcuk Gun
2025-05-12 10:02:51 -07:00
committed by Copybara-Service
parent 812485fdfa
commit 4c4cfb74ae
4 changed files with 165 additions and 7 deletions

View File

@@ -92,6 +92,16 @@ def test_canonical_instruction():
assert agent.canonical_instruction(ctx) == 'instruction: state_value'
def test_async_canonical_instruction():
async def _instruction_provider(ctx: ReadonlyContext) -> str:
return f'instruction: {ctx.state["state_var"]}'
agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
assert agent.canonical_instruction(ctx) == 'instruction: state_value'
def test_canonical_global_instruction_str():
agent = LlmAgent(name='test_agent', global_instruction='global instruction')
ctx = _create_readonly_context(agent)
@@ -114,6 +124,21 @@ def test_canonical_global_instruction():
)
def test_async_canonical_global_instruction():
async def _global_instruction_provider(ctx: ReadonlyContext) -> str:
return f'global instruction: {ctx.state["state_var"]}'
agent = LlmAgent(
name='test_agent', global_instruction=_global_instruction_provider
)
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
assert (
agent.canonical_global_instruction(ctx)
== 'global instruction: state_value'
)
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):
with caplog.at_level('WARNING'):