ADK changes

PiperOrigin-RevId: 755201925
This commit is contained in:
Shangjie Chen
2025-05-05 21:57:51 -07:00
committed by Copybara-Service
parent 6dec235c13
commit 905c20dad6
12 changed files with 86 additions and 69 deletions

View File

@@ -122,7 +122,7 @@ class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
if not invocation_context.agent.code_executor:
return
for event in _run_pre_processor(invocation_context, llm_request):
async for event in _run_pre_processor(invocation_context, llm_request):
yield event
# Convert the code execution parts to text parts.
@@ -159,10 +159,10 @@ class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
response_processor = _CodeExecutionResponseProcessor()
def _run_pre_processor(
async def _run_pre_processor(
invocation_context: InvocationContext,
llm_request: LlmRequest,
) -> Generator[Event, None, None]:
) -> AsyncGenerator[Event, None]:
"""Pre-process the user message by adding the user message to the Colab notebook."""
from ...agents.llm_agent import LlmAgent
@@ -242,7 +242,7 @@ def _run_pre_processor(
code_executor_context.add_processed_file_names([file.name])
# Emit the execution result, and add it to the LLM request.
execution_result_event = _post_process_code_execution_result(
execution_result_event = await _post_process_code_execution_result(
invocation_context, code_executor_context, code_execution_result
)
yield execution_result_event
@@ -375,7 +375,7 @@ def _get_or_set_execution_id(
return execution_id
def _post_process_code_execution_result(
async def _post_process_code_execution_result(
invocation_context: InvocationContext,
code_executor_context: CodeExecutorContext,
code_execution_result: CodeExecutionResult,
@@ -406,7 +406,7 @@ def _post_process_code_execution_result(
# Handle output files.
for output_file in code_execution_result.output_files:
version = invocation_context.artifact_service.save_artifact(
version = await invocation_context.artifact_service.save_artifact(
app_name=invocation_context.app_name,
user_id=invocation_context.user_id,
session_id=invocation_context.session.id,

View File

@@ -56,13 +56,13 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
raw_si = root_agent.canonical_global_instruction(
ReadonlyContext(invocation_context)
)
si = _populate_values(raw_si, invocation_context)
si = await _populate_values(raw_si, invocation_context)
llm_request.append_instructions([si])
# Appends agent instructions if set.
if agent.instruction: # not empty str
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
si = _populate_values(raw_si, invocation_context)
si = await _populate_values(raw_si, invocation_context)
llm_request.append_instructions([si])
# Maintain async generator behavior
@@ -73,13 +73,24 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
request_processor = _InstructionsLlmRequestProcessor()
def _populate_values(
async def _populate_values(
instruction_template: str,
context: InvocationContext,
) -> str:
"""Populates values in the instruction template, e.g. state, artifact, etc."""
def _replace_match(match) -> str:
async def _async_sub(pattern, repl_async_fn, string) -> str:
result = []
last_end = 0
for match in re.finditer(pattern, string):
result.append(string[last_end : match.start()])
replacement = await repl_async_fn(match)
result.append(replacement)
last_end = match.end()
result.append(string[last_end:])
return ''.join(result)
async def _replace_match(match) -> str:
var_name = match.group().lstrip('{').rstrip('}').strip()
optional = False
if var_name.endswith('?'):
@@ -89,7 +100,7 @@ def _populate_values(
var_name = var_name.removeprefix('artifact.')
if context.artifact_service is None:
raise ValueError('Artifact service is not initialized.')
artifact = context.artifact_service.load_artifact(
artifact = await context.artifact_service.load_artifact(
app_name=context.session.app_name,
user_id=context.session.user_id,
session_id=context.session.id,
@@ -109,7 +120,7 @@ def _populate_values(
else:
raise KeyError(f'Context variable not found: `{var_name}`.')
return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template)
return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template)
def _is_valid_state_name(var_name):