Support async agent and model callbacks

PiperOrigin-RevId: 755542756
This commit is contained in:
Selcuk Gun
2025-05-06 15:13:39 -07:00
committed by Copybara-Service
parent f96cdc675c
commit 794a70edcd
25 changed files with 371 additions and 117 deletions

View File

@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -33,16 +33,34 @@ def _before_agent_callback_noop(callback_context: CallbackContext) -> None:
pass
async def _async_before_agent_callback_noop(
callback_context: CallbackContext,
) -> None:
pass
def _before_agent_callback_bypass_agent(
callback_context: CallbackContext,
) -> types.Content:
return types.Content(parts=[types.Part(text='agent run is bypassed.')])
async def _async_before_agent_callback_bypass_agent(
callback_context: CallbackContext,
) -> types.Content:
return types.Content(parts=[types.Part(text='agent run is bypassed.')])
def _after_agent_callback_noop(callback_context: CallbackContext) -> None:
pass
async def _async_after_agent_callback_noop(
callback_context: CallbackContext,
) -> None:
pass
def _after_agent_callback_append_agent_reply(
callback_context: CallbackContext,
) -> types.Content:
@@ -51,6 +69,14 @@ def _after_agent_callback_append_agent_reply(
)
async def _async_after_agent_callback_append_agent_reply(
callback_context: CallbackContext,
) -> types.Content:
return types.Content(
parts=[types.Part(text='Agent reply from after agent callback.')]
)
class _IncompleteAgent(BaseAgent):
pass
@@ -158,6 +184,34 @@ async def test_run_async_before_agent_callback_noop(
spy_run_async_impl.assert_called_once()
@pytest.mark.asyncio
async def test_run_async_with_async_before_agent_callback_noop(
request: pytest.FixtureRequest,
mocker: pytest_mock.MockerFixture,
) -> Union[types.Content, None]:
# Arrange
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')
# Act
_ = [e async for e in agent.run_async(parent_ctx)]
# Assert
spy_before_agent_callback.assert_called_once()
_, kwargs = spy_before_agent_callback.call_args
assert 'callback_context' in kwargs
assert isinstance(kwargs['callback_context'], CallbackContext)
spy_run_async_impl.assert_called_once()
@pytest.mark.asyncio
async def test_run_async_before_agent_callback_bypass_agent(
request: pytest.FixtureRequest,
@@ -185,6 +239,33 @@ async def test_run_async_before_agent_callback_bypass_agent(
assert events[0].content.parts[0].text == 'agent run is bypassed.'
@pytest.mark.asyncio
async def test_run_async_with_async_before_agent_callback_bypass_agent(
request: pytest.FixtureRequest,
mocker: pytest_mock.MockerFixture,
):
# Arrange
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_bypass_agent,
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')
# Act
events = [e async for e in agent.run_async(parent_ctx)]
# Assert
spy_before_agent_callback.assert_called_once()
spy_run_async_impl.assert_not_called()
assert len(events) == 1
assert events[0].content.parts[0].text == 'agent run is bypassed.'
@pytest.mark.asyncio
async def test_run_async_after_agent_callback_noop(
request: pytest.FixtureRequest,
@@ -211,6 +292,32 @@ async def test_run_async_after_agent_callback_noop(
assert len(events) == 1
@pytest.mark.asyncio
async def test_run_async_with_async_after_agent_callback_noop(
request: pytest.FixtureRequest,
mocker: pytest_mock.MockerFixture,
):
# Arrange
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
# Act
events = [e async for e in agent.run_async(parent_ctx)]
# Assert
spy_after_agent_callback.assert_called_once()
_, kwargs = spy_after_agent_callback.call_args
assert 'callback_context' in kwargs
assert isinstance(kwargs['callback_context'], CallbackContext)
assert len(events) == 1
@pytest.mark.asyncio
async def test_run_async_after_agent_callback_append_reply(
request: pytest.FixtureRequest,
@@ -236,6 +343,31 @@ async def test_run_async_after_agent_callback_append_reply(
)
@pytest.mark.asyncio
async def test_run_async_with_async_after_agent_callback_append_reply(
request: pytest.FixtureRequest,
):
# Arrange
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_append_agent_reply,
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
# Act
events = [e async for e in agent.run_async(parent_ctx)]
# Assert
assert len(events) == 2
assert events[1].author == agent.name
assert (
events[1].content.parts[0].text
== 'Agent reply from after agent callback.'
)
@pytest.mark.asyncio
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')

View File

@@ -56,10 +56,44 @@ class MockAfterModelCallback(BaseModel):
)
class MockAsyncBeforeModelCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
llm_request: LlmRequest,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
class MockAsyncAfterModelCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
llm_response: LlmResponse,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
def noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
@pytest.mark.asyncio
async def test_before_model_callback():
responses = ['model_response']
@@ -98,26 +132,6 @@ async def test_before_model_callback_noop():
]
@pytest.mark.asyncio
async def test_before_model_callback_end():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=MockBeforeModelCallback(
mock_response='before_model_callback',
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'before_model_callback'),
]
@pytest.mark.asyncio
async def test_after_model_callback():
responses = ['model_response']
@@ -136,3 +150,61 @@ async def test_after_model_callback():
) == [
('root_agent', 'after_model_callback'),
]
@pytest.mark.asyncio
async def test_async_before_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=MockAsyncBeforeModelCallback(
mock_response='async_before_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_before_model_callback'),
]
@pytest.mark.asyncio
async def test_async_before_model_callback_noop():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=async_noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'model_response'),
]
@pytest.mark.asyncio
async def test_async_after_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_model_callback=MockAsyncAfterModelCallback(
mock_response='async_after_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_after_model_callback'),
]

View File

@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -74,9 +74,12 @@ class TestConnectionsClient:
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {"data": "test"}
with mock.patch.object(
client, "_get_access_token", return_value=mock_credentials.token
), mock.patch("requests.get", return_value=mock_response):
with (
mock.patch.object(
client, "_get_access_token", return_value=mock_credentials.token
),
mock.patch("requests.get", return_value=mock_response),
):
response = client._execute_api_call("https://test.url")
assert response.json() == {"data": "test"}
requests.get.assert_called_once_with(
@@ -121,9 +124,12 @@ class TestConnectionsClient:
f"HTTP error {status_code}: {response_text}"
)
with mock.patch.object(
client, "_get_access_token", return_value=mock_credentials.token
), mock.patch("requests.get", return_value=mock_response):
with (
mock.patch.object(
client, "_get_access_token", return_value=mock_credentials.token
),
mock.patch("requests.get", return_value=mock_response),
):
with pytest.raises(
ValueError, match="Invalid request. Please check the provided"
):
@@ -140,9 +146,12 @@ class TestConnectionsClient:
"Internal Server Error"
)
with mock.patch.object(
client, "_get_access_token", return_value=mock_credentials.token
), mock.patch("requests.get", return_value=mock_response):
with (
mock.patch.object(
client, "_get_access_token", return_value=mock_credentials.token
),
mock.patch("requests.get", return_value=mock_response),
):
with pytest.raises(ValueError, match="Request error: "):
client._execute_api_call("https://test.url")
@@ -151,10 +160,13 @@ class TestConnectionsClient:
):
credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials)
with mock.patch.object(
client, "_get_access_token", return_value=mock_credentials.token
), mock.patch(
"requests.get", side_effect=Exception("Something went wrong")
with (
mock.patch.object(
client, "_get_access_token", return_value=mock_credentials.token
),
mock.patch(
"requests.get", side_effect=Exception("Something went wrong")
),
):
with pytest.raises(
Exception, match="An unexpected error occurred: Something went wrong"
@@ -539,10 +551,13 @@ class TestConnectionsClient:
mock_creds.token = "sa_token"
mock_creds.expired = False
with mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info",
return_value=mock_creds,
), mock.patch.object(mock_creds, "refresh", return_value=None):
with (
mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info",
return_value=mock_creds,
),
mock.patch.object(mock_creds, "refresh", return_value=None),
):
token = client._get_access_token()
assert token == "sa_token"
google.oauth2.service_account.Credentials.from_service_account_info.assert_called_once_with(
@@ -555,10 +570,13 @@ class TestConnectionsClient:
self, project, location, connection_name, mock_credentials
):
client = ConnectionsClient(project, location, connection_name, None)
with mock.patch(
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
return_value=(mock_credentials, "test_project_id"),
), mock.patch.object(mock_credentials, "refresh", return_value=None):
with (
mock.patch(
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
return_value=(mock_credentials, "test_project_id"),
),
mock.patch.object(mock_credentials, "refresh", return_value=None),
):
token = client._get_access_token()
assert token == "test_token"

View File

@@ -114,11 +114,14 @@ class TestIntegrationClient:
mock_response.status_code = 200
mock_response.json.return_value = {"openApiSpec": json.dumps(expected_spec)}
with mock.patch.object(
IntegrationClient,
"_get_access_token",
return_value=mock_credentials.token,
), mock.patch("requests.post", return_value=mock_response):
with (
mock.patch.object(
IntegrationClient,
"_get_access_token",
return_value=mock_credentials.token,
),
mock.patch("requests.post", return_value=mock_response),
):
client = IntegrationClient(
project=project,
location=location,
@@ -202,11 +205,14 @@ class TestIntegrationClient:
f"HTTP error {status_code}: {response_text}"
)
with mock.patch.object(
IntegrationClient,
"_get_access_token",
return_value=mock_credentials.token,
), mock.patch("requests.post", return_value=mock_response):
with (
mock.patch.object(
IntegrationClient,
"_get_access_token",
return_value=mock_credentials.token,
),
mock.patch("requests.post", return_value=mock_response),
):
client = IntegrationClient(
project=project,
location=location,
@@ -243,11 +249,14 @@ class TestIntegrationClient:
"Internal Server Error"
)
with mock.patch.object(
IntegrationClient,
"_get_access_token",
return_value=mock_credentials.token,
), mock.patch("requests.post", return_value=mock_response):
with (
mock.patch.object(
IntegrationClient,
"_get_access_token",
return_value=mock_credentials.token,
),
mock.patch("requests.post", return_value=mock_response),
):
client = IntegrationClient(
project=project,
location=location,
@@ -270,12 +279,15 @@ class TestIntegrationClient:
mock_credentials,
mock_connections_client,
):
with mock.patch.object(
IntegrationClient,
"_get_access_token",
return_value=mock_credentials.token,
), mock.patch(
"requests.post", side_effect=Exception("Something went wrong")
with (
mock.patch.object(
IntegrationClient,
"_get_access_token",
return_value=mock_credentials.token,
),
mock.patch(
"requests.post", side_effect=Exception("Something went wrong")
),
):
client = IntegrationClient(
project=project,
@@ -486,10 +498,13 @@ class TestIntegrationClient:
mock_creds.token = "sa_token"
mock_creds.expired = False
with mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info",
return_value=mock_creds,
), mock.patch.object(mock_creds, "refresh", return_value=None):
with (
mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info",
return_value=mock_creds,
),
mock.patch.object(mock_creds, "refresh", return_value=None),
):
client = IntegrationClient(
project=project,
location=location,
@@ -518,10 +533,13 @@ class TestIntegrationClient:
mock_credentials,
):
mock_credentials.expired = False
with mock.patch(
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
return_value=(mock_credentials, "test_project_id"),
), mock.patch.object(mock_credentials, "refresh", return_value=None):
with (
mock.patch(
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
return_value=(mock_credentials, "test_project_id"),
),
mock.patch.object(mock_credentials, "refresh", return_value=None),
):
client = IntegrationClient(
project=project,
location=location,
@@ -538,12 +556,15 @@ class TestIntegrationClient:
def test_get_access_token_no_valid_credentials(
self, project, location, integration_name, trigger_name, connection_name
):
with mock.patch(
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
return_value=(None, None),
), mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info",
return_value=None,
with (
mock.patch(
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
return_value=(None, None),
),
mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info",
return_value=None,
),
):
client = IntegrationClient(
project=project,
@@ -587,9 +608,12 @@ class TestIntegrationClient:
service_account_json=None,
)
client.credential_cache = mock_credentials # Simulate a cached credential
with mock.patch("google.auth.default") as mock_default, mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info"
) as mock_sa:
with (
mock.patch("google.auth.default") as mock_default,
mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info"
) as mock_sa,
):
token = client._get_access_token()
assert token == "cached_token"
mock_default.assert_not_called()

View File

@@ -236,8 +236,7 @@ def test_initialization_with_connection_and_actions(
)
mock_connections_client.return_value.get_connection_details.assert_called_once()
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name,
tool_instructions
tool_name, tool_instructions
)
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
assert len(toolset.get_tools()) == 1
@@ -390,6 +389,5 @@ def test_initialization_with_connection_details(
tool_instructions=tool_instructions,
)
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name,
tool_instructions
tool_name, tool_instructions
)

View File

@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.