structure saas with tools
This commit is contained in:
85
.venv/lib/python3.10/site-packages/mcp/client/__main__.py
Normal file
85
.venv/lib/python3.10/site-packages/mcp/client/__main__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
if not sys.warnoptions:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("client")
|
||||
|
||||
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
logger.error("Error: %s", message)
|
||||
return
|
||||
|
||||
logger.info("Received message from server: %s", message)
|
||||
|
||||
|
||||
async def run_session(
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream, write_stream, message_handler=message_handler
|
||||
) as session:
|
||||
logger.info("Initializing session")
|
||||
await session.initialize()
|
||||
logger.info("Initialized")
|
||||
|
||||
|
||||
async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]):
|
||||
env_dict = dict(env)
|
||||
|
||||
if urlparse(command_or_url).scheme in ("http", "https"):
|
||||
# Use SSE client for HTTP(S) URLs
|
||||
async with sse_client(command_or_url) as streams:
|
||||
await run_session(*streams)
|
||||
else:
|
||||
# Use stdio client for commands
|
||||
server_parameters = StdioServerParameters(
|
||||
command=command_or_url, args=args, env=env_dict
|
||||
)
|
||||
async with stdio_client(server_parameters) as streams:
|
||||
await run_session(*streams)
|
||||
|
||||
|
||||
def cli():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("command_or_url", help="Command or URL to connect to")
|
||||
parser.add_argument("args", nargs="*", help="Additional arguments")
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--env",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("KEY", "VALUE"),
|
||||
help="Environment variables to set. Can be used multiple times.",
|
||||
default=[],
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
anyio.run(partial(main, args.command_or_url, args.args, args.env), backend="trio")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
378
.venv/lib/python3.10/site-packages/mcp/client/session.py
Normal file
378
.venv/lib/python3.10/site-packages/mcp/client/session.py
Normal file
@@ -0,0 +1,378 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.session import BaseSession, RequestResponder
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
|
||||
|
||||
class SamplingFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class ListRootsFnT(Protocol):
|
||||
async def __call__(
|
||||
self, context: RequestContext["ClientSession", Any]
|
||||
) -> types.ListRootsResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class LoggingFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class MessageHandlerFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
async def _default_message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
async def _default_sampling_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Sampling not supported",
|
||||
)
|
||||
|
||||
|
||||
async def _default_list_roots_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
) -> types.ListRootsResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="List roots not supported",
|
||||
)
|
||||
|
||||
|
||||
async def _default_logging_callback(
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
|
||||
types.ClientResult | types.ErrorData
|
||||
)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
types.ClientRequest,
|
||||
types.ClientNotification,
|
||||
types.ClientResult,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
logging_callback: LoggingFnT | None = None,
|
||||
message_handler: MessageHandlerFnT | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream,
|
||||
write_stream,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
self._sampling_callback = sampling_callback or _default_sampling_callback
|
||||
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
|
||||
self._logging_callback = logging_callback or _default_logging_callback
|
||||
self._message_handler = message_handler or _default_message_handler
|
||||
|
||||
async def initialize(self) -> types.InitializeResult:
|
||||
sampling = types.SamplingCapability()
|
||||
roots = types.RootsCapability(
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
listChanged=True,
|
||||
)
|
||||
|
||||
result = await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.InitializeRequest(
|
||||
method="initialize",
|
||||
params=types.InitializeRequestParams(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ClientCapabilities(
|
||||
sampling=sampling,
|
||||
experimental=None,
|
||||
roots=roots,
|
||||
),
|
||||
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
|
||||
),
|
||||
)
|
||||
),
|
||||
types.InitializeResult,
|
||||
)
|
||||
|
||||
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||
raise RuntimeError(
|
||||
"Unsupported protocol version from the server: "
|
||||
f"{result.protocolVersion}"
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.InitializedNotification(method="notifications/initialized")
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def send_ping(self) -> types.EmptyResult:
|
||||
"""Send a ping request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
await self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=types.ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
|
||||
"""Send a logging/setLevel request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.SetLevelRequest(
|
||||
method="logging/setLevel",
|
||||
params=types.SetLevelRequestParams(level=level),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def list_resources(self) -> types.ListResourcesResult:
|
||||
"""Send a resources/list request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListResourcesRequest(
|
||||
method="resources/list",
|
||||
)
|
||||
),
|
||||
types.ListResourcesResult,
|
||||
)
|
||||
|
||||
async def list_resource_templates(self) -> types.ListResourceTemplatesResult:
|
||||
"""Send a resources/templates/list request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListResourceTemplatesRequest(
|
||||
method="resources/templates/list",
|
||||
)
|
||||
),
|
||||
types.ListResourceTemplatesResult,
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
|
||||
"""Send a resources/read request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ReadResourceRequest(
|
||||
method="resources/read",
|
||||
params=types.ReadResourceRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.ReadResourceResult,
|
||||
)
|
||||
|
||||
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/subscribe request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.SubscribeRequest(
|
||||
method="resources/subscribe",
|
||||
params=types.SubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/unsubscribe request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.UnsubscribeRequest(
|
||||
method="resources/unsubscribe",
|
||||
params=types.UnsubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: dict[str, Any] | None = None
|
||||
) -> types.CallToolResult:
|
||||
"""Send a tools/call request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> types.ListPromptsResult:
|
||||
"""Send a prompts/list request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListPromptsRequest(
|
||||
method="prompts/list",
|
||||
)
|
||||
),
|
||||
types.ListPromptsResult,
|
||||
)
|
||||
|
||||
async def get_prompt(
|
||||
self, name: str, arguments: dict[str, str] | None = None
|
||||
) -> types.GetPromptResult:
|
||||
"""Send a prompts/get request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.GetPromptRequest(
|
||||
method="prompts/get",
|
||||
params=types.GetPromptRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.GetPromptResult,
|
||||
)
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
ref: types.ResourceReference | types.PromptReference,
|
||||
argument: dict[str, str],
|
||||
) -> types.CompleteResult:
|
||||
"""Send a completion/complete request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CompleteRequest(
|
||||
method="completion/complete",
|
||||
params=types.CompleteRequestParams(
|
||||
ref=ref,
|
||||
argument=types.CompletionArgument(**argument),
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CompleteResult,
|
||||
)
|
||||
|
||||
async def list_tools(self) -> types.ListToolsResult:
|
||||
"""Send a tools/list request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListToolsRequest(
|
||||
method="tools/list",
|
||||
)
|
||||
),
|
||||
types.ListToolsResult,
|
||||
)
|
||||
|
||||
async def send_roots_list_changed(self) -> None:
|
||||
"""Send a roots/list_changed notification."""
|
||||
await self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.RootsListChangedNotification(
|
||||
method="notifications/roots/list_changed",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
) -> None:
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
meta=responder.request_meta,
|
||||
session=self,
|
||||
lifespan_context=None,
|
||||
)
|
||||
|
||||
match responder.request.root:
|
||||
case types.CreateMessageRequest(params=params):
|
||||
with responder:
|
||||
response = await self._sampling_callback(ctx, params)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.ListRootsRequest():
|
||||
with responder:
|
||||
response = await self._list_roots_callback(ctx)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.PingRequest():
|
||||
with responder:
|
||||
return await responder.respond(
|
||||
types.ClientResult(root=types.EmptyResult())
|
||||
)
|
||||
|
||||
async def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
"""Handle incoming messages by forwarding to the message handler."""
|
||||
await self._message_handler(req)
|
||||
|
||||
async def _received_notification(
|
||||
self, notification: types.ServerNotification
|
||||
) -> None:
|
||||
"""Handle notifications from the server."""
|
||||
# Process specific notification types
|
||||
match notification.root:
|
||||
case types.LoggingMessageNotification(params=params):
|
||||
await self._logging_callback(params)
|
||||
case _:
|
||||
pass
|
||||
146
.venv/lib/python3.10/site-packages/mcp/client/sse.py
Normal file
146
.venv/lib/python3.10/site-packages/mcp/client/sse.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import TaskStatus
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import aconnect_sse
|
||||
|
||||
import mcp.types as types
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5,
|
||||
sse_read_timeout: float = 60 * 5,
|
||||
):
|
||||
"""
|
||||
Client transport for SSE.
|
||||
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
try:
|
||||
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
||||
async with httpx.AsyncClient(headers=headers) as client:
|
||||
async with aconnect_sse(
|
||||
client,
|
||||
"GET",
|
||||
url,
|
||||
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("SSE connection established")
|
||||
|
||||
async def sse_reader(
|
||||
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
|
||||
):
|
||||
try:
|
||||
async for sse in event_source.aiter_sse():
|
||||
logger.debug(f"Received SSE event: {sse.event}")
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
endpoint_url = urljoin(url, sse.data)
|
||||
logger.info(
|
||||
f"Received endpoint URL: {endpoint_url}"
|
||||
)
|
||||
|
||||
url_parsed = urlparse(url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
if (
|
||||
url_parsed.netloc != endpoint_parsed.netloc
|
||||
or url_parsed.scheme
|
||||
!= endpoint_parsed.scheme
|
||||
):
|
||||
error_msg = (
|
||||
"Endpoint origin does not match "
|
||||
f"connection origin: {endpoint_url}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
task_status.started(endpoint_url)
|
||||
|
||||
case "message":
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
|
||||
sse.data
|
||||
)
|
||||
logger.debug(
|
||||
f"Received server message: {message}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"Error parsing server message: {exc}"
|
||||
)
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
case _:
|
||||
logger.warning(
|
||||
f"Unknown SSE event: {sse.event}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in sse_reader: {exc}")
|
||||
await read_stream_writer.send(exc)
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
|
||||
async def post_writer(endpoint_url: str):
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
response = await client.post(
|
||||
endpoint_url,
|
||||
json=message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(
|
||||
"Client message sent successfully: "
|
||||
f"{response.status_code}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in post_writer: {exc}")
|
||||
finally:
|
||||
await write_stream.aclose()
|
||||
|
||||
endpoint_url = await tg.start(sse_reader)
|
||||
logger.info(
|
||||
f"Starting post writer with endpoint URL: {endpoint_url}"
|
||||
)
|
||||
tg.start_soon(post_writer, endpoint_url)
|
||||
|
||||
try:
|
||||
yield read_stream, write_stream
|
||||
finally:
|
||||
tg.cancel_scope.cancel()
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream.aclose()
|
||||
216
.venv/lib/python3.10/site-packages/mcp/client/stdio/__init__.py
Normal file
216
.venv/lib/python3.10/site-packages/mcp/client/stdio/__init__.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Literal, TextIO
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import mcp.types as types
|
||||
|
||||
from .win32 import (
|
||||
create_windows_process,
|
||||
get_windows_executable_command,
|
||||
terminate_windows_process,
|
||||
)
|
||||
|
||||
# Environment variables to inherit by default
|
||||
DEFAULT_INHERITED_ENV_VARS = (
|
||||
[
|
||||
"APPDATA",
|
||||
"HOMEDRIVE",
|
||||
"HOMEPATH",
|
||||
"LOCALAPPDATA",
|
||||
"PATH",
|
||||
"PROCESSOR_ARCHITECTURE",
|
||||
"SYSTEMDRIVE",
|
||||
"SYSTEMROOT",
|
||||
"TEMP",
|
||||
"USERNAME",
|
||||
"USERPROFILE",
|
||||
]
|
||||
if sys.platform == "win32"
|
||||
else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
|
||||
)
|
||||
|
||||
|
||||
def get_default_environment() -> dict[str, str]:
|
||||
"""
|
||||
Returns a default environment object including only environment variables deemed
|
||||
safe to inherit.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
|
||||
for key in DEFAULT_INHERITED_ENV_VARS:
|
||||
value = os.environ.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if value.startswith("()"):
|
||||
# Skip functions, which are a security risk
|
||||
continue
|
||||
|
||||
env[key] = value
|
||||
|
||||
return env
|
||||
|
||||
|
||||
class StdioServerParameters(BaseModel):
|
||||
command: str
|
||||
"""The executable to run to start the server."""
|
||||
|
||||
args: list[str] = Field(default_factory=list)
|
||||
"""Command line arguments to pass to the executable."""
|
||||
|
||||
env: dict[str, str] | None = None
|
||||
"""
|
||||
The environment to use when spawning the process.
|
||||
|
||||
If not specified, the result of get_default_environment() will be used.
|
||||
"""
|
||||
|
||||
cwd: str | Path | None = None
|
||||
"""The working directory to use when spawning the process."""
|
||||
|
||||
encoding: str = "utf-8"
|
||||
"""
|
||||
The text encoding used when sending/receiving messages to the server
|
||||
|
||||
defaults to utf-8
|
||||
"""
|
||||
|
||||
encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict"
|
||||
"""
|
||||
The text encoding error handler.
|
||||
|
||||
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
|
||||
explanations of possible values
|
||||
"""
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr):
|
||||
"""
|
||||
Client transport for stdio: this will connect to a server by spawning a
|
||||
process and communicating with it over stdin/stdout.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
command = _get_executable_command(server.command)
|
||||
|
||||
# Open process with stderr piped for capture
|
||||
process = await _create_platform_compatible_process(
|
||||
command=command,
|
||||
args=server.args,
|
||||
env=(
|
||||
{**get_default_environment(), **server.env}
|
||||
if server.env is not None
|
||||
else get_default_environment()
|
||||
),
|
||||
errlog=errlog,
|
||||
cwd=server.cwd,
|
||||
)
|
||||
|
||||
async def stdout_reader():
|
||||
assert process.stdout, "Opened process is missing stdout"
|
||||
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
buffer = ""
|
||||
async for chunk in TextReceiveStream(
|
||||
process.stdout,
|
||||
encoding=server.encoding,
|
||||
errors=server.encoding_error_handler,
|
||||
):
|
||||
lines = (buffer + chunk).split("\n")
|
||||
buffer = lines.pop()
|
||||
|
||||
for line in lines:
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(line)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async def stdin_writer():
|
||||
assert process.stdin, "Opened process is missing stdin"
|
||||
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
json = message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
await process.stdin.send(
|
||||
(json + "\n").encode(
|
||||
encoding=server.encoding,
|
||||
errors=server.encoding_error_handler,
|
||||
)
|
||||
)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async with (
|
||||
anyio.create_task_group() as tg,
|
||||
process,
|
||||
):
|
||||
tg.start_soon(stdout_reader)
|
||||
tg.start_soon(stdin_writer)
|
||||
try:
|
||||
yield read_stream, write_stream
|
||||
finally:
|
||||
# Clean up process to prevent any dangling orphaned processes
|
||||
if sys.platform == "win32":
|
||||
await terminate_windows_process(process)
|
||||
else:
|
||||
process.terminate()
|
||||
|
||||
|
||||
def _get_executable_command(command: str) -> str:
|
||||
"""
|
||||
Get the correct executable command normalized for the current platform.
|
||||
|
||||
Args:
|
||||
command: Base command (e.g., 'uvx', 'npx')
|
||||
|
||||
Returns:
|
||||
str: Platform-appropriate command
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
return get_windows_executable_command(command)
|
||||
else:
|
||||
return command
|
||||
|
||||
|
||||
async def _create_platform_compatible_process(
|
||||
command: str,
|
||||
args: list[str],
|
||||
env: dict[str, str] | None = None,
|
||||
errlog: TextIO = sys.stderr,
|
||||
cwd: Path | str | None = None,
|
||||
):
|
||||
"""
|
||||
Creates a subprocess in a platform-compatible way.
|
||||
Returns a process handle.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
process = await create_windows_process(command, args, env, errlog, cwd)
|
||||
else:
|
||||
process = await anyio.open_process(
|
||||
[command, *args], env=env, stderr=errlog, cwd=cwd
|
||||
)
|
||||
|
||||
return process
|
||||
Binary file not shown.
Binary file not shown.
109
.venv/lib/python3.10/site-packages/mcp/client/stdio/win32.py
Normal file
109
.venv/lib/python3.10/site-packages/mcp/client/stdio/win32.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Windows-specific functionality for stdio client operations.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
|
||||
import anyio
|
||||
from anyio.abc import Process
|
||||
|
||||
|
||||
def get_windows_executable_command(command: str) -> str:
|
||||
"""
|
||||
Get the correct executable command normalized for Windows.
|
||||
|
||||
On Windows, commands might exist with specific extensions (.exe, .cmd, etc.)
|
||||
that need to be located for proper execution.
|
||||
|
||||
Args:
|
||||
command: Base command (e.g., 'uvx', 'npx')
|
||||
|
||||
Returns:
|
||||
str: Windows-appropriate command path
|
||||
"""
|
||||
try:
|
||||
# First check if command exists in PATH as-is
|
||||
if command_path := shutil.which(command):
|
||||
return command_path
|
||||
|
||||
# Check for Windows-specific extensions
|
||||
for ext in [".cmd", ".bat", ".exe", ".ps1"]:
|
||||
ext_version = f"{command}{ext}"
|
||||
if ext_path := shutil.which(ext_version):
|
||||
return ext_path
|
||||
|
||||
# For regular commands or if we couldn't find special versions
|
||||
return command
|
||||
except OSError:
|
||||
# Handle file system errors during path resolution
|
||||
# (permissions, broken symlinks, etc.)
|
||||
return command
|
||||
|
||||
|
||||
async def create_windows_process(
|
||||
command: str,
|
||||
args: list[str],
|
||||
env: dict[str, str] | None = None,
|
||||
errlog: TextIO = sys.stderr,
|
||||
cwd: Path | str | None = None,
|
||||
):
|
||||
"""
|
||||
Creates a subprocess in a Windows-compatible way.
|
||||
|
||||
Windows processes need special handling for console windows and
|
||||
process creation flags.
|
||||
|
||||
Args:
|
||||
command: The command to execute
|
||||
args: Command line arguments
|
||||
env: Environment variables
|
||||
errlog: Where to send stderr output
|
||||
cwd: Working directory for the process
|
||||
|
||||
Returns:
|
||||
A process handle
|
||||
"""
|
||||
try:
|
||||
# Try with Windows-specific flags to hide console window
|
||||
process = await anyio.open_process(
|
||||
[command, *args],
|
||||
env=env,
|
||||
# Ensure we don't create console windows for each process
|
||||
creationflags=subprocess.CREATE_NO_WINDOW # type: ignore
|
||||
if hasattr(subprocess, "CREATE_NO_WINDOW")
|
||||
else 0,
|
||||
stderr=errlog,
|
||||
cwd=cwd,
|
||||
)
|
||||
return process
|
||||
except Exception:
|
||||
# Don't raise, let's try to create the process without creation flags
|
||||
process = await anyio.open_process(
|
||||
[command, *args], env=env, stderr=errlog, cwd=cwd
|
||||
)
|
||||
return process
|
||||
|
||||
|
||||
async def terminate_windows_process(process: Process):
|
||||
"""
|
||||
Terminate a Windows process.
|
||||
|
||||
Note: On Windows, terminating a process with process.terminate() doesn't
|
||||
always guarantee immediate process termination.
|
||||
So we give it 2s to exit, or we call process.kill()
|
||||
which sends a SIGKILL equivalent signal.
|
||||
|
||||
Args:
|
||||
process: The process to terminate
|
||||
"""
|
||||
try:
|
||||
process.terminate()
|
||||
with anyio.fail_after(2.0):
|
||||
await process.wait()
|
||||
except TimeoutError:
|
||||
# Force kill if it doesn't terminate
|
||||
process.kill()
|
||||
89
.venv/lib/python3.10/site-packages/mcp/client/websocket.py
Normal file
89
.venv/lib/python3.10/site-packages/mcp/client/websocket.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import ValidationError
|
||||
from websockets.asyncio.client import connect as ws_connect
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
import mcp.types as types
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def websocket_client(
|
||||
url: str,
|
||||
) -> AsyncGenerator[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
],
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
WebSocket client transport for MCP, symmetrical to the server version.
|
||||
|
||||
Connects to 'url' using the 'mcp' subprotocol, then yields:
|
||||
(read_stream, write_stream)
|
||||
|
||||
- read_stream: As you read from this stream, you'll receive either valid
|
||||
JSONRPCMessage objects or Exception objects (when validation fails).
|
||||
- write_stream: Write JSONRPCMessage objects to this stream to send them
|
||||
over the WebSocket to the server.
|
||||
"""
|
||||
|
||||
# Create two in-memory streams:
|
||||
# - One for incoming messages (read_stream, written by ws_reader)
|
||||
# - One for outgoing messages (write_stream, read by ws_writer)
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
# Connect using websockets, requesting the "mcp" subprotocol
|
||||
async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:
|
||||
|
||||
async def ws_reader():
|
||||
"""
|
||||
Reads text messages from the WebSocket, parses them as JSON-RPC messages,
|
||||
and sends them into read_stream_writer.
|
||||
"""
|
||||
async with read_stream_writer:
|
||||
async for raw_text in ws:
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(raw_text)
|
||||
await read_stream_writer.send(message)
|
||||
except ValidationError as exc:
|
||||
# If JSON parse or model validation fails, send the exception
|
||||
await read_stream_writer.send(exc)
|
||||
|
||||
async def ws_writer():
|
||||
"""
|
||||
Reads JSON-RPC messages from write_stream_reader and
|
||||
sends them to the server.
|
||||
"""
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
# Convert to a dict, then to JSON
|
||||
msg_dict = message.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
await ws.send(json.dumps(msg_dict))
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
# Start reader and writer tasks
|
||||
tg.start_soon(ws_reader)
|
||||
tg.start_soon(ws_writer)
|
||||
|
||||
# Yield the receive/send streams
|
||||
yield (read_stream, write_stream)
|
||||
|
||||
# Once the caller's 'async with' block exits, we shut down
|
||||
tg.cancel_scope.cancel()
|
||||
Reference in New Issue
Block a user