Provide inject_session_state as public util method

This is useful for injecting artifacts and session state variable into instruction template typically in instruction providers.

PiperOrigin-RevId: 761595473
This commit is contained in:
Selcuk Gun
2025-05-21 11:14:41 -07:00
committed by Copybara-Service
parent e060344e39
commit c5a0437745
3 changed files with 234 additions and 77 deletions

View File

@@ -26,6 +26,7 @@ from typing_extensions import override
from ...agents.readonly_context import ReadonlyContext
from ...events.event import Event
from ...sessions.state import State
from ...utils import instructions_utils
from ._base_llm_processor import BaseLlmRequestProcessor
if TYPE_CHECKING:
@@ -60,7 +61,9 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
)
si = raw_si
if not bypass_state_injection:
si = await _populate_values(raw_si, invocation_context)
si = await instructions_utils.inject_session_state(
raw_si, ReadonlyContext(invocation_context)
)
llm_request.append_instructions([si])
# Appends agent instructions if set.
@@ -70,7 +73,9 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
)
si = raw_si
if not bypass_state_injection:
si = await _populate_values(raw_si, invocation_context)
si = await instructions_utils.inject_session_state(
raw_si, ReadonlyContext(invocation_context)
)
llm_request.append_instructions([si])
# Maintain async generator behavior
@@ -79,78 +84,3 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
request_processor = _InstructionsLlmRequestProcessor()
async def _populate_values(
instruction_template: str,
context: InvocationContext,
) -> str:
"""Populates values in the instruction template, e.g. state, artifact, etc."""
async def _async_sub(pattern, repl_async_fn, string) -> str:
result = []
last_end = 0
for match in re.finditer(pattern, string):
result.append(string[last_end : match.start()])
replacement = await repl_async_fn(match)
result.append(replacement)
last_end = match.end()
result.append(string[last_end:])
return ''.join(result)
async def _replace_match(match) -> str:
var_name = match.group().lstrip('{').rstrip('}').strip()
optional = False
if var_name.endswith('?'):
optional = True
var_name = var_name.removesuffix('?')
if var_name.startswith('artifact.'):
var_name = var_name.removeprefix('artifact.')
if context.artifact_service is None:
raise ValueError('Artifact service is not initialized.')
artifact = await context.artifact_service.load_artifact(
app_name=context.session.app_name,
user_id=context.session.user_id,
session_id=context.session.id,
filename=var_name,
)
if not var_name:
raise KeyError(f'Artifact {var_name} not found.')
return str(artifact)
else:
if not _is_valid_state_name(var_name):
return match.group()
if var_name in context.session.state:
return str(context.session.state[var_name])
else:
if optional:
return ''
else:
raise KeyError(f'Context variable not found: `{var_name}`.')
return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template)
def _is_valid_state_name(var_name):
"""Checks if the variable name is a valid state name.
Valid state is either:
- Valid identifier
- <Valid prefix>:<Valid identifier>
All the others will just return as it is.
Args:
var_name: The variable name to check.
Returns:
True if the variable name is a valid state name, False otherwise.
"""
parts = var_name.split(':')
if len(parts) == 1:
return var_name.isidentifier()
if len(parts) == 2:
prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX]
if (parts[0] + ':') in prefixes:
return parts[1].isidentifier()
return False