mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-25 14:37:45 -06:00
ADK changes
PiperOrigin-RevId: 755201925
This commit is contained in:
committed by
Copybara-Service
parent
6dec235c13
commit
905c20dad6
@@ -122,7 +122,7 @@ class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
|
||||
if not invocation_context.agent.code_executor:
|
||||
return
|
||||
|
||||
for event in _run_pre_processor(invocation_context, llm_request):
|
||||
async for event in _run_pre_processor(invocation_context, llm_request):
|
||||
yield event
|
||||
|
||||
# Convert the code execution parts to text parts.
|
||||
@@ -159,10 +159,10 @@ class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
|
||||
response_processor = _CodeExecutionResponseProcessor()
|
||||
|
||||
|
||||
def _run_pre_processor(
|
||||
async def _run_pre_processor(
|
||||
invocation_context: InvocationContext,
|
||||
llm_request: LlmRequest,
|
||||
) -> Generator[Event, None, None]:
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Pre-process the user message by adding the user message to the Colab notebook."""
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
@@ -242,7 +242,7 @@ def _run_pre_processor(
|
||||
code_executor_context.add_processed_file_names([file.name])
|
||||
|
||||
# Emit the execution result, and add it to the LLM request.
|
||||
execution_result_event = _post_process_code_execution_result(
|
||||
execution_result_event = await _post_process_code_execution_result(
|
||||
invocation_context, code_executor_context, code_execution_result
|
||||
)
|
||||
yield execution_result_event
|
||||
@@ -375,7 +375,7 @@ def _get_or_set_execution_id(
|
||||
return execution_id
|
||||
|
||||
|
||||
def _post_process_code_execution_result(
|
||||
async def _post_process_code_execution_result(
|
||||
invocation_context: InvocationContext,
|
||||
code_executor_context: CodeExecutorContext,
|
||||
code_execution_result: CodeExecutionResult,
|
||||
@@ -406,7 +406,7 @@ def _post_process_code_execution_result(
|
||||
|
||||
# Handle output files.
|
||||
for output_file in code_execution_result.output_files:
|
||||
version = invocation_context.artifact_service.save_artifact(
|
||||
version = await invocation_context.artifact_service.save_artifact(
|
||||
app_name=invocation_context.app_name,
|
||||
user_id=invocation_context.user_id,
|
||||
session_id=invocation_context.session.id,
|
||||
|
||||
@@ -56,13 +56,13 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
raw_si = root_agent.canonical_global_instruction(
|
||||
ReadonlyContext(invocation_context)
|
||||
)
|
||||
si = _populate_values(raw_si, invocation_context)
|
||||
si = await _populate_values(raw_si, invocation_context)
|
||||
llm_request.append_instructions([si])
|
||||
|
||||
# Appends agent instructions if set.
|
||||
if agent.instruction: # not empty str
|
||||
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
|
||||
si = _populate_values(raw_si, invocation_context)
|
||||
si = await _populate_values(raw_si, invocation_context)
|
||||
llm_request.append_instructions([si])
|
||||
|
||||
# Maintain async generator behavior
|
||||
@@ -73,13 +73,24 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
request_processor = _InstructionsLlmRequestProcessor()
|
||||
|
||||
|
||||
def _populate_values(
|
||||
async def _populate_values(
|
||||
instruction_template: str,
|
||||
context: InvocationContext,
|
||||
) -> str:
|
||||
"""Populates values in the instruction template, e.g. state, artifact, etc."""
|
||||
|
||||
def _replace_match(match) -> str:
|
||||
async def _async_sub(pattern, repl_async_fn, string) -> str:
|
||||
result = []
|
||||
last_end = 0
|
||||
for match in re.finditer(pattern, string):
|
||||
result.append(string[last_end : match.start()])
|
||||
replacement = await repl_async_fn(match)
|
||||
result.append(replacement)
|
||||
last_end = match.end()
|
||||
result.append(string[last_end:])
|
||||
return ''.join(result)
|
||||
|
||||
async def _replace_match(match) -> str:
|
||||
var_name = match.group().lstrip('{').rstrip('}').strip()
|
||||
optional = False
|
||||
if var_name.endswith('?'):
|
||||
@@ -89,7 +100,7 @@ def _populate_values(
|
||||
var_name = var_name.removeprefix('artifact.')
|
||||
if context.artifact_service is None:
|
||||
raise ValueError('Artifact service is not initialized.')
|
||||
artifact = context.artifact_service.load_artifact(
|
||||
artifact = await context.artifact_service.load_artifact(
|
||||
app_name=context.session.app_name,
|
||||
user_id=context.session.user_id,
|
||||
session_id=context.session.id,
|
||||
@@ -109,7 +120,7 @@ def _populate_values(
|
||||
else:
|
||||
raise KeyError(f'Context variable not found: `{var_name}`.')
|
||||
|
||||
return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template)
|
||||
return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template)
|
||||
|
||||
|
||||
def _is_valid_state_name(var_name):
|
||||
|
||||
Reference in New Issue
Block a user