mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-27 07:17:43 -06:00
No public description
PiperOrigin-RevId: 748777998
This commit is contained in:
committed by
hangfei
parent
290058eb05
commit
61d4be2d76
@@ -196,11 +196,12 @@ class IntegrationClient:
|
||||
action_details = connections_client.get_action_schema(action)
|
||||
input_schema = action_details["inputSchema"]
|
||||
output_schema = action_details["outputSchema"]
|
||||
action_display_name = action_details["displayName"]
|
||||
# Remove spaces from the display name to generate valid spec
|
||||
action_display_name = action_details["displayName"].replace(" ", "")
|
||||
operation = "EXECUTE_ACTION"
|
||||
if action == "ExecuteCustomQuery":
|
||||
connector_spec["components"]["schemas"][
|
||||
f"{action}_Request"
|
||||
f"{action_display_name}_Request"
|
||||
] = connections_client.execute_custom_query_request()
|
||||
operation = "EXECUTE_QUERY"
|
||||
else:
|
||||
|
||||
@@ -291,7 +291,7 @@ def _parse_schema_from_parameter(
|
||||
return schema
|
||||
raise ValueError(
|
||||
f'Failed to parse the parameter {param} of function {func_name} for'
|
||||
' automatic function calling.Automatic function calling works best with'
|
||||
' automatic function calling. Automatic function calling works best with'
|
||||
' simpler function signature schema,consider manually parse your'
|
||||
f' function declaration for function {func_name}.'
|
||||
)
|
||||
|
||||
@@ -11,4 +11,77 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .google_api_tool_sets import calendar_tool_set
|
||||
__all__ = [
|
||||
'bigquery_tool_set',
|
||||
'calendar_tool_set',
|
||||
'gmail_tool_set',
|
||||
'youtube_tool_set',
|
||||
'slides_tool_set',
|
||||
'sheets_tool_set',
|
||||
'docs_tool_set',
|
||||
]
|
||||
|
||||
# Nothing is imported here automatically
|
||||
# Each tool set will only be imported when accessed
|
||||
|
||||
_bigquery_tool_set = None
|
||||
_calendar_tool_set = None
|
||||
_gmail_tool_set = None
|
||||
_youtube_tool_set = None
|
||||
_slides_tool_set = None
|
||||
_sheets_tool_set = None
|
||||
_docs_tool_set = None
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
|
||||
|
||||
match name:
|
||||
case 'bigquery_tool_set':
|
||||
if _bigquery_tool_set is None:
|
||||
from .google_api_tool_sets import bigquery_tool_set as bigquery
|
||||
|
||||
_bigquery_tool_set = bigquery
|
||||
return _bigquery_tool_set
|
||||
|
||||
case 'calendar_tool_set':
|
||||
if _calendar_tool_set is None:
|
||||
from .google_api_tool_sets import calendar_tool_set as calendar
|
||||
|
||||
_calendar_tool_set = calendar
|
||||
return _calendar_tool_set
|
||||
|
||||
case 'gmail_tool_set':
|
||||
if _gmail_tool_set is None:
|
||||
from .google_api_tool_sets import gmail_tool_set as gmail
|
||||
|
||||
_gmail_tool_set = gmail
|
||||
return _gmail_tool_set
|
||||
|
||||
case 'youtube_tool_set':
|
||||
if _youtube_tool_set is None:
|
||||
from .google_api_tool_sets import youtube_tool_set as youtube
|
||||
|
||||
_youtube_tool_set = youtube
|
||||
return _youtube_tool_set
|
||||
|
||||
case 'slides_tool_set':
|
||||
if _slides_tool_set is None:
|
||||
from .google_api_tool_sets import slides_tool_set as slides
|
||||
|
||||
_slides_tool_set = slides
|
||||
return _slides_tool_set
|
||||
|
||||
case 'sheets_tool_set':
|
||||
if _sheets_tool_set is None:
|
||||
from .google_api_tool_sets import sheets_tool_set as sheets
|
||||
|
||||
_sheets_tool_set = sheets
|
||||
return _sheets_tool_set
|
||||
|
||||
case 'docs_tool_set':
|
||||
if _docs_tool_set is None:
|
||||
from .google_api_tool_sets import docs_tool_set as docs
|
||||
|
||||
_docs_tool_set = docs
|
||||
return _docs_tool_set
|
||||
|
||||
@@ -19,37 +19,94 @@ from .google_api_tool_set import GoogleApiToolSet
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
calendar_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="calendar",
|
||||
api_version="v3",
|
||||
)
|
||||
_bigquery_tool_set = None
|
||||
_calendar_tool_set = None
|
||||
_gmail_tool_set = None
|
||||
_youtube_tool_set = None
|
||||
_slides_tool_set = None
|
||||
_sheets_tool_set = None
|
||||
_docs_tool_set = None
|
||||
|
||||
bigquery_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="bigquery",
|
||||
api_version="v2",
|
||||
)
|
||||
|
||||
gmail_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="gmail",
|
||||
api_version="v1",
|
||||
)
|
||||
def __getattr__(name):
|
||||
"""This method dynamically loads and returns GoogleApiToolSet instances for
|
||||
|
||||
youtube_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="youtube",
|
||||
api_version="v3",
|
||||
)
|
||||
various Google APIs. It uses a lazy loading approach, initializing each
|
||||
tool set only when it is first requested. This avoids unnecessary loading
|
||||
of tool sets that are not used in a given session.
|
||||
|
||||
slides_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="slides",
|
||||
api_version="v1",
|
||||
)
|
||||
Args:
|
||||
name (str): The name of the tool set to retrieve (e.g.,
|
||||
"bigquery_tool_set").
|
||||
|
||||
sheets_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="sheets",
|
||||
api_version="v4",
|
||||
)
|
||||
Returns:
|
||||
GoogleApiToolSet: The requested tool set instance.
|
||||
|
||||
docs_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="docs",
|
||||
api_version="v1",
|
||||
)
|
||||
Raises:
|
||||
AttributeError: If the requested tool set name is not recognized.
|
||||
"""
|
||||
global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
|
||||
|
||||
match name:
|
||||
case "bigquery_tool_set":
|
||||
if _bigquery_tool_set is None:
|
||||
_bigquery_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="bigquery",
|
||||
api_version="v2",
|
||||
)
|
||||
|
||||
return _bigquery_tool_set
|
||||
|
||||
case "calendar_tool_set":
|
||||
if _calendar_tool_set is None:
|
||||
_calendar_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="calendar",
|
||||
api_version="v3",
|
||||
)
|
||||
|
||||
return _calendar_tool_set
|
||||
|
||||
case "gmail_tool_set":
|
||||
if _gmail_tool_set is None:
|
||||
_gmail_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="gmail",
|
||||
api_version="v1",
|
||||
)
|
||||
|
||||
return _gmail_tool_set
|
||||
|
||||
case "youtube_tool_set":
|
||||
if _youtube_tool_set is None:
|
||||
_youtube_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="youtube",
|
||||
api_version="v3",
|
||||
)
|
||||
|
||||
return _youtube_tool_set
|
||||
|
||||
case "slides_tool_set":
|
||||
if _slides_tool_set is None:
|
||||
_slides_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="slides",
|
||||
api_version="v1",
|
||||
)
|
||||
|
||||
return _slides_tool_set
|
||||
|
||||
case "sheets_tool_set":
|
||||
if _sheets_tool_set is None:
|
||||
_sheets_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="sheets",
|
||||
api_version="v4",
|
||||
)
|
||||
|
||||
return _sheets_tool_set
|
||||
|
||||
case "docs_tool_set":
|
||||
if _docs_tool_set is None:
|
||||
_docs_tool_set = GoogleApiToolSet.load_tool_set(
|
||||
api_name="docs",
|
||||
api_version="v1",
|
||||
)
|
||||
|
||||
return _docs_tool_set
|
||||
|
||||
@@ -311,7 +311,9 @@ class GoogleApiToOpenApiConverter:
|
||||
|
||||
# Determine the actual endpoint path
|
||||
# Google often has the format something like 'users.messages.list'
|
||||
rest_path = method_data.get("path", "/")
|
||||
# flatPath is preferred as it provides the actual path, while path
|
||||
# might contain variables like {+projectId}
|
||||
rest_path = method_data.get("flatPath", method_data.get("path", "/"))
|
||||
if not rest_path.startswith("/"):
|
||||
rest_path = "/" + rest_path
|
||||
|
||||
|
||||
@@ -16,18 +16,26 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .function_tool import FunctionTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import LlmRequest
|
||||
from ..memory.base_memory_service import MemoryResult
|
||||
from ..models import LlmRequest
|
||||
|
||||
|
||||
def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]':
|
||||
"""Loads the memory for the current user."""
|
||||
"""Loads the memory for the current user.
|
||||
|
||||
Args:
|
||||
query: The query to load the memory for.
|
||||
|
||||
Returns:
|
||||
A list of memory results.
|
||||
"""
|
||||
response = tool_context.search_memory(query)
|
||||
return response.memories
|
||||
|
||||
@@ -38,6 +46,21 @@ class LoadMemoryTool(FunctionTool):
|
||||
def __init__(self):
|
||||
super().__init__(load_memory)
|
||||
|
||||
@override
|
||||
def _get_declaration(self) -> types.FunctionDeclaration | None:
|
||||
return types.FunctionDeclaration(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
parameters=types.Schema(
|
||||
type=types.Type.OBJECT,
|
||||
properties={
|
||||
'query': types.Schema(
|
||||
type=types.Type.STRING,
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
async def process_llm_request(
|
||||
self,
|
||||
|
||||
176
src/google/adk/tools/mcp_tool/mcp_session_manager.py
Normal file
176
src/google/adk/tools/mcp_tool/mcp_session_manager.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from contextlib import AsyncExitStack
|
||||
import functools
|
||||
import sys
|
||||
from typing import Any, TextIO
|
||||
import anyio
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
except ImportError as e:
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
raise ImportError(
|
||||
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
|
||||
' version.'
|
||||
) from e
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
class SseServerParams(BaseModel):
|
||||
"""Parameters for the MCP SSE connection.
|
||||
|
||||
See MCP SSE Client documentation for more details.
|
||||
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
|
||||
"""
|
||||
|
||||
url: str
|
||||
headers: dict[str, Any] | None = None
|
||||
timeout: float = 5
|
||||
sse_read_timeout: float = 60 * 5
|
||||
|
||||
|
||||
def retry_on_closed_resource(async_reinit_func_name: str):
|
||||
"""Decorator to automatically reinitialize session and retry action.
|
||||
|
||||
When MCP session was closed, the decorator will automatically recreate the
|
||||
session and retry the action with the same parameters.
|
||||
|
||||
Note:
|
||||
1. async_reinit_func_name is the name of the class member function that
|
||||
reinitializes the MCP session.
|
||||
2. Both the decorated function and the async_reinit_func_name must be async
|
||||
functions.
|
||||
|
||||
Usage:
|
||||
class MCPTool:
|
||||
...
|
||||
async def create_session(self):
|
||||
self.session = ...
|
||||
|
||||
@retry_on_closed_resource('create_session')
|
||||
async def use_session(self):
|
||||
await self.session.call_tool()
|
||||
|
||||
Args:
|
||||
async_reinit_func_name: The name of the async function to recreate session.
|
||||
|
||||
Returns:
|
||||
The decorated function.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@functools.wraps(
|
||||
func
|
||||
) # Preserves original function metadata (name, docstring)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
return await func(self, *args, **kwargs)
|
||||
except anyio.ClosedResourceError:
|
||||
try:
|
||||
if hasattr(self, async_reinit_func_name) and callable(
|
||||
getattr(self, async_reinit_func_name)
|
||||
):
|
||||
async_init_fn = getattr(self, async_reinit_func_name)
|
||||
await async_init_fn()
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Function {async_reinit_func_name} does not exist in decorated'
|
||||
' class. Please check the function name in'
|
||||
' retry_on_closed_resource decorator.'
|
||||
)
|
||||
except Exception as reinit_err:
|
||||
raise RuntimeError(
|
||||
f'Error reinitializing: {reinit_err}'
|
||||
) from reinit_err
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class MCPSessionManager:
|
||||
"""Manages MCP client sessions.
|
||||
|
||||
This class provides methods for creating and initializing MCP client sessions,
|
||||
handling different connection parameters (Stdio and SSE).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_params: StdioServerParameters | SseServerParams,
|
||||
exit_stack: AsyncExitStack,
|
||||
errlog: TextIO = sys.stderr,
|
||||
) -> ClientSession:
|
||||
"""Initializes the MCP session manager.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
mcp_session_manager = MCPSessionManager(
|
||||
connection_params=connection_params,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
session = await mcp_session_manager.create_session()
|
||||
```
|
||||
|
||||
Args:
|
||||
connection_params: Parameters for the MCP connection (Stdio or SSE).
|
||||
exit_stack: AsyncExitStack to manage the session lifecycle.
|
||||
errlog: (Optional) TextIO stream for error logging. Use only for
|
||||
initializing a local stdio MCP session.
|
||||
"""
|
||||
self.connection_params = connection_params
|
||||
self.exit_stack = exit_stack
|
||||
self.errlog = errlog
|
||||
|
||||
async def create_session(self) -> ClientSession:
|
||||
return await MCPSessionManager.initialize_session(
|
||||
connection_params=self.connection_params,
|
||||
exit_stack=self.exit_stack,
|
||||
errlog=self.errlog,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def initialize_session(
|
||||
cls,
|
||||
*,
|
||||
connection_params: StdioServerParameters | SseServerParams,
|
||||
exit_stack: AsyncExitStack,
|
||||
errlog: TextIO = sys.stderr,
|
||||
) -> ClientSession:
|
||||
"""Initializes an MCP client session.
|
||||
|
||||
Args:
|
||||
connection_params: Parameters for the MCP connection (Stdio or SSE).
|
||||
exit_stack: AsyncExitStack to manage the session lifecycle.
|
||||
errlog: (Optional) TextIO stream for error logging. Use only for
|
||||
initializing a local stdio MCP session.
|
||||
|
||||
Returns:
|
||||
ClientSession: The initialized MCP client session.
|
||||
"""
|
||||
if isinstance(connection_params, StdioServerParameters):
|
||||
client = stdio_client(server=connection_params, errlog=errlog)
|
||||
elif isinstance(connection_params, SseServerParams):
|
||||
client = sse_client(
|
||||
url=connection_params.url,
|
||||
headers=connection_params.headers,
|
||||
timeout=connection_params.timeout,
|
||||
sse_read_timeout=connection_params.sse_read_timeout,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unable to initialize connection. Connection should be'
|
||||
' StdioServerParameters or SseServerParams, but got'
|
||||
f' {connection_params}'
|
||||
)
|
||||
|
||||
transports = await exit_stack.enter_async_context(client)
|
||||
session = await exit_stack.enter_async_context(ClientSession(*transports))
|
||||
await session.initialize()
|
||||
return session
|
||||
@@ -17,6 +17,8 @@ from typing import Optional
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from typing_extensions import override
|
||||
|
||||
from .mcp_session_manager import MCPSessionManager, retry_on_closed_resource
|
||||
|
||||
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
||||
# their Python version to 3.10 if it fails.
|
||||
try:
|
||||
@@ -33,6 +35,7 @@ except ImportError as e:
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
from ..base_tool import BaseTool
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_schemes import AuthScheme
|
||||
@@ -51,6 +54,7 @@ class MCPTool(BaseTool):
|
||||
self,
|
||||
mcp_tool: McpBaseTool,
|
||||
mcp_session: ClientSession,
|
||||
mcp_session_manager: MCPSessionManager,
|
||||
auth_scheme: Optional[AuthScheme] = None,
|
||||
auth_credential: Optional[AuthCredential] | None = None,
|
||||
):
|
||||
@@ -79,10 +83,14 @@ class MCPTool(BaseTool):
|
||||
self.description = mcp_tool.description if mcp_tool.description else ""
|
||||
self.mcp_tool = mcp_tool
|
||||
self.mcp_session = mcp_session
|
||||
self.mcp_session_manager = mcp_session_manager
|
||||
# TODO(cheliu): Support passing auth to MCP Server.
|
||||
self.auth_scheme = auth_scheme
|
||||
self.auth_credential = auth_credential
|
||||
|
||||
async def _reinitialize_session(self):
|
||||
self.mcp_session = await self.mcp_session_manager.create_session()
|
||||
|
||||
@override
|
||||
def _get_declaration(self) -> FunctionDeclaration:
|
||||
"""Gets the function declaration for the tool.
|
||||
@@ -98,6 +106,7 @@ class MCPTool(BaseTool):
|
||||
return function_decl
|
||||
|
||||
@override
|
||||
@retry_on_closed_resource("_reinitialize_session")
|
||||
async def run_async(self, *, args, tool_context: ToolContext):
|
||||
"""Runs the tool asynchronously.
|
||||
|
||||
@@ -109,5 +118,9 @@ class MCPTool(BaseTool):
|
||||
Any: The response from the tool.
|
||||
"""
|
||||
# TODO(cheliu): Support passing tool context to MCP Server.
|
||||
response = await self.mcp_session.call_tool(self.name, arguments=args)
|
||||
return response
|
||||
try:
|
||||
response = await self.mcp_session.call_tool(self.name, arguments=args)
|
||||
return response
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
@@ -13,15 +13,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import AsyncExitStack
|
||||
import sys
|
||||
from types import TracebackType
|
||||
from typing import Any, List, Optional, Tuple, Type
|
||||
from typing import List, Optional, TextIO, Tuple, Type
|
||||
|
||||
from .mcp_session_manager import MCPSessionManager, SseServerParams, retry_on_closed_resource
|
||||
|
||||
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
||||
# their Python version to 3.10 if it fails.
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.types import ListToolsResult
|
||||
except ImportError as e:
|
||||
import sys
|
||||
@@ -34,18 +35,9 @@ except ImportError as e:
|
||||
else:
|
||||
raise e
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .mcp_tool import MCPTool
|
||||
|
||||
|
||||
class SseServerParams(BaseModel):
|
||||
url: str
|
||||
headers: dict[str, Any] | None = None
|
||||
timeout: float = 5
|
||||
sse_read_timeout: float = 60 * 5
|
||||
|
||||
|
||||
class MCPToolset:
|
||||
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
|
||||
|
||||
@@ -110,7 +102,11 @@ class MCPToolset:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *, connection_params: StdioServerParameters | SseServerParams
|
||||
self,
|
||||
*,
|
||||
connection_params: StdioServerParameters | SseServerParams,
|
||||
errlog: TextIO = sys.stderr,
|
||||
exit_stack=AsyncExitStack(),
|
||||
):
|
||||
"""Initializes the MCPToolset.
|
||||
|
||||
@@ -175,7 +171,14 @@ class MCPToolset:
|
||||
if not connection_params:
|
||||
raise ValueError('Missing connection params in MCPToolset.')
|
||||
self.connection_params = connection_params
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self.errlog = errlog
|
||||
self.exit_stack = exit_stack
|
||||
|
||||
self.session_manager = MCPSessionManager(
|
||||
connection_params=self.connection_params,
|
||||
exit_stack=self.exit_stack,
|
||||
errlog=self.errlog,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def from_server(
|
||||
@@ -183,6 +186,7 @@ class MCPToolset:
|
||||
*,
|
||||
connection_params: StdioServerParameters | SseServerParams,
|
||||
async_exit_stack: Optional[AsyncExitStack] = None,
|
||||
errlog: TextIO = sys.stderr,
|
||||
) -> Tuple[List[MCPTool], AsyncExitStack]:
|
||||
"""Retrieve all tools from the MCP connection.
|
||||
|
||||
@@ -209,41 +213,27 @@ class MCPToolset:
|
||||
the MCP server. Use `await async_exit_stack.aclose()` to close the
|
||||
connection when server shuts down.
|
||||
"""
|
||||
toolset = cls(connection_params=connection_params)
|
||||
async_exit_stack = async_exit_stack or AsyncExitStack()
|
||||
toolset = cls(
|
||||
connection_params=connection_params,
|
||||
exit_stack=async_exit_stack,
|
||||
errlog=errlog,
|
||||
)
|
||||
|
||||
await async_exit_stack.enter_async_context(toolset)
|
||||
tools = await toolset.load_tools()
|
||||
return (tools, async_exit_stack)
|
||||
|
||||
async def _initialize(self) -> ClientSession:
|
||||
"""Connects to the MCP Server and initializes the ClientSession."""
|
||||
if isinstance(self.connection_params, StdioServerParameters):
|
||||
client = stdio_client(self.connection_params)
|
||||
elif isinstance(self.connection_params, SseServerParams):
|
||||
client = sse_client(
|
||||
url=self.connection_params.url,
|
||||
headers=self.connection_params.headers,
|
||||
timeout=self.connection_params.timeout,
|
||||
sse_read_timeout=self.connection_params.sse_read_timeout,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unable to initialize connection. Connection should be'
|
||||
' StdioServerParameters or SseServerParams, but got'
|
||||
f' {self.connection_params}'
|
||||
)
|
||||
|
||||
transports = await self.exit_stack.enter_async_context(client)
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
ClientSession(*transports)
|
||||
)
|
||||
await self.session.initialize()
|
||||
self.session = await self.session_manager.create_session()
|
||||
return self.session
|
||||
|
||||
async def _exit(self):
|
||||
"""Closes the connection to MCP Server."""
|
||||
await self.exit_stack.aclose()
|
||||
|
||||
@retry_on_closed_resource('_initialize')
|
||||
async def load_tools(self) -> List[MCPTool]:
|
||||
"""Loads all tools from the MCP Server.
|
||||
|
||||
@@ -252,7 +242,11 @@ class MCPToolset:
|
||||
"""
|
||||
tools_response: ListToolsResult = await self.session.list_tools()
|
||||
return [
|
||||
MCPTool(mcp_tool=tool, mcp_session=self.session)
|
||||
MCPTool(
|
||||
mcp_tool=tool,
|
||||
mcp_session=self.session,
|
||||
mcp_session_manager=self.session_manager,
|
||||
)
|
||||
for tool in tools_response.tools
|
||||
]
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ from typing_extensions import override
|
||||
|
||||
from ....auth.auth_credential import AuthCredential
|
||||
from ....auth.auth_schemes import AuthScheme
|
||||
from ....tools import BaseTool
|
||||
from ....tools.base_tool import BaseTool
|
||||
from ...tool_context import ToolContext
|
||||
from ..auth.auth_helpers import credential_to_param
|
||||
from ..auth.auth_helpers import dict_to_auth_scheme
|
||||
|
||||
Reference in New Issue
Block a user