refactor: simplify toolset cleanup codes and extract common cleanup codes to utils which could be utilized by cli or client codes that directly call runners

PiperOrigin-RevId: 762463028
This commit is contained in:
Xiang (Sean) Zhou
2025-05-23 09:48:48 -07:00
committed by Copybara-Service
parent b9b2c3fb54
commit 92c37496d3
6 changed files with 239 additions and 406 deletions
+27 -24
View File
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from google.genai.types import FunctionDeclaration
@@ -23,7 +25,6 @@ from .mcp_session_manager import retry_on_closed_resource
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
from mcp import ClientSession
from mcp.types import Tool as McpBaseTool
except ImportError as e:
import sys
@@ -43,6 +44,8 @@ from ..base_tool import BaseTool
from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from ..tool_context import ToolContext
logger = logging.getLogger("google_adk." + __name__)
class MCPTool(BaseTool):
"""Turns a MCP Tool into a Vertex Agent Framework Tool.
@@ -53,44 +56,40 @@ class MCPTool(BaseTool):
def __init__(
self,
*,
mcp_tool: McpBaseTool,
mcp_session: ClientSession,
mcp_session_manager: MCPSessionManager,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] | None = None,
auth_credential: Optional[AuthCredential] = None,
):
"""Initializes a MCPTool.
This tool wraps a MCP Tool interface and an active MCP Session. It invokes
the MCP Tool through executing the tool from remote MCP Session.
Example:
tool = MCPTool(mcp_tool=mcp_tool, mcp_session=mcp_session)
This tool wraps a MCP Tool interface and uses a session manager to
communicate with the MCP server.
Args:
mcp_tool: The MCP tool to wrap.
mcp_session: The MCP session to use to call the tool.
mcp_session_manager: The MCP session manager to use for communication.
auth_scheme: The authentication scheme to use.
auth_credential: The authentication credential to use.
Raises:
ValueError: If mcp_tool or mcp_session is None.
ValueError: If mcp_tool or mcp_session_manager is None.
"""
if mcp_tool is None:
raise ValueError("mcp_tool cannot be None")
if mcp_session is None:
raise ValueError("mcp_session cannot be None")
super().__init__(name=mcp_tool.name, description=mcp_tool.description or "")
if mcp_session_manager is None:
raise ValueError("mcp_session_manager cannot be None")
super().__init__(
name=mcp_tool.name,
description=mcp_tool.description if mcp_tool.description else "",
)
self._mcp_tool = mcp_tool
self._mcp_session = mcp_session
self._mcp_session_manager = mcp_session_manager
# TODO(cheliu): Support passing auth to MCP Server.
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential
async def _reinitialize_session(self):
self._mcp_session = await self._mcp_session_manager.create_session()
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Gets the function declaration for the tool.
@@ -105,7 +104,6 @@ class MCPTool(BaseTool):
)
return function_decl
@override
@retry_on_closed_resource("_reinitialize_session")
async def run_async(self, *, args, tool_context: ToolContext):
"""Runs the tool asynchronously.
@@ -117,10 +115,15 @@ class MCPTool(BaseTool):
Returns:
Any: The response from the tool.
"""
# Get the session from the session manager
session = await self._mcp_session_manager.create_session()
# TODO(cheliu): Support passing tool context to MCP Server.
try:
response = await self._mcp_session.call_tool(self.name, arguments=args)
return response
except Exception as e:
print(e)
raise e
response = await session.call_tool(self.name, arguments=args)
return response
async def _reinitialize_session(self):
"""Reinitializes the session when connection is lost."""
# Close the old session and create a new one
await self._mcp_session_manager.close()
await self._mcp_session_manager.create_session()