mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 15:14:50 -06:00
Support async agent and model callbacks
PiperOrigin-RevId: 755542756
This commit is contained in:
parent
f96cdc675c
commit
794a70edcd
@ -14,7 +14,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
import inspect
|
||||||
|
from typing import Any, Awaitable, Union
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from typing import final
|
from typing import final
|
||||||
@ -37,10 +38,15 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
tracer = trace.get_tracer('gcp.vertex.agent')
|
tracer = trace.get_tracer('gcp.vertex.agent')
|
||||||
|
|
||||||
BeforeAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
|
BeforeAgentCallback = Callable[
|
||||||
|
[CallbackContext],
|
||||||
|
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
|
||||||
|
]
|
||||||
|
|
||||||
|
AfterAgentCallback = Callable[
|
||||||
AfterAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
|
[CallbackContext],
|
||||||
|
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(BaseModel):
|
class BaseAgent(BaseModel):
|
||||||
@ -119,7 +125,7 @@ class BaseAgent(BaseModel):
|
|||||||
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
|
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
|
||||||
ctx = self._create_invocation_context(parent_context)
|
ctx = self._create_invocation_context(parent_context)
|
||||||
|
|
||||||
if event := self.__handle_before_agent_callback(ctx):
|
if event := await self.__handle_before_agent_callback(ctx):
|
||||||
yield event
|
yield event
|
||||||
if ctx.end_invocation:
|
if ctx.end_invocation:
|
||||||
return
|
return
|
||||||
@ -130,7 +136,7 @@ class BaseAgent(BaseModel):
|
|||||||
if ctx.end_invocation:
|
if ctx.end_invocation:
|
||||||
return
|
return
|
||||||
|
|
||||||
if event := self.__handle_after_agent_callback(ctx):
|
if event := await self.__handle_after_agent_callback(ctx):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@ -230,7 +236,7 @@ class BaseAgent(BaseModel):
|
|||||||
invocation_context.branch = f'{parent_context.branch}.{self.name}'
|
invocation_context.branch = f'{parent_context.branch}.{self.name}'
|
||||||
return invocation_context
|
return invocation_context
|
||||||
|
|
||||||
def __handle_before_agent_callback(
|
async def __handle_before_agent_callback(
|
||||||
self, ctx: InvocationContext
|
self, ctx: InvocationContext
|
||||||
) -> Optional[Event]:
|
) -> Optional[Event]:
|
||||||
"""Runs the before_agent_callback if it exists.
|
"""Runs the before_agent_callback if it exists.
|
||||||
@ -248,6 +254,9 @@ class BaseAgent(BaseModel):
|
|||||||
callback_context=callback_context
|
callback_context=callback_context
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if inspect.isawaitable(before_agent_callback_content):
|
||||||
|
before_agent_callback_content = await before_agent_callback_content
|
||||||
|
|
||||||
if before_agent_callback_content:
|
if before_agent_callback_content:
|
||||||
ret_event = Event(
|
ret_event = Event(
|
||||||
invocation_id=ctx.invocation_id,
|
invocation_id=ctx.invocation_id,
|
||||||
@ -269,7 +278,7 @@ class BaseAgent(BaseModel):
|
|||||||
|
|
||||||
return ret_event
|
return ret_event
|
||||||
|
|
||||||
def __handle_after_agent_callback(
|
async def __handle_after_agent_callback(
|
||||||
self, invocation_context: InvocationContext
|
self, invocation_context: InvocationContext
|
||||||
) -> Optional[Event]:
|
) -> Optional[Event]:
|
||||||
"""Runs the after_agent_callback if it exists.
|
"""Runs the after_agent_callback if it exists.
|
||||||
@ -287,6 +296,9 @@ class BaseAgent(BaseModel):
|
|||||||
callback_context=callback_context
|
callback_context=callback_context
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if inspect.isawaitable(after_agent_callback_content):
|
||||||
|
after_agent_callback_content = await after_agent_callback_content
|
||||||
|
|
||||||
if after_agent_callback_content or callback_context.state.has_delta():
|
if after_agent_callback_content or callback_context.state.has_delta():
|
||||||
ret_event = Event(
|
ret_event = Event(
|
||||||
invocation_id=invocation_context.invocation_id,
|
invocation_id=invocation_context.invocation_id,
|
||||||
|
@ -49,11 +49,12 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
BeforeModelCallback: TypeAlias = Callable[
|
BeforeModelCallback: TypeAlias = Callable[
|
||||||
[CallbackContext, LlmRequest], Optional[LlmResponse]
|
[CallbackContext, LlmRequest],
|
||||||
|
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
|
||||||
]
|
]
|
||||||
AfterModelCallback: TypeAlias = Callable[
|
AfterModelCallback: TypeAlias = Callable[
|
||||||
[CallbackContext, LlmResponse],
|
[CallbackContext, LlmResponse],
|
||||||
Optional[LlmResponse],
|
Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]],
|
||||||
]
|
]
|
||||||
BeforeToolCallback: TypeAlias = Callable[
|
BeforeToolCallback: TypeAlias = Callable[
|
||||||
[BaseTool, dict[str, Any], ToolContext],
|
[BaseTool, dict[str, Any], ToolContext],
|
||||||
|
@ -16,6 +16,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
from typing import cast
|
from typing import cast
|
||||||
@ -199,7 +200,7 @@ class BaseLlmFlow(ABC):
|
|||||||
return "user"
|
return "user"
|
||||||
else:
|
else:
|
||||||
return invocation_context.agent.name
|
return invocation_context.agent.name
|
||||||
|
|
||||||
assert invocation_context.live_request_queue
|
assert invocation_context.live_request_queue
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
@ -447,7 +448,7 @@ class BaseLlmFlow(ABC):
|
|||||||
model_response_event: Event,
|
model_response_event: Event,
|
||||||
) -> AsyncGenerator[LlmResponse, None]:
|
) -> AsyncGenerator[LlmResponse, None]:
|
||||||
# Runs before_model_callback if it exists.
|
# Runs before_model_callback if it exists.
|
||||||
if response := self._handle_before_model_callback(
|
if response := await self._handle_before_model_callback(
|
||||||
invocation_context, llm_request, model_response_event
|
invocation_context, llm_request, model_response_event
|
||||||
):
|
):
|
||||||
yield response
|
yield response
|
||||||
@ -460,7 +461,7 @@ class BaseLlmFlow(ABC):
|
|||||||
invocation_context.live_request_queue = LiveRequestQueue()
|
invocation_context.live_request_queue = LiveRequestQueue()
|
||||||
async for llm_response in self.run_live(invocation_context):
|
async for llm_response in self.run_live(invocation_context):
|
||||||
# Runs after_model_callback if it exists.
|
# Runs after_model_callback if it exists.
|
||||||
if altered_llm_response := self._handle_after_model_callback(
|
if altered_llm_response := await self._handle_after_model_callback(
|
||||||
invocation_context, llm_response, model_response_event
|
invocation_context, llm_response, model_response_event
|
||||||
):
|
):
|
||||||
llm_response = altered_llm_response
|
llm_response = altered_llm_response
|
||||||
@ -489,14 +490,14 @@ class BaseLlmFlow(ABC):
|
|||||||
llm_response,
|
llm_response,
|
||||||
)
|
)
|
||||||
# Runs after_model_callback if it exists.
|
# Runs after_model_callback if it exists.
|
||||||
if altered_llm_response := self._handle_after_model_callback(
|
if altered_llm_response := await self._handle_after_model_callback(
|
||||||
invocation_context, llm_response, model_response_event
|
invocation_context, llm_response, model_response_event
|
||||||
):
|
):
|
||||||
llm_response = altered_llm_response
|
llm_response = altered_llm_response
|
||||||
|
|
||||||
yield llm_response
|
yield llm_response
|
||||||
|
|
||||||
def _handle_before_model_callback(
|
async def _handle_before_model_callback(
|
||||||
self,
|
self,
|
||||||
invocation_context: InvocationContext,
|
invocation_context: InvocationContext,
|
||||||
llm_request: LlmRequest,
|
llm_request: LlmRequest,
|
||||||
@ -514,11 +515,16 @@ class BaseLlmFlow(ABC):
|
|||||||
callback_context = CallbackContext(
|
callback_context = CallbackContext(
|
||||||
invocation_context, event_actions=model_response_event.actions
|
invocation_context, event_actions=model_response_event.actions
|
||||||
)
|
)
|
||||||
return agent.before_model_callback(
|
before_model_callback_content = agent.before_model_callback(
|
||||||
callback_context=callback_context, llm_request=llm_request
|
callback_context=callback_context, llm_request=llm_request
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_after_model_callback(
|
if inspect.isawaitable(before_model_callback_content):
|
||||||
|
before_model_callback_content = await before_model_callback_content
|
||||||
|
|
||||||
|
return before_model_callback_content
|
||||||
|
|
||||||
|
async def _handle_after_model_callback(
|
||||||
self,
|
self,
|
||||||
invocation_context: InvocationContext,
|
invocation_context: InvocationContext,
|
||||||
llm_response: LlmResponse,
|
llm_response: LlmResponse,
|
||||||
@ -536,10 +542,15 @@ class BaseLlmFlow(ABC):
|
|||||||
callback_context = CallbackContext(
|
callback_context = CallbackContext(
|
||||||
invocation_context, event_actions=model_response_event.actions
|
invocation_context, event_actions=model_response_event.actions
|
||||||
)
|
)
|
||||||
return agent.after_model_callback(
|
after_model_callback_content = agent.after_model_callback(
|
||||||
callback_context=callback_context, llm_response=llm_response
|
callback_context=callback_context, llm_response=llm_response
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if inspect.isawaitable(after_model_callback_content):
|
||||||
|
after_model_callback_content = await after_model_callback_content
|
||||||
|
|
||||||
|
return after_model_callback_content
|
||||||
|
|
||||||
def _finalize_model_response_event(
|
def _finalize_model_response_event(
|
||||||
self,
|
self,
|
||||||
llm_request: LlmRequest,
|
llm_request: LlmRequest,
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -12,4 +12,4 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from . import agent
|
from . import agent
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ import os
|
|||||||
from google.adk.evaluation import AgentEvaluator
|
from google.adk.evaluation import AgentEvaluator
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def agent_eval_artifacts_in_fixture():
|
def agent_eval_artifacts_in_fixture():
|
||||||
"""Get all agents from fixture folder."""
|
"""Get all agents from fixture folder."""
|
||||||
agent_eval_artifacts = []
|
agent_eval_artifacts = []
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -33,16 +33,34 @@ def _before_agent_callback_noop(callback_context: CallbackContext) -> None:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_before_agent_callback_noop(
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _before_agent_callback_bypass_agent(
|
def _before_agent_callback_bypass_agent(
|
||||||
callback_context: CallbackContext,
|
callback_context: CallbackContext,
|
||||||
) -> types.Content:
|
) -> types.Content:
|
||||||
return types.Content(parts=[types.Part(text='agent run is bypassed.')])
|
return types.Content(parts=[types.Part(text='agent run is bypassed.')])
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_before_agent_callback_bypass_agent(
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
) -> types.Content:
|
||||||
|
return types.Content(parts=[types.Part(text='agent run is bypassed.')])
|
||||||
|
|
||||||
|
|
||||||
def _after_agent_callback_noop(callback_context: CallbackContext) -> None:
|
def _after_agent_callback_noop(callback_context: CallbackContext) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_after_agent_callback_noop(
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _after_agent_callback_append_agent_reply(
|
def _after_agent_callback_append_agent_reply(
|
||||||
callback_context: CallbackContext,
|
callback_context: CallbackContext,
|
||||||
) -> types.Content:
|
) -> types.Content:
|
||||||
@ -51,6 +69,14 @@ def _after_agent_callback_append_agent_reply(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_after_agent_callback_append_agent_reply(
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
) -> types.Content:
|
||||||
|
return types.Content(
|
||||||
|
parts=[types.Part(text='Agent reply from after agent callback.')]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class _IncompleteAgent(BaseAgent):
|
class _IncompleteAgent(BaseAgent):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -158,6 +184,34 @@ async def test_run_async_before_agent_callback_noop(
|
|||||||
spy_run_async_impl.assert_called_once()
|
spy_run_async_impl.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_with_async_before_agent_callback_noop(
|
||||||
|
request: pytest.FixtureRequest,
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
) -> Union[types.Content, None]:
|
||||||
|
# Arrange
|
||||||
|
agent = _TestingAgent(
|
||||||
|
name=f'{request.function.__name__}_test_agent',
|
||||||
|
before_agent_callback=_async_before_agent_callback_noop,
|
||||||
|
)
|
||||||
|
parent_ctx = _create_parent_invocation_context(
|
||||||
|
request.function.__name__, agent
|
||||||
|
)
|
||||||
|
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
|
||||||
|
spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')
|
||||||
|
|
||||||
|
# Act
|
||||||
|
_ = [e async for e in agent.run_async(parent_ctx)]
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
spy_before_agent_callback.assert_called_once()
|
||||||
|
_, kwargs = spy_before_agent_callback.call_args
|
||||||
|
assert 'callback_context' in kwargs
|
||||||
|
assert isinstance(kwargs['callback_context'], CallbackContext)
|
||||||
|
|
||||||
|
spy_run_async_impl.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_async_before_agent_callback_bypass_agent(
|
async def test_run_async_before_agent_callback_bypass_agent(
|
||||||
request: pytest.FixtureRequest,
|
request: pytest.FixtureRequest,
|
||||||
@ -185,6 +239,33 @@ async def test_run_async_before_agent_callback_bypass_agent(
|
|||||||
assert events[0].content.parts[0].text == 'agent run is bypassed.'
|
assert events[0].content.parts[0].text == 'agent run is bypassed.'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_with_async_before_agent_callback_bypass_agent(
|
||||||
|
request: pytest.FixtureRequest,
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
):
|
||||||
|
# Arrange
|
||||||
|
agent = _TestingAgent(
|
||||||
|
name=f'{request.function.__name__}_test_agent',
|
||||||
|
before_agent_callback=_async_before_agent_callback_bypass_agent,
|
||||||
|
)
|
||||||
|
parent_ctx = _create_parent_invocation_context(
|
||||||
|
request.function.__name__, agent
|
||||||
|
)
|
||||||
|
spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
|
||||||
|
spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')
|
||||||
|
|
||||||
|
# Act
|
||||||
|
events = [e async for e in agent.run_async(parent_ctx)]
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
spy_before_agent_callback.assert_called_once()
|
||||||
|
spy_run_async_impl.assert_not_called()
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].content.parts[0].text == 'agent run is bypassed.'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_async_after_agent_callback_noop(
|
async def test_run_async_after_agent_callback_noop(
|
||||||
request: pytest.FixtureRequest,
|
request: pytest.FixtureRequest,
|
||||||
@ -211,6 +292,32 @@ async def test_run_async_after_agent_callback_noop(
|
|||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_with_async_after_agent_callback_noop(
|
||||||
|
request: pytest.FixtureRequest,
|
||||||
|
mocker: pytest_mock.MockerFixture,
|
||||||
|
):
|
||||||
|
# Arrange
|
||||||
|
agent = _TestingAgent(
|
||||||
|
name=f'{request.function.__name__}_test_agent',
|
||||||
|
after_agent_callback=_async_after_agent_callback_noop,
|
||||||
|
)
|
||||||
|
parent_ctx = _create_parent_invocation_context(
|
||||||
|
request.function.__name__, agent
|
||||||
|
)
|
||||||
|
spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
|
||||||
|
|
||||||
|
# Act
|
||||||
|
events = [e async for e in agent.run_async(parent_ctx)]
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
spy_after_agent_callback.assert_called_once()
|
||||||
|
_, kwargs = spy_after_agent_callback.call_args
|
||||||
|
assert 'callback_context' in kwargs
|
||||||
|
assert isinstance(kwargs['callback_context'], CallbackContext)
|
||||||
|
assert len(events) == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_async_after_agent_callback_append_reply(
|
async def test_run_async_after_agent_callback_append_reply(
|
||||||
request: pytest.FixtureRequest,
|
request: pytest.FixtureRequest,
|
||||||
@ -236,6 +343,31 @@ async def test_run_async_after_agent_callback_append_reply(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_with_async_after_agent_callback_append_reply(
|
||||||
|
request: pytest.FixtureRequest,
|
||||||
|
):
|
||||||
|
# Arrange
|
||||||
|
agent = _TestingAgent(
|
||||||
|
name=f'{request.function.__name__}_test_agent',
|
||||||
|
after_agent_callback=_async_after_agent_callback_append_agent_reply,
|
||||||
|
)
|
||||||
|
parent_ctx = _create_parent_invocation_context(
|
||||||
|
request.function.__name__, agent
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
events = [e async for e in agent.run_async(parent_ctx)]
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(events) == 2
|
||||||
|
assert events[1].author == agent.name
|
||||||
|
assert (
|
||||||
|
events[1].content.parts[0].text
|
||||||
|
== 'Agent reply from after agent callback.'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
|
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
|
||||||
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
|
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
|
||||||
|
@ -56,10 +56,44 @@ class MockAfterModelCallback(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncBeforeModelCallback(BaseModel):
|
||||||
|
mock_response: str
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
llm_request: LlmRequest,
|
||||||
|
) -> LlmResponse:
|
||||||
|
return LlmResponse(
|
||||||
|
content=utils.ModelContent(
|
||||||
|
[types.Part.from_text(text=self.mock_response)]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncAfterModelCallback(BaseModel):
|
||||||
|
mock_response: str
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
callback_context: CallbackContext,
|
||||||
|
llm_response: LlmResponse,
|
||||||
|
) -> LlmResponse:
|
||||||
|
return LlmResponse(
|
||||||
|
content=utils.ModelContent(
|
||||||
|
[types.Part.from_text(text=self.mock_response)]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def noop_callback(**kwargs) -> Optional[LlmResponse]:
|
def noop_callback(**kwargs) -> Optional[LlmResponse]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def async_noop_callback(**kwargs) -> Optional[LlmResponse]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_before_model_callback():
|
async def test_before_model_callback():
|
||||||
responses = ['model_response']
|
responses = ['model_response']
|
||||||
@ -98,26 +132,6 @@ async def test_before_model_callback_noop():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_before_model_callback_end():
|
|
||||||
responses = ['model_response']
|
|
||||||
mock_model = utils.MockModel.create(responses=responses)
|
|
||||||
agent = Agent(
|
|
||||||
name='root_agent',
|
|
||||||
model=mock_model,
|
|
||||||
before_model_callback=MockBeforeModelCallback(
|
|
||||||
mock_response='before_model_callback',
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
runner = utils.TestInMemoryRunner(agent)
|
|
||||||
assert utils.simplify_events(
|
|
||||||
await runner.run_async_with_new_session('test')
|
|
||||||
) == [
|
|
||||||
('root_agent', 'before_model_callback'),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_after_model_callback():
|
async def test_after_model_callback():
|
||||||
responses = ['model_response']
|
responses = ['model_response']
|
||||||
@ -136,3 +150,61 @@ async def test_after_model_callback():
|
|||||||
) == [
|
) == [
|
||||||
('root_agent', 'after_model_callback'),
|
('root_agent', 'after_model_callback'),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_before_model_callback():
|
||||||
|
responses = ['model_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
before_model_callback=MockAsyncBeforeModelCallback(
|
||||||
|
mock_response='async_before_model_callback'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'async_before_model_callback'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_before_model_callback_noop():
|
||||||
|
responses = ['model_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
before_model_callback=async_noop_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'model_response'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_after_model_callback():
|
||||||
|
responses = ['model_response']
|
||||||
|
mock_model = utils.MockModel.create(responses=responses)
|
||||||
|
agent = Agent(
|
||||||
|
name='root_agent',
|
||||||
|
model=mock_model,
|
||||||
|
after_model_callback=MockAsyncAfterModelCallback(
|
||||||
|
mock_response='async_after_model_callback'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = utils.TestInMemoryRunner(agent)
|
||||||
|
assert utils.simplify_events(
|
||||||
|
await runner.run_async_with_new_session('test')
|
||||||
|
) == [
|
||||||
|
('root_agent', 'async_after_model_callback'),
|
||||||
|
]
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
@ -74,9 +74,12 @@ class TestConnectionsClient:
|
|||||||
mock_response.raise_for_status.return_value = None
|
mock_response.raise_for_status.return_value = None
|
||||||
mock_response.json.return_value = {"data": "test"}
|
mock_response.json.return_value = {"data": "test"}
|
||||||
|
|
||||||
with mock.patch.object(
|
with (
|
||||||
client, "_get_access_token", return_value=mock_credentials.token
|
mock.patch.object(
|
||||||
), mock.patch("requests.get", return_value=mock_response):
|
client, "_get_access_token", return_value=mock_credentials.token
|
||||||
|
),
|
||||||
|
mock.patch("requests.get", return_value=mock_response),
|
||||||
|
):
|
||||||
response = client._execute_api_call("https://test.url")
|
response = client._execute_api_call("https://test.url")
|
||||||
assert response.json() == {"data": "test"}
|
assert response.json() == {"data": "test"}
|
||||||
requests.get.assert_called_once_with(
|
requests.get.assert_called_once_with(
|
||||||
@ -121,9 +124,12 @@ class TestConnectionsClient:
|
|||||||
f"HTTP error {status_code}: {response_text}"
|
f"HTTP error {status_code}: {response_text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
with mock.patch.object(
|
with (
|
||||||
client, "_get_access_token", return_value=mock_credentials.token
|
mock.patch.object(
|
||||||
), mock.patch("requests.get", return_value=mock_response):
|
client, "_get_access_token", return_value=mock_credentials.token
|
||||||
|
),
|
||||||
|
mock.patch("requests.get", return_value=mock_response),
|
||||||
|
):
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="Invalid request. Please check the provided"
|
ValueError, match="Invalid request. Please check the provided"
|
||||||
):
|
):
|
||||||
@ -140,9 +146,12 @@ class TestConnectionsClient:
|
|||||||
"Internal Server Error"
|
"Internal Server Error"
|
||||||
)
|
)
|
||||||
|
|
||||||
with mock.patch.object(
|
with (
|
||||||
client, "_get_access_token", return_value=mock_credentials.token
|
mock.patch.object(
|
||||||
), mock.patch("requests.get", return_value=mock_response):
|
client, "_get_access_token", return_value=mock_credentials.token
|
||||||
|
),
|
||||||
|
mock.patch("requests.get", return_value=mock_response),
|
||||||
|
):
|
||||||
with pytest.raises(ValueError, match="Request error: "):
|
with pytest.raises(ValueError, match="Request error: "):
|
||||||
client._execute_api_call("https://test.url")
|
client._execute_api_call("https://test.url")
|
||||||
|
|
||||||
@ -151,10 +160,13 @@ class TestConnectionsClient:
|
|||||||
):
|
):
|
||||||
credentials = {"email": "test@example.com"}
|
credentials = {"email": "test@example.com"}
|
||||||
client = ConnectionsClient(project, location, connection_name, credentials)
|
client = ConnectionsClient(project, location, connection_name, credentials)
|
||||||
with mock.patch.object(
|
with (
|
||||||
client, "_get_access_token", return_value=mock_credentials.token
|
mock.patch.object(
|
||||||
), mock.patch(
|
client, "_get_access_token", return_value=mock_credentials.token
|
||||||
"requests.get", side_effect=Exception("Something went wrong")
|
),
|
||||||
|
mock.patch(
|
||||||
|
"requests.get", side_effect=Exception("Something went wrong")
|
||||||
|
),
|
||||||
):
|
):
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
Exception, match="An unexpected error occurred: Something went wrong"
|
Exception, match="An unexpected error occurred: Something went wrong"
|
||||||
@ -539,10 +551,13 @@ class TestConnectionsClient:
|
|||||||
mock_creds.token = "sa_token"
|
mock_creds.token = "sa_token"
|
||||||
mock_creds.expired = False
|
mock_creds.expired = False
|
||||||
|
|
||||||
with mock.patch(
|
with (
|
||||||
"google.oauth2.service_account.Credentials.from_service_account_info",
|
mock.patch(
|
||||||
return_value=mock_creds,
|
"google.oauth2.service_account.Credentials.from_service_account_info",
|
||||||
), mock.patch.object(mock_creds, "refresh", return_value=None):
|
return_value=mock_creds,
|
||||||
|
),
|
||||||
|
mock.patch.object(mock_creds, "refresh", return_value=None),
|
||||||
|
):
|
||||||
token = client._get_access_token()
|
token = client._get_access_token()
|
||||||
assert token == "sa_token"
|
assert token == "sa_token"
|
||||||
google.oauth2.service_account.Credentials.from_service_account_info.assert_called_once_with(
|
google.oauth2.service_account.Credentials.from_service_account_info.assert_called_once_with(
|
||||||
@ -555,10 +570,13 @@ class TestConnectionsClient:
|
|||||||
self, project, location, connection_name, mock_credentials
|
self, project, location, connection_name, mock_credentials
|
||||||
):
|
):
|
||||||
client = ConnectionsClient(project, location, connection_name, None)
|
client = ConnectionsClient(project, location, connection_name, None)
|
||||||
with mock.patch(
|
with (
|
||||||
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
|
mock.patch(
|
||||||
return_value=(mock_credentials, "test_project_id"),
|
"google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
|
||||||
), mock.patch.object(mock_credentials, "refresh", return_value=None):
|
return_value=(mock_credentials, "test_project_id"),
|
||||||
|
),
|
||||||
|
mock.patch.object(mock_credentials, "refresh", return_value=None),
|
||||||
|
):
|
||||||
token = client._get_access_token()
|
token = client._get_access_token()
|
||||||
assert token == "test_token"
|
assert token == "test_token"
|
||||||
|
|
||||||
|
@ -114,11 +114,14 @@ class TestIntegrationClient:
|
|||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.json.return_value = {"openApiSpec": json.dumps(expected_spec)}
|
mock_response.json.return_value = {"openApiSpec": json.dumps(expected_spec)}
|
||||||
|
|
||||||
with mock.patch.object(
|
with (
|
||||||
IntegrationClient,
|
mock.patch.object(
|
||||||
"_get_access_token",
|
IntegrationClient,
|
||||||
return_value=mock_credentials.token,
|
"_get_access_token",
|
||||||
), mock.patch("requests.post", return_value=mock_response):
|
return_value=mock_credentials.token,
|
||||||
|
),
|
||||||
|
mock.patch("requests.post", return_value=mock_response),
|
||||||
|
):
|
||||||
client = IntegrationClient(
|
client = IntegrationClient(
|
||||||
project=project,
|
project=project,
|
||||||
location=location,
|
location=location,
|
||||||
@ -202,11 +205,14 @@ class TestIntegrationClient:
|
|||||||
f"HTTP error {status_code}: {response_text}"
|
f"HTTP error {status_code}: {response_text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
with mock.patch.object(
|
with (
|
||||||
IntegrationClient,
|
mock.patch.object(
|
||||||
"_get_access_token",
|
IntegrationClient,
|
||||||
return_value=mock_credentials.token,
|
"_get_access_token",
|
||||||
), mock.patch("requests.post", return_value=mock_response):
|
return_value=mock_credentials.token,
|
||||||
|
),
|
||||||
|
mock.patch("requests.post", return_value=mock_response),
|
||||||
|
):
|
||||||
client = IntegrationClient(
|
client = IntegrationClient(
|
||||||
project=project,
|
project=project,
|
||||||
location=location,
|
location=location,
|
||||||
@ -243,11 +249,14 @@ class TestIntegrationClient:
|
|||||||
"Internal Server Error"
|
"Internal Server Error"
|
||||||
)
|
)
|
||||||
|
|
||||||
with mock.patch.object(
|
with (
|
||||||
IntegrationClient,
|
mock.patch.object(
|
||||||
"_get_access_token",
|
IntegrationClient,
|
||||||
return_value=mock_credentials.token,
|
"_get_access_token",
|
||||||
), mock.patch("requests.post", return_value=mock_response):
|
return_value=mock_credentials.token,
|
||||||
|
),
|
||||||
|
mock.patch("requests.post", return_value=mock_response),
|
||||||
|
):
|
||||||
client = IntegrationClient(
|
client = IntegrationClient(
|
||||||
project=project,
|
project=project,
|
||||||
location=location,
|
location=location,
|
||||||
@ -270,12 +279,15 @@ class TestIntegrationClient:
|
|||||||
mock_credentials,
|
mock_credentials,
|
||||||
mock_connections_client,
|
mock_connections_client,
|
||||||
):
|
):
|
||||||
with mock.patch.object(
|
with (
|
||||||
IntegrationClient,
|
mock.patch.object(
|
||||||
"_get_access_token",
|
IntegrationClient,
|
||||||
return_value=mock_credentials.token,
|
"_get_access_token",
|
||||||
), mock.patch(
|
return_value=mock_credentials.token,
|
||||||
"requests.post", side_effect=Exception("Something went wrong")
|
),
|
||||||
|
mock.patch(
|
||||||
|
"requests.post", side_effect=Exception("Something went wrong")
|
||||||
|
),
|
||||||
):
|
):
|
||||||
client = IntegrationClient(
|
client = IntegrationClient(
|
||||||
project=project,
|
project=project,
|
||||||
@ -486,10 +498,13 @@ class TestIntegrationClient:
|
|||||||
mock_creds.token = "sa_token"
|
mock_creds.token = "sa_token"
|
||||||
mock_creds.expired = False
|
mock_creds.expired = False
|
||||||
|
|
||||||
with mock.patch(
|
with (
|
||||||
"google.oauth2.service_account.Credentials.from_service_account_info",
|
mock.patch(
|
||||||
return_value=mock_creds,
|
"google.oauth2.service_account.Credentials.from_service_account_info",
|
||||||
), mock.patch.object(mock_creds, "refresh", return_value=None):
|
return_value=mock_creds,
|
||||||
|
),
|
||||||
|
mock.patch.object(mock_creds, "refresh", return_value=None),
|
||||||
|
):
|
||||||
client = IntegrationClient(
|
client = IntegrationClient(
|
||||||
project=project,
|
project=project,
|
||||||
location=location,
|
location=location,
|
||||||
@ -518,10 +533,13 @@ class TestIntegrationClient:
|
|||||||
mock_credentials,
|
mock_credentials,
|
||||||
):
|
):
|
||||||
mock_credentials.expired = False
|
mock_credentials.expired = False
|
||||||
with mock.patch(
|
with (
|
||||||
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
|
mock.patch(
|
||||||
return_value=(mock_credentials, "test_project_id"),
|
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
|
||||||
), mock.patch.object(mock_credentials, "refresh", return_value=None):
|
return_value=(mock_credentials, "test_project_id"),
|
||||||
|
),
|
||||||
|
mock.patch.object(mock_credentials, "refresh", return_value=None),
|
||||||
|
):
|
||||||
client = IntegrationClient(
|
client = IntegrationClient(
|
||||||
project=project,
|
project=project,
|
||||||
location=location,
|
location=location,
|
||||||
@ -538,12 +556,15 @@ class TestIntegrationClient:
|
|||||||
def test_get_access_token_no_valid_credentials(
|
def test_get_access_token_no_valid_credentials(
|
||||||
self, project, location, integration_name, trigger_name, connection_name
|
self, project, location, integration_name, trigger_name, connection_name
|
||||||
):
|
):
|
||||||
with mock.patch(
|
with (
|
||||||
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
|
mock.patch(
|
||||||
return_value=(None, None),
|
"google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
|
||||||
), mock.patch(
|
return_value=(None, None),
|
||||||
"google.oauth2.service_account.Credentials.from_service_account_info",
|
),
|
||||||
return_value=None,
|
mock.patch(
|
||||||
|
"google.oauth2.service_account.Credentials.from_service_account_info",
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
client = IntegrationClient(
|
client = IntegrationClient(
|
||||||
project=project,
|
project=project,
|
||||||
@ -587,9 +608,12 @@ class TestIntegrationClient:
|
|||||||
service_account_json=None,
|
service_account_json=None,
|
||||||
)
|
)
|
||||||
client.credential_cache = mock_credentials # Simulate a cached credential
|
client.credential_cache = mock_credentials # Simulate a cached credential
|
||||||
with mock.patch("google.auth.default") as mock_default, mock.patch(
|
with (
|
||||||
"google.oauth2.service_account.Credentials.from_service_account_info"
|
mock.patch("google.auth.default") as mock_default,
|
||||||
) as mock_sa:
|
mock.patch(
|
||||||
|
"google.oauth2.service_account.Credentials.from_service_account_info"
|
||||||
|
) as mock_sa,
|
||||||
|
):
|
||||||
token = client._get_access_token()
|
token = client._get_access_token()
|
||||||
assert token == "cached_token"
|
assert token == "cached_token"
|
||||||
mock_default.assert_not_called()
|
mock_default.assert_not_called()
|
||||||
|
@ -236,8 +236,7 @@ def test_initialization_with_connection_and_actions(
|
|||||||
)
|
)
|
||||||
mock_connections_client.return_value.get_connection_details.assert_called_once()
|
mock_connections_client.return_value.get_connection_details.assert_called_once()
|
||||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||||
tool_name,
|
tool_name, tool_instructions
|
||||||
tool_instructions
|
|
||||||
)
|
)
|
||||||
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
|
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
|
||||||
assert len(toolset.get_tools()) == 1
|
assert len(toolset.get_tools()) == 1
|
||||||
@ -390,6 +389,5 @@ def test_initialization_with_connection_details(
|
|||||||
tool_instructions=tool_instructions,
|
tool_instructions=tool_instructions,
|
||||||
)
|
)
|
||||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||||
tool_name,
|
tool_name, tool_instructions
|
||||||
tool_instructions
|
|
||||||
)
|
)
|
||||||
|
@ -11,4 +11,3 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user