diff --git a/contributing/samples/mcp_agent/agent.py b/contributing/samples/mcp_agent/agent.py index 9626cc6..6f8dcca 100755 --- a/contributing/samples/mcp_agent/agent.py +++ b/contributing/samples/mcp_agent/agent.py @@ -34,8 +34,12 @@ root_agent = LlmAgent( ], ), # don't want agent to do write operation - tool_predicate=lambda tool, ctx=None: tool.name - not in ('write_file', 'edit_file', 'create_directory', 'move_file'), + tool_filter=[ + 'write_file', + 'edit_file', + 'create_directory', + 'move_file', + ], ) ], ) diff --git a/src/google/adk/tools/base_toolset.py b/src/google/adk/tools/base_toolset.py index 10ea2fa..cc4ad02 100644 --- a/src/google/adk/tools/base_toolset.py +++ b/src/google/adk/tools/base_toolset.py @@ -1,12 +1,13 @@ from abc import ABC from abc import abstractmethod -from typing import Optional +from typing import Optional, runtime_checkable from typing import Protocol from google.adk.agents.readonly_context import ReadonlyContext from google.adk.tools.base_tool import BaseTool +@runtime_checkable class ToolPredicate(Protocol): """Base class for a predicate that defines the interface to decide whether a diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 4beb25a..6cca5a0 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -97,6 +97,18 @@ class MCPToolset(BaseToolset): self.session = await self.session_manager.create_session() return self.session + def _is_selected( + self, tool: ..., readonly_context: Optional[ReadonlyContext] + ) -> bool: + """Checks if a tool should be selected based on the tool filter.""" + if self.tool_filter is None: + return True + if isinstance(self.tool_filter, ToolPredicate): + return self.tool_filter(tool, readonly_context) + if isinstance(self.tool_filter, list): + return tool.name in self.tool_filter + return False + @override async def close(self): """Closes the connection to MCP Server.""" @@ -123,5 +135,5 @@ class MCPToolset(BaseToolset): mcp_session_manager=self.session_manager, ) for tool in tools_response.tools - if self.tool_filter is None or self.tool_filter(tool, readonly_context) + if self._is_selected(tool, readonly_context) ]