refactor: update callback type signatures to support sync and async responses

This commit is contained in:
Alankrit Verma 2025-04-29 09:02:09 -04:00
parent 504aa6ba06
commit fcbf57466e
3 changed files with 82 additions and 83 deletions

View File

@ -57,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[
] ]
BeforeToolCallback: TypeAlias = Callable[ BeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext], [BaseTool, dict[str, Any], ToolContext],
Awaitable[Optional[dict]], Union[Awaitable[Optional[dict]], Optional[dict]],
] ]
AfterToolCallback: TypeAlias = Callable[ AfterToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext, dict], [BaseTool, dict[str, Any], ToolContext, dict],
Awaitable[Optional[dict]], Union[Awaitable[Optional[dict]], Optional[dict]],
] ]
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str] InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]

View File

@ -151,36 +151,33 @@ async def handle_function_calls_async(
# do not use "args" as the variable name, because it is a reserved keyword # do not use "args" as the variable name, because it is a reserved keyword
# in python debugger. # in python debugger.
function_args = function_call.args or {} function_args = function_call.args or {}
function_response = None function_response: Optional[dict] = None
# # Calls the tool if before_tool_callback does not exist or returns None.
# if agent.before_tool_callback: # before_tool_callback (sync or async)
# function_response = agent.before_tool_callback(
# tool=tool, args=function_args, tool_context=tool_context
# )
# Short-circuit via before_tool_callback (sync *or* async)
if agent.before_tool_callback: if agent.before_tool_callback:
_maybe = agent.before_tool_callback( function_response = agent.before_tool_callback(
tool=tool, args=function_args, tool_context=tool_context tool=tool, args=function_args, tool_context=tool_context
) )
if inspect.isawaitable(_maybe): if inspect.isawaitable(function_response):
_maybe = await _maybe function_response = await function_response
function_response = _maybe
if not function_response: if not function_response:
function_response = await __call_tool_async( function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context tool, args=function_args, tool_context=tool_context
) )
# Calls after_tool_callback if it exists.
# after_tool_callback (sync or async)
if agent.after_tool_callback: if agent.after_tool_callback:
_maybe2 = agent.after_tool_callback( altered_function_response = agent.after_tool_callback(
tool=tool, tool=tool,
args=function_args, args=function_args,
tool_context=tool_context, tool_context=tool_context,
tool_response=function_response, tool_response=function_response,
) )
if inspect.isawaitable(_maybe2): if inspect.isawaitable(altered_function_response):
_maybe2 = await _maybe2 altered_function_response = await altered_function_response
if _maybe2 is not None: if altered_function_response is not None:
function_response = _maybe2 function_response = altered_function_response
if tool.is_long_running: if tool.is_long_running:
# Allow long running function to return None to not provide function response. # Allow long running function to return None to not provide function response.

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import pytest import pytest
@ -27,81 +27,83 @@ from ... import utils
class AsyncBeforeToolCallback: class AsyncBeforeToolCallback:
def __init__(self, mock_response: Dict[str, Any]):
self.mock_response = mock_response
async def __call__( def __init__(self, mock_response: Dict[str, Any]):
self, self.mock_response = mock_response
tool: FunctionTool,
args: Dict[str, Any], async def __call__(
tool_context: ToolContext, self,
) -> Optional[Dict[str, Any]]: tool: FunctionTool,
return self.mock_response args: Dict[str, Any],
tool_context: ToolContext,
) -> Optional[Dict[str, Any]]:
return self.mock_response
class AsyncAfterToolCallback: class AsyncAfterToolCallback:
def __init__(self, mock_response: Dict[str, Any]):
self.mock_response = mock_response
async def __call__( def __init__(self, mock_response: Dict[str, Any]):
self, self.mock_response = mock_response
tool: FunctionTool,
args: Dict[str, Any], async def __call__(
tool_context: ToolContext, self,
tool_response: Dict[str, Any], tool: FunctionTool,
) -> Optional[Dict[str, Any]]: args: Dict[str, Any],
return self.mock_response tool_context: ToolContext,
tool_response: Dict[str, Any],
) -> Optional[Dict[str, Any]]:
return self.mock_response
async def invoke_tool_with_callbacks( async def invoke_tool_with_callbacks(
before_cb=None, after_cb=None before_cb=None, after_cb=None
) -> Optional[Event]: ) -> Optional[Event]:
def simple_fn(**kwargs) -> Dict[str, Any]: def simple_fn(**kwargs) -> Dict[str, Any]:
return {"initial": "response"} return {"initial": "response"}
tool = FunctionTool(simple_fn) tool = FunctionTool(simple_fn)
model = utils.MockModel.create(responses=[]) model = utils.MockModel.create(responses=[])
agent = Agent( agent = Agent(
name="agent", name="agent",
model=model, model=model,
tools=[tool], tools=[tool],
before_tool_callback=before_cb, before_tool_callback=before_cb,
after_tool_callback=after_cb, after_tool_callback=after_cb,
) )
invocation_context = utils.create_invocation_context( invocation_context = utils.create_invocation_context(
agent=agent, user_content="" agent=agent, user_content=""
) )
# Build function call event # Build function call event
function_call = types.FunctionCall(name=tool.name, args={}) function_call = types.FunctionCall(name=tool.name, args={})
content = types.Content(parts=[types.Part(function_call=function_call)]) content = types.Content(parts=[types.Part(function_call=function_call)])
event = Event( event = Event(
invocation_id=invocation_context.invocation_id, invocation_id=invocation_context.invocation_id,
author=agent.name, author=agent.name,
content=content, content=content,
) )
tools_dict = {tool.name: tool} tools_dict = {tool.name: tool}
return await handle_function_calls_async( return await handle_function_calls_async(
invocation_context, invocation_context,
event, event,
tools_dict, tools_dict,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_before_tool_callback(): async def test_async_before_tool_callback():
mock_resp = {"test": "before_tool_callback"} mock_resp = {"test": "before_tool_callback"}
before_cb = AsyncBeforeToolCallback(mock_resp) before_cb = AsyncBeforeToolCallback(mock_resp)
result_event = await invoke_tool_with_callbacks(before_cb=before_cb) result_event = await invoke_tool_with_callbacks(before_cb=before_cb)
assert result_event is not None assert result_event is not None
part = result_event.content.parts[0] part = result_event.content.parts[0]
assert part.function_response.response == mock_resp assert part.function_response.response == mock_resp
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_after_tool_callback(): async def test_async_after_tool_callback():
mock_resp = {"test": "after_tool_callback"} mock_resp = {"test": "after_tool_callback"}
after_cb = AsyncAfterToolCallback(mock_resp) after_cb = AsyncAfterToolCallback(mock_resp)
result_event = await invoke_tool_with_callbacks(after_cb=after_cb) result_event = await invoke_tool_with_callbacks(after_cb=after_cb)
assert result_event is not None assert result_event is not None
part = result_event.content.parts[0] part = result_event.content.parts[0]
assert part.function_response.response == mock_resp assert part.function_response.response == mock_resp