Support chaining for model callbacks

(before/after) model callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync.

PiperOrigin-RevId: 755565583
This commit is contained in:
Selcuk Gun
2025-05-06 16:15:33 -07:00
committed by Copybara-Service
parent 794a70edcd
commit e4317c9eb7
5 changed files with 731 additions and 23 deletions

View File

@@ -193,8 +193,9 @@ class BaseLlmFlow(ABC):
"""Receive data from model and process events using BaseLlmConnection."""
def get_author(llm_response):
"""Get the author of the event.
When the model returns transcription, the author is "user". Otherwise, the author is the agent.
When the model returns transcription, the author is "user". Otherwise, the
author is the agent.
"""
if llm_response and llm_response.content and llm_response.content.role == "user":
return "user"
@@ -509,20 +510,21 @@ class BaseLlmFlow(ABC):
if not isinstance(agent, LlmAgent):
return
if not agent.before_model_callback:
if not agent.canonical_before_model_callbacks:
return
callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions
)
before_model_callback_content = agent.before_model_callback(
callback_context=callback_context, llm_request=llm_request
)
if inspect.isawaitable(before_model_callback_content):
before_model_callback_content = await before_model_callback_content
return before_model_callback_content
for callback in agent.canonical_before_model_callbacks:
before_model_callback_content = callback(
callback_context=callback_context, llm_request=llm_request
)
if inspect.isawaitable(before_model_callback_content):
before_model_callback_content = await before_model_callback_content
if before_model_callback_content:
return before_model_callback_content
async def _handle_after_model_callback(
self,
@@ -536,20 +538,21 @@ class BaseLlmFlow(ABC):
if not isinstance(agent, LlmAgent):
return
if not agent.after_model_callback:
if not agent.canonical_after_model_callbacks:
return
callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions
)
after_model_callback_content = agent.after_model_callback(
callback_context=callback_context, llm_response=llm_response
)
if inspect.isawaitable(after_model_callback_content):
after_model_callback_content = await after_model_callback_content
return after_model_callback_content
for callback in agent.canonical_after_model_callbacks:
after_model_callback_content = callback(
callback_context=callback_context, llm_response=llm_response
)
if inspect.isawaitable(after_model_callback_content):
after_model_callback_content = await after_model_callback_content
if after_model_callback_content:
return after_model_callback_content
def _finalize_model_response_event(
self,