Support async agent and model callbacks

PiperOrigin-RevId: 755542756
This commit is contained in:
Selcuk Gun 2025-05-06 15:13:39 -07:00 committed by Copybara-Service
parent f96cdc675c
commit 794a70edcd
25 changed files with 371 additions and 117 deletions

View File

@ -14,7 +14,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Any import inspect
from typing import Any, Awaitable, Union
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Callable from typing import Callable
from typing import final from typing import final
@ -37,10 +38,15 @@ if TYPE_CHECKING:
tracer = trace.get_tracer('gcp.vertex.agent') tracer = trace.get_tracer('gcp.vertex.agent')
BeforeAgentCallback = Callable[[CallbackContext], Optional[types.Content]] BeforeAgentCallback = Callable[
[CallbackContext],
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
]
AfterAgentCallback = Callable[
AfterAgentCallback = Callable[[CallbackContext], Optional[types.Content]] [CallbackContext],
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
]
class BaseAgent(BaseModel): class BaseAgent(BaseModel):
@ -119,7 +125,7 @@ class BaseAgent(BaseModel):
with tracer.start_as_current_span(f'agent_run [{self.name}]'): with tracer.start_as_current_span(f'agent_run [{self.name}]'):
ctx = self._create_invocation_context(parent_context) ctx = self._create_invocation_context(parent_context)
if event := self.__handle_before_agent_callback(ctx): if event := await self.__handle_before_agent_callback(ctx):
yield event yield event
if ctx.end_invocation: if ctx.end_invocation:
return return
@ -130,7 +136,7 @@ class BaseAgent(BaseModel):
if ctx.end_invocation: if ctx.end_invocation:
return return
if event := self.__handle_after_agent_callback(ctx): if event := await self.__handle_after_agent_callback(ctx):
yield event yield event
@final @final
@ -230,7 +236,7 @@ class BaseAgent(BaseModel):
invocation_context.branch = f'{parent_context.branch}.{self.name}' invocation_context.branch = f'{parent_context.branch}.{self.name}'
return invocation_context return invocation_context
def __handle_before_agent_callback( async def __handle_before_agent_callback(
self, ctx: InvocationContext self, ctx: InvocationContext
) -> Optional[Event]: ) -> Optional[Event]:
"""Runs the before_agent_callback if it exists. """Runs the before_agent_callback if it exists.
@ -248,6 +254,9 @@ class BaseAgent(BaseModel):
callback_context=callback_context callback_context=callback_context
) )
if inspect.isawaitable(before_agent_callback_content):
before_agent_callback_content = await before_agent_callback_content
if before_agent_callback_content: if before_agent_callback_content:
ret_event = Event( ret_event = Event(
invocation_id=ctx.invocation_id, invocation_id=ctx.invocation_id,
@ -269,7 +278,7 @@ class BaseAgent(BaseModel):
return ret_event return ret_event
def __handle_after_agent_callback( async def __handle_after_agent_callback(
self, invocation_context: InvocationContext self, invocation_context: InvocationContext
) -> Optional[Event]: ) -> Optional[Event]:
"""Runs the after_agent_callback if it exists. """Runs the after_agent_callback if it exists.
@ -287,6 +296,9 @@ class BaseAgent(BaseModel):
callback_context=callback_context callback_context=callback_context
) )
if inspect.isawaitable(after_agent_callback_content):
after_agent_callback_content = await after_agent_callback_content
if after_agent_callback_content or callback_context.state.has_delta(): if after_agent_callback_content or callback_context.state.has_delta():
ret_event = Event( ret_event = Event(
invocation_id=invocation_context.invocation_id, invocation_id=invocation_context.invocation_id,

View File

@ -49,11 +49,12 @@ logger = logging.getLogger(__name__)
BeforeModelCallback: TypeAlias = Callable[ BeforeModelCallback: TypeAlias = Callable[
[CallbackContext, LlmRequest], Optional[LlmResponse] [CallbackContext, LlmRequest],
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
] ]
AfterModelCallback: TypeAlias = Callable[ AfterModelCallback: TypeAlias = Callable[
[CallbackContext, LlmResponse], [CallbackContext, LlmResponse],
Optional[LlmResponse], Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
] ]
BeforeToolCallback: TypeAlias = Callable[ BeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext], [BaseTool, dict[str, Any], ToolContext],

View File

@ -16,6 +16,7 @@ from __future__ import annotations
from abc import ABC from abc import ABC
import asyncio import asyncio
import inspect
import logging import logging
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import cast from typing import cast
@ -199,7 +200,7 @@ class BaseLlmFlow(ABC):
return "user" return "user"
else: else:
return invocation_context.agent.name return invocation_context.agent.name
assert invocation_context.live_request_queue assert invocation_context.live_request_queue
try: try:
while True: while True:
@ -447,7 +448,7 @@ class BaseLlmFlow(ABC):
model_response_event: Event, model_response_event: Event,
) -> AsyncGenerator[LlmResponse, None]: ) -> AsyncGenerator[LlmResponse, None]:
# Runs before_model_callback if it exists. # Runs before_model_callback if it exists.
if response := self._handle_before_model_callback( if response := await self._handle_before_model_callback(
invocation_context, llm_request, model_response_event invocation_context, llm_request, model_response_event
): ):
yield response yield response
@ -460,7 +461,7 @@ class BaseLlmFlow(ABC):
invocation_context.live_request_queue = LiveRequestQueue() invocation_context.live_request_queue = LiveRequestQueue()
async for llm_response in self.run_live(invocation_context): async for llm_response in self.run_live(invocation_context):
# Runs after_model_callback if it exists. # Runs after_model_callback if it exists.
if altered_llm_response := self._handle_after_model_callback( if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event invocation_context, llm_response, model_response_event
): ):
llm_response = altered_llm_response llm_response = altered_llm_response
@ -489,14 +490,14 @@ class BaseLlmFlow(ABC):
llm_response, llm_response,
) )
# Runs after_model_callback if it exists. # Runs after_model_callback if it exists.
if altered_llm_response := self._handle_after_model_callback( if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event invocation_context, llm_response, model_response_event
): ):
llm_response = altered_llm_response llm_response = altered_llm_response
yield llm_response yield llm_response
def _handle_before_model_callback( async def _handle_before_model_callback(
self, self,
invocation_context: InvocationContext, invocation_context: InvocationContext,
llm_request: LlmRequest, llm_request: LlmRequest,
@ -514,11 +515,16 @@ class BaseLlmFlow(ABC):
callback_context = CallbackContext( callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions invocation_context, event_actions=model_response_event.actions
) )
return agent.before_model_callback( before_model_callback_content = agent.before_model_callback(
callback_context=callback_context, llm_request=llm_request callback_context=callback_context, llm_request=llm_request
) )
def _handle_after_model_callback( if inspect.isawaitable(before_model_callback_content):
before_model_callback_content = await before_model_callback_content
return before_model_callback_content
async def _handle_after_model_callback(
self, self,
invocation_context: InvocationContext, invocation_context: InvocationContext,
llm_response: LlmResponse, llm_response: LlmResponse,
@ -536,10 +542,15 @@ class BaseLlmFlow(ABC):
callback_context = CallbackContext( callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions invocation_context, event_actions=model_response_event.actions
) )
return agent.after_model_callback( after_model_callback_content = agent.after_model_callback(
callback_context=callback_context, llm_response=llm_response callback_context=callback_context, llm_response=llm_response
) )
if inspect.isawaitable(after_model_callback_content):
after_model_callback_content = await after_model_callback_content
return after_model_callback_content
def _finalize_model_response_event( def _finalize_model_response_event(
self, self,
llm_request: LlmRequest, llm_request: LlmRequest,

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import agent from . import agent

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -19,6 +19,7 @@ import os
from google.adk.evaluation import AgentEvaluator from google.adk.evaluation import AgentEvaluator
import pytest import pytest
def agent_eval_artifacts_in_fixture(): def agent_eval_artifacts_in_fixture():
"""Get all agents from fixture folder.""" """Get all agents from fixture folder."""
agent_eval_artifacts = [] agent_eval_artifacts = []

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -33,16 +33,34 @@ def _before_agent_callback_noop(callback_context: CallbackContext) -> None:
pass pass
async def _async_before_agent_callback_noop(
callback_context: CallbackContext,
) -> None:
pass
def _before_agent_callback_bypass_agent( def _before_agent_callback_bypass_agent(
callback_context: CallbackContext, callback_context: CallbackContext,
) -> types.Content: ) -> types.Content:
return types.Content(parts=[types.Part(text='agent run is bypassed.')]) return types.Content(parts=[types.Part(text='agent run is bypassed.')])
async def _async_before_agent_callback_bypass_agent(
callback_context: CallbackContext,
) -> types.Content:
return types.Content(parts=[types.Part(text='agent run is bypassed.')])
def _after_agent_callback_noop(callback_context: CallbackContext) -> None: def _after_agent_callback_noop(callback_context: CallbackContext) -> None:
pass pass
async def _async_after_agent_callback_noop(
callback_context: CallbackContext,
) -> None:
pass
def _after_agent_callback_append_agent_reply( def _after_agent_callback_append_agent_reply(
callback_context: CallbackContext, callback_context: CallbackContext,
) -> types.Content: ) -> types.Content:
@ -51,6 +69,14 @@ def _after_agent_callback_append_agent_reply(
) )
async def _async_after_agent_callback_append_agent_reply(
callback_context: CallbackContext,
) -> types.Content:
return types.Content(
parts=[types.Part(text='Agent reply from after agent callback.')]
)
class _IncompleteAgent(BaseAgent): class _IncompleteAgent(BaseAgent):
pass pass
@ -158,6 +184,34 @@ async def test_run_async_before_agent_callback_noop(
spy_run_async_impl.assert_called_once() spy_run_async_impl.assert_called_once()
@pytest.mark.asyncio
async def test_run_async_with_async_before_agent_callback_noop(
request: pytest.FixtureRequest,
mocker: pytest_mock.MockerFixture,
) -> Union[types.Content, None]:
# Arrange
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')
# Act
_ = [e async for e in agent.run_async(parent_ctx)]
# Assert
spy_before_agent_callback.assert_called_once()
_, kwargs = spy_before_agent_callback.call_args
assert 'callback_context' in kwargs
assert isinstance(kwargs['callback_context'], CallbackContext)
spy_run_async_impl.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async_before_agent_callback_bypass_agent( async def test_run_async_before_agent_callback_bypass_agent(
request: pytest.FixtureRequest, request: pytest.FixtureRequest,
@ -185,6 +239,33 @@ async def test_run_async_before_agent_callback_bypass_agent(
assert events[0].content.parts[0].text == 'agent run is bypassed.' assert events[0].content.parts[0].text == 'agent run is bypassed.'
@pytest.mark.asyncio
async def test_run_async_with_async_before_agent_callback_bypass_agent(
request: pytest.FixtureRequest,
mocker: pytest_mock.MockerFixture,
):
# Arrange
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
before_agent_callback=_async_before_agent_callback_bypass_agent,
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')
# Act
events = [e async for e in agent.run_async(parent_ctx)]
# Assert
spy_before_agent_callback.assert_called_once()
spy_run_async_impl.assert_not_called()
assert len(events) == 1
assert events[0].content.parts[0].text == 'agent run is bypassed.'
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async_after_agent_callback_noop( async def test_run_async_after_agent_callback_noop(
request: pytest.FixtureRequest, request: pytest.FixtureRequest,
@ -211,6 +292,32 @@ async def test_run_async_after_agent_callback_noop(
assert len(events) == 1 assert len(events) == 1
@pytest.mark.asyncio
async def test_run_async_with_async_after_agent_callback_noop(
request: pytest.FixtureRequest,
mocker: pytest_mock.MockerFixture,
):
# Arrange
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_noop,
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
# Act
events = [e async for e in agent.run_async(parent_ctx)]
# Assert
spy_after_agent_callback.assert_called_once()
_, kwargs = spy_after_agent_callback.call_args
assert 'callback_context' in kwargs
assert isinstance(kwargs['callback_context'], CallbackContext)
assert len(events) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async_after_agent_callback_append_reply( async def test_run_async_after_agent_callback_append_reply(
request: pytest.FixtureRequest, request: pytest.FixtureRequest,
@ -236,6 +343,31 @@ async def test_run_async_after_agent_callback_append_reply(
) )
@pytest.mark.asyncio
async def test_run_async_with_async_after_agent_callback_append_reply(
request: pytest.FixtureRequest,
):
# Arrange
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_async_after_agent_callback_append_agent_reply,
)
parent_ctx = _create_parent_invocation_context(
request.function.__name__, agent
)
# Act
events = [e async for e in agent.run_async(parent_ctx)]
# Assert
assert len(events) == 2
assert events[1].author == agent.name
assert (
events[1].content.parts[0].text
== 'Agent reply from after agent callback.'
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest): async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent') agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')

View File

@ -56,10 +56,44 @@ class MockAfterModelCallback(BaseModel):
) )
class MockAsyncBeforeModelCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
llm_request: LlmRequest,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
class MockAsyncAfterModelCallback(BaseModel):
mock_response: str
async def __call__(
self,
callback_context: CallbackContext,
llm_response: LlmResponse,
) -> LlmResponse:
return LlmResponse(
content=utils.ModelContent(
[types.Part.from_text(text=self.mock_response)]
)
)
def noop_callback(**kwargs) -> Optional[LlmResponse]: def noop_callback(**kwargs) -> Optional[LlmResponse]:
pass pass
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
pass
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_before_model_callback(): async def test_before_model_callback():
responses = ['model_response'] responses = ['model_response']
@ -98,26 +132,6 @@ async def test_before_model_callback_noop():
] ]
@pytest.mark.asyncio
async def test_before_model_callback_end():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=MockBeforeModelCallback(
mock_response='before_model_callback',
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'before_model_callback'),
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_after_model_callback(): async def test_after_model_callback():
responses = ['model_response'] responses = ['model_response']
@ -136,3 +150,61 @@ async def test_after_model_callback():
) == [ ) == [
('root_agent', 'after_model_callback'), ('root_agent', 'after_model_callback'),
] ]
@pytest.mark.asyncio
async def test_async_before_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=MockAsyncBeforeModelCallback(
mock_response='async_before_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_before_model_callback'),
]
@pytest.mark.asyncio
async def test_async_before_model_callback_noop():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
before_model_callback=async_noop_callback,
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'model_response'),
]
@pytest.mark.asyncio
async def test_async_after_model_callback():
responses = ['model_response']
mock_model = utils.MockModel.create(responses=responses)
agent = Agent(
name='root_agent',
model=mock_model,
after_model_callback=MockAsyncAfterModelCallback(
mock_response='async_after_model_callback'
),
)
runner = utils.TestInMemoryRunner(agent)
assert utils.simplify_events(
await runner.run_async_with_new_session('test')
) == [
('root_agent', 'async_after_model_callback'),
]

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.

View File

@ -74,9 +74,12 @@ class TestConnectionsClient:
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {"data": "test"} mock_response.json.return_value = {"data": "test"}
with mock.patch.object( with (
client, "_get_access_token", return_value=mock_credentials.token mock.patch.object(
), mock.patch("requests.get", return_value=mock_response): client, "_get_access_token", return_value=mock_credentials.token
),
mock.patch("requests.get", return_value=mock_response),
):
response = client._execute_api_call("https://test.url") response = client._execute_api_call("https://test.url")
assert response.json() == {"data": "test"} assert response.json() == {"data": "test"}
requests.get.assert_called_once_with( requests.get.assert_called_once_with(
@ -121,9 +124,12 @@ class TestConnectionsClient:
f"HTTP error {status_code}: {response_text}" f"HTTP error {status_code}: {response_text}"
) )
with mock.patch.object( with (
client, "_get_access_token", return_value=mock_credentials.token mock.patch.object(
), mock.patch("requests.get", return_value=mock_response): client, "_get_access_token", return_value=mock_credentials.token
),
mock.patch("requests.get", return_value=mock_response),
):
with pytest.raises( with pytest.raises(
ValueError, match="Invalid request. Please check the provided" ValueError, match="Invalid request. Please check the provided"
): ):
@ -140,9 +146,12 @@ class TestConnectionsClient:
"Internal Server Error" "Internal Server Error"
) )
with mock.patch.object( with (
client, "_get_access_token", return_value=mock_credentials.token mock.patch.object(
), mock.patch("requests.get", return_value=mock_response): client, "_get_access_token", return_value=mock_credentials.token
),
mock.patch("requests.get", return_value=mock_response),
):
with pytest.raises(ValueError, match="Request error: "): with pytest.raises(ValueError, match="Request error: "):
client._execute_api_call("https://test.url") client._execute_api_call("https://test.url")
@ -151,10 +160,13 @@ class TestConnectionsClient:
): ):
credentials = {"email": "test@example.com"} credentials = {"email": "test@example.com"}
client = ConnectionsClient(project, location, connection_name, credentials) client = ConnectionsClient(project, location, connection_name, credentials)
with mock.patch.object( with (
client, "_get_access_token", return_value=mock_credentials.token mock.patch.object(
), mock.patch( client, "_get_access_token", return_value=mock_credentials.token
"requests.get", side_effect=Exception("Something went wrong") ),
mock.patch(
"requests.get", side_effect=Exception("Something went wrong")
),
): ):
with pytest.raises( with pytest.raises(
Exception, match="An unexpected error occurred: Something went wrong" Exception, match="An unexpected error occurred: Something went wrong"
@ -539,10 +551,13 @@ class TestConnectionsClient:
mock_creds.token = "sa_token" mock_creds.token = "sa_token"
mock_creds.expired = False mock_creds.expired = False
with mock.patch( with (
"google.oauth2.service_account.Credentials.from_service_account_info", mock.patch(
return_value=mock_creds, "google.oauth2.service_account.Credentials.from_service_account_info",
), mock.patch.object(mock_creds, "refresh", return_value=None): return_value=mock_creds,
),
mock.patch.object(mock_creds, "refresh", return_value=None),
):
token = client._get_access_token() token = client._get_access_token()
assert token == "sa_token" assert token == "sa_token"
google.oauth2.service_account.Credentials.from_service_account_info.assert_called_once_with( google.oauth2.service_account.Credentials.from_service_account_info.assert_called_once_with(
@ -555,10 +570,13 @@ class TestConnectionsClient:
self, project, location, connection_name, mock_credentials self, project, location, connection_name, mock_credentials
): ):
client = ConnectionsClient(project, location, connection_name, None) client = ConnectionsClient(project, location, connection_name, None)
with mock.patch( with (
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential", mock.patch(
return_value=(mock_credentials, "test_project_id"), "google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
), mock.patch.object(mock_credentials, "refresh", return_value=None): return_value=(mock_credentials, "test_project_id"),
),
mock.patch.object(mock_credentials, "refresh", return_value=None),
):
token = client._get_access_token() token = client._get_access_token()
assert token == "test_token" assert token == "test_token"

View File

@ -114,11 +114,14 @@ class TestIntegrationClient:
mock_response.status_code = 200 mock_response.status_code = 200
mock_response.json.return_value = {"openApiSpec": json.dumps(expected_spec)} mock_response.json.return_value = {"openApiSpec": json.dumps(expected_spec)}
with mock.patch.object( with (
IntegrationClient, mock.patch.object(
"_get_access_token", IntegrationClient,
return_value=mock_credentials.token, "_get_access_token",
), mock.patch("requests.post", return_value=mock_response): return_value=mock_credentials.token,
),
mock.patch("requests.post", return_value=mock_response),
):
client = IntegrationClient( client = IntegrationClient(
project=project, project=project,
location=location, location=location,
@ -202,11 +205,14 @@ class TestIntegrationClient:
f"HTTP error {status_code}: {response_text}" f"HTTP error {status_code}: {response_text}"
) )
with mock.patch.object( with (
IntegrationClient, mock.patch.object(
"_get_access_token", IntegrationClient,
return_value=mock_credentials.token, "_get_access_token",
), mock.patch("requests.post", return_value=mock_response): return_value=mock_credentials.token,
),
mock.patch("requests.post", return_value=mock_response),
):
client = IntegrationClient( client = IntegrationClient(
project=project, project=project,
location=location, location=location,
@ -243,11 +249,14 @@ class TestIntegrationClient:
"Internal Server Error" "Internal Server Error"
) )
with mock.patch.object( with (
IntegrationClient, mock.patch.object(
"_get_access_token", IntegrationClient,
return_value=mock_credentials.token, "_get_access_token",
), mock.patch("requests.post", return_value=mock_response): return_value=mock_credentials.token,
),
mock.patch("requests.post", return_value=mock_response),
):
client = IntegrationClient( client = IntegrationClient(
project=project, project=project,
location=location, location=location,
@ -270,12 +279,15 @@ class TestIntegrationClient:
mock_credentials, mock_credentials,
mock_connections_client, mock_connections_client,
): ):
with mock.patch.object( with (
IntegrationClient, mock.patch.object(
"_get_access_token", IntegrationClient,
return_value=mock_credentials.token, "_get_access_token",
), mock.patch( return_value=mock_credentials.token,
"requests.post", side_effect=Exception("Something went wrong") ),
mock.patch(
"requests.post", side_effect=Exception("Something went wrong")
),
): ):
client = IntegrationClient( client = IntegrationClient(
project=project, project=project,
@ -486,10 +498,13 @@ class TestIntegrationClient:
mock_creds.token = "sa_token" mock_creds.token = "sa_token"
mock_creds.expired = False mock_creds.expired = False
with mock.patch( with (
"google.oauth2.service_account.Credentials.from_service_account_info", mock.patch(
return_value=mock_creds, "google.oauth2.service_account.Credentials.from_service_account_info",
), mock.patch.object(mock_creds, "refresh", return_value=None): return_value=mock_creds,
),
mock.patch.object(mock_creds, "refresh", return_value=None),
):
client = IntegrationClient( client = IntegrationClient(
project=project, project=project,
location=location, location=location,
@ -518,10 +533,13 @@ class TestIntegrationClient:
mock_credentials, mock_credentials,
): ):
mock_credentials.expired = False mock_credentials.expired = False
with mock.patch( with (
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential", mock.patch(
return_value=(mock_credentials, "test_project_id"), "google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
), mock.patch.object(mock_credentials, "refresh", return_value=None): return_value=(mock_credentials, "test_project_id"),
),
mock.patch.object(mock_credentials, "refresh", return_value=None),
):
client = IntegrationClient( client = IntegrationClient(
project=project, project=project,
location=location, location=location,
@ -538,12 +556,15 @@ class TestIntegrationClient:
def test_get_access_token_no_valid_credentials( def test_get_access_token_no_valid_credentials(
self, project, location, integration_name, trigger_name, connection_name self, project, location, integration_name, trigger_name, connection_name
): ):
with mock.patch( with (
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential", mock.patch(
return_value=(None, None), "google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
), mock.patch( return_value=(None, None),
"google.oauth2.service_account.Credentials.from_service_account_info", ),
return_value=None, mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info",
return_value=None,
),
): ):
client = IntegrationClient( client = IntegrationClient(
project=project, project=project,
@ -587,9 +608,12 @@ class TestIntegrationClient:
service_account_json=None, service_account_json=None,
) )
client.credential_cache = mock_credentials # Simulate a cached credential client.credential_cache = mock_credentials # Simulate a cached credential
with mock.patch("google.auth.default") as mock_default, mock.patch( with (
"google.oauth2.service_account.Credentials.from_service_account_info" mock.patch("google.auth.default") as mock_default,
) as mock_sa: mock.patch(
"google.oauth2.service_account.Credentials.from_service_account_info"
) as mock_sa,
):
token = client._get_access_token() token = client._get_access_token()
assert token == "cached_token" assert token == "cached_token"
mock_default.assert_not_called() mock_default.assert_not_called()

View File

@ -236,8 +236,7 @@ def test_initialization_with_connection_and_actions(
) )
mock_connections_client.return_value.get_connection_details.assert_called_once() mock_connections_client.return_value.get_connection_details.assert_called_once()
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_name, tool_instructions
tool_instructions
) )
mock_openapi_action_spec_parser.return_value.parse.assert_called_once() mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
assert len(toolset.get_tools()) == 1 assert len(toolset.get_tools()) == 1
@ -390,6 +389,5 @@ def test_initialization_with_connection_details(
tool_instructions=tool_instructions, tool_instructions=tool_instructions,
) )
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name, tool_name, tool_instructions
tool_instructions
) )

View File

@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.