Support async agent and model callbacks

PiperOrigin-RevId: 755542756
This commit is contained in:
Selcuk Gun
2025-05-06 15:13:39 -07:00
committed by Copybara-Service
parent f96cdc675c
commit 794a70edcd
25 changed files with 371 additions and 117 deletions

View File

@@ -33,16 +33,34 @@ def _before_agent_callback_noop(callback_context: CallbackContext) -> None:
pass
async def _async_before_agent_callback_noop(
callback_context: CallbackContext,
) -> None:
pass
def _before_agent_callback_bypass_agent(
callback_context: CallbackContext,
) -> types.Content:
return types.Content(parts=[types.Part(text='agent run is bypassed.')])
async def _async_before_agent_callback_bypass_agent(
callback_context: CallbackContext,
) -> types.Content:
return types.Content(parts=[types.Part(text='agent run is bypassed.')])
def _after_agent_callback_noop(callback_context: CallbackContext) -> None:
pass
async def _async_after_agent_callback_noop(
callback_context: CallbackContext,
) -> None:
pass
def _after_agent_callback_append_agent_reply(
callback_context: CallbackContext,
) -> types.Content:
@@ -51,6 +69,14 @@ def _after_agent_callback_append_agent_reply(
)
async def _async_after_agent_callback_append_agent_reply(
callback_context: CallbackContext,
) -> types.Content:
return types.Content(
parts=[types.Part(text='Agent reply from after agent callback.')]
)
class _IncompleteAgent(BaseAgent):
pass
@@ -158,6 +184,34 @@ async def test_run_async_before_agent_callback_noop(
spy_run_async_impl.assert_called_once()
@pytest.mark.asyncio
async def test_run_async_with_async_before_agent_callback_noop(
request: pytest.FixtureRequest,
mocker: pytest_mock.MockerFixture,
) -> Union[types.Content, None]:
# Arrange
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')
# Act
_ = [e async for e in agent.run_async(parent_ctx)]
# Assert
spy_before_agent_callback.assert_called_once()
_, kwargs = spy_before_agent_callback.call_args
assert 'callback_context' in kwargs
assert isinstance(kwargs['callback_context'], CallbackContext)
spy_run_async_impl.assert_called_once()
@pytest.mark.asyncio
async def test_run_async_before_agent_callback_bypass_agent(
request: pytest.FixtureRequest,
@@ -185,6 +239,33 @@ async def test_run_async_before_agent_callback_bypass_agent(
assert events[0].content.parts[0].text == 'agent run is bypassed.'
@pytest.mark.asyncio
async def test_run_async_with_async_before_agent_callback_bypass_agent(
request: pytest.FixtureRequest,
mocker: pytest_mock.MockerFixture,
):
# Arrange
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_bypass_agent,
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')
# Act
events = [e async for e in agent.run_async(parent_ctx)]
# Assert
spy_before_agent_callback.assert_called_once()
spy_run_async_impl.assert_not_called()
assert len(events) == 1
assert events[0].content.parts[0].text == 'agent run is bypassed.'
@pytest.mark.asyncio
async def test_run_async_after_agent_callback_noop(
request: pytest.FixtureRequest,
@@ -211,6 +292,32 @@ async def test_run_async_after_agent_callback_noop(
assert len(events) == 1
@pytest.mark.asyncio
async def test_run_async_with_async_after_agent_callback_noop(
request: pytest.FixtureRequest,
mocker: pytest_mock.MockerFixture,
):
# Arrange
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
# Act
events = [e async for e in agent.run_async(parent_ctx)]
# Assert
spy_after_agent_callback.assert_called_once()
_, kwargs = spy_after_agent_callback.call_args
assert 'callback_context' in kwargs
assert isinstance(kwargs['callback_context'], CallbackContext)
assert len(events) == 1
@pytest.mark.asyncio
async def test_run_async_after_agent_callback_append_reply(
request: pytest.FixtureRequest,
@@ -236,6 +343,31 @@ async def test_run_async_after_agent_callback_append_reply(
)
@pytest.mark.asyncio
async def test_run_async_with_async_after_agent_callback_append_reply(
request: pytest.FixtureRequest,
):
# Arrange
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_append_agent_reply,
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
# Act
events = [e async for e in agent.run_async(parent_ctx)]
# Assert
assert len(events) == 2
assert events[1].author == agent.name
assert (
events[1].content.parts[0].text
== 'Agent reply from after agent callback.'
)
@pytest.mark.asyncio
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')