No public description

PiperOrigin-RevId: 748777998
This commit is contained in:
Google ADK Member
2025-04-17 19:50:22 +00:00
committed by hangfei
parent 290058eb05
commit 61d4be2d76
99 changed files with 2120 additions and 256 deletions

View File

@@ -196,11 +196,12 @@ class IntegrationClient:
action_details = connections_client.get_action_schema(action)
input_schema = action_details["inputSchema"]
output_schema = action_details["outputSchema"]
action_display_name = action_details["displayName"]
# Remove spaces from the display name to generate valid spec
action_display_name = action_details["displayName"].replace(" ", "")
operation = "EXECUTE_ACTION"
if action == "ExecuteCustomQuery":
connector_spec["components"]["schemas"][
f"{action}_Request"
f"{action_display_name}_Request"
] = connections_client.execute_custom_query_request()
operation = "EXECUTE_QUERY"
else:

View File

@@ -291,7 +291,7 @@ def _parse_schema_from_parameter(
return schema
raise ValueError(
f'Failed to parse the parameter {param} of function {func_name} for'
' automatic function calling.Automatic function calling works best with'
' automatic function calling. Automatic function calling works best with'
' simpler function signature schema,consider manually parse your'
f' function declaration for function {func_name}.'
)

View File

@@ -11,4 +11,77 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .google_api_tool_sets import calendar_tool_set
__all__ = [
'bigquery_tool_set',
'calendar_tool_set',
'gmail_tool_set',
'youtube_tool_set',
'slides_tool_set',
'sheets_tool_set',
'docs_tool_set',
]
# Nothing is imported here automatically
# Each tool set will only be imported when accessed
_bigquery_tool_set = None
_calendar_tool_set = None
_gmail_tool_set = None
_youtube_tool_set = None
_slides_tool_set = None
_sheets_tool_set = None
_docs_tool_set = None
def __getattr__(name):
global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
match name:
case 'bigquery_tool_set':
if _bigquery_tool_set is None:
from .google_api_tool_sets import bigquery_tool_set as bigquery
_bigquery_tool_set = bigquery
return _bigquery_tool_set
case 'calendar_tool_set':
if _calendar_tool_set is None:
from .google_api_tool_sets import calendar_tool_set as calendar
_calendar_tool_set = calendar
return _calendar_tool_set
case 'gmail_tool_set':
if _gmail_tool_set is None:
from .google_api_tool_sets import gmail_tool_set as gmail
_gmail_tool_set = gmail
return _gmail_tool_set
case 'youtube_tool_set':
if _youtube_tool_set is None:
from .google_api_tool_sets import youtube_tool_set as youtube
_youtube_tool_set = youtube
return _youtube_tool_set
case 'slides_tool_set':
if _slides_tool_set is None:
from .google_api_tool_sets import slides_tool_set as slides
_slides_tool_set = slides
return _slides_tool_set
case 'sheets_tool_set':
if _sheets_tool_set is None:
from .google_api_tool_sets import sheets_tool_set as sheets
_sheets_tool_set = sheets
return _sheets_tool_set
case 'docs_tool_set':
if _docs_tool_set is None:
from .google_api_tool_sets import docs_tool_set as docs
_docs_tool_set = docs
return _docs_tool_set

View File

@@ -19,37 +19,94 @@ from .google_api_tool_set import GoogleApiToolSet
logger = logging.getLogger(__name__)
calendar_tool_set = GoogleApiToolSet.load_tool_set(
api_name="calendar",
api_version="v3",
)
_bigquery_tool_set = None
_calendar_tool_set = None
_gmail_tool_set = None
_youtube_tool_set = None
_slides_tool_set = None
_sheets_tool_set = None
_docs_tool_set = None
bigquery_tool_set = GoogleApiToolSet.load_tool_set(
api_name="bigquery",
api_version="v2",
)
gmail_tool_set = GoogleApiToolSet.load_tool_set(
api_name="gmail",
api_version="v1",
)
def __getattr__(name):
"""This method dynamically loads and returns GoogleApiToolSet instances for
youtube_tool_set = GoogleApiToolSet.load_tool_set(
api_name="youtube",
api_version="v3",
)
various Google APIs. It uses a lazy loading approach, initializing each
tool set only when it is first requested. This avoids unnecessary loading
of tool sets that are not used in a given session.
slides_tool_set = GoogleApiToolSet.load_tool_set(
api_name="slides",
api_version="v1",
)
Args:
name (str): The name of the tool set to retrieve (e.g.,
"bigquery_tool_set").
sheets_tool_set = GoogleApiToolSet.load_tool_set(
api_name="sheets",
api_version="v4",
)
Returns:
GoogleApiToolSet: The requested tool set instance.
docs_tool_set = GoogleApiToolSet.load_tool_set(
api_name="docs",
api_version="v1",
)
Raises:
AttributeError: If the requested tool set name is not recognized.
"""
global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
match name:
case "bigquery_tool_set":
if _bigquery_tool_set is None:
_bigquery_tool_set = GoogleApiToolSet.load_tool_set(
api_name="bigquery",
api_version="v2",
)
return _bigquery_tool_set
case "calendar_tool_set":
if _calendar_tool_set is None:
_calendar_tool_set = GoogleApiToolSet.load_tool_set(
api_name="calendar",
api_version="v3",
)
return _calendar_tool_set
case "gmail_tool_set":
if _gmail_tool_set is None:
_gmail_tool_set = GoogleApiToolSet.load_tool_set(
api_name="gmail",
api_version="v1",
)
return _gmail_tool_set
case "youtube_tool_set":
if _youtube_tool_set is None:
_youtube_tool_set = GoogleApiToolSet.load_tool_set(
api_name="youtube",
api_version="v3",
)
return _youtube_tool_set
case "slides_tool_set":
if _slides_tool_set is None:
_slides_tool_set = GoogleApiToolSet.load_tool_set(
api_name="slides",
api_version="v1",
)
return _slides_tool_set
case "sheets_tool_set":
if _sheets_tool_set is None:
_sheets_tool_set = GoogleApiToolSet.load_tool_set(
api_name="sheets",
api_version="v4",
)
return _sheets_tool_set
case "docs_tool_set":
if _docs_tool_set is None:
_docs_tool_set = GoogleApiToolSet.load_tool_set(
api_name="docs",
api_version="v1",
)
return _docs_tool_set

View File

@@ -311,7 +311,9 @@ class GoogleApiToOpenApiConverter:
# Determine the actual endpoint path
# Google often has the format something like 'users.messages.list'
rest_path = method_data.get("path", "/")
# flatPath is preferred as it provides the actual path, while path
# might contain variables like {+projectId}
rest_path = method_data.get("flatPath", method_data.get("path", "/"))
if not rest_path.startswith("/"):
rest_path = "/" + rest_path

View File

@@ -16,18 +16,26 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from .function_tool import FunctionTool
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..models import LlmRequest
from ..memory.base_memory_service import MemoryResult
from ..models import LlmRequest
def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]':
"""Loads the memory for the current user."""
"""Loads the memory for the current user.
Args:
query: The query to load the memory for.
Returns:
A list of memory results.
"""
response = tool_context.search_memory(query)
return response.memories
@@ -38,6 +46,21 @@ class LoadMemoryTool(FunctionTool):
def __init__(self):
super().__init__(load_memory)
@override
def _get_declaration(self) -> types.FunctionDeclaration | None:
return types.FunctionDeclaration(
name=self.name,
description=self.description,
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
'query': types.Schema(
type=types.Type.STRING,
)
},
),
)
@override
async def process_llm_request(
self,

View File

@@ -0,0 +1,176 @@
from contextlib import AsyncExitStack
import functools
import sys
from typing import Any, TextIO
import anyio
from pydantic import BaseModel
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
except ImportError as e:
import sys
if sys.version_info < (3, 10):
raise ImportError(
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
' version.'
) from e
else:
raise e
class SseServerParams(BaseModel):
"""Parameters for the MCP SSE connection.
See MCP SSE Client documentation for more details.
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
"""
url: str
headers: dict[str, Any] | None = None
timeout: float = 5
sse_read_timeout: float = 60 * 5
def retry_on_closed_resource(async_reinit_func_name: str):
"""Decorator to automatically reinitialize session and retry action.
When MCP session was closed, the decorator will automatically recreate the
session and retry the action with the same parameters.
Note:
1. async_reinit_func_name is the name of the class member function that
reinitializes the MCP session.
2. Both the decorated function and the async_reinit_func_name must be async
functions.
Usage:
class MCPTool:
...
async def create_session(self):
self.session = ...
@retry_on_closed_resource('create_session')
async def use_session(self):
await self.session.call_tool()
Args:
async_reinit_func_name: The name of the async function to recreate session.
Returns:
The decorated function.
"""
def decorator(func):
@functools.wraps(
func
) # Preserves original function metadata (name, docstring)
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except anyio.ClosedResourceError:
try:
if hasattr(self, async_reinit_func_name) and callable(
getattr(self, async_reinit_func_name)
):
async_init_fn = getattr(self, async_reinit_func_name)
await async_init_fn()
else:
raise ValueError(
f'Function {async_reinit_func_name} does not exist in decorated'
' class. Please check the function name in'
' retry_on_closed_resource decorator.'
)
except Exception as reinit_err:
raise RuntimeError(
f'Error reinitializing: {reinit_err}'
) from reinit_err
return await func(self, *args, **kwargs)
return wrapper
return decorator
class MCPSessionManager:
"""Manages MCP client sessions.
This class provides methods for creating and initializing MCP client sessions,
handling different connection parameters (Stdio and SSE).
"""
def __init__(
self,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
) -> ClientSession:
"""Initializes the MCP session manager.
Example usage:
```
mcp_session_manager = MCPSessionManager(
connection_params=connection_params,
exit_stack=exit_stack,
)
session = await mcp_session_manager.create_session()
```
Args:
connection_params: Parameters for the MCP connection (Stdio or SSE).
exit_stack: AsyncExitStack to manage the session lifecycle.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
"""
self.connection_params = connection_params
self.exit_stack = exit_stack
self.errlog = errlog
async def create_session(self) -> ClientSession:
return await MCPSessionManager.initialize_session(
connection_params=self.connection_params,
exit_stack=self.exit_stack,
errlog=self.errlog,
)
@classmethod
async def initialize_session(
cls,
*,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
) -> ClientSession:
"""Initializes an MCP client session.
Args:
connection_params: Parameters for the MCP connection (Stdio or SSE).
exit_stack: AsyncExitStack to manage the session lifecycle.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
Returns:
ClientSession: The initialized MCP client session.
"""
if isinstance(connection_params, StdioServerParameters):
client = stdio_client(server=connection_params, errlog=errlog)
elif isinstance(connection_params, SseServerParams):
client = sse_client(
url=connection_params.url,
headers=connection_params.headers,
timeout=connection_params.timeout,
sse_read_timeout=connection_params.sse_read_timeout,
)
else:
raise ValueError(
'Unable to initialize connection. Connection should be'
' StdioServerParameters or SseServerParams, but got'
f' {connection_params}'
)
transports = await exit_stack.enter_async_context(client)
session = await exit_stack.enter_async_context(ClientSession(*transports))
await session.initialize()
return session

View File

@@ -17,6 +17,8 @@ from typing import Optional
from google.genai.types import FunctionDeclaration
from typing_extensions import override
from .mcp_session_manager import MCPSessionManager, retry_on_closed_resource
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
@@ -33,6 +35,7 @@ except ImportError as e:
else:
raise e
from ..base_tool import BaseTool
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
@@ -51,6 +54,7 @@ class MCPTool(BaseTool):
self,
mcp_tool: McpBaseTool,
mcp_session: ClientSession,
mcp_session_manager: MCPSessionManager,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] | None = None,
):
@@ -79,10 +83,14 @@ class MCPTool(BaseTool):
self.description = mcp_tool.description if mcp_tool.description else ""
self.mcp_tool = mcp_tool
self.mcp_session = mcp_session
self.mcp_session_manager = mcp_session_manager
# TODO(cheliu): Support passing auth to MCP Server.
self.auth_scheme = auth_scheme
self.auth_credential = auth_credential
async def _reinitialize_session(self):
self.mcp_session = await self.mcp_session_manager.create_session()
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Gets the function declaration for the tool.
@@ -98,6 +106,7 @@ class MCPTool(BaseTool):
return function_decl
@override
@retry_on_closed_resource("_reinitialize_session")
async def run_async(self, *, args, tool_context: ToolContext):
"""Runs the tool asynchronously.
@@ -109,5 +118,9 @@ class MCPTool(BaseTool):
Any: The response from the tool.
"""
# TODO(cheliu): Support passing tool context to MCP Server.
response = await self.mcp_session.call_tool(self.name, arguments=args)
return response
try:
response = await self.mcp_session.call_tool(self.name, arguments=args)
return response
except Exception as e:
print(e)
raise e

View File

@@ -13,15 +13,16 @@
# limitations under the License.
from contextlib import AsyncExitStack
import sys
from types import TracebackType
from typing import Any, List, Optional, Tuple, Type
from typing import List, Optional, TextIO, Tuple, Type
from .mcp_session_manager import MCPSessionManager, SseServerParams, retry_on_closed_resource
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import ListToolsResult
except ImportError as e:
import sys
@@ -34,18 +35,9 @@ except ImportError as e:
else:
raise e
from pydantic import BaseModel
from .mcp_tool import MCPTool
class SseServerParams(BaseModel):
url: str
headers: dict[str, Any] | None = None
timeout: float = 5
sse_read_timeout: float = 60 * 5
class MCPToolset:
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
@@ -110,7 +102,11 @@ class MCPToolset:
"""
def __init__(
self, *, connection_params: StdioServerParameters | SseServerParams
self,
*,
connection_params: StdioServerParameters | SseServerParams,
errlog: TextIO = sys.stderr,
exit_stack=AsyncExitStack(),
):
"""Initializes the MCPToolset.
@@ -175,7 +171,14 @@ class MCPToolset:
if not connection_params:
raise ValueError('Missing connection params in MCPToolset.')
self.connection_params = connection_params
self.exit_stack = AsyncExitStack()
self.errlog = errlog
self.exit_stack = exit_stack
self.session_manager = MCPSessionManager(
connection_params=self.connection_params,
exit_stack=self.exit_stack,
errlog=self.errlog,
)
@classmethod
async def from_server(
@@ -183,6 +186,7 @@ class MCPToolset:
*,
connection_params: StdioServerParameters | SseServerParams,
async_exit_stack: Optional[AsyncExitStack] = None,
errlog: TextIO = sys.stderr,
) -> Tuple[List[MCPTool], AsyncExitStack]:
"""Retrieve all tools from the MCP connection.
@@ -209,41 +213,27 @@ class MCPToolset:
the MCP server. Use `await async_exit_stack.aclose()` to close the
connection when server shuts down.
"""
toolset = cls(connection_params=connection_params)
async_exit_stack = async_exit_stack or AsyncExitStack()
toolset = cls(
connection_params=connection_params,
exit_stack=async_exit_stack,
errlog=errlog,
)
await async_exit_stack.enter_async_context(toolset)
tools = await toolset.load_tools()
return (tools, async_exit_stack)
async def _initialize(self) -> ClientSession:
"""Connects to the MCP Server and initializes the ClientSession."""
if isinstance(self.connection_params, StdioServerParameters):
client = stdio_client(self.connection_params)
elif isinstance(self.connection_params, SseServerParams):
client = sse_client(
url=self.connection_params.url,
headers=self.connection_params.headers,
timeout=self.connection_params.timeout,
sse_read_timeout=self.connection_params.sse_read_timeout,
)
else:
raise ValueError(
'Unable to initialize connection. Connection should be'
' StdioServerParameters or SseServerParams, but got'
f' {self.connection_params}'
)
transports = await self.exit_stack.enter_async_context(client)
self.session = await self.exit_stack.enter_async_context(
ClientSession(*transports)
)
await self.session.initialize()
self.session = await self.session_manager.create_session()
return self.session
async def _exit(self):
"""Closes the connection to MCP Server."""
await self.exit_stack.aclose()
@retry_on_closed_resource('_initialize')
async def load_tools(self) -> List[MCPTool]:
"""Loads all tools from the MCP Server.
@@ -252,7 +242,11 @@ class MCPToolset:
"""
tools_response: ListToolsResult = await self.session.list_tools()
return [
MCPTool(mcp_tool=tool, mcp_session=self.session)
MCPTool(
mcp_tool=tool,
mcp_session=self.session,
mcp_session_manager=self.session_manager,
)
for tool in tools_response.tools
]

View File

@@ -28,7 +28,7 @@ from typing_extensions import override
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ....tools import BaseTool
from ....tools.base_tool import BaseTool
from ...tool_context import ToolContext
from ..auth.auth_helpers import credential_to_param
from ..auth.auth_helpers import dict_to_auth_scheme