Support async instruction and global instruction provider

PiperOrigin-RevId: 757808335
This commit is contained in:
Selcuk Gun
2025-05-12 10:02:51 -07:00
committed by Copybara-Service
parent 812485fdfa
commit 4c4cfb74ae
4 changed files with 165 additions and 7 deletions

View File

@@ -14,6 +14,7 @@
from __future__ import annotations
import inspect
import logging
from typing import (
Any,
@@ -96,7 +97,9 @@ AfterToolCallback: TypeAlias = Union[
list[_SingleAfterToolCallback],
]
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
InstructionProvider: TypeAlias = Callable[
[ReadonlyContext], Union[str, Awaitable[str]]
]
ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset]
ExamplesUnion = Union[list[Example], BaseExampleProvider]
@@ -302,7 +305,7 @@ class LlmAgent(BaseAgent):
ancestor_agent = ancestor_agent.parent_agent
raise ValueError(f'No model found for {self.name}.')
def canonical_instruction(self, ctx: ReadonlyContext) -> str:
async def canonical_instruction(self, ctx: ReadonlyContext) -> str:
"""The resolved self.instruction field to construct instruction for this agent.
This method is only for use by Agent Development Kit.
@@ -310,9 +313,12 @@ class LlmAgent(BaseAgent):
if isinstance(self.instruction, str):
return self.instruction
else:
return self.instruction(ctx)
instruction = self.instruction(ctx)
if inspect.isawaitable(instruction):
instruction = await instruction
return instruction
def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
async def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
"""The resolved self.instruction field to construct global instruction.
This method is only for use by Agent Development Kit.
@@ -320,7 +326,10 @@ class LlmAgent(BaseAgent):
if isinstance(self.global_instruction, str):
return self.global_instruction
else:
return self.global_instruction(ctx)
global_instruction = self.global_instruction(ctx)
if inspect.isawaitable(global_instruction):
global_instruction = await global_instruction
return global_instruction
async def canonical_tools(
self, ctx: ReadonlyContext = None

View File

@@ -53,7 +53,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
if (
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
): # not empty str
raw_si = root_agent.canonical_global_instruction(
raw_si = await root_agent.canonical_global_instruction(
ReadonlyContext(invocation_context)
)
si = await _populate_values(raw_si, invocation_context)
@@ -61,7 +61,9 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
# Appends agent instructions if set.
if agent.instruction: # not empty str
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
raw_si = await agent.canonical_instruction(
ReadonlyContext(invocation_context)
)
si = await _populate_values(raw_si, invocation_context)
llm_request.append_instructions([si])