Support chaining for tool callbacks

(before/after) tool callbacks are invoked throughout the provided chain until one callback does not return None. Callbacks can be async and sync.

PiperOrigin-RevId: 756526507
This commit is contained in:
Selcuk Gun
2025-05-08 17:37:30 -07:00
committed by Copybara-Service
parent 0299020cc4
commit 2cbbf88135
5 changed files with 282 additions and 17 deletions

View File

@@ -117,6 +117,31 @@ def before_agent_cb3(callback_context):
print('@before_agent_cb3')
def before_tool_cb1(tool, args, tool_context):
print('@before_tool_cb1')
def before_tool_cb2(tool, args, tool_context):
print('@before_tool_cb2')
def before_tool_cb3(tool, args, tool_context):
print('@before_tool_cb3')
def after_tool_cb1(tool, args, tool_context, tool_response):
print('@after_tool_cb1')
def after_tool_cb2(tool, args, tool_context, tool_response):
print('@after_tool_cb2')
return {'test': 'after_tool_cb2', 'response': tool_response}
def after_tool_cb3(tool, args, tool_context, tool_response):
print('@after_tool_cb3')
root_agent = Agent(
model='gemini-2.0-flash-exp',
name='data_processing_agent',
@@ -166,4 +191,6 @@ root_agent = Agent(
after_agent_callback=[after_agent_cb1, after_agent_cb2, after_agent_cb3],
before_model_callback=before_model_callback,
after_model_callback=after_model_callback,
before_tool_callback=[before_tool_cb1, before_tool_cb2, before_tool_cb3],
after_tool_callback=[after_tool_cb1, after_tool_cb2, after_tool_cb3],
)

View File

@@ -83,7 +83,7 @@ async def main():
print('------------------------------------')
await run_prompt(session_11, 'Hi')
await run_prompt(session_11, 'Roll a die with 100 sides')
await run_prompt(session_11, 'Roll a die again.')
await run_prompt(session_11, 'Roll a die again with 100 sides.')
await run_prompt(session_11, 'What numbers did I got?')
await run_prompt_bytes(session_11, 'Hi bytes')
print(
@@ -130,7 +130,7 @@ def main_sync():
print('------------------------------------')
run_prompt(session_11, 'Hi')
run_prompt(session_11, 'Roll a die with 100 sides.')
run_prompt(session_11, 'Roll a die again.')
run_prompt(session_11, 'Roll a die again with 100 sides.')
run_prompt(session_11, 'What numbers did I got?')
end_time = time.time()
print('------------------------------------')