structure saas with tools
This commit is contained in:
114
.venv/lib/python3.10/site-packages/mcp/__init__.py
Normal file
114
.venv/lib/python3.10/site-packages/mcp/__init__.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from .client.session import ClientSession
|
||||
from .client.stdio import StdioServerParameters, stdio_client
|
||||
from .server.session import ServerSession
|
||||
from .server.stdio import stdio_server
|
||||
from .shared.exceptions import McpError
|
||||
from .types import (
|
||||
CallToolRequest,
|
||||
ClientCapabilities,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
CompleteRequest,
|
||||
CreateMessageRequest,
|
||||
CreateMessageResult,
|
||||
ErrorData,
|
||||
GetPromptRequest,
|
||||
GetPromptResult,
|
||||
Implementation,
|
||||
IncludeContext,
|
||||
InitializedNotification,
|
||||
InitializeRequest,
|
||||
InitializeResult,
|
||||
JSONRPCError,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
ListPromptsRequest,
|
||||
ListPromptsResult,
|
||||
ListResourcesRequest,
|
||||
ListResourcesResult,
|
||||
ListToolsResult,
|
||||
LoggingLevel,
|
||||
LoggingMessageNotification,
|
||||
Notification,
|
||||
PingRequest,
|
||||
ProgressNotification,
|
||||
PromptsCapability,
|
||||
ReadResourceRequest,
|
||||
ReadResourceResult,
|
||||
Resource,
|
||||
ResourcesCapability,
|
||||
ResourceUpdatedNotification,
|
||||
RootsCapability,
|
||||
SamplingMessage,
|
||||
ServerCapabilities,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
SetLevelRequest,
|
||||
StopReason,
|
||||
SubscribeRequest,
|
||||
Tool,
|
||||
ToolsCapability,
|
||||
UnsubscribeRequest,
|
||||
)
|
||||
from .types import (
|
||||
Role as SamplingRole,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CallToolRequest",
|
||||
"ClientCapabilities",
|
||||
"ClientNotification",
|
||||
"ClientRequest",
|
||||
"ClientResult",
|
||||
"ClientSession",
|
||||
"CreateMessageRequest",
|
||||
"CreateMessageResult",
|
||||
"ErrorData",
|
||||
"GetPromptRequest",
|
||||
"GetPromptResult",
|
||||
"Implementation",
|
||||
"IncludeContext",
|
||||
"InitializeRequest",
|
||||
"InitializeResult",
|
||||
"InitializedNotification",
|
||||
"JSONRPCError",
|
||||
"JSONRPCRequest",
|
||||
"ListPromptsRequest",
|
||||
"ListPromptsResult",
|
||||
"ListResourcesRequest",
|
||||
"ListResourcesResult",
|
||||
"ListToolsResult",
|
||||
"LoggingLevel",
|
||||
"LoggingMessageNotification",
|
||||
"McpError",
|
||||
"Notification",
|
||||
"PingRequest",
|
||||
"ProgressNotification",
|
||||
"PromptsCapability",
|
||||
"ReadResourceRequest",
|
||||
"ReadResourceResult",
|
||||
"ResourcesCapability",
|
||||
"ResourceUpdatedNotification",
|
||||
"Resource",
|
||||
"RootsCapability",
|
||||
"SamplingMessage",
|
||||
"SamplingRole",
|
||||
"ServerCapabilities",
|
||||
"ServerNotification",
|
||||
"ServerRequest",
|
||||
"ServerResult",
|
||||
"ServerSession",
|
||||
"SetLevelRequest",
|
||||
"StdioServerParameters",
|
||||
"StopReason",
|
||||
"SubscribeRequest",
|
||||
"Tool",
|
||||
"ToolsCapability",
|
||||
"UnsubscribeRequest",
|
||||
"stdio_client",
|
||||
"stdio_server",
|
||||
"CompleteRequest",
|
||||
"JSONRPCResponse",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
6
.venv/lib/python3.10/site-packages/mcp/cli/__init__.py
Normal file
6
.venv/lib/python3.10/site-packages/mcp/cli/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""FastMCP CLI package."""
|
||||
|
||||
from .cli import app
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
142
.venv/lib/python3.10/site-packages/mcp/cli/claude.py
Normal file
142
.venv/lib/python3.10/site-packages/mcp/cli/claude.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Claude app integration utilities."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
MCP_PACKAGE = "mcp[cli]"
|
||||
|
||||
|
||||
def get_claude_config_path() -> Path | None:
|
||||
"""Get the Claude config directory based on platform."""
|
||||
if sys.platform == "win32":
|
||||
path = Path(Path.home(), "AppData", "Roaming", "Claude")
|
||||
elif sys.platform == "darwin":
|
||||
path = Path(Path.home(), "Library", "Application Support", "Claude")
|
||||
elif sys.platform.startswith("linux"):
|
||||
path = Path(
|
||||
os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"), "Claude"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
if path.exists():
|
||||
return path
|
||||
return None
|
||||
|
||||
|
||||
def update_claude_config(
|
||||
file_spec: str,
|
||||
server_name: str,
|
||||
*,
|
||||
with_editable: Path | None = None,
|
||||
with_packages: list[str] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
) -> bool:
|
||||
"""Add or update a FastMCP server in Claude's configuration.
|
||||
|
||||
Args:
|
||||
file_spec: Path to the server file, optionally with :object suffix
|
||||
server_name: Name for the server in Claude's config
|
||||
with_editable: Optional directory to install in editable mode
|
||||
with_packages: Optional list of additional packages to install
|
||||
env_vars: Optional dictionary of environment variables. These are merged with
|
||||
any existing variables, with new values taking precedence.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If Claude Desktop's config directory is not found, indicating
|
||||
Claude Desktop may not be installed or properly set up.
|
||||
"""
|
||||
config_dir = get_claude_config_path()
|
||||
if not config_dir:
|
||||
raise RuntimeError(
|
||||
"Claude Desktop config directory not found. Please ensure Claude Desktop"
|
||||
" is installed and has been run at least once to initialize its config."
|
||||
)
|
||||
|
||||
config_file = config_dir / "claude_desktop_config.json"
|
||||
if not config_file.exists():
|
||||
try:
|
||||
config_file.write_text("{}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to create Claude config file",
|
||||
extra={
|
||||
"error": str(e),
|
||||
"config_file": str(config_file),
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
config = json.loads(config_file.read_text())
|
||||
if "mcpServers" not in config:
|
||||
config["mcpServers"] = {}
|
||||
|
||||
# Always preserve existing env vars and merge with new ones
|
||||
if (
|
||||
server_name in config["mcpServers"]
|
||||
and "env" in config["mcpServers"][server_name]
|
||||
):
|
||||
existing_env = config["mcpServers"][server_name]["env"]
|
||||
if env_vars:
|
||||
# New vars take precedence over existing ones
|
||||
env_vars = {**existing_env, **env_vars}
|
||||
else:
|
||||
env_vars = existing_env
|
||||
|
||||
# Build uv run command
|
||||
args = ["run"]
|
||||
|
||||
# Collect all packages in a set to deduplicate
|
||||
packages = {MCP_PACKAGE}
|
||||
if with_packages:
|
||||
packages.update(pkg for pkg in with_packages if pkg)
|
||||
|
||||
# Add all packages with --with
|
||||
for pkg in sorted(packages):
|
||||
args.extend(["--with", pkg])
|
||||
|
||||
if with_editable:
|
||||
args.extend(["--with-editable", str(with_editable)])
|
||||
|
||||
# Convert file path to absolute before adding to command
|
||||
# Split off any :object suffix first
|
||||
if ":" in file_spec:
|
||||
file_path, server_object = file_spec.rsplit(":", 1)
|
||||
file_spec = f"{Path(file_path).resolve()}:{server_object}"
|
||||
else:
|
||||
file_spec = str(Path(file_spec).resolve())
|
||||
|
||||
# Add fastmcp run command
|
||||
args.extend(["mcp", "run", file_spec])
|
||||
|
||||
server_config: dict[str, Any] = {"command": "uv", "args": args}
|
||||
|
||||
# Add environment variables if specified
|
||||
if env_vars:
|
||||
server_config["env"] = env_vars
|
||||
|
||||
config["mcpServers"][server_name] = server_config
|
||||
|
||||
config_file.write_text(json.dumps(config, indent=2))
|
||||
logger.info(
|
||||
f"Added server '{server_name}' to Claude config",
|
||||
extra={"config_file": str(config_file)},
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to update Claude config",
|
||||
extra={
|
||||
"error": str(e),
|
||||
"config_file": str(config_file),
|
||||
},
|
||||
)
|
||||
return False
|
||||
470
.venv/lib/python3.10/site-packages/mcp/cli/cli.py
Normal file
470
.venv/lib/python3.10/site-packages/mcp/cli/cli.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""MCP CLI tools."""
|
||||
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
try:
|
||||
import typer
|
||||
except ImportError:
|
||||
print("Error: typer is required. Install with 'pip install mcp[cli]'")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
from mcp.cli import claude
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
except ImportError:
|
||||
print("Error: mcp.server.fastmcp is not installed or not in PYTHONPATH")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import dotenv
|
||||
except ImportError:
|
||||
dotenv = None
|
||||
|
||||
logger = get_logger("cli")
|
||||
|
||||
app = typer.Typer(
|
||||
name="mcp",
|
||||
help="MCP development tools",
|
||||
add_completion=False,
|
||||
no_args_is_help=True, # Show help if no args provided
|
||||
)
|
||||
|
||||
|
||||
def _get_npx_command():
|
||||
"""Get the correct npx command for the current platform."""
|
||||
if sys.platform == "win32":
|
||||
# Try both npx.cmd and npx.exe on Windows
|
||||
for cmd in ["npx.cmd", "npx.exe", "npx"]:
|
||||
try:
|
||||
subprocess.run(
|
||||
[cmd, "--version"], check=True, capture_output=True, shell=True
|
||||
)
|
||||
return cmd
|
||||
except subprocess.CalledProcessError:
|
||||
continue
|
||||
return None
|
||||
return "npx" # On Unix-like systems, just use npx
|
||||
|
||||
|
||||
def _parse_env_var(env_var: str) -> tuple[str, str]:
|
||||
"""Parse environment variable string in format KEY=VALUE."""
|
||||
if "=" not in env_var:
|
||||
logger.error(
|
||||
f"Invalid environment variable format: {env_var}. Must be KEY=VALUE"
|
||||
)
|
||||
sys.exit(1)
|
||||
key, value = env_var.split("=", 1)
|
||||
return key.strip(), value.strip()
|
||||
|
||||
|
||||
def _build_uv_command(
|
||||
file_spec: str,
|
||||
with_editable: Path | None = None,
|
||||
with_packages: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Build the uv run command that runs a MCP server through mcp run."""
|
||||
cmd = ["uv"]
|
||||
|
||||
cmd.extend(["run", "--with", "mcp"])
|
||||
|
||||
if with_editable:
|
||||
cmd.extend(["--with-editable", str(with_editable)])
|
||||
|
||||
if with_packages:
|
||||
for pkg in with_packages:
|
||||
if pkg:
|
||||
cmd.extend(["--with", pkg])
|
||||
|
||||
# Add mcp run command
|
||||
cmd.extend(["mcp", "run", file_spec])
|
||||
return cmd
|
||||
|
||||
|
||||
def _parse_file_path(file_spec: str) -> tuple[Path, str | None]:
|
||||
"""Parse a file path that may include a server object specification.
|
||||
|
||||
Args:
|
||||
file_spec: Path to file, optionally with :object suffix
|
||||
|
||||
Returns:
|
||||
Tuple of (file_path, server_object)
|
||||
"""
|
||||
# First check if we have a Windows path (e.g., C:\...)
|
||||
has_windows_drive = len(file_spec) > 1 and file_spec[1] == ":"
|
||||
|
||||
# Split on the last colon, but only if it's not part of the Windows drive letter
|
||||
# and there's actually another colon in the string after the drive letter
|
||||
if ":" in (file_spec[2:] if has_windows_drive else file_spec):
|
||||
file_str, server_object = file_spec.rsplit(":", 1)
|
||||
else:
|
||||
file_str, server_object = file_spec, None
|
||||
|
||||
# Resolve the file path
|
||||
file_path = Path(file_str).expanduser().resolve()
|
||||
if not file_path.exists():
|
||||
logger.error(f"File not found: {file_path}")
|
||||
sys.exit(1)
|
||||
if not file_path.is_file():
|
||||
logger.error(f"Not a file: {file_path}")
|
||||
sys.exit(1)
|
||||
|
||||
return file_path, server_object
|
||||
|
||||
|
||||
def _import_server(file: Path, server_object: str | None = None):
|
||||
"""Import a MCP server from a file.
|
||||
|
||||
Args:
|
||||
file: Path to the file
|
||||
server_object: Optional object name in format "module:object" or just "object"
|
||||
|
||||
Returns:
|
||||
The server object
|
||||
"""
|
||||
# Add parent directory to Python path so imports can be resolved
|
||||
file_dir = str(file.parent)
|
||||
if file_dir not in sys.path:
|
||||
sys.path.insert(0, file_dir)
|
||||
|
||||
# Import the module
|
||||
spec = importlib.util.spec_from_file_location("server_module", file)
|
||||
if not spec or not spec.loader:
|
||||
logger.error("Could not load module", extra={"file": str(file)})
|
||||
sys.exit(1)
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# If no object specified, try common server names
|
||||
if not server_object:
|
||||
# Look for the most common server object names
|
||||
for name in ["mcp", "server", "app"]:
|
||||
if hasattr(module, name):
|
||||
return getattr(module, name)
|
||||
|
||||
logger.error(
|
||||
f"No server object found in {file}. Please either:\n"
|
||||
"1. Use a standard variable name (mcp, server, or app)\n"
|
||||
"2. Specify the object name with file:object syntax",
|
||||
extra={"file": str(file)},
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Handle module:object syntax
|
||||
if ":" in server_object:
|
||||
module_name, object_name = server_object.split(":", 1)
|
||||
try:
|
||||
server_module = importlib.import_module(module_name)
|
||||
server = getattr(server_module, object_name, None)
|
||||
except ImportError:
|
||||
logger.error(
|
||||
f"Could not import module '{module_name}'",
|
||||
extra={"file": str(file)},
|
||||
)
|
||||
sys.exit(1)
|
||||
else:
|
||||
# Just object name
|
||||
server = getattr(module, server_object, None)
|
||||
|
||||
if server is None:
|
||||
logger.error(
|
||||
f"Server object '{server_object}' not found",
|
||||
extra={"file": str(file)},
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
return server
|
||||
|
||||
|
||||
@app.command()
|
||||
def version() -> None:
|
||||
"""Show the MCP version."""
|
||||
try:
|
||||
version = importlib.metadata.version("mcp")
|
||||
print(f"MCP version {version}")
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
print("MCP version unknown (package not installed)")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@app.command()
|
||||
def dev(
|
||||
file_spec: str = typer.Argument(
|
||||
...,
|
||||
help="Python file to run, optionally with :object suffix",
|
||||
),
|
||||
with_editable: Annotated[
|
||||
Path | None,
|
||||
typer.Option(
|
||||
"--with-editable",
|
||||
"-e",
|
||||
help="Directory containing pyproject.toml to install in editable mode",
|
||||
exists=True,
|
||||
file_okay=False,
|
||||
resolve_path=True,
|
||||
),
|
||||
] = None,
|
||||
with_packages: Annotated[
|
||||
list[str],
|
||||
typer.Option(
|
||||
"--with",
|
||||
help="Additional packages to install",
|
||||
),
|
||||
] = [],
|
||||
) -> None:
|
||||
"""Run a MCP server with the MCP Inspector."""
|
||||
file, server_object = _parse_file_path(file_spec)
|
||||
|
||||
logger.debug(
|
||||
"Starting dev server",
|
||||
extra={
|
||||
"file": str(file),
|
||||
"server_object": server_object,
|
||||
"with_editable": str(with_editable) if with_editable else None,
|
||||
"with_packages": with_packages,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Import server to get dependencies
|
||||
server = _import_server(file, server_object)
|
||||
if hasattr(server, "dependencies"):
|
||||
with_packages = list(set(with_packages + server.dependencies))
|
||||
|
||||
uv_cmd = _build_uv_command(file_spec, with_editable, with_packages)
|
||||
|
||||
# Get the correct npx command
|
||||
npx_cmd = _get_npx_command()
|
||||
if not npx_cmd:
|
||||
logger.error(
|
||||
"npx not found. Please ensure Node.js and npm are properly installed "
|
||||
"and added to your system PATH."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Run the MCP Inspector command with shell=True on Windows
|
||||
shell = sys.platform == "win32"
|
||||
process = subprocess.run(
|
||||
[npx_cmd, "@modelcontextprotocol/inspector"] + uv_cmd,
|
||||
check=True,
|
||||
shell=shell,
|
||||
env=dict(os.environ.items()), # Convert to list of tuples for env update
|
||||
)
|
||||
sys.exit(process.returncode)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(
|
||||
"Dev server failed",
|
||||
extra={
|
||||
"file": str(file),
|
||||
"error": str(e),
|
||||
"returncode": e.returncode,
|
||||
},
|
||||
)
|
||||
sys.exit(e.returncode)
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"npx not found. Please ensure Node.js and npm are properly installed "
|
||||
"and added to your system PATH. You may need to restart your terminal "
|
||||
"after installation.",
|
||||
extra={"file": str(file)},
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@app.command()
|
||||
def run(
|
||||
file_spec: str = typer.Argument(
|
||||
...,
|
||||
help="Python file to run, optionally with :object suffix",
|
||||
),
|
||||
transport: Annotated[
|
||||
str | None,
|
||||
typer.Option(
|
||||
"--transport",
|
||||
"-t",
|
||||
help="Transport protocol to use (stdio or sse)",
|
||||
),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Run a MCP server.
|
||||
|
||||
The server can be specified in two ways:\n
|
||||
1. Module approach: server.py - runs the module directly, expecting a server.run() call.\n
|
||||
2. Import approach: server.py:app - imports and runs the specified server object.\n\n
|
||||
|
||||
Note: This command runs the server directly. You are responsible for ensuring
|
||||
all dependencies are available.\n
|
||||
For dependency management, use `mcp install` or `mcp dev` instead.
|
||||
""" # noqa: E501
|
||||
file, server_object = _parse_file_path(file_spec)
|
||||
|
||||
logger.debug(
|
||||
"Running server",
|
||||
extra={
|
||||
"file": str(file),
|
||||
"server_object": server_object,
|
||||
"transport": transport,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Import and get server object
|
||||
server = _import_server(file, server_object)
|
||||
|
||||
# Run the server
|
||||
kwargs = {}
|
||||
if transport:
|
||||
kwargs["transport"] = transport
|
||||
|
||||
server.run(**kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to run server: {e}",
|
||||
extra={
|
||||
"file": str(file),
|
||||
"error": str(e),
|
||||
},
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@app.command()
|
||||
def install(
|
||||
file_spec: str = typer.Argument(
|
||||
...,
|
||||
help="Python file to run, optionally with :object suffix",
|
||||
),
|
||||
server_name: Annotated[
|
||||
str | None,
|
||||
typer.Option(
|
||||
"--name",
|
||||
"-n",
|
||||
help="Custom name for the server (defaults to server's name attribute or"
|
||||
" file name)",
|
||||
),
|
||||
] = None,
|
||||
with_editable: Annotated[
|
||||
Path | None,
|
||||
typer.Option(
|
||||
"--with-editable",
|
||||
"-e",
|
||||
help="Directory containing pyproject.toml to install in editable mode",
|
||||
exists=True,
|
||||
file_okay=False,
|
||||
resolve_path=True,
|
||||
),
|
||||
] = None,
|
||||
with_packages: Annotated[
|
||||
list[str],
|
||||
typer.Option(
|
||||
"--with",
|
||||
help="Additional packages to install",
|
||||
),
|
||||
] = [],
|
||||
env_vars: Annotated[
|
||||
list[str],
|
||||
typer.Option(
|
||||
"--env-var",
|
||||
"-v",
|
||||
help="Environment variables in KEY=VALUE format",
|
||||
),
|
||||
] = [],
|
||||
env_file: Annotated[
|
||||
Path | None,
|
||||
typer.Option(
|
||||
"--env-file",
|
||||
"-f",
|
||||
help="Load environment variables from a .env file",
|
||||
exists=True,
|
||||
file_okay=True,
|
||||
dir_okay=False,
|
||||
resolve_path=True,
|
||||
),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Install a MCP server in the Claude desktop app.
|
||||
|
||||
Environment variables are preserved once added and only updated if new values
|
||||
are explicitly provided.
|
||||
"""
|
||||
file, server_object = _parse_file_path(file_spec)
|
||||
|
||||
logger.debug(
|
||||
"Installing server",
|
||||
extra={
|
||||
"file": str(file),
|
||||
"server_name": server_name,
|
||||
"server_object": server_object,
|
||||
"with_editable": str(with_editable) if with_editable else None,
|
||||
"with_packages": with_packages,
|
||||
},
|
||||
)
|
||||
|
||||
if not claude.get_claude_config_path():
|
||||
logger.error("Claude app not found")
|
||||
sys.exit(1)
|
||||
|
||||
# Try to import server to get its name, but fall back to file name if dependencies
|
||||
# missing
|
||||
name = server_name
|
||||
server = None
|
||||
if not name:
|
||||
try:
|
||||
server = _import_server(file, server_object)
|
||||
name = server.name
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.debug(
|
||||
"Could not import server (likely missing dependencies), using file"
|
||||
" name",
|
||||
extra={"error": str(e)},
|
||||
)
|
||||
name = file.stem
|
||||
|
||||
# Get server dependencies if available
|
||||
server_dependencies = getattr(server, "dependencies", []) if server else []
|
||||
if server_dependencies:
|
||||
with_packages = list(set(with_packages + server_dependencies))
|
||||
|
||||
# Process environment variables if provided
|
||||
env_dict: dict[str, str] | None = None
|
||||
if env_file or env_vars:
|
||||
env_dict = {}
|
||||
# Load from .env file if specified
|
||||
if env_file:
|
||||
if dotenv:
|
||||
try:
|
||||
env_dict |= {
|
||||
k: v
|
||||
for k, v in dotenv.dotenv_values(env_file).items()
|
||||
if v is not None
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load .env file: {e}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.error("python-dotenv is not installed. Cannot load .env file.")
|
||||
sys.exit(1)
|
||||
|
||||
# Add command line environment variables
|
||||
for env_var in env_vars:
|
||||
key, value = _parse_env_var(env_var)
|
||||
env_dict[key] = value
|
||||
|
||||
if claude.update_claude_config(
|
||||
file_spec,
|
||||
name,
|
||||
with_editable=with_editable,
|
||||
with_packages=with_packages,
|
||||
env_vars=env_dict,
|
||||
):
|
||||
logger.info(f"Successfully installed {name} in Claude app")
|
||||
else:
|
||||
logger.error(f"Failed to install {name} in Claude app")
|
||||
sys.exit(1)
|
||||
85
.venv/lib/python3.10/site-packages/mcp/client/__main__.py
Normal file
85
.venv/lib/python3.10/site-packages/mcp/client/__main__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.types import JSONRPCMessage
|
||||
|
||||
if not sys.warnoptions:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("client")
|
||||
|
||||
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
logger.error("Error: %s", message)
|
||||
return
|
||||
|
||||
logger.info("Received message from server: %s", message)
|
||||
|
||||
|
||||
async def run_session(
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream, write_stream, message_handler=message_handler
|
||||
) as session:
|
||||
logger.info("Initializing session")
|
||||
await session.initialize()
|
||||
logger.info("Initialized")
|
||||
|
||||
|
||||
async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]):
|
||||
env_dict = dict(env)
|
||||
|
||||
if urlparse(command_or_url).scheme in ("http", "https"):
|
||||
# Use SSE client for HTTP(S) URLs
|
||||
async with sse_client(command_or_url) as streams:
|
||||
await run_session(*streams)
|
||||
else:
|
||||
# Use stdio client for commands
|
||||
server_parameters = StdioServerParameters(
|
||||
command=command_or_url, args=args, env=env_dict
|
||||
)
|
||||
async with stdio_client(server_parameters) as streams:
|
||||
await run_session(*streams)
|
||||
|
||||
|
||||
def cli():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("command_or_url", help="Command or URL to connect to")
|
||||
parser.add_argument("args", nargs="*", help="Additional arguments")
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--env",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("KEY", "VALUE"),
|
||||
help="Environment variables to set. Can be used multiple times.",
|
||||
default=[],
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
anyio.run(partial(main, args.command_or_url, args.args, args.env), backend="trio")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
378
.venv/lib/python3.10/site-packages/mcp/client/session.py
Normal file
378
.venv/lib/python3.10/site-packages/mcp/client/session.py
Normal file
@@ -0,0 +1,378 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.session import BaseSession, RequestResponder
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
|
||||
|
||||
class SamplingFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class ListRootsFnT(Protocol):
|
||||
async def __call__(
|
||||
self, context: RequestContext["ClientSession", Any]
|
||||
) -> types.ListRootsResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class LoggingFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class MessageHandlerFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
async def _default_message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
async def _default_sampling_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Sampling not supported",
|
||||
)
|
||||
|
||||
|
||||
async def _default_list_roots_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
) -> types.ListRootsResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="List roots not supported",
|
||||
)
|
||||
|
||||
|
||||
async def _default_logging_callback(
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
|
||||
types.ClientResult | types.ErrorData
|
||||
)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
types.ClientRequest,
|
||||
types.ClientNotification,
|
||||
types.ClientResult,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
logging_callback: LoggingFnT | None = None,
|
||||
message_handler: MessageHandlerFnT | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream,
|
||||
write_stream,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
self._sampling_callback = sampling_callback or _default_sampling_callback
|
||||
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
|
||||
self._logging_callback = logging_callback or _default_logging_callback
|
||||
self._message_handler = message_handler or _default_message_handler
|
||||
|
||||
async def initialize(self) -> types.InitializeResult:
|
||||
sampling = types.SamplingCapability()
|
||||
roots = types.RootsCapability(
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
listChanged=True,
|
||||
)
|
||||
|
||||
result = await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.InitializeRequest(
|
||||
method="initialize",
|
||||
params=types.InitializeRequestParams(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ClientCapabilities(
|
||||
sampling=sampling,
|
||||
experimental=None,
|
||||
roots=roots,
|
||||
),
|
||||
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
|
||||
),
|
||||
)
|
||||
),
|
||||
types.InitializeResult,
|
||||
)
|
||||
|
||||
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||
raise RuntimeError(
|
||||
"Unsupported protocol version from the server: "
|
||||
f"{result.protocolVersion}"
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.InitializedNotification(method="notifications/initialized")
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def send_ping(self) -> types.EmptyResult:
|
||||
"""Send a ping request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
await self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=types.ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
|
||||
"""Send a logging/setLevel request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.SetLevelRequest(
|
||||
method="logging/setLevel",
|
||||
params=types.SetLevelRequestParams(level=level),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def list_resources(self) -> types.ListResourcesResult:
|
||||
"""Send a resources/list request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListResourcesRequest(
|
||||
method="resources/list",
|
||||
)
|
||||
),
|
||||
types.ListResourcesResult,
|
||||
)
|
||||
|
||||
async def list_resource_templates(self) -> types.ListResourceTemplatesResult:
|
||||
"""Send a resources/templates/list request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListResourceTemplatesRequest(
|
||||
method="resources/templates/list",
|
||||
)
|
||||
),
|
||||
types.ListResourceTemplatesResult,
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
|
||||
"""Send a resources/read request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ReadResourceRequest(
|
||||
method="resources/read",
|
||||
params=types.ReadResourceRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.ReadResourceResult,
|
||||
)
|
||||
|
||||
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/subscribe request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.SubscribeRequest(
|
||||
method="resources/subscribe",
|
||||
params=types.SubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/unsubscribe request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.UnsubscribeRequest(
|
||||
method="resources/unsubscribe",
|
||||
params=types.UnsubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: dict[str, Any] | None = None
|
||||
) -> types.CallToolResult:
|
||||
"""Send a tools/call request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> types.ListPromptsResult:
|
||||
"""Send a prompts/list request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListPromptsRequest(
|
||||
method="prompts/list",
|
||||
)
|
||||
),
|
||||
types.ListPromptsResult,
|
||||
)
|
||||
|
||||
async def get_prompt(
|
||||
self, name: str, arguments: dict[str, str] | None = None
|
||||
) -> types.GetPromptResult:
|
||||
"""Send a prompts/get request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.GetPromptRequest(
|
||||
method="prompts/get",
|
||||
params=types.GetPromptRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.GetPromptResult,
|
||||
)
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
ref: types.ResourceReference | types.PromptReference,
|
||||
argument: dict[str, str],
|
||||
) -> types.CompleteResult:
|
||||
"""Send a completion/complete request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CompleteRequest(
|
||||
method="completion/complete",
|
||||
params=types.CompleteRequestParams(
|
||||
ref=ref,
|
||||
argument=types.CompletionArgument(**argument),
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CompleteResult,
|
||||
)
|
||||
|
||||
async def list_tools(self) -> types.ListToolsResult:
|
||||
"""Send a tools/list request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListToolsRequest(
|
||||
method="tools/list",
|
||||
)
|
||||
),
|
||||
types.ListToolsResult,
|
||||
)
|
||||
|
||||
async def send_roots_list_changed(self) -> None:
|
||||
"""Send a roots/list_changed notification."""
|
||||
await self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.RootsListChangedNotification(
|
||||
method="notifications/roots/list_changed",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
) -> None:
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
meta=responder.request_meta,
|
||||
session=self,
|
||||
lifespan_context=None,
|
||||
)
|
||||
|
||||
match responder.request.root:
|
||||
case types.CreateMessageRequest(params=params):
|
||||
with responder:
|
||||
response = await self._sampling_callback(ctx, params)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.ListRootsRequest():
|
||||
with responder:
|
||||
response = await self._list_roots_callback(ctx)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.PingRequest():
|
||||
with responder:
|
||||
return await responder.respond(
|
||||
types.ClientResult(root=types.EmptyResult())
|
||||
)
|
||||
|
||||
async def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
) -> None:
|
||||
"""Handle incoming messages by forwarding to the message handler."""
|
||||
await self._message_handler(req)
|
||||
|
||||
async def _received_notification(
|
||||
self, notification: types.ServerNotification
|
||||
) -> None:
|
||||
"""Handle notifications from the server."""
|
||||
# Process specific notification types
|
||||
match notification.root:
|
||||
case types.LoggingMessageNotification(params=params):
|
||||
await self._logging_callback(params)
|
||||
case _:
|
||||
pass
|
||||
146
.venv/lib/python3.10/site-packages/mcp/client/sse.py
Normal file
146
.venv/lib/python3.10/site-packages/mcp/client/sse.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import TaskStatus
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import aconnect_sse
|
||||
|
||||
import mcp.types as types
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5,
|
||||
sse_read_timeout: float = 60 * 5,
|
||||
):
|
||||
"""
|
||||
Client transport for SSE.
|
||||
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
try:
|
||||
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
||||
async with httpx.AsyncClient(headers=headers) as client:
|
||||
async with aconnect_sse(
|
||||
client,
|
||||
"GET",
|
||||
url,
|
||||
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("SSE connection established")
|
||||
|
||||
async def sse_reader(
|
||||
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
|
||||
):
|
||||
try:
|
||||
async for sse in event_source.aiter_sse():
|
||||
logger.debug(f"Received SSE event: {sse.event}")
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
endpoint_url = urljoin(url, sse.data)
|
||||
logger.info(
|
||||
f"Received endpoint URL: {endpoint_url}"
|
||||
)
|
||||
|
||||
url_parsed = urlparse(url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
if (
|
||||
url_parsed.netloc != endpoint_parsed.netloc
|
||||
or url_parsed.scheme
|
||||
!= endpoint_parsed.scheme
|
||||
):
|
||||
error_msg = (
|
||||
"Endpoint origin does not match "
|
||||
f"connection origin: {endpoint_url}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
task_status.started(endpoint_url)
|
||||
|
||||
case "message":
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
|
||||
sse.data
|
||||
)
|
||||
logger.debug(
|
||||
f"Received server message: {message}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"Error parsing server message: {exc}"
|
||||
)
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
case _:
|
||||
logger.warning(
|
||||
f"Unknown SSE event: {sse.event}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in sse_reader: {exc}")
|
||||
await read_stream_writer.send(exc)
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
|
||||
async def post_writer(endpoint_url: str):
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
response = await client.post(
|
||||
endpoint_url,
|
||||
json=message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(
|
||||
"Client message sent successfully: "
|
||||
f"{response.status_code}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in post_writer: {exc}")
|
||||
finally:
|
||||
await write_stream.aclose()
|
||||
|
||||
endpoint_url = await tg.start(sse_reader)
|
||||
logger.info(
|
||||
f"Starting post writer with endpoint URL: {endpoint_url}"
|
||||
)
|
||||
tg.start_soon(post_writer, endpoint_url)
|
||||
|
||||
try:
|
||||
yield read_stream, write_stream
|
||||
finally:
|
||||
tg.cancel_scope.cancel()
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream.aclose()
|
||||
216
.venv/lib/python3.10/site-packages/mcp/client/stdio/__init__.py
Normal file
216
.venv/lib/python3.10/site-packages/mcp/client/stdio/__init__.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Literal, TextIO
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import mcp.types as types
|
||||
|
||||
from .win32 import (
|
||||
create_windows_process,
|
||||
get_windows_executable_command,
|
||||
terminate_windows_process,
|
||||
)
|
||||
|
||||
# Environment variables to inherit by default
|
||||
DEFAULT_INHERITED_ENV_VARS = (
|
||||
[
|
||||
"APPDATA",
|
||||
"HOMEDRIVE",
|
||||
"HOMEPATH",
|
||||
"LOCALAPPDATA",
|
||||
"PATH",
|
||||
"PROCESSOR_ARCHITECTURE",
|
||||
"SYSTEMDRIVE",
|
||||
"SYSTEMROOT",
|
||||
"TEMP",
|
||||
"USERNAME",
|
||||
"USERPROFILE",
|
||||
]
|
||||
if sys.platform == "win32"
|
||||
else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
|
||||
)
|
||||
|
||||
|
||||
def get_default_environment() -> dict[str, str]:
|
||||
"""
|
||||
Returns a default environment object including only environment variables deemed
|
||||
safe to inherit.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
|
||||
for key in DEFAULT_INHERITED_ENV_VARS:
|
||||
value = os.environ.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if value.startswith("()"):
|
||||
# Skip functions, which are a security risk
|
||||
continue
|
||||
|
||||
env[key] = value
|
||||
|
||||
return env
|
||||
|
||||
|
||||
class StdioServerParameters(BaseModel):
|
||||
command: str
|
||||
"""The executable to run to start the server."""
|
||||
|
||||
args: list[str] = Field(default_factory=list)
|
||||
"""Command line arguments to pass to the executable."""
|
||||
|
||||
env: dict[str, str] | None = None
|
||||
"""
|
||||
The environment to use when spawning the process.
|
||||
|
||||
If not specified, the result of get_default_environment() will be used.
|
||||
"""
|
||||
|
||||
cwd: str | Path | None = None
|
||||
"""The working directory to use when spawning the process."""
|
||||
|
||||
encoding: str = "utf-8"
|
||||
"""
|
||||
The text encoding used when sending/receiving messages to the server
|
||||
|
||||
defaults to utf-8
|
||||
"""
|
||||
|
||||
encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict"
|
||||
"""
|
||||
The text encoding error handler.
|
||||
|
||||
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
|
||||
explanations of possible values
|
||||
"""
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr):
|
||||
"""
|
||||
Client transport for stdio: this will connect to a server by spawning a
|
||||
process and communicating with it over stdin/stdout.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
command = _get_executable_command(server.command)
|
||||
|
||||
# Open process with stderr piped for capture
|
||||
process = await _create_platform_compatible_process(
|
||||
command=command,
|
||||
args=server.args,
|
||||
env=(
|
||||
{**get_default_environment(), **server.env}
|
||||
if server.env is not None
|
||||
else get_default_environment()
|
||||
),
|
||||
errlog=errlog,
|
||||
cwd=server.cwd,
|
||||
)
|
||||
|
||||
async def stdout_reader():
|
||||
assert process.stdout, "Opened process is missing stdout"
|
||||
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
buffer = ""
|
||||
async for chunk in TextReceiveStream(
|
||||
process.stdout,
|
||||
encoding=server.encoding,
|
||||
errors=server.encoding_error_handler,
|
||||
):
|
||||
lines = (buffer + chunk).split("\n")
|
||||
buffer = lines.pop()
|
||||
|
||||
for line in lines:
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(line)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async def stdin_writer():
|
||||
assert process.stdin, "Opened process is missing stdin"
|
||||
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
json = message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
await process.stdin.send(
|
||||
(json + "\n").encode(
|
||||
encoding=server.encoding,
|
||||
errors=server.encoding_error_handler,
|
||||
)
|
||||
)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async with (
|
||||
anyio.create_task_group() as tg,
|
||||
process,
|
||||
):
|
||||
tg.start_soon(stdout_reader)
|
||||
tg.start_soon(stdin_writer)
|
||||
try:
|
||||
yield read_stream, write_stream
|
||||
finally:
|
||||
# Clean up process to prevent any dangling orphaned processes
|
||||
if sys.platform == "win32":
|
||||
await terminate_windows_process(process)
|
||||
else:
|
||||
process.terminate()
|
||||
|
||||
|
||||
def _get_executable_command(command: str) -> str:
|
||||
"""
|
||||
Get the correct executable command normalized for the current platform.
|
||||
|
||||
Args:
|
||||
command: Base command (e.g., 'uvx', 'npx')
|
||||
|
||||
Returns:
|
||||
str: Platform-appropriate command
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
return get_windows_executable_command(command)
|
||||
else:
|
||||
return command
|
||||
|
||||
|
||||
async def _create_platform_compatible_process(
|
||||
command: str,
|
||||
args: list[str],
|
||||
env: dict[str, str] | None = None,
|
||||
errlog: TextIO = sys.stderr,
|
||||
cwd: Path | str | None = None,
|
||||
):
|
||||
"""
|
||||
Creates a subprocess in a platform-compatible way.
|
||||
Returns a process handle.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
process = await create_windows_process(command, args, env, errlog, cwd)
|
||||
else:
|
||||
process = await anyio.open_process(
|
||||
[command, *args], env=env, stderr=errlog, cwd=cwd
|
||||
)
|
||||
|
||||
return process
|
||||
Binary file not shown.
Binary file not shown.
109
.venv/lib/python3.10/site-packages/mcp/client/stdio/win32.py
Normal file
109
.venv/lib/python3.10/site-packages/mcp/client/stdio/win32.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Windows-specific functionality for stdio client operations.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
|
||||
import anyio
|
||||
from anyio.abc import Process
|
||||
|
||||
|
||||
def get_windows_executable_command(command: str) -> str:
|
||||
"""
|
||||
Get the correct executable command normalized for Windows.
|
||||
|
||||
On Windows, commands might exist with specific extensions (.exe, .cmd, etc.)
|
||||
that need to be located for proper execution.
|
||||
|
||||
Args:
|
||||
command: Base command (e.g., 'uvx', 'npx')
|
||||
|
||||
Returns:
|
||||
str: Windows-appropriate command path
|
||||
"""
|
||||
try:
|
||||
# First check if command exists in PATH as-is
|
||||
if command_path := shutil.which(command):
|
||||
return command_path
|
||||
|
||||
# Check for Windows-specific extensions
|
||||
for ext in [".cmd", ".bat", ".exe", ".ps1"]:
|
||||
ext_version = f"{command}{ext}"
|
||||
if ext_path := shutil.which(ext_version):
|
||||
return ext_path
|
||||
|
||||
# For regular commands or if we couldn't find special versions
|
||||
return command
|
||||
except OSError:
|
||||
# Handle file system errors during path resolution
|
||||
# (permissions, broken symlinks, etc.)
|
||||
return command
|
||||
|
||||
|
||||
async def create_windows_process(
|
||||
command: str,
|
||||
args: list[str],
|
||||
env: dict[str, str] | None = None,
|
||||
errlog: TextIO = sys.stderr,
|
||||
cwd: Path | str | None = None,
|
||||
):
|
||||
"""
|
||||
Creates a subprocess in a Windows-compatible way.
|
||||
|
||||
Windows processes need special handling for console windows and
|
||||
process creation flags.
|
||||
|
||||
Args:
|
||||
command: The command to execute
|
||||
args: Command line arguments
|
||||
env: Environment variables
|
||||
errlog: Where to send stderr output
|
||||
cwd: Working directory for the process
|
||||
|
||||
Returns:
|
||||
A process handle
|
||||
"""
|
||||
try:
|
||||
# Try with Windows-specific flags to hide console window
|
||||
process = await anyio.open_process(
|
||||
[command, *args],
|
||||
env=env,
|
||||
# Ensure we don't create console windows for each process
|
||||
creationflags=subprocess.CREATE_NO_WINDOW # type: ignore
|
||||
if hasattr(subprocess, "CREATE_NO_WINDOW")
|
||||
else 0,
|
||||
stderr=errlog,
|
||||
cwd=cwd,
|
||||
)
|
||||
return process
|
||||
except Exception:
|
||||
# Don't raise, let's try to create the process without creation flags
|
||||
process = await anyio.open_process(
|
||||
[command, *args], env=env, stderr=errlog, cwd=cwd
|
||||
)
|
||||
return process
|
||||
|
||||
|
||||
async def terminate_windows_process(process: Process):
|
||||
"""
|
||||
Terminate a Windows process.
|
||||
|
||||
Note: On Windows, terminating a process with process.terminate() doesn't
|
||||
always guarantee immediate process termination.
|
||||
So we give it 2s to exit, or we call process.kill()
|
||||
which sends a SIGKILL equivalent signal.
|
||||
|
||||
Args:
|
||||
process: The process to terminate
|
||||
"""
|
||||
try:
|
||||
process.terminate()
|
||||
with anyio.fail_after(2.0):
|
||||
await process.wait()
|
||||
except TimeoutError:
|
||||
# Force kill if it doesn't terminate
|
||||
process.kill()
|
||||
89
.venv/lib/python3.10/site-packages/mcp/client/websocket.py
Normal file
89
.venv/lib/python3.10/site-packages/mcp/client/websocket.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import ValidationError
|
||||
from websockets.asyncio.client import connect as ws_connect
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
import mcp.types as types
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def websocket_client(
|
||||
url: str,
|
||||
) -> AsyncGenerator[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
],
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
WebSocket client transport for MCP, symmetrical to the server version.
|
||||
|
||||
Connects to 'url' using the 'mcp' subprotocol, then yields:
|
||||
(read_stream, write_stream)
|
||||
|
||||
- read_stream: As you read from this stream, you'll receive either valid
|
||||
JSONRPCMessage objects or Exception objects (when validation fails).
|
||||
- write_stream: Write JSONRPCMessage objects to this stream to send them
|
||||
over the WebSocket to the server.
|
||||
"""
|
||||
|
||||
# Create two in-memory streams:
|
||||
# - One for incoming messages (read_stream, written by ws_reader)
|
||||
# - One for outgoing messages (write_stream, read by ws_writer)
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
# Connect using websockets, requesting the "mcp" subprotocol
|
||||
async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:
|
||||
|
||||
async def ws_reader():
|
||||
"""
|
||||
Reads text messages from the WebSocket, parses them as JSON-RPC messages,
|
||||
and sends them into read_stream_writer.
|
||||
"""
|
||||
async with read_stream_writer:
|
||||
async for raw_text in ws:
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(raw_text)
|
||||
await read_stream_writer.send(message)
|
||||
except ValidationError as exc:
|
||||
# If JSON parse or model validation fails, send the exception
|
||||
await read_stream_writer.send(exc)
|
||||
|
||||
async def ws_writer():
|
||||
"""
|
||||
Reads JSON-RPC messages from write_stream_reader and
|
||||
sends them to the server.
|
||||
"""
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
# Convert to a dict, then to JSON
|
||||
msg_dict = message.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
await ws.send(json.dumps(msg_dict))
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
# Start reader and writer tasks
|
||||
tg.start_soon(ws_reader)
|
||||
tg.start_soon(ws_writer)
|
||||
|
||||
# Yield the receive/send streams
|
||||
yield (read_stream, write_stream)
|
||||
|
||||
# Once the caller's 'async with' block exits, we shut down
|
||||
tg.cancel_scope.cancel()
|
||||
0
.venv/lib/python3.10/site-packages/mcp/py.typed
Normal file
0
.venv/lib/python3.10/site-packages/mcp/py.typed
Normal file
@@ -0,0 +1,5 @@
|
||||
from .fastmcp import FastMCP
|
||||
from .lowlevel import NotificationOptions, Server
|
||||
from .models import InitializationOptions
|
||||
|
||||
__all__ = ["Server", "FastMCP", "NotificationOptions", "InitializationOptions"]
|
||||
50
.venv/lib/python3.10/site-packages/mcp/server/__main__.py
Normal file
50
.venv/lib/python3.10/site-packages/mcp/server/__main__.py
Normal file
@@ -0,0 +1,50 @@
import importlib.metadata
import logging
import sys

import anyio

from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server
from mcp.types import ServerCapabilities

if not sys.warnoptions:
    import warnings

    warnings.simplefilter("ignore")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("server")


async def receive_loop(session: ServerSession):
    logger.info("Starting receive loop")
    async for message in session.incoming_messages:
        if isinstance(message, Exception):
            logger.error("Error: %s", message)
            continue

        logger.info("Received message from client: %s", message)


async def main():
    version = importlib.metadata.version("mcp")
    async with stdio_server() as (read_stream, write_stream):
        async with (
            ServerSession(
                read_stream,
                write_stream,
                InitializationOptions(
                    server_name="mcp",
                    server_version=version,
                    capabilities=ServerCapabilities(),
                ),
            ) as session,
            write_stream,
        ):
            await receive_loop(session)


if __name__ == "__main__":
    anyio.run(main, backend="trio")
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,9 @@
"""FastMCP - A more ergonomic interface for MCP servers."""

from importlib.metadata import version

from .server import Context, FastMCP
from .utilities.types import Image

__version__ = version("mcp")
__all__ = ["FastMCP", "Context", "Image"]
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,21 @@
"""Custom exceptions for FastMCP."""


class FastMCPError(Exception):
    """Base error for FastMCP."""


class ValidationError(FastMCPError):
    """Error in validating parameters or return values."""


class ResourceError(FastMCPError):
    """Error in resource operations."""


class ToolError(FastMCPError):
    """Error in tool operations."""


class InvalidSignature(Exception):
    """Invalid signature for use with FastMCP."""
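A small sketch (illustrative names only) of how this hierarchy is meant to layer: tool code raises the specific class, and callers may catch either that class or the FastMCPError base.

```python
from mcp.server.fastmcp.exceptions import FastMCPError, ToolError


def divide(a: float, b: float) -> float:
    if b == 0:
        # Signal a tool-level failure instead of leaking a bare ZeroDivisionError.
        raise ToolError("Cannot divide by zero")
    return a / b


try:
    divide(1, 0)
except FastMCPError as exc:  # ToolError subclasses FastMCPError, so this catches it
    print(f"tool failed: {exc}")
```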
@@ -0,0 +1,4 @@
from .base import Prompt
from .manager import PromptManager

__all__ = ["Prompt", "PromptManager"]
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,167 @@
"""Base classes for FastMCP prompts."""

import inspect
import json
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, Literal

import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call

from mcp.types import EmbeddedResource, ImageContent, TextContent

CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource


class Message(BaseModel):
    """Base class for all prompt messages."""

    role: Literal["user", "assistant"]
    content: CONTENT_TYPES

    def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
        if isinstance(content, str):
            content = TextContent(type="text", text=content)
        super().__init__(content=content, **kwargs)


class UserMessage(Message):
    """A message from the user."""

    role: Literal["user", "assistant"] = "user"

    def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
        super().__init__(content=content, **kwargs)


class AssistantMessage(Message):
    """A message from the assistant."""

    role: Literal["user", "assistant"] = "assistant"

    def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
        super().__init__(content=content, **kwargs)


message_validator = TypeAdapter[UserMessage | AssistantMessage](
    UserMessage | AssistantMessage
)

SyncPromptResult = (
    str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
)
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]


class PromptArgument(BaseModel):
    """An argument that can be passed to a prompt."""

    name: str = Field(description="Name of the argument")
    description: str | None = Field(
        None, description="Description of what the argument does"
    )
    required: bool = Field(
        default=False, description="Whether the argument is required"
    )


class Prompt(BaseModel):
    """A prompt template that can be rendered with parameters."""

    name: str = Field(description="Name of the prompt")
    description: str | None = Field(
        None, description="Description of what the prompt does"
    )
    arguments: list[PromptArgument] | None = Field(
        None, description="Arguments that can be passed to the prompt"
    )
    fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)

    @classmethod
    def from_function(
        cls,
        fn: Callable[..., PromptResult | Awaitable[PromptResult]],
        name: str | None = None,
        description: str | None = None,
    ) -> "Prompt":
        """Create a Prompt from a function.

        The function can return:
        - A string (converted to a message)
        - A Message object
        - A dict (converted to a message)
        - A sequence of any of the above
        """
        func_name = name or fn.__name__

        if func_name == "<lambda>":
            raise ValueError("You must provide a name for lambda functions")

        # Get schema from TypeAdapter - will fail if function isn't properly typed
        parameters = TypeAdapter(fn).json_schema()

        # Convert parameters to PromptArguments
        arguments: list[PromptArgument] = []
        if "properties" in parameters:
            for param_name, param in parameters["properties"].items():
                required = param_name in parameters.get("required", [])
                arguments.append(
                    PromptArgument(
                        name=param_name,
                        description=param.get("description"),
                        required=required,
                    )
                )

        # ensure the arguments are properly cast
        fn = validate_call(fn)

        return cls(
            name=func_name,
            description=description or fn.__doc__ or "",
            arguments=arguments,
            fn=fn,
        )

    async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]:
        """Render the prompt with arguments."""
        # Validate required arguments
        if self.arguments:
            required = {arg.name for arg in self.arguments if arg.required}
            provided = set(arguments or {})
            missing = required - provided
            if missing:
                raise ValueError(f"Missing required arguments: {missing}")

        try:
            # Call function and check if result is a coroutine
            result = self.fn(**(arguments or {}))
            if inspect.iscoroutine(result):
                result = await result

            # Validate messages
            if not isinstance(result, list | tuple):
                result = [result]

            # Convert result to messages
            messages: list[Message] = []
            for msg in result:  # type: ignore[reportUnknownVariableType]
                try:
                    if isinstance(msg, Message):
                        messages.append(msg)
                    elif isinstance(msg, dict):
                        messages.append(message_validator.validate_python(msg))
                    elif isinstance(msg, str):
                        content = TextContent(type="text", text=msg)
                        messages.append(UserMessage(content=content))
                    else:
                        content = json.dumps(pydantic_core.to_jsonable_python(msg))
                        messages.append(Message(role="user", content=content))
                except Exception:
                    raise ValueError(
                        f"Could not convert prompt result to message: {msg}"
                    )

            return messages
        except Exception as e:
            raise ValueError(f"Error rendering prompt {self.name}: {e}")
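A short sketch of how Prompt.from_function and render fit together (function and argument names are illustrative): the typed signature drives the generated PromptArgument list, and render normalizes the return value into Message objects.

```python
import anyio

from mcp.server.fastmcp.prompts.base import Prompt, UserMessage


def code_review(language: str, snippet: str) -> UserMessage:
    """Ask for a review of a code snippet."""
    return UserMessage(f"Please review this {language} code:\n\n{snippet}")


prompt = Prompt.from_function(code_review)
print([arg.name for arg in prompt.arguments or []])  # ['language', 'snippet']

messages = anyio.run(prompt.render, {"language": "python", "snippet": "print('hi')"})
print(messages[0].content.text)
```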
@@ -0,0 +1,50 @@
"""Prompt management functionality."""

from typing import Any

from mcp.server.fastmcp.prompts.base import Message, Prompt
from mcp.server.fastmcp.utilities.logging import get_logger

logger = get_logger(__name__)


class PromptManager:
    """Manages FastMCP prompts."""

    def __init__(self, warn_on_duplicate_prompts: bool = True):
        self._prompts: dict[str, Prompt] = {}
        self.warn_on_duplicate_prompts = warn_on_duplicate_prompts

    def get_prompt(self, name: str) -> Prompt | None:
        """Get prompt by name."""
        return self._prompts.get(name)

    def list_prompts(self) -> list[Prompt]:
        """List all registered prompts."""
        return list(self._prompts.values())

    def add_prompt(
        self,
        prompt: Prompt,
    ) -> Prompt:
        """Add a prompt to the manager."""

        # Check for duplicates
        existing = self._prompts.get(prompt.name)
        if existing:
            if self.warn_on_duplicate_prompts:
                logger.warning(f"Prompt already exists: {prompt.name}")
            return existing

        self._prompts[prompt.name] = prompt
        return prompt

    async def render_prompt(
        self, name: str, arguments: dict[str, Any] | None = None
    ) -> list[Message]:
        """Render a prompt by name with arguments."""
        prompt = self.get_prompt(name)
        if not prompt:
            raise ValueError(f"Unknown prompt: {name}")

        return await prompt.render(arguments)
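And a sketch of the manager wiring (names illustrative): duplicate registrations log a warning and return the existing Prompt, while rendering goes through the stored object.

```python
import anyio

from mcp.server.fastmcp.prompts import Prompt, PromptManager


def greet(name: str) -> str:
    """Say hello."""
    return f"Hello, {name}!"


manager = PromptManager(warn_on_duplicate_prompts=True)
manager.add_prompt(Prompt.from_function(greet))
manager.add_prompt(Prompt.from_function(greet))  # warns, returns the existing prompt

messages = anyio.run(manager.render_prompt, "greet", {"name": "Ada"})
print(messages[0].content.text)  # Hello, Ada!
```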
@@ -0,0 +1,33 @@
"""Prompt management functionality."""

from mcp.server.fastmcp.prompts.base import Prompt
from mcp.server.fastmcp.utilities.logging import get_logger

logger = get_logger(__name__)


class PromptManager:
    """Manages FastMCP prompts."""

    def __init__(self, warn_on_duplicate_prompts: bool = True):
        self._prompts: dict[str, Prompt] = {}
        self.warn_on_duplicate_prompts = warn_on_duplicate_prompts

    def add_prompt(self, prompt: Prompt) -> Prompt:
        """Add a prompt to the manager."""
        logger.debug(f"Adding prompt: {prompt.name}")
        existing = self._prompts.get(prompt.name)
        if existing:
            if self.warn_on_duplicate_prompts:
                logger.warning(f"Prompt already exists: {prompt.name}")
            return existing
        self._prompts[prompt.name] = prompt
        return prompt

    def get_prompt(self, name: str) -> Prompt | None:
        """Get prompt by name."""
        return self._prompts.get(name)

    def list_prompts(self) -> list[Prompt]:
        """List all registered prompts."""
        return list(self._prompts.values())
@@ -0,0 +1,23 @@
from .base import Resource
from .resource_manager import ResourceManager
from .templates import ResourceTemplate
from .types import (
    BinaryResource,
    DirectoryResource,
    FileResource,
    FunctionResource,
    HttpResource,
    TextResource,
)

__all__ = [
    "Resource",
    "TextResource",
    "BinaryResource",
    "FunctionResource",
    "FileResource",
    "HttpResource",
    "DirectoryResource",
    "ResourceTemplate",
    "ResourceManager",
]
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,48 @@
"""Base classes and interfaces for FastMCP resources."""

import abc
from typing import Annotated

from pydantic import (
    AnyUrl,
    BaseModel,
    ConfigDict,
    Field,
    UrlConstraints,
    ValidationInfo,
    field_validator,
)


class Resource(BaseModel, abc.ABC):
    """Base class for all resources."""

    model_config = ConfigDict(validate_default=True)

    uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(
        default=..., description="URI of the resource"
    )
    name: str | None = Field(description="Name of the resource", default=None)
    description: str | None = Field(
        description="Description of the resource", default=None
    )
    mime_type: str = Field(
        default="text/plain",
        description="MIME type of the resource content",
        pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+$",
    )

    @field_validator("name", mode="before")
    @classmethod
    def set_default_name(cls, name: str | None, info: ValidationInfo) -> str:
        """Set default name from URI if not provided."""
        if name:
            return name
        if uri := info.data.get("uri"):
            return str(uri)
        raise ValueError("Either name or uri must be provided")

    @abc.abstractmethod
    async def read(self) -> str | bytes:
        """Read the resource content."""
        pass
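A minimal sketch of a concrete subclass (the class and URI below are illustrative): the only required pieces are a uri, from which the name defaults, and an async read() implementation.

```python
import anyio
from pydantic import Field

from mcp.server.fastmcp.resources.base import Resource


class StaticResource(Resource):
    """Illustrative resource that always returns the same text."""

    text: str = Field(default="hello")

    async def read(self) -> str:
        return self.text


res = StaticResource(uri="demo://static/greeting", text="hello world")
print(res.name)             # falls back to the URI string
print(anyio.run(res.read))  # hello world
```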
@@ -0,0 +1,95 @@
"""Resource manager functionality."""

from collections.abc import Callable
from typing import Any

from pydantic import AnyUrl

from mcp.server.fastmcp.resources.base import Resource
from mcp.server.fastmcp.resources.templates import ResourceTemplate
from mcp.server.fastmcp.utilities.logging import get_logger

logger = get_logger(__name__)


class ResourceManager:
    """Manages FastMCP resources."""

    def __init__(self, warn_on_duplicate_resources: bool = True):
        self._resources: dict[str, Resource] = {}
        self._templates: dict[str, ResourceTemplate] = {}
        self.warn_on_duplicate_resources = warn_on_duplicate_resources

    def add_resource(self, resource: Resource) -> Resource:
        """Add a resource to the manager.

        Args:
            resource: A Resource instance to add

        Returns:
            The added resource. If a resource with the same URI already exists,
            returns the existing resource.
        """
        logger.debug(
            "Adding resource",
            extra={
                "uri": resource.uri,
                "type": type(resource).__name__,
                "resource_name": resource.name,
            },
        )
        existing = self._resources.get(str(resource.uri))
        if existing:
            if self.warn_on_duplicate_resources:
                logger.warning(f"Resource already exists: {resource.uri}")
            return existing
        self._resources[str(resource.uri)] = resource
        return resource

    def add_template(
        self,
        fn: Callable[..., Any],
        uri_template: str,
        name: str | None = None,
        description: str | None = None,
        mime_type: str | None = None,
    ) -> ResourceTemplate:
        """Add a template from a function."""
        template = ResourceTemplate.from_function(
            fn,
            uri_template=uri_template,
            name=name,
            description=description,
            mime_type=mime_type,
        )
        self._templates[template.uri_template] = template
        return template

    async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
        """Get resource by URI, checking concrete resources first, then templates."""
        uri_str = str(uri)
        logger.debug("Getting resource", extra={"uri": uri_str})

        # First check concrete resources
        if resource := self._resources.get(uri_str):
            return resource

        # Then check templates
        for template in self._templates.values():
            if params := template.matches(uri_str):
                try:
                    return await template.create_resource(uri_str, params)
                except Exception as e:
                    raise ValueError(f"Error creating resource from template: {e}")

        raise ValueError(f"Unknown resource: {uri}")

    def list_resources(self) -> list[Resource]:
        """List all registered resources."""
        logger.debug("Listing resources", extra={"count": len(self._resources)})
        return list(self._resources.values())

    def list_templates(self) -> list[ResourceTemplate]:
        """List all registered templates."""
        logger.debug("Listing templates", extra={"count": len(self._templates)})
        return list(self._templates.values())
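A sketch of the manager in isolation (functions and URIs are illustrative): concrete resources win over templates, and a template match is materialized into a FunctionResource on demand.

```python
import anyio

from mcp.server.fastmcp.resources import ResourceManager, TextResource


def weather(city: str) -> str:
    """Illustrative template function."""
    return f"Forecast for {city}"


async def main() -> None:
    manager = ResourceManager()

    static = manager.add_resource(TextResource(uri="demo://config/app", text="debug=false"))
    manager.add_template(weather, uri_template="weather://{city}/current")

    concrete = await manager.get_resource(static.uri)
    templated = await manager.get_resource("weather://berlin/current")
    print(await concrete.read())   # debug=false
    print(await templated.read())  # Forecast for berlin


anyio.run(main)
```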
@@ -0,0 +1,85 @@
"""Resource template functionality."""

from __future__ import annotations

import inspect
import re
from collections.abc import Callable
from typing import Any

from pydantic import BaseModel, Field, TypeAdapter, validate_call

from mcp.server.fastmcp.resources.types import FunctionResource, Resource


class ResourceTemplate(BaseModel):
    """A template for dynamically creating resources."""

    uri_template: str = Field(
        description="URI template with parameters (e.g. weather://{city}/current)"
    )
    name: str = Field(description="Name of the resource")
    description: str | None = Field(description="Description of what the resource does")
    mime_type: str = Field(
        default="text/plain", description="MIME type of the resource content"
    )
    fn: Callable[..., Any] = Field(exclude=True)
    parameters: dict[str, Any] = Field(
        description="JSON schema for function parameters"
    )

    @classmethod
    def from_function(
        cls,
        fn: Callable[..., Any],
        uri_template: str,
        name: str | None = None,
        description: str | None = None,
        mime_type: str | None = None,
    ) -> ResourceTemplate:
        """Create a template from a function."""
        func_name = name or fn.__name__
        if func_name == "<lambda>":
            raise ValueError("You must provide a name for lambda functions")

        # Get schema from TypeAdapter - will fail if function isn't properly typed
        parameters = TypeAdapter(fn).json_schema()

        # ensure the arguments are properly cast
        fn = validate_call(fn)

        return cls(
            uri_template=uri_template,
            name=func_name,
            description=description or fn.__doc__ or "",
            mime_type=mime_type or "text/plain",
            fn=fn,
            parameters=parameters,
        )

    def matches(self, uri: str) -> dict[str, Any] | None:
        """Check if URI matches template and extract parameters."""
        # Convert template to regex pattern
        pattern = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
        match = re.match(f"^{pattern}$", uri)
        if match:
            return match.groupdict()
        return None

    async def create_resource(self, uri: str, params: dict[str, Any]) -> Resource:
        """Create a resource from the template with the given parameters."""
        try:
            # Call function and check if result is a coroutine
            result = self.fn(**params)
            if inspect.iscoroutine(result):
                result = await result

            return FunctionResource(
                uri=uri,  # type: ignore
                name=self.name,
                description=self.description,
                mime_type=self.mime_type,
                fn=lambda: result,  # Capture result in closure
            )
        except Exception as e:
            raise ValueError(f"Error creating resource from template: {e}")
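For the matching logic above, a standalone sketch (URI template and function are illustrative) showing how the `{param}` placeholders become named regex groups:

```python
from mcp.server.fastmcp.resources.templates import ResourceTemplate


def fetch_profile(user_id: str) -> str:
    """Illustrative function backing the template."""
    return f"profile for {user_id}"


template = ResourceTemplate.from_function(
    fetch_profile, uri_template="users://{user_id}/profile"
)

print(template.matches("users://42/profile"))   # {'user_id': '42'}
print(template.matches("users://42/settings"))  # None
```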
@@ -0,0 +1,185 @@
"""Concrete resource implementations."""

import inspect
import json
from collections.abc import Callable
from pathlib import Path
from typing import Any

import anyio
import anyio.to_thread
import httpx
import pydantic.json
import pydantic_core
from pydantic import Field, ValidationInfo

from mcp.server.fastmcp.resources.base import Resource


class TextResource(Resource):
    """A resource that reads from a string."""

    text: str = Field(description="Text content of the resource")

    async def read(self) -> str:
        """Read the text content."""
        return self.text


class BinaryResource(Resource):
    """A resource that reads from bytes."""

    data: bytes = Field(description="Binary content of the resource")

    async def read(self) -> bytes:
        """Read the binary content."""
        return self.data


class FunctionResource(Resource):
    """A resource that defers data loading by wrapping a function.

    The function is only called when the resource is read, allowing for lazy loading
    of potentially expensive data. This is particularly useful when listing resources,
    as the function won't be called until the resource is actually accessed.

    The function can return:
    - str for text content (default)
    - bytes for binary content
    - other types will be converted to JSON
    """

    fn: Callable[[], Any] = Field(exclude=True)

    async def read(self) -> str | bytes:
        """Read the resource by calling the wrapped function."""
        try:
            result = (
                await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn()
            )
            if isinstance(result, Resource):
                return await result.read()
            if isinstance(result, bytes):
                return result
            if isinstance(result, str):
                return result
            try:
                return json.dumps(pydantic_core.to_jsonable_python(result))
            except (TypeError, pydantic_core.PydanticSerializationError):
                # If JSON serialization fails, try str()
                return str(result)
        except Exception as e:
            raise ValueError(f"Error reading resource {self.uri}: {e}")


class FileResource(Resource):
    """A resource that reads from a file.

    Set is_binary=True to read file as binary data instead of text.
    """

    path: Path = Field(description="Path to the file")
    is_binary: bool = Field(
        default=False,
        description="Whether to read the file as binary data",
    )
    mime_type: str = Field(
        default="text/plain",
        description="MIME type of the resource content",
    )

    @pydantic.field_validator("path")
    @classmethod
    def validate_absolute_path(cls, path: Path) -> Path:
        """Ensure path is absolute."""
        if not path.is_absolute():
            raise ValueError("Path must be absolute")
        return path

    @pydantic.field_validator("is_binary")
    @classmethod
    def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> bool:
        """Set is_binary based on mime_type if not explicitly set."""
        if is_binary:
            return True
        mime_type = info.data.get("mime_type", "text/plain")
        return not mime_type.startswith("text/")

    async def read(self) -> str | bytes:
        """Read the file content."""
        try:
            if self.is_binary:
                return await anyio.to_thread.run_sync(self.path.read_bytes)
            return await anyio.to_thread.run_sync(self.path.read_text)
        except Exception as e:
            raise ValueError(f"Error reading file {self.path}: {e}")


class HttpResource(Resource):
    """A resource that reads from an HTTP endpoint."""

    url: str = Field(description="URL to fetch content from")
    mime_type: str = Field(
        default="application/json", description="MIME type of the resource content"
    )

    async def read(self) -> str | bytes:
        """Read the HTTP content."""
        async with httpx.AsyncClient() as client:
            response = await client.get(self.url)
            response.raise_for_status()
            return response.text


class DirectoryResource(Resource):
    """A resource that lists files in a directory."""

    path: Path = Field(description="Path to the directory")
    recursive: bool = Field(
        default=False, description="Whether to list files recursively"
    )
    pattern: str | None = Field(
        default=None, description="Optional glob pattern to filter files"
    )
    mime_type: str = Field(
        default="application/json", description="MIME type of the resource content"
    )

    @pydantic.field_validator("path")
    @classmethod
    def validate_absolute_path(cls, path: Path) -> Path:
        """Ensure path is absolute."""
        if not path.is_absolute():
            raise ValueError("Path must be absolute")
        return path

    def list_files(self) -> list[Path]:
        """List files in the directory."""
        if not self.path.exists():
            raise FileNotFoundError(f"Directory not found: {self.path}")
        if not self.path.is_dir():
            raise NotADirectoryError(f"Not a directory: {self.path}")

        try:
            if self.pattern:
                return (
                    list(self.path.glob(self.pattern))
                    if not self.recursive
                    else list(self.path.rglob(self.pattern))
                )
            return (
                list(self.path.glob("*"))
                if not self.recursive
                else list(self.path.rglob("*"))
            )
        except Exception as e:
            raise ValueError(f"Error listing directory {self.path}: {e}")

    async def read(self) -> str:  # Always returns JSON string
        """Read the directory listing."""
        try:
            files = await anyio.to_thread.run_sync(self.list_files)
            file_list = [str(f.relative_to(self.path)) for f in files if f.is_file()]
            return json.dumps({"files": file_list}, indent=2)
        except Exception as e:
            raise ValueError(f"Error reading directory {self.path}: {e}")
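A brief sketch of the concrete types above (paths and URIs are illustrative and need not exist until read() is called); note that FileResource flips is_binary automatically when the MIME type is not text/*.

```python
import anyio

from mcp.server.fastmcp.resources import FileResource, FunctionResource


def load_report() -> dict[str, int]:
    """Illustrative loader; non-str results are serialized to JSON on read."""
    return {"rows": 42}


async def main() -> None:
    lazy = FunctionResource(uri="demo://reports/latest", fn=load_report)
    print(await lazy.read())  # {"rows": 42}

    # The file is only touched when read() runs, not at construction time.
    readme = FileResource(uri="file:///tmp/readme.txt", path="/tmp/readme.txt")
    print(readme.is_binary)  # False (text/plain default)


anyio.run(main)
```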
713
.venv/lib/python3.10/site-packages/mcp/server/fastmcp/server.py
Normal file
713
.venv/lib/python3.10/site-packages/mcp/server/fastmcp/server.py
Normal file
@@ -0,0 +1,713 @@
|
||||
"""FastMCP - A more ergonomic interface for MCP servers."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
from collections.abc import AsyncIterator, Callable, Iterable, Sequence
|
||||
from contextlib import (
|
||||
AbstractAsyncContextManager,
|
||||
asynccontextmanager,
|
||||
)
|
||||
from itertools import chain
|
||||
from typing import Any, Generic, Literal
|
||||
|
||||
import anyio
|
||||
import pydantic_core
|
||||
import uvicorn
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.networks import AnyUrl
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
from mcp.server.fastmcp.exceptions import ResourceError
|
||||
from mcp.server.fastmcp.prompts import Prompt, PromptManager
|
||||
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
|
||||
from mcp.server.fastmcp.tools import ToolManager
|
||||
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
|
||||
from mcp.server.fastmcp.utilities.types import Image
|
||||
from mcp.server.lowlevel.helper_types import ReadResourceContents
|
||||
from mcp.server.lowlevel.server import LifespanResultT
|
||||
from mcp.server.lowlevel.server import Server as MCPServer
|
||||
from mcp.server.lowlevel.server import lifespan as default_lifespan
|
||||
from mcp.server.session import ServerSession, ServerSessionT
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.shared.context import LifespanContextT, RequestContext
|
||||
from mcp.types import (
|
||||
AnyFunction,
|
||||
EmbeddedResource,
|
||||
GetPromptResult,
|
||||
ImageContent,
|
||||
TextContent,
|
||||
)
|
||||
from mcp.types import Prompt as MCPPrompt
|
||||
from mcp.types import PromptArgument as MCPPromptArgument
|
||||
from mcp.types import Resource as MCPResource
|
||||
from mcp.types import ResourceTemplate as MCPResourceTemplate
|
||||
from mcp.types import Tool as MCPTool
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
"""FastMCP server settings.
|
||||
|
||||
All settings can be configured via environment variables with the prefix FASTMCP_.
|
||||
For example, FASTMCP_DEBUG=true will set debug=True.
|
||||
"""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="FASTMCP_",
|
||||
env_file=".env",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# Server settings
|
||||
debug: bool = False
|
||||
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
||||
|
||||
# HTTP settings
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
sse_path: str = "/sse"
|
||||
message_path: str = "/messages/"
|
||||
|
||||
# resource settings
|
||||
warn_on_duplicate_resources: bool = True
|
||||
|
||||
# tool settings
|
||||
warn_on_duplicate_tools: bool = True
|
||||
|
||||
# prompt settings
|
||||
warn_on_duplicate_prompts: bool = True
|
||||
|
||||
dependencies: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of dependencies to install in the server environment",
|
||||
)
|
||||
|
||||
lifespan: (
|
||||
Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None
|
||||
) = Field(None, description="Lifespan context manager")
|
||||
|
||||
|
||||
def lifespan_wrapper(
|
||||
app: FastMCP,
|
||||
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
|
||||
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
|
||||
@asynccontextmanager
|
||||
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
|
||||
async with lifespan(app) as context:
|
||||
yield context
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
class FastMCP:
|
||||
def __init__(
|
||||
self, name: str | None = None, instructions: str | None = None, **settings: Any
|
||||
):
|
||||
self.settings = Settings(**settings)
|
||||
|
||||
self._mcp_server = MCPServer(
|
||||
name=name or "FastMCP",
|
||||
instructions=instructions,
|
||||
lifespan=lifespan_wrapper(self, self.settings.lifespan)
|
||||
if self.settings.lifespan
|
||||
else default_lifespan,
|
||||
)
|
||||
self._tool_manager = ToolManager(
|
||||
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
|
||||
)
|
||||
self._resource_manager = ResourceManager(
|
||||
warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources
|
||||
)
|
||||
self._prompt_manager = PromptManager(
|
||||
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts
|
||||
)
|
||||
self.dependencies = self.settings.dependencies
|
||||
|
||||
# Set up MCP protocol handlers
|
||||
self._setup_handlers()
|
||||
|
||||
# Configure logging
|
||||
configure_logging(self.settings.log_level)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._mcp_server.name
|
||||
|
||||
@property
|
||||
def instructions(self) -> str | None:
|
||||
return self._mcp_server.instructions
|
||||
|
||||
def run(self, transport: Literal["stdio", "sse"] = "stdio") -> None:
|
||||
"""Run the FastMCP server. Note this is a synchronous function.
|
||||
|
||||
Args:
|
||||
transport: Transport protocol to use ("stdio" or "sse")
|
||||
"""
|
||||
TRANSPORTS = Literal["stdio", "sse"]
|
||||
if transport not in TRANSPORTS.__args__: # type: ignore
|
||||
raise ValueError(f"Unknown transport: {transport}")
|
||||
|
||||
if transport == "stdio":
|
||||
anyio.run(self.run_stdio_async)
|
||||
else: # transport == "sse"
|
||||
anyio.run(self.run_sse_async)
|
||||
|
||||
def _setup_handlers(self) -> None:
|
||||
"""Set up core MCP protocol handlers."""
|
||||
self._mcp_server.list_tools()(self.list_tools)
|
||||
self._mcp_server.call_tool()(self.call_tool)
|
||||
self._mcp_server.list_resources()(self.list_resources)
|
||||
self._mcp_server.read_resource()(self.read_resource)
|
||||
self._mcp_server.list_prompts()(self.list_prompts)
|
||||
self._mcp_server.get_prompt()(self.get_prompt)
|
||||
self._mcp_server.list_resource_templates()(self.list_resource_templates)
|
||||
|
||||
async def list_tools(self) -> list[MCPTool]:
|
||||
"""List all available tools."""
|
||||
tools = self._tool_manager.list_tools()
|
||||
return [
|
||||
MCPTool(
|
||||
name=info.name,
|
||||
description=info.description,
|
||||
inputSchema=info.parameters,
|
||||
)
|
||||
for info in tools
|
||||
]
|
||||
|
||||
def get_context(self) -> Context[ServerSession, object]:
|
||||
"""
|
||||
Returns a Context object. Note that the context will only be valid
|
||||
during a request; outside a request, most methods will error.
|
||||
"""
|
||||
try:
|
||||
request_context = self._mcp_server.request_context
|
||||
except LookupError:
|
||||
request_context = None
|
||||
return Context(request_context=request_context, fastmcp=self)
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: dict[str, Any]
|
||||
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
||||
"""Call a tool by name with arguments."""
|
||||
context = self.get_context()
|
||||
result = await self._tool_manager.call_tool(name, arguments, context=context)
|
||||
converted_result = _convert_to_content(result)
|
||||
return converted_result
|
||||
|
||||
async def list_resources(self) -> list[MCPResource]:
|
||||
"""List all available resources."""
|
||||
|
||||
resources = self._resource_manager.list_resources()
|
||||
return [
|
||||
MCPResource(
|
||||
uri=resource.uri,
|
||||
name=resource.name or "",
|
||||
description=resource.description,
|
||||
mimeType=resource.mime_type,
|
||||
)
|
||||
for resource in resources
|
||||
]
|
||||
|
||||
async def list_resource_templates(self) -> list[MCPResourceTemplate]:
|
||||
templates = self._resource_manager.list_templates()
|
||||
return [
|
||||
MCPResourceTemplate(
|
||||
uriTemplate=template.uri_template,
|
||||
name=template.name,
|
||||
description=template.description,
|
||||
)
|
||||
for template in templates
|
||||
]
|
||||
|
||||
async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]:
|
||||
"""Read a resource by URI."""
|
||||
|
||||
resource = await self._resource_manager.get_resource(uri)
|
||||
if not resource:
|
||||
raise ResourceError(f"Unknown resource: {uri}")
|
||||
|
||||
try:
|
||||
content = await resource.read()
|
||||
return [ReadResourceContents(content=content, mime_type=resource.mime_type)]
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading resource {uri}: {e}")
|
||||
raise ResourceError(str(e))
|
||||
|
||||
def add_tool(
|
||||
self,
|
||||
fn: AnyFunction,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> None:
|
||||
"""Add a tool to the server.
|
||||
|
||||
The tool function can optionally request a Context object by adding a parameter
|
||||
with the Context type annotation. See the @tool decorator for examples.
|
||||
|
||||
Args:
|
||||
fn: The function to register as a tool
|
||||
name: Optional name for the tool (defaults to function name)
|
||||
description: Optional description of what the tool does
|
||||
"""
|
||||
self._tool_manager.add_tool(fn, name=name, description=description)
|
||||
|
||||
def tool(
|
||||
self, name: str | None = None, description: str | None = None
|
||||
) -> Callable[[AnyFunction], AnyFunction]:
|
||||
"""Decorator to register a tool.
|
||||
|
||||
Tools can optionally request a Context object by adding a parameter with the
|
||||
Context type annotation. The context provides access to MCP capabilities like
|
||||
logging, progress reporting, and resource access.
|
||||
|
||||
Args:
|
||||
name: Optional name for the tool (defaults to function name)
|
||||
description: Optional description of what the tool does
|
||||
|
||||
Example:
|
||||
@server.tool()
|
||||
def my_tool(x: int) -> str:
|
||||
return str(x)
|
||||
|
||||
@server.tool()
|
||||
def tool_with_context(x: int, ctx: Context) -> str:
|
||||
ctx.info(f"Processing {x}")
|
||||
return str(x)
|
||||
|
||||
@server.tool()
|
||||
async def async_tool(x: int, context: Context) -> str:
|
||||
await context.report_progress(50, 100)
|
||||
return str(x)
|
||||
"""
|
||||
# Check if user passed function directly instead of calling decorator
|
||||
if callable(name):
|
||||
raise TypeError(
|
||||
"The @tool decorator was used incorrectly. "
|
||||
"Did you forget to call it? Use @tool() instead of @tool"
|
||||
)
|
||||
|
||||
def decorator(fn: AnyFunction) -> AnyFunction:
|
||||
self.add_tool(fn, name=name, description=description)
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
def add_resource(self, resource: Resource) -> None:
|
||||
"""Add a resource to the server.
|
||||
|
||||
Args:
|
||||
resource: A Resource instance to add
|
||||
"""
|
||||
self._resource_manager.add_resource(resource)
|
||||
|
||||
def resource(
|
||||
self,
|
||||
uri: str,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
) -> Callable[[AnyFunction], AnyFunction]:
|
||||
"""Decorator to register a function as a resource.
|
||||
|
||||
The function will be called when the resource is read to generate its content.
|
||||
The function can return:
|
||||
- str for text content
|
||||
- bytes for binary content
|
||||
- other types will be converted to JSON
|
||||
|
||||
If the URI contains parameters (e.g. "resource://{param}") or the function
|
||||
has parameters, it will be registered as a template resource.
|
||||
|
||||
Args:
|
||||
uri: URI for the resource (e.g. "resource://my-resource" or "resource://{param}")
|
||||
name: Optional name for the resource
|
||||
description: Optional description of the resource
|
||||
mime_type: Optional MIME type for the resource
|
||||
|
||||
Example:
|
||||
@server.resource("resource://my-resource")
|
||||
def get_data() -> str:
|
||||
return "Hello, world!"
|
||||
|
||||
@server.resource("resource://my-resource")
|
||||
async get_data() -> str:
|
||||
data = await fetch_data()
|
||||
return f"Hello, world! {data}"
|
||||
|
||||
@server.resource("resource://{city}/weather")
|
||||
def get_weather(city: str) -> str:
|
||||
return f"Weather for {city}"
|
||||
|
||||
@server.resource("resource://{city}/weather")
|
||||
async def get_weather(city: str) -> str:
|
||||
data = await fetch_weather(city)
|
||||
return f"Weather for {city}: {data}"
|
||||
"""
|
||||
# Check if user passed function directly instead of calling decorator
|
||||
if callable(uri):
|
||||
raise TypeError(
|
||||
"The @resource decorator was used incorrectly. "
|
||||
"Did you forget to call it? Use @resource('uri') instead of @resource"
|
||||
)
|
||||
|
||||
def decorator(fn: AnyFunction) -> AnyFunction:
|
||||
# Check if this should be a template
|
||||
has_uri_params = "{" in uri and "}" in uri
|
||||
has_func_params = bool(inspect.signature(fn).parameters)
|
||||
|
||||
if has_uri_params or has_func_params:
|
||||
# Validate that URI params match function params
|
||||
uri_params = set(re.findall(r"{(\w+)}", uri))
|
||||
func_params = set(inspect.signature(fn).parameters.keys())
|
||||
|
||||
if uri_params != func_params:
|
||||
raise ValueError(
|
||||
f"Mismatch between URI parameters {uri_params} "
|
||||
f"and function parameters {func_params}"
|
||||
)
|
||||
|
||||
# Register as template
|
||||
self._resource_manager.add_template(
|
||||
fn=fn,
|
||||
uri_template=uri,
|
||||
name=name,
|
||||
description=description,
|
||||
mime_type=mime_type or "text/plain",
|
||||
)
|
||||
else:
|
||||
# Register as regular resource
|
||||
resource = FunctionResource(
|
||||
uri=AnyUrl(uri),
|
||||
name=name,
|
||||
description=description,
|
||||
mime_type=mime_type or "text/plain",
|
||||
fn=fn,
|
||||
)
|
||||
self.add_resource(resource)
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
def add_prompt(self, prompt: Prompt) -> None:
|
||||
"""Add a prompt to the server.
|
||||
|
||||
Args:
|
||||
prompt: A Prompt instance to add
|
||||
"""
|
||||
self._prompt_manager.add_prompt(prompt)
|
||||
|
||||
def prompt(
|
||||
self, name: str | None = None, description: str | None = None
|
||||
) -> Callable[[AnyFunction], AnyFunction]:
|
||||
"""Decorator to register a prompt.
|
||||
|
||||
Args:
|
||||
name: Optional name for the prompt (defaults to function name)
|
||||
description: Optional description of what the prompt does
|
||||
|
||||
Example:
|
||||
@server.prompt()
|
||||
def analyze_table(table_name: str) -> list[Message]:
|
||||
schema = read_table_schema(table_name)
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Analyze this schema:\n{schema}"
|
||||
}
|
||||
]
|
||||
|
||||
@server.prompt()
|
||||
async def analyze_file(path: str) -> list[Message]:
|
||||
content = await read_file(path)
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": {
|
||||
"type": "resource",
|
||||
"resource": {
|
||||
"uri": f"file://{path}",
|
||||
"text": content
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
# Check if user passed function directly instead of calling decorator
|
||||
if callable(name):
|
||||
raise TypeError(
|
||||
"The @prompt decorator was used incorrectly. "
|
||||
"Did you forget to call it? Use @prompt() instead of @prompt"
|
||||
)
|
||||
|
||||
def decorator(func: AnyFunction) -> AnyFunction:
|
||||
prompt = Prompt.from_function(func, name=name, description=description)
|
||||
self.add_prompt(prompt)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def run_stdio_async(self) -> None:
|
||||
"""Run the server using stdio transport."""
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
await self._mcp_server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
self._mcp_server.create_initialization_options(),
|
||||
)
|
||||
|
||||
async def run_sse_async(self) -> None:
|
||||
"""Run the server using SSE transport."""
|
||||
starlette_app = self.sse_app()
|
||||
|
||||
config = uvicorn.Config(
|
||||
starlette_app,
|
||||
host=self.settings.host,
|
||||
port=self.settings.port,
|
||||
log_level=self.settings.log_level.lower(),
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
def sse_app(self) -> Starlette:
|
||||
"""Return an instance of the SSE server app."""
|
||||
sse = SseServerTransport(self.settings.message_path)
|
||||
|
||||
async def handle_sse(request: Request) -> None:
|
||||
async with sse.connect_sse(
|
||||
request.scope,
|
||||
request.receive,
|
||||
request._send, # type: ignore[reportPrivateUsage]
|
||||
) as streams:
|
||||
await self._mcp_server.run(
|
||||
streams[0],
|
||||
streams[1],
|
||||
self._mcp_server.create_initialization_options(),
|
||||
)
|
||||
|
||||
return Starlette(
|
||||
debug=self.settings.debug,
|
||||
routes=[
|
||||
Route(self.settings.sse_path, endpoint=handle_sse),
|
||||
Mount(self.settings.message_path, app=sse.handle_post_message),
|
||||
],
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> list[MCPPrompt]:
|
||||
"""List all available prompts."""
|
||||
prompts = self._prompt_manager.list_prompts()
|
||||
return [
|
||||
MCPPrompt(
|
||||
name=prompt.name,
|
||||
description=prompt.description,
|
||||
arguments=[
|
||||
MCPPromptArgument(
|
||||
name=arg.name,
|
||||
description=arg.description,
|
||||
required=arg.required,
|
||||
)
|
||||
for arg in (prompt.arguments or [])
|
||||
],
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
|
||||
async def get_prompt(
|
||||
self, name: str, arguments: dict[str, Any] | None = None
|
||||
) -> GetPromptResult:
|
||||
"""Get a prompt by name with arguments."""
|
||||
try:
|
||||
messages = await self._prompt_manager.render_prompt(name, arguments)
|
||||
|
||||
return GetPromptResult(messages=pydantic_core.to_jsonable_python(messages))
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting prompt {name}: {e}")
|
||||
raise ValueError(str(e))
|
||||
|
||||
|
||||
def _convert_to_content(
|
||||
result: Any,
|
||||
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
||||
"""Convert a result to a sequence of content objects."""
|
||||
if result is None:
|
||||
return []
|
||||
|
||||
if isinstance(result, TextContent | ImageContent | EmbeddedResource):
|
||||
return [result]
|
||||
|
||||
if isinstance(result, Image):
|
||||
return [result.to_image_content()]
|
||||
|
||||
if isinstance(result, list | tuple):
|
||||
return list(chain.from_iterable(_convert_to_content(item) for item in result)) # type: ignore[reportUnknownVariableType]
|
||||
|
||||
if not isinstance(result, str):
|
||||
try:
|
||||
result = json.dumps(pydantic_core.to_jsonable_python(result))
|
||||
except Exception:
|
||||
result = str(result)
|
||||
|
||||
return [TextContent(type="text", text=result)]
|
||||
|
||||
|
||||
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
|
||||
"""Context object providing access to MCP capabilities.
|
||||
|
||||
This provides a cleaner interface to MCP's RequestContext functionality.
|
||||
It gets injected into tool and resource functions that request it via type hints.
|
||||
|
||||
To use context in a tool function, add a parameter with the Context type annotation:
|
||||
|
||||
```python
|
||||
@server.tool()
|
||||
def my_tool(x: int, ctx: Context) -> str:
|
||||
# Log messages to the client
|
||||
ctx.info(f"Processing {x}")
|
||||
ctx.debug("Debug info")
|
||||
ctx.warning("Warning message")
|
||||
ctx.error("Error message")
|
||||
|
||||
# Report progress
|
||||
ctx.report_progress(50, 100)
|
||||
|
||||
# Access resources
|
||||
data = ctx.read_resource("resource://data")
|
||||
|
||||
# Get request info
|
||||
request_id = ctx.request_id
|
||||
client_id = ctx.client_id
|
||||
|
||||
return str(x)
|
||||
```
|
||||
|
||||
The context parameter name can be anything as long as it's annotated with Context.
|
||||
The context is optional - tools that don't need it can omit the parameter.
|
||||
"""
|
||||
|
||||
_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
|
||||
_fastmcp: FastMCP | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
|
||||
fastmcp: FastMCP | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._request_context = request_context
|
||||
self._fastmcp = fastmcp
|
||||
|
||||
@property
|
||||
def fastmcp(self) -> FastMCP:
|
||||
"""Access to the FastMCP server."""
|
||||
if self._fastmcp is None:
|
||||
raise ValueError("Context is not available outside of a request")
|
||||
return self._fastmcp
|
||||
|
||||
@property
|
||||
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
|
||||
"""Access to the underlying request context."""
|
||||
if self._request_context is None:
|
||||
raise ValueError("Context is not available outside of a request")
|
||||
return self._request_context
|
||||
|
||||
async def report_progress(
|
||||
self, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Report progress for the current operation.
|
||||
|
||||
Args:
|
||||
progress: Current progress value e.g. 24
|
||||
total: Optional total value e.g. 100
|
||||
"""
|
||||
|
||||
progress_token = (
|
||||
self.request_context.meta.progressToken
|
||||
if self.request_context.meta
|
||||
else None
|
||||
)
|
||||
|
||||
if progress_token is None:
|
||||
return
|
||||
|
||||
await self.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=progress, total=total
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:
|
||||
"""Read a resource by URI.
|
||||
|
||||
Args:
|
||||
uri: Resource URI to read
|
||||
|
||||
Returns:
|
||||
The resource content as either text or bytes
|
||||
"""
|
||||
assert (
|
||||
self._fastmcp is not None
|
||||
), "Context is not available outside of a request"
|
||||
return await self._fastmcp.read_resource(uri)
|
||||
|
||||
async def log(
|
||||
self,
|
||||
level: Literal["debug", "info", "warning", "error"],
|
||||
message: str,
|
||||
*,
|
||||
logger_name: str | None = None,
|
||||
) -> None:
|
||||
"""Send a log message to the client.
|
||||
|
||||
Args:
|
||||
level: Log level (debug, info, warning, error)
|
||||
message: Log message
|
||||
logger_name: Optional logger name
|
||||
**extra: Additional structured data to include
|
||||
"""
|
||||
await self.request_context.session.send_log_message(
|
||||
level=level, data=message, logger=logger_name
|
||||
)
|
||||
|
||||
@property
|
||||
def client_id(self) -> str | None:
|
||||
"""Get the client ID if available."""
|
||||
return (
|
||||
getattr(self.request_context.meta, "client_id", None)
|
||||
if self.request_context.meta
|
||||
else None
|
||||
)
|
||||
|
||||
@property
|
||||
def request_id(self) -> str:
|
||||
"""Get the unique ID for this request."""
|
||||
return str(self.request_context.request_id)
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
"""Access to the underlying session for advanced usage."""
|
||||
return self.request_context.session
|
||||
|
||||
# Convenience methods for common log levels
|
||||
async def debug(self, message: str, **extra: Any) -> None:
|
||||
"""Send a debug log message."""
|
||||
await self.log("debug", message, **extra)
|
||||
|
||||
async def info(self, message: str, **extra: Any) -> None:
|
||||
"""Send an info log message."""
|
||||
await self.log("info", message, **extra)
|
||||
|
||||
async def warning(self, message: str, **extra: Any) -> None:
|
||||
"""Send a warning log message."""
|
||||
await self.log("warning", message, **extra)
|
||||
|
||||
async def error(self, message: str, **extra: Any) -> None:
|
||||
"""Send an error log message."""
|
||||
await self.log("error", message, **extra)
|
||||
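Pulling together the FastMCP server implementation above (server.py), a minimal end-to-end sketch of how the decorators and run() are meant to be used; the tool, resource, and prompt names below are illustrative.

```python
from mcp.server.fastmcp import Context, FastMCP

mcp = FastMCP("Demo")


@mcp.tool()
async def add(a: int, b: int, ctx: Context) -> int:
    """Add two numbers, reporting progress along the way."""
    await ctx.report_progress(50, 100)
    return a + b


@mcp.resource("config://{section}")
def read_config(section: str) -> str:
    """Return a configuration section (registered as a template resource)."""
    return f"[{section}]\nenabled = true"


@mcp.prompt()
def summarize(text: str) -> str:
    """Ask for a summary of the given text."""
    return f"Please summarize:\n\n{text}"


if __name__ == "__main__":
    mcp.run(transport="stdio")
```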
@@ -0,0 +1,4 @@
from .base import Tool
from .tool_manager import ToolManager

__all__ = ["Tool", "ToolManager"]
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,92 @@
from __future__ import annotations as _annotations

import inspect
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel, Field

from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata

if TYPE_CHECKING:
    from mcp.server.fastmcp.server import Context
    from mcp.server.session import ServerSessionT
    from mcp.shared.context import LifespanContextT


class Tool(BaseModel):
    """Internal tool registration info."""

    fn: Callable[..., Any] = Field(exclude=True)
    name: str = Field(description="Name of the tool")
    description: str = Field(description="Description of what the tool does")
    parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
    fn_metadata: FuncMetadata = Field(
        description="Metadata about the function including a pydantic model for tool"
        " arguments"
    )
    is_async: bool = Field(description="Whether the tool is async")
    context_kwarg: str | None = Field(
        None, description="Name of the kwarg that should receive context"
    )

    @classmethod
    def from_function(
        cls,
        fn: Callable[..., Any],
        name: str | None = None,
        description: str | None = None,
        context_kwarg: str | None = None,
    ) -> Tool:
        """Create a Tool from a function."""
        from mcp.server.fastmcp import Context

        func_name = name or fn.__name__

        if func_name == "<lambda>":
            raise ValueError("You must provide a name for lambda functions")

        func_doc = description or fn.__doc__ or ""
        is_async = inspect.iscoroutinefunction(fn)

        if context_kwarg is None:
            sig = inspect.signature(fn)
            for param_name, param in sig.parameters.items():
                if param.annotation is Context:
                    context_kwarg = param_name
                    break

        func_arg_metadata = func_metadata(
            fn,
            skip_names=[context_kwarg] if context_kwarg is not None else [],
        )
        parameters = func_arg_metadata.arg_model.model_json_schema()

        return cls(
            fn=fn,
            name=func_name,
            description=func_doc,
            parameters=parameters,
            fn_metadata=func_arg_metadata,
            is_async=is_async,
            context_kwarg=context_kwarg,
        )

    async def run(
        self,
        arguments: dict[str, Any],
        context: Context[ServerSessionT, LifespanContextT] | None = None,
    ) -> Any:
        """Run the tool with arguments."""
        try:
            return await self.fn_metadata.call_fn_with_arg_validation(
                self.fn,
                self.is_async,
                arguments,
                {self.context_kwarg: context}
                if self.context_kwarg is not None
                else None,
            )
        except Exception as e:
            raise ToolError(f"Error executing tool {self.name}: {e}") from e
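A standalone sketch of this plumbing (the function and arguments below are illustrative): run() validates and coerces incoming arguments against the generated pydantic model before calling the wrapped function.

```python
import anyio

from mcp.server.fastmcp.tools.base import Tool


def multiply(x: int, y: int) -> int:
    """Multiply two integers."""
    return x * y


tool = Tool.from_function(multiply)
print(tool.parameters["required"])  # ['x', 'y']

# String inputs are coerced by the generated argument model before the call.
print(anyio.run(tool.run, {"x": "6", "y": 7}))  # 42
```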
@@ -0,0 +1,60 @@
from __future__ import annotations as _annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools.base import Tool
from mcp.server.fastmcp.utilities.logging import get_logger
from mcp.shared.context import LifespanContextT

if TYPE_CHECKING:
    from mcp.server.fastmcp.server import Context
    from mcp.server.session import ServerSessionT

logger = get_logger(__name__)


class ToolManager:
    """Manages FastMCP tools."""

    def __init__(self, warn_on_duplicate_tools: bool = True):
        self._tools: dict[str, Tool] = {}
        self.warn_on_duplicate_tools = warn_on_duplicate_tools

    def get_tool(self, name: str) -> Tool | None:
        """Get tool by name."""
        return self._tools.get(name)

    def list_tools(self) -> list[Tool]:
        """List all registered tools."""
        return list(self._tools.values())

    def add_tool(
        self,
        fn: Callable[..., Any],
        name: str | None = None,
        description: str | None = None,
    ) -> Tool:
        """Add a tool to the server."""
        tool = Tool.from_function(fn, name=name, description=description)
        existing = self._tools.get(tool.name)
        if existing:
            if self.warn_on_duplicate_tools:
                logger.warning(f"Tool already exists: {tool.name}")
            return existing
        self._tools[tool.name] = tool
        return tool

    async def call_tool(
        self,
        name: str,
        arguments: dict[str, Any],
        context: Context[ServerSessionT, LifespanContextT] | None = None,
    ) -> Any:
        """Call a tool by name with arguments."""
        tool = self.get_tool(name)
        if not tool:
            raise ToolError(f"Unknown tool: {name}")

        return await tool.run(arguments, context=context)
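And the manager layer (sketch, names illustrative): registration of a duplicate name warns and returns the existing Tool, while call_tool raises ToolError for unknown names.

```python
import anyio

from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools import ToolManager


def echo(text: str) -> str:
    """Echo the input back."""
    return text


manager = ToolManager()
manager.add_tool(echo)
manager.add_tool(echo)  # warns and returns the already-registered tool

print(anyio.run(manager.call_tool, "echo", {"text": "hi"}))  # hi

try:
    anyio.run(manager.call_tool, "missing", {})
except ToolError as exc:
    print(exc)  # Unknown tool: missing
```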
@@ -0,0 +1 @@
"""FastMCP utility modules."""
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,214 @@
import inspect
import json
from collections.abc import Awaitable, Callable, Sequence
from typing import (
    Annotated,
    Any,
    ForwardRef,
)

from pydantic import BaseModel, ConfigDict, Field, WithJsonSchema, create_model
from pydantic._internal._typing_extra import eval_type_backport
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined

from mcp.server.fastmcp.exceptions import InvalidSignature
from mcp.server.fastmcp.utilities.logging import get_logger

logger = get_logger(__name__)


class ArgModelBase(BaseModel):
    """A model representing the arguments to a function."""

    def model_dump_one_level(self) -> dict[str, Any]:
        """Return a dict of the model's fields, one level deep.

        That is, sub-models etc. are not dumped - they are kept as pydantic models.
        """
        kwargs: dict[str, Any] = {}
        for field_name in self.model_fields.keys():
            kwargs[field_name] = getattr(self, field_name)
        return kwargs

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )


class FuncMetadata(BaseModel):
    arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)]
    # We can add things in the future like
    #  - Maybe some args are excluded from attempting to parse from JSON
    #  - Maybe some args are special (like context) for dependency injection

    async def call_fn_with_arg_validation(
        self,
        fn: Callable[..., Any] | Awaitable[Any],
        fn_is_async: bool,
        arguments_to_validate: dict[str, Any],
        arguments_to_pass_directly: dict[str, Any] | None,
    ) -> Any:
        """Call the given function with arguments validated and injected.

        Arguments are first attempted to be parsed from JSON, then validated against
        the argument model, before being passed to the function.
        """
        arguments_pre_parsed = self.pre_parse_json(arguments_to_validate)
        arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed)
        arguments_parsed_dict = arguments_parsed_model.model_dump_one_level()

        arguments_parsed_dict |= arguments_to_pass_directly or {}

        if fn_is_async:
            if isinstance(fn, Awaitable):
                return await fn
            return await fn(**arguments_parsed_dict)
        if isinstance(fn, Callable):
            return fn(**arguments_parsed_dict)
        raise TypeError("fn must be either Callable or Awaitable")

    def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
        """Pre-parse data from JSON.

        Return a dict with the same keys as the input, but with values parsed from
        JSON where appropriate.

        This handles cases like `["a", "b", "c"]` being passed in as JSON inside
        a string rather than an actual list. Claude Desktop is prone to this - in
        fact it seems incapable of NOT doing this. For sub-models, it tends to pass
        dicts (JSON objects) as JSON strings, which can be pre-parsed here.
        """
        new_data = data.copy()  # Shallow copy
        for field_name, _field_info in self.arg_model.model_fields.items():
            if field_name not in data.keys():
                continue
            if isinstance(data[field_name], str):
                try:
                    pre_parsed = json.loads(data[field_name])
                except json.JSONDecodeError:
                    continue  # Not JSON - skip
                if isinstance(pre_parsed, str | int | float):
                    # The raw value is likely something like `"hello"`, which should
                    # stay the string 'hello'. Parsing it as JSON would strip the
                    # quotes (or turn "1" into the int 1), so we skip it.
                    continue
                new_data[field_name] = pre_parsed
        assert new_data.keys() == data.keys()
        return new_data

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )


def func_metadata(
    func: Callable[..., Any], skip_names: Sequence[str] = ()
) -> FuncMetadata:
    """Given a function, return metadata including a pydantic model representing its
    signature.

    The use case for this is
    ```
    meta = func_metadata(func)
    validated_args = meta.arg_model.model_validate(some_raw_data_dict)
    return func(**validated_args.model_dump_one_level())
    ```

    **critically** it also provides a pre-parse helper to attempt to parse things from
    JSON.

    Args:
        func: The function to convert to a pydantic model
        skip_names: A list of parameter names to skip. These will not be included in
            the model.
    Returns:
        A pydantic model representing the function's signature.
    """
    sig = _get_typed_signature(func)
    params = sig.parameters
    dynamic_pydantic_model_params: dict[str, Any] = {}
    globalns = getattr(func, "__globals__", {})
    for param in params.values():
        if param.name.startswith("_"):
            raise InvalidSignature(
                f"Parameter {param.name} of {func.__name__} cannot start with '_'"
            )
        if param.name in skip_names:
            continue
        annotation = param.annotation

        # `x: None` / `x: None = None`
        if annotation is None:
            annotation = Annotated[
                None,
                Field(
                    default=param.default
                    if param.default is not inspect.Parameter.empty
                    else PydanticUndefined
                ),
            ]

        # Untyped field
        if annotation is inspect.Parameter.empty:
            annotation = Annotated[
                Any,
                Field(),
                # 🤷
                WithJsonSchema({"title": param.name, "type": "string"}),
            ]

        field_info = FieldInfo.from_annotated_attribute(
            _get_typed_annotation(annotation, globalns),
            param.default
            if param.default is not inspect.Parameter.empty
            else PydanticUndefined,
        )
        dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)
        continue

    arguments_model = create_model(
        f"{func.__name__}Arguments",
        **dynamic_pydantic_model_params,
        __base__=ArgModelBase,
    )
    resp = FuncMetadata(arg_model=arguments_model)
    return resp


def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
    def try_eval_type(
        value: Any, globalns: dict[str, Any], localns: dict[str, Any]
    ) -> tuple[Any, bool]:
        try:
            return eval_type_backport(value, globalns, localns), True
        except NameError:
            return value, False

    if isinstance(annotation, str):
        annotation = ForwardRef(annotation)
        annotation, status = try_eval_type(annotation, globalns, globalns)

        # This check and raise could perhaps be skipped, and we (FastMCP) just call
        # model_rebuild right before using it 🤷
        if status is False:
            raise InvalidSignature(f"Unable to evaluate type annotation {annotation}")

    return annotation


def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Get function signature while evaluating forward references"""
    signature = inspect.signature(call)
    globalns = getattr(call, "__globals__", {})
    typed_params = [
        inspect.Parameter(
            name=param.name,
            kind=param.kind,
            default=param.default,
            annotation=_get_typed_annotation(param.annotation, globalns),
        )
        for param in signature.parameters.values()
    ]
    typed_signature = inspect.Signature(typed_params)
    return typed_signature
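To make the flow above concrete, here is a small sketch (not part of the package) of what `func_metadata` produces for a typed function, including the JSON-string coercion done by `pre_parse_json`. The module path and the `search` function are assumptions for illustration.

```
import asyncio

from mcp.server.fastmcp.utilities.func_metadata import func_metadata


def search(query: str, tags: list[str], limit: int = 10) -> str:
    return f"{query!r} with {tags} (limit={limit})"


async def main() -> None:
    meta = func_metadata(search)

    # The generated pydantic model exposes a JSON schema for the tool's arguments.
    print(list(meta.arg_model.model_json_schema()["properties"]))

    # `tags` arrives as a JSON string (as some clients tend to send it); it is
    # pre-parsed back into a list before validation and the call.
    result = await meta.call_fn_with_arg_validation(
        search,
        False,                                   # search() is a sync function
        {"query": "mcp", "tags": '["a", "b"]'},  # note the stringified list
        None,                                    # nothing injected directly
    )
    print(result)


asyncio.run(main())
```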
@@ -0,0 +1,43 @@
"""Logging utilities for FastMCP."""

import logging
from typing import Literal


def get_logger(name: str) -> logging.Logger:
    """Get a logger nested under the MCP namespace.

    Args:
        name: the name of the logger, which will be prefixed with 'FastMCP.'

    Returns:
        a configured logger instance
    """
    return logging.getLogger(name)


def configure_logging(
    level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
) -> None:
    """Configure logging for MCP.

    Args:
        level: the log level to use
    """
    handlers: list[logging.Handler] = []
    try:
        from rich.console import Console
        from rich.logging import RichHandler

        handlers.append(RichHandler(console=Console(stderr=True), rich_tracebacks=True))
    except ImportError:
        pass

    if not handlers:
        handlers.append(logging.StreamHandler())

    logging.basicConfig(
        level=level,
        format="%(message)s",
        handlers=handlers,
    )
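A short usage sketch. The import path is taken from the `get_logger` import seen earlier in the tools module; routing log output to stderr keeps stdout free for the stdio transport's JSON-RPC traffic.

```
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger

# Uses a RichHandler on stderr if rich is installed, else a plain StreamHandler.
configure_logging(level="DEBUG")

logger = get_logger(__name__)
logger.debug("server starting")
```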
@@ -0,0 +1,54 @@
"""Common types used across FastMCP."""

import base64
from pathlib import Path

from mcp.types import ImageContent


class Image:
    """Helper class for returning images from tools."""

    def __init__(
        self,
        path: str | Path | None = None,
        data: bytes | None = None,
        format: str | None = None,
    ):
        if path is None and data is None:
            raise ValueError("Either path or data must be provided")
        if path is not None and data is not None:
            raise ValueError("Only one of path or data can be provided")

        self.path = Path(path) if path else None
        self.data = data
        self._format = format
        self._mime_type = self._get_mime_type()

    def _get_mime_type(self) -> str:
        """Get MIME type from format or guess from file extension."""
        if self._format:
            return f"image/{self._format.lower()}"

        if self.path:
            suffix = self.path.suffix.lower()
            return {
                ".png": "image/png",
                ".jpg": "image/jpeg",
                ".jpeg": "image/jpeg",
                ".gif": "image/gif",
                ".webp": "image/webp",
            }.get(suffix, "application/octet-stream")
        return "image/png"  # default for raw binary data

    def to_image_content(self) -> ImageContent:
        """Convert to MCP ImageContent."""
        if self.path:
            with open(self.path, "rb") as f:
                data = base64.b64encode(f.read()).decode()
        elif self.data is not None:
            data = base64.b64encode(self.data).decode()
        else:
            raise ValueError("No image data available")

        return ImageContent(type="image", data=data, mimeType=self._mime_type)
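A brief sketch of the two construction modes. The import path and the file name `chart.png` are assumptions for illustration.

```
from mcp.server.fastmcp.utilities.types import Image

# From a file on disk: the MIME type is guessed from the ".png" suffix.
chart = Image(path="chart.png")
print(chart.to_image_content().mimeType)  # "image/png"

# From raw bytes: pass the format explicitly so the MIME type is correct.
thumb = Image(data=b"...raw JPEG bytes...", format="jpeg")
print(thumb.to_image_content().mimeType)  # "image/jpeg"
```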
@@ -0,0 +1,3 @@
from .server import NotificationOptions, Server

__all__ = ["Server", "NotificationOptions"]
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,9 @@
from dataclasses import dataclass


@dataclass
class ReadResourceContents:
    """Contents returned from a read_resource call."""

    content: str | bytes
    mime_type: str | None = None
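For reference, this is how a `read_resource` handler might build return values with this dataclass (the concrete contents are illustrative); the import path matches the one used by the lowlevel server below.

```
from mcp.server.lowlevel.helper_types import ReadResourceContents

# A text part and a binary part, as returned from a @server.read_resource() handler.
text_part = ReadResourceContents(content="hello, world", mime_type="text/plain")
blob_part = ReadResourceContents(content=b"\x89PNG...", mime_type="image/png")
```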
590
.venv/lib/python3.10/site-packages/mcp/server/lowlevel/server.py
Normal file
590
.venv/lib/python3.10/site-packages/mcp/server/lowlevel/server.py
Normal file
@@ -0,0 +1,590 @@
"""
MCP Server Module

This module provides a framework for creating an MCP (Model Context Protocol) server.
It allows you to easily define and handle various types of requests and notifications
in an asynchronous manner.

Usage:
1. Create a Server instance:
   server = Server("your_server_name")

2. Define request handlers using decorators:
   @server.list_prompts()
   async def handle_list_prompts() -> list[types.Prompt]:
       # Implementation

   @server.get_prompt()
   async def handle_get_prompt(
       name: str, arguments: dict[str, str] | None
   ) -> types.GetPromptResult:
       # Implementation

   @server.list_tools()
   async def handle_list_tools() -> list[types.Tool]:
       # Implementation

   @server.call_tool()
   async def handle_call_tool(
       name: str, arguments: dict | None
   ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
       # Implementation

   @server.list_resource_templates()
   async def handle_list_resource_templates() -> list[types.ResourceTemplate]:
       # Implementation

3. Define notification handlers if needed:
   @server.progress_notification()
   async def handle_progress(
       progress_token: str | int, progress: float, total: float | None
   ) -> None:
       # Implementation

4. Run the server:
   async def main():
       async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
           await server.run(
               read_stream,
               write_stream,
               InitializationOptions(
                   server_name="your_server_name",
                   server_version="your_version",
                   capabilities=server.get_capabilities(
                       notification_options=NotificationOptions(),
                       experimental_capabilities={},
                   ),
               ),
           )

   asyncio.run(main())

The Server class provides methods to register handlers for various MCP requests and
notifications. It automatically manages the request context and handles incoming
messages from the client.
"""

from __future__ import annotations as _annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.lowlevel.helper_types import ReadResourceContents
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.server.stdio import stdio_server as stdio_server
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.session import RequestResponder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LifespanResultT = TypeVar("LifespanResultT")
|
||||
|
||||
# This will be properly typed in each Server instance's context
|
||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
|
||||
contextvars.ContextVar("request_ctx")
|
||||
)
|
||||
|
||||
|
||||
class NotificationOptions:
|
||||
def __init__(
|
||||
self,
|
||||
prompts_changed: bool = False,
|
||||
resources_changed: bool = False,
|
||||
tools_changed: bool = False,
|
||||
):
|
||||
self.prompts_changed = prompts_changed
|
||||
self.resources_changed = resources_changed
|
||||
self.tools_changed = tools_changed
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
|
||||
"""Default lifespan context manager that does nothing.
|
||||
|
||||
Args:
|
||||
server: The server instance this lifespan is managing
|
||||
|
||||
Returns:
|
||||
An empty context object
|
||||
"""
|
||||
yield {}
|
||||
|
||||
|
||||
class Server(Generic[LifespanResultT]):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
version: str | None = None,
|
||||
instructions: str | None = None,
|
||||
lifespan: Callable[
|
||||
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
|
||||
] = lifespan,
|
||||
):
|
||||
self.name = name
|
||||
self.version = version
|
||||
self.instructions = instructions
|
||||
self.lifespan = lifespan
|
||||
self.request_handlers: dict[
|
||||
type, Callable[..., Awaitable[types.ServerResult]]
|
||||
] = {
|
||||
types.PingRequest: _ping_handler,
|
||||
}
|
||||
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
||||
self.notification_options = NotificationOptions()
|
||||
logger.debug(f"Initializing server '{name}'")
|
||||
|
||||
def create_initialization_options(
|
||||
self,
|
||||
notification_options: NotificationOptions | None = None,
|
||||
experimental_capabilities: dict[str, dict[str, Any]] | None = None,
|
||||
) -> InitializationOptions:
|
||||
"""Create initialization options from this server instance."""
|
||||
|
||||
def pkg_version(package: str) -> str:
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
return version(package)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "unknown"
|
||||
|
||||
return InitializationOptions(
|
||||
server_name=self.name,
|
||||
server_version=self.version if self.version else pkg_version("mcp"),
|
||||
capabilities=self.get_capabilities(
|
||||
notification_options or NotificationOptions(),
|
||||
experimental_capabilities or {},
|
||||
),
|
||||
instructions=self.instructions,
|
||||
)
|
||||
|
||||
def get_capabilities(
|
||||
self,
|
||||
notification_options: NotificationOptions,
|
||||
experimental_capabilities: dict[str, dict[str, Any]],
|
||||
) -> types.ServerCapabilities:
|
||||
"""Convert existing handlers to a ServerCapabilities object."""
|
||||
prompts_capability = None
|
||||
resources_capability = None
|
||||
tools_capability = None
|
||||
logging_capability = None
|
||||
|
||||
# Set prompt capabilities if handler exists
|
||||
if types.ListPromptsRequest in self.request_handlers:
|
||||
prompts_capability = types.PromptsCapability(
|
||||
listChanged=notification_options.prompts_changed
|
||||
)
|
||||
|
||||
# Set resource capabilities if handler exists
|
||||
if types.ListResourcesRequest in self.request_handlers:
|
||||
resources_capability = types.ResourcesCapability(
|
||||
subscribe=False, listChanged=notification_options.resources_changed
|
||||
)
|
||||
|
||||
# Set tool capabilities if handler exists
|
||||
if types.ListToolsRequest in self.request_handlers:
|
||||
tools_capability = types.ToolsCapability(
|
||||
listChanged=notification_options.tools_changed
|
||||
)
|
||||
|
||||
# Set logging capabilities if handler exists
|
||||
if types.SetLevelRequest in self.request_handlers:
|
||||
logging_capability = types.LoggingCapability()
|
||||
|
||||
return types.ServerCapabilities(
|
||||
prompts=prompts_capability,
|
||||
resources=resources_capability,
|
||||
tools=tools_capability,
|
||||
logging=logging_capability,
|
||||
experimental=experimental_capabilities,
|
||||
)
|
||||
|
||||
@property
|
||||
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
|
||||
"""If called outside of a request context, this will raise a LookupError."""
|
||||
return request_ctx.get()
|
||||
|
||||
def list_prompts(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]):
|
||||
logger.debug("Registering handler for PromptListRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
prompts = await func()
|
||||
return types.ServerResult(types.ListPromptsResult(prompts=prompts))
|
||||
|
||||
self.request_handlers[types.ListPromptsRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_prompt(self):
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[str, dict[str, str] | None], Awaitable[types.GetPromptResult]
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for GetPromptRequest")
|
||||
|
||||
async def handler(req: types.GetPromptRequest):
|
||||
prompt_get = await func(req.params.name, req.params.arguments)
|
||||
return types.ServerResult(prompt_get)
|
||||
|
||||
self.request_handlers[types.GetPromptRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_resources(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[types.Resource]]]):
|
||||
logger.debug("Registering handler for ListResourcesRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
resources = await func()
|
||||
return types.ServerResult(
|
||||
types.ListResourcesResult(resources=resources)
|
||||
)
|
||||
|
||||
self.request_handlers[types.ListResourcesRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_resource_templates(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]):
|
||||
logger.debug("Registering handler for ListResourceTemplatesRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
templates = await func()
|
||||
return types.ServerResult(
|
||||
types.ListResourceTemplatesResult(resourceTemplates=templates)
|
||||
)
|
||||
|
||||
self.request_handlers[types.ListResourceTemplatesRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def read_resource(self):
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for ReadResourceRequest")
|
||||
|
||||
async def handler(req: types.ReadResourceRequest):
|
||||
result = await func(req.params.uri)
|
||||
|
||||
def create_content(data: str | bytes, mime_type: str | None):
|
||||
match data:
|
||||
case str() as data:
|
||||
return types.TextResourceContents(
|
||||
uri=req.params.uri,
|
||||
text=data,
|
||||
mimeType=mime_type or "text/plain",
|
||||
)
|
||||
case bytes() as data:
|
||||
import base64
|
||||
|
||||
return types.BlobResourceContents(
|
||||
uri=req.params.uri,
|
||||
blob=base64.b64encode(data).decode(),
|
||||
mimeType=mime_type or "application/octet-stream",
|
||||
)
|
||||
|
||||
match result:
|
||||
case str() | bytes() as data:
|
||||
warnings.warn(
|
||||
"Returning str or bytes from read_resource is deprecated. "
|
||||
"Use Iterable[ReadResourceContents] instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
content = create_content(data, None)
|
||||
case Iterable() as contents:
|
||||
contents_list = [
|
||||
create_content(content_item.content, content_item.mime_type)
|
||||
for content_item in contents
|
||||
]
|
||||
return types.ServerResult(
|
||||
types.ReadResourceResult(
|
||||
contents=contents_list,
|
||||
)
|
||||
)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unexpected return type from read_resource: {type(result)}"
|
||||
)
|
||||
|
||||
return types.ServerResult(
|
||||
types.ReadResourceResult(
|
||||
contents=[content],
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[types.ReadResourceRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def set_logging_level(self):
|
||||
def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]):
|
||||
logger.debug("Registering handler for SetLevelRequest")
|
||||
|
||||
async def handler(req: types.SetLevelRequest):
|
||||
await func(req.params.level)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
self.request_handlers[types.SetLevelRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def subscribe_resource(self):
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug("Registering handler for SubscribeRequest")
|
||||
|
||||
async def handler(req: types.SubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
self.request_handlers[types.SubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def unsubscribe_resource(self):
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug("Registering handler for UnsubscribeRequest")
|
||||
|
||||
async def handler(req: types.UnsubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
self.request_handlers[types.UnsubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_tools(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[types.Tool]]]):
|
||||
logger.debug("Registering handler for ListToolsRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
tools = await func()
|
||||
return types.ServerResult(types.ListToolsResult(tools=tools))
|
||||
|
||||
self.request_handlers[types.ListToolsRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def call_tool(self):
|
||||
def decorator(
|
||||
func: Callable[
|
||||
...,
|
||||
Awaitable[
|
||||
Iterable[
|
||||
types.TextContent | types.ImageContent | types.EmbeddedResource
|
||||
]
|
||||
],
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for CallToolRequest")
|
||||
|
||||
async def handler(req: types.CallToolRequest):
|
||||
try:
|
||||
results = await func(req.params.name, (req.params.arguments or {}))
|
||||
return types.ServerResult(
|
||||
types.CallToolResult(content=list(results), isError=False)
|
||||
)
|
||||
except Exception as e:
|
||||
return types.ServerResult(
|
||||
types.CallToolResult(
|
||||
content=[types.TextContent(type="text", text=str(e))],
|
||||
isError=True,
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[types.CallToolRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def progress_notification(self):
|
||||
def decorator(
|
||||
func: Callable[[str | int, float, float | None], Awaitable[None]],
|
||||
):
|
||||
logger.debug("Registering handler for ProgressNotification")
|
||||
|
||||
async def handler(req: types.ProgressNotification):
|
||||
await func(
|
||||
req.params.progressToken, req.params.progress, req.params.total
|
||||
)
|
||||
|
||||
self.notification_handlers[types.ProgressNotification] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def completion(self):
|
||||
"""Provides completions for prompts and resource templates"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[
|
||||
types.PromptReference | types.ResourceReference,
|
||||
types.CompletionArgument,
|
||||
],
|
||||
Awaitable[types.Completion | None],
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for CompleteRequest")
|
||||
|
||||
async def handler(req: types.CompleteRequest):
|
||||
completion = await func(req.params.ref, req.params.argument)
|
||||
return types.ServerResult(
|
||||
types.CompleteResult(
|
||||
completion=completion
|
||||
if completion is not None
|
||||
else types.Completion(values=[], total=None, hasMore=None),
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[types.CompleteRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def run(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
initialization_options: InitializationOptions,
|
||||
# When False, exceptions are returned as messages to the client.
|
||||
# When True, exceptions are raised, which will cause the server to shut down
|
||||
# but also make tracing exceptions much easier during testing and when using
|
||||
# in-process servers.
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
async with AsyncExitStack() as stack:
|
||||
lifespan_context = await stack.enter_async_context(self.lifespan(self))
|
||||
session = await stack.enter_async_context(
|
||||
ServerSession(read_stream, write_stream, initialization_options)
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
async for message in session.incoming_messages:
|
||||
logger.debug(f"Received message: {message}")
|
||||
|
||||
tg.start_soon(
|
||||
self._handle_message,
|
||||
message,
|
||||
session,
|
||||
lifespan_context,
|
||||
raise_exceptions,
|
||||
)
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
message: RequestResponder[types.ClientRequest, types.ServerResult]
|
||||
| types.ClientNotification
|
||||
| Exception,
|
||||
session: ServerSession,
|
||||
lifespan_context: LifespanResultT,
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
# TODO(Marcelo): We should be checking if message is Exception here.
|
||||
match message: # type: ignore[reportMatchNotExhaustive]
|
||||
case (
|
||||
RequestResponder(request=types.ClientRequest(root=req)) as responder
|
||||
):
|
||||
with responder:
|
||||
await self._handle_request(
|
||||
message, req, session, lifespan_context, raise_exceptions
|
||||
)
|
||||
case types.ClientNotification(root=notify):
|
||||
await self._handle_notification(notify)
|
||||
|
||||
for warning in w:
|
||||
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
|
||||
|
||||
async def _handle_request(
|
||||
self,
|
||||
message: RequestResponder[types.ClientRequest, types.ServerResult],
|
||||
req: Any,
|
||||
session: ServerSession,
|
||||
lifespan_context: LifespanResultT,
|
||||
raise_exceptions: bool,
|
||||
):
|
||||
logger.info(f"Processing request of type {type(req).__name__}")
|
||||
if type(req) in self.request_handlers:
|
||||
handler = self.request_handlers[type(req)]
|
||||
logger.debug(f"Dispatching request of type {type(req).__name__}")
|
||||
|
||||
token = None
|
||||
try:
|
||||
# Set our global state that can be retrieved via
|
||||
# app.get_request_context()
|
||||
token = request_ctx.set(
|
||||
RequestContext(
|
||||
message.request_id,
|
||||
message.request_meta,
|
||||
session,
|
||||
lifespan_context,
|
||||
)
|
||||
)
|
||||
response = await handler(req)
|
||||
except McpError as err:
|
||||
response = err.error
|
||||
except Exception as err:
|
||||
if raise_exceptions:
|
||||
raise err
|
||||
response = types.ErrorData(code=0, message=str(err), data=None)
|
||||
finally:
|
||||
# Reset the global state after we are done
|
||||
if token is not None:
|
||||
request_ctx.reset(token)
|
||||
|
||||
await message.respond(response)
|
||||
else:
|
||||
await message.respond(
|
||||
types.ErrorData(
|
||||
code=types.METHOD_NOT_FOUND,
|
||||
message="Method not found",
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("Response sent")
|
||||
|
||||
async def _handle_notification(self, notify: Any):
|
||||
if type(notify) in self.notification_handlers:
|
||||
assert type(notify) in self.notification_handlers
|
||||
|
||||
handler = self.notification_handlers[type(notify)]
|
||||
logger.debug(
|
||||
f"Dispatching notification of type " f"{type(notify).__name__}"
|
||||
)
|
||||
|
||||
try:
|
||||
await handler(notify)
|
||||
except Exception as err:
|
||||
logger.error(f"Uncaught exception in notification handler: " f"{err}")
|
||||
|
||||
|
||||
async def _ping_handler(request: types.PingRequest) -> types.ServerResult:
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
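Tying the usage steps from the module docstring together, here is a minimal, self-contained sketch of a lowlevel server exposing one tool over stdio. The server name, the `echo` tool, and its schema are illustrative, not part of the package.

```
import anyio

import mcp.types as types
from mcp.server.lowlevel import NotificationOptions, Server
from mcp.server.stdio import stdio_server

server = Server("echo-server")


@server.list_tools()
async def list_tools() -> list[types.Tool]:
    return [
        types.Tool(
            name="echo",
            description="Echo back the provided text.",
            inputSchema={
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"],
            },
        )
    ]


@server.call_tool()
async def call_tool(name: str, arguments: dict) -> list[types.TextContent]:
    if name != "echo":
        raise ValueError(f"Unknown tool: {name}")
    # Raising here is fine: the registered handler converts exceptions into a
    # CallToolResult with isError=True instead of crashing the server.
    return [types.TextContent(type="text", text=arguments["text"])]


async def main() -> None:
    async with stdio_server() as (read_stream, write_stream):
        await server.run(
            read_stream,
            write_stream,
            server.create_initialization_options(
                notification_options=NotificationOptions(tools_changed=True)
            ),
        )


anyio.run(main)
```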
17
.venv/lib/python3.10/site-packages/mcp/server/models.py
Normal file
17
.venv/lib/python3.10/site-packages/mcp/server/models.py
Normal file
@@ -0,0 +1,17 @@
"""
This module provides simpler types to use with the server for managing prompts
and tools.
"""

from pydantic import BaseModel

from mcp.types import (
    ServerCapabilities,
)


class InitializationOptions(BaseModel):
    server_name: str
    server_version: str
    capabilities: ServerCapabilities
    instructions: str | None = None
317
.venv/lib/python3.10/site-packages/mcp/server/session.py
Normal file
317
.venv/lib/python3.10/site-packages/mcp/server/session.py
Normal file
@@ -0,0 +1,317 @@
"""
ServerSession Module

This module provides the ServerSession class, which manages communication between the
server and client in the MCP (Model Context Protocol) framework. It is most commonly
used in MCP servers to interact with the client.

Common usage pattern:
```
server = Server(name)

@server.call_tool()
async def handle_tool_call(ctx: RequestContext, arguments: dict[str, Any]) -> Any:
    # Check client capabilities before proceeding
    if ctx.session.check_client_capability(
        types.ClientCapabilities(experimental={"advanced_tools": dict()})
    ):
        # Perform advanced tool operations
        result = await perform_advanced_tool_operation(arguments)
    else:
        # Fall back to basic tool operations
        result = await perform_basic_tool_operation(arguments)

    return result

@server.list_prompts()
async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
    # Access session for any necessary checks or operations
    if ctx.session.client_params:
        # Customize prompts based on client initialization parameters
        return generate_custom_prompts(ctx.session.client_params)
    else:
        return default_prompts
```

The ServerSession class is typically used internally by the Server class and should not
be instantiated directly by users of the MCP framework.
"""

from enum import Enum
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.shared.session import (
|
||||
BaseSession,
|
||||
RequestResponder,
|
||||
)
|
||||
|
||||
|
||||
class InitializationState(Enum):
|
||||
NotInitialized = 1
|
||||
Initializing = 2
|
||||
Initialized = 3
|
||||
|
||||
|
||||
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
|
||||
|
||||
ServerRequestResponder = (
|
||||
RequestResponder[types.ClientRequest, types.ServerResult]
|
||||
| types.ClientNotification
|
||||
| Exception
|
||||
)
|
||||
|
||||
|
||||
class ServerSession(
|
||||
BaseSession[
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
types.ServerResult,
|
||||
types.ClientRequest,
|
||||
types.ClientNotification,
|
||||
]
|
||||
):
|
||||
_initialized: InitializationState = InitializationState.NotInitialized
|
||||
_client_params: types.InitializeRequestParams | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
init_options: InitializationOptions,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream, write_stream, types.ClientRequest, types.ClientNotification
|
||||
)
|
||||
self._initialization_state = InitializationState.NotInitialized
|
||||
self._init_options = init_options
|
||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
||||
anyio.create_memory_object_stream[ServerRequestResponder](0)
|
||||
)
|
||||
self._exit_stack.push_async_callback(
|
||||
lambda: self._incoming_message_stream_reader.aclose()
|
||||
)
|
||||
self._exit_stack.push_async_callback(
|
||||
lambda: self._incoming_message_stream_writer.aclose()
|
||||
)
|
||||
|
||||
@property
|
||||
def client_params(self) -> types.InitializeRequestParams | None:
|
||||
return self._client_params
|
||||
|
||||
def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
|
||||
"""Check if the client supports a specific capability."""
|
||||
if self._client_params is None:
|
||||
return False
|
||||
|
||||
# Get client capabilities from initialization params
|
||||
client_caps = self._client_params.capabilities
|
||||
|
||||
# Check each specified capability in the passed in capability object
|
||||
if capability.roots is not None:
|
||||
if client_caps.roots is None:
|
||||
return False
|
||||
if capability.roots.listChanged and not client_caps.roots.listChanged:
|
||||
return False
|
||||
|
||||
if capability.sampling is not None:
|
||||
if client_caps.sampling is None:
|
||||
return False
|
||||
|
||||
if capability.experimental is not None:
|
||||
if client_caps.experimental is None:
|
||||
return False
|
||||
# Check each experimental capability
|
||||
for exp_key, exp_value in capability.experimental.items():
|
||||
if (
|
||||
exp_key not in client_caps.experimental
|
||||
or client_caps.experimental[exp_key] != exp_value
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
|
||||
):
|
||||
match responder.request.root:
|
||||
case types.InitializeRequest(params=params):
|
||||
self._initialization_state = InitializationState.Initializing
|
||||
self._client_params = params
|
||||
with responder:
|
||||
await responder.respond(
|
||||
types.ServerResult(
|
||||
types.InitializeResult(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self._init_options.capabilities,
|
||||
serverInfo=types.Implementation(
|
||||
name=self._init_options.server_name,
|
||||
version=self._init_options.server_version,
|
||||
),
|
||||
instructions=self._init_options.instructions,
|
||||
)
|
||||
)
|
||||
)
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized:
|
||||
raise RuntimeError(
|
||||
"Received request before initialization was complete"
|
||||
)
|
||||
|
||||
async def _received_notification(
|
||||
self, notification: types.ClientNotification
|
||||
) -> None:
|
||||
# Need this to avoid ASYNC910
|
||||
await anyio.lowlevel.checkpoint()
|
||||
match notification.root:
|
||||
case types.InitializedNotification():
|
||||
self._initialization_state = InitializationState.Initialized
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized:
|
||||
raise RuntimeError(
|
||||
"Received notification before initialization was complete"
|
||||
)
|
||||
|
||||
async def send_log_message(
|
||||
self, level: types.LoggingLevel, data: Any, logger: str | None = None
|
||||
) -> None:
|
||||
"""Send a log message notification."""
|
||||
await self.send_notification(
|
||||
types.ServerNotification(
|
||||
types.LoggingMessageNotification(
|
||||
method="notifications/message",
|
||||
params=types.LoggingMessageNotificationParams(
|
||||
level=level,
|
||||
data=data,
|
||||
logger=logger,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def send_resource_updated(self, uri: AnyUrl) -> None:
|
||||
"""Send a resource updated notification."""
|
||||
await self.send_notification(
|
||||
types.ServerNotification(
|
||||
types.ResourceUpdatedNotification(
|
||||
method="notifications/resources/updated",
|
||||
params=types.ResourceUpdatedNotificationParams(uri=uri),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def create_message(
|
||||
self,
|
||||
messages: list[types.SamplingMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
system_prompt: str | None = None,
|
||||
include_context: types.IncludeContext | None = None,
|
||||
temperature: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_preferences: types.ModelPreferences | None = None,
|
||||
) -> types.CreateMessageResult:
|
||||
"""Send a sampling/create_message request."""
|
||||
return await self.send_request(
|
||||
types.ServerRequest(
|
||||
types.CreateMessageRequest(
|
||||
method="sampling/createMessage",
|
||||
params=types.CreateMessageRequestParams(
|
||||
messages=messages,
|
||||
systemPrompt=system_prompt,
|
||||
includeContext=include_context,
|
||||
temperature=temperature,
|
||||
maxTokens=max_tokens,
|
||||
stopSequences=stop_sequences,
|
||||
metadata=metadata,
|
||||
modelPreferences=model_preferences,
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CreateMessageResult,
|
||||
)
|
||||
|
||||
async def list_roots(self) -> types.ListRootsResult:
|
||||
"""Send a roots/list request."""
|
||||
return await self.send_request(
|
||||
types.ServerRequest(
|
||||
types.ListRootsRequest(
|
||||
method="roots/list",
|
||||
)
|
||||
),
|
||||
types.ListRootsResult,
|
||||
)
|
||||
|
||||
async def send_ping(self) -> types.EmptyResult:
|
||||
"""Send a ping request."""
|
||||
return await self.send_request(
|
||||
types.ServerRequest(
|
||||
types.PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
await self.send_notification(
|
||||
types.ServerNotification(
|
||||
types.ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=types.ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def send_resource_list_changed(self) -> None:
|
||||
"""Send a resource list changed notification."""
|
||||
await self.send_notification(
|
||||
types.ServerNotification(
|
||||
types.ResourceListChangedNotification(
|
||||
method="notifications/resources/list_changed",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def send_tool_list_changed(self) -> None:
|
||||
"""Send a tool list changed notification."""
|
||||
await self.send_notification(
|
||||
types.ServerNotification(
|
||||
types.ToolListChangedNotification(
|
||||
method="notifications/tools/list_changed",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def send_prompt_list_changed(self) -> None:
|
||||
"""Send a prompt list changed notification."""
|
||||
await self.send_notification(
|
||||
types.ServerNotification(
|
||||
types.PromptListChangedNotification(
|
||||
method="notifications/prompts/list_changed",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_incoming(self, req: ServerRequestResponder) -> None:
|
||||
await self._incoming_message_stream_writer.send(req)
|
||||
|
||||
@property
|
||||
def incoming_messages(
|
||||
self,
|
||||
) -> MemoryObjectReceiveStream[ServerRequestResponder]:
|
||||
return self._incoming_message_stream_reader
|
||||
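A sketch of reaching the active ServerSession from inside a handler, via the server's request context, to push log and progress notifications back to the client. The server name and `slow_tool` are illustrative; progress updates only reach the client if it supplied a progress token in the request metadata.

```
import anyio

import mcp.types as types
from mcp.server.lowlevel import Server

server = Server("notifier")


@server.call_tool()
async def slow_tool(name: str, arguments: dict) -> list[types.TextContent]:
    ctx = server.request_context  # RequestContext for the current request
    session = ctx.session         # the active ServerSession

    await session.send_log_message(level="info", data=f"starting {name}")

    total = 5
    for i in range(total):
        await anyio.sleep(0.1)
        if ctx.meta is not None and ctx.meta.progressToken is not None:
            await session.send_progress_notification(
                ctx.meta.progressToken, progress=i + 1, total=total
            )

    return [types.TextContent(type="text", text="done")]
```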
175
.venv/lib/python3.10/site-packages/mcp/server/sse.py
Normal file
175
.venv/lib/python3.10/site-packages/mcp/server/sse.py
Normal file
@@ -0,0 +1,175 @@
"""
SSE Server Transport Module

This module implements a Server-Sent Events (SSE) transport layer for MCP servers.

Example usage:
```
# Create an SSE transport at an endpoint
sse = SseServerTransport("/messages/")

# Create Starlette routes for SSE and message handling
routes = [
    Route("/sse", endpoint=handle_sse),
    Mount("/messages/", app=sse.handle_post_message),
]

# Define handler functions
async def handle_sse(request):
    async with sse.connect_sse(
        request.scope, request.receive, request._send
    ) as streams:
        await app.run(
            streams[0], streams[1], app.create_initialization_options()
        )

# Create and run Starlette app
starlette_app = Starlette(routes=routes)
uvicorn.run(starlette_app, host="0.0.0.0", port=port)
```

See SseServerTransport class documentation for more details.
"""

import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import ValidationError
|
||||
from sse_starlette import EventSourceResponse
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
import mcp.types as types
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SseServerTransport:
|
||||
"""
|
||||
SSE server transport for MCP. This class provides _two_ ASGI applications,
|
||||
suitable to be used with a framework like Starlette and a server like Hypercorn:
|
||||
|
||||
1. connect_sse() is an ASGI application which receives incoming GET requests,
|
||||
and sets up a new SSE stream to send server messages to the client.
|
||||
2. handle_post_message() is an ASGI application which receives incoming POST
|
||||
requests, which should contain client messages that link to a
|
||||
previously-established SSE session.
|
||||
"""
|
||||
|
||||
_endpoint: str
|
||||
_read_stream_writers: dict[
|
||||
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
]
|
||||
|
||||
def __init__(self, endpoint: str) -> None:
|
||||
"""
|
||||
Creates a new SSE server transport, which will direct the client to POST
|
||||
messages to the relative or absolute URL given.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self._endpoint = endpoint
|
||||
self._read_stream_writers = {}
|
||||
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
|
||||
if scope["type"] != "http":
|
||||
logger.error("connect_sse received non-HTTP request")
|
||||
raise ValueError("connect_sse can only handle HTTP requests")
|
||||
|
||||
logger.debug("Setting up SSE connection")
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
session_id = uuid4()
|
||||
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
|
||||
self._read_stream_writers[session_id] = read_stream_writer
|
||||
logger.debug(f"Created new session with ID: {session_id}")
|
||||
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
|
||||
dict[str, Any]
|
||||
](0)
|
||||
|
||||
async def sse_writer():
|
||||
logger.debug("Starting SSE writer")
|
||||
async with sse_stream_writer, write_stream_reader:
|
||||
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
|
||||
logger.debug(f"Sent endpoint event: {session_uri}")
|
||||
|
||||
async for message in write_stream_reader:
|
||||
logger.debug(f"Sending message via SSE: {message}")
|
||||
await sse_stream_writer.send(
|
||||
{
|
||||
"event": "message",
|
||||
"data": message.model_dump_json(
|
||||
by_alias=True, exclude_none=True
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
response = EventSourceResponse(
|
||||
content=sse_stream_reader, data_sender_callable=sse_writer
|
||||
)
|
||||
logger.debug("Starting SSE response task")
|
||||
tg.start_soon(response, scope, receive, send)
|
||||
|
||||
logger.debug("Yielding read and write streams")
|
||||
yield (read_stream, write_stream)
|
||||
|
||||
async def handle_post_message(
|
||||
self, scope: Scope, receive: Receive, send: Send
|
||||
) -> None:
|
||||
logger.debug("Handling POST message")
|
||||
request = Request(scope, receive)
|
||||
|
||||
session_id_param = request.query_params.get("session_id")
|
||||
if session_id_param is None:
|
||||
logger.warning("Received request without session_id")
|
||||
response = Response("session_id is required", status_code=400)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
try:
|
||||
session_id = UUID(hex=session_id_param)
|
||||
logger.debug(f"Parsed session ID: {session_id}")
|
||||
except ValueError:
|
||||
logger.warning(f"Received invalid session ID: {session_id_param}")
|
||||
response = Response("Invalid session ID", status_code=400)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
writer = self._read_stream_writers.get(session_id)
|
||||
if not writer:
|
||||
logger.warning(f"Could not find session for ID: {session_id}")
|
||||
response = Response("Could not find session", status_code=404)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
body = await request.body()
|
||||
logger.debug(f"Received JSON: {body}")
|
||||
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(body)
|
||||
logger.debug(f"Validated client message: {message}")
|
||||
except ValidationError as err:
|
||||
logger.error(f"Failed to parse message: {err}")
|
||||
response = Response("Could not parse message", status_code=400)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(err)
|
||||
return
|
||||
|
||||
logger.debug(f"Sending message to writer: {message}")
|
||||
response = Response("Accepted", status_code=202)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(message)
|
||||
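The docstring example above leaves out imports and ordering; a fuller sketch, assuming `server` is an already-configured lowlevel Server instance, might be wired up like this.

```
import uvicorn
from starlette.applications import Starlette
from starlette.routing import Mount, Route

from mcp.server.lowlevel import Server
from mcp.server.sse import SseServerTransport

server = Server("sse-example")  # register handlers on this as usual

sse = SseServerTransport("/messages/")


async def handle_sse(request):
    # request._send is the raw ASGI send callable; the transport needs it to
    # stream events back on the same connection (as in the module docstring).
    async with sse.connect_sse(
        request.scope, request.receive, request._send
    ) as (read_stream, write_stream):
        await server.run(
            read_stream, write_stream, server.create_initialization_options()
        )


app = Starlette(
    routes=[
        Route("/sse", endpoint=handle_sse),
        Mount("/messages/", app=sse.handle_post_message),
    ]
)

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)
```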
86
.venv/lib/python3.10/site-packages/mcp/server/stdio.py
Normal file
86
.venv/lib/python3.10/site-packages/mcp/server/stdio.py
Normal file
@@ -0,0 +1,86 @@
"""
Stdio Server Transport Module

This module provides functionality for creating an stdio-based transport layer
that can be used to communicate with an MCP client through standard input/output
streams.

Example usage:
```
async def run_server():
    async with stdio_server() as (read_stream, write_stream):
        # read_stream contains incoming JSONRPCMessages from stdin
        # write_stream allows sending JSONRPCMessages to stdout
        server = await create_my_server()
        await server.run(read_stream, write_stream, init_options)

anyio.run(run_server)
```
"""

import sys
from contextlib import asynccontextmanager
from io import TextIOWrapper

import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

import mcp.types as types


@asynccontextmanager
async def stdio_server(
    stdin: anyio.AsyncFile[str] | None = None,
    stdout: anyio.AsyncFile[str] | None = None,
):
    """
    Server transport for stdio: this communicates with an MCP client by reading
    from the current process' stdin and writing to stdout.
    """
    # Purposely not using context managers for these, as we don't want to close
    # standard process handles. Encoding of stdin/stdout as text streams on
    # python is platform-dependent (Windows is particularly problematic), so we
    # re-wrap the underlying binary stream to ensure UTF-8.
    if not stdin:
        stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8"))
    if not stdout:
        stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))

    read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]

    write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
    write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]

    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    async def stdin_reader():
        try:
            async with read_stream_writer:
                async for line in stdin:
                    try:
                        message = types.JSONRPCMessage.model_validate_json(line)
                    except Exception as exc:
                        await read_stream_writer.send(exc)
                        continue

                    await read_stream_writer.send(message)
        except anyio.ClosedResourceError:
            await anyio.lowlevel.checkpoint()

    async def stdout_writer():
        try:
            async with write_stream_reader:
                async for message in write_stream_reader:
                    json = message.model_dump_json(by_alias=True, exclude_none=True)
                    await stdout.write(json + "\n")
                    await stdout.flush()
        except anyio.ClosedResourceError:
            await anyio.lowlevel.checkpoint()

    async with anyio.create_task_group() as tg:
        tg.start_soon(stdin_reader)
        tg.start_soon(stdout_writer)
        yield read_stream, write_stream
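For completeness, the client-side counterpart of this transport launches the server as a subprocess and speaks to it over its stdin/stdout. The command and arguments below are placeholders for a real server script.

```
import anyio

from mcp import ClientSession, StdioServerParameters, stdio_client

params = StdioServerParameters(command="python", args=["my_server.py"])


async def main() -> None:
    async with stdio_client(params) as (read_stream, write_stream):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            tools = await session.list_tools()
            print([tool.name for tool in tools.tools])


anyio.run(main)
```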
60
.venv/lib/python3.10/site-packages/mcp/server/websocket.py
Normal file
60
.venv/lib/python3.10/site-packages/mcp/server/websocket.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic_core import ValidationError
|
||||
from starlette.types import Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
import mcp.types as types
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
||||
"""
|
||||
WebSocket server transport for MCP. This is an ASGI application, suitable to be
|
||||
used with a framework like Starlette and a server like Hypercorn.
|
||||
"""
|
||||
|
||||
websocket = WebSocket(scope, receive, send)
|
||||
await websocket.accept(subprotocol="mcp")
|
||||
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async def ws_reader():
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
async for msg in websocket.iter_text():
|
||||
try:
|
||||
client_message = types.JSONRPCMessage.model_validate_json(msg)
|
||||
except ValidationError as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(client_message)
|
||||
except anyio.ClosedResourceError:
|
||||
await websocket.close()
|
||||
|
||||
async def ws_writer():
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
obj = message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
await websocket.send_text(obj)
|
||||
except anyio.ClosedResourceError:
|
||||
await websocket.close()
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(ws_reader)
|
||||
tg.start_soon(ws_writer)
|
||||
yield (read_stream, write_stream)
|
||||
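A possible way to mount this transport in Starlette, sketched under the assumption that a class-based ASGI endpoint is used so `websocket_server` receives the raw scope/receive/send triple directly; treat the wiring as illustrative rather than the package's documented pattern.

```
import uvicorn
from starlette.applications import Starlette
from starlette.routing import WebSocketRoute

from mcp.server.lowlevel import Server
from mcp.server.websocket import websocket_server

server = Server("ws-example")  # register handlers on this as usual


class MCPWebSocketApp:
    """Raw ASGI endpoint so the transport gets (scope, receive, send) unchanged."""

    async def __call__(self, scope, receive, send) -> None:
        async with websocket_server(scope, receive, send) as (read_stream, write_stream):
            await server.run(
                read_stream, write_stream, server.create_initialization_options()
            )


app = Starlette(routes=[WebSocketRoute("/mcp", endpoint=MCPWebSocketApp())])

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)
```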
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
18
.venv/lib/python3.10/site-packages/mcp/shared/context.py
Normal file
18
.venv/lib/python3.10/site-packages/mcp/shared/context.py
Normal file
@@ -0,0 +1,18 @@
from dataclasses import dataclass
from typing import Any, Generic

from typing_extensions import TypeVar

from mcp.shared.session import BaseSession
from mcp.types import RequestId, RequestParams

SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")


@dataclass
class RequestContext(Generic[SessionT, LifespanContextT]):
    request_id: RequestId
    meta: RequestParams.Meta | None
    session: SessionT
    lifespan_context: LifespanContextT
14
.venv/lib/python3.10/site-packages/mcp/shared/exceptions.py
Normal file
14
.venv/lib/python3.10/site-packages/mcp/shared/exceptions.py
Normal file
@@ -0,0 +1,14 @@
from mcp.types import ErrorData


class McpError(Exception):
    """
    Exception type raised when an error arrives over an MCP connection.
    """

    error: ErrorData

    def __init__(self, error: ErrorData):
        """Initialize McpError."""
        super().__init__(error.message)
        self.error = error
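A small sketch of raising and inspecting this exception; the error code constant is the one the lowlevel server uses for unknown methods.

```
from mcp.shared.exceptions import McpError
from mcp.types import METHOD_NOT_FOUND, ErrorData

try:
    raise McpError(ErrorData(code=METHOD_NOT_FOUND, message="Method not found"))
except McpError as exc:
    # The structured error travels with the exception.
    print(exc.error.code, exc.error.message)
```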
102
.venv/lib/python3.10/site-packages/mcp/shared/memory.py
Normal file
102
.venv/lib/python3.10/site-packages/mcp/shared/memory.py
Normal file
@@ -0,0 +1,102 @@
"""
In-memory transports
"""

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import Any

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp.client.session import (
    ClientSession,
    ListRootsFnT,
    LoggingFnT,
    MessageHandlerFnT,
    SamplingFnT,
)
from mcp.server import Server
from mcp.types import JSONRPCMessage

MessageStream = tuple[
    MemoryObjectReceiveStream[JSONRPCMessage | Exception],
    MemoryObjectSendStream[JSONRPCMessage],
]


@asynccontextmanager
async def create_client_server_memory_streams() -> (
    AsyncGenerator[tuple[MessageStream, MessageStream], None]
):
    """
    Creates a pair of bidirectional memory streams for client-server communication.

    Returns:
        A tuple of (client_streams, server_streams) where each is a tuple of
        (read_stream, write_stream)
    """
    # Create streams for both directions
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
        JSONRPCMessage | Exception
    ](1)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
        JSONRPCMessage | Exception
    ](1)

    client_streams = (server_to_client_receive, client_to_server_send)
    server_streams = (client_to_server_receive, server_to_client_send)

    async with (
        server_to_client_receive,
        client_to_server_send,
        client_to_server_receive,
        server_to_client_send,
    ):
        yield client_streams, server_streams


@asynccontextmanager
async def create_connected_server_and_client_session(
    server: Server[Any],
    read_timeout_seconds: timedelta | None = None,
    sampling_callback: SamplingFnT | None = None,
    list_roots_callback: ListRootsFnT | None = None,
    logging_callback: LoggingFnT | None = None,
    message_handler: MessageHandlerFnT | None = None,
    raise_exceptions: bool = False,
) -> AsyncGenerator[ClientSession, None]:
    """Creates a ClientSession that is connected to a running MCP server."""
    async with create_client_server_memory_streams() as (
        client_streams,
        server_streams,
    ):
        client_read, client_write = client_streams
        server_read, server_write = server_streams

        # Create a cancel scope for the server task
        async with anyio.create_task_group() as tg:
            tg.start_soon(
                lambda: server.run(
                    server_read,
                    server_write,
                    server.create_initialization_options(),
                    raise_exceptions=raise_exceptions,
                )
            )

            try:
                async with ClientSession(
                    read_stream=client_read,
                    write_stream=client_write,
                    read_timeout_seconds=read_timeout_seconds,
                    sampling_callback=sampling_callback,
                    list_roots_callback=list_roots_callback,
                    logging_callback=logging_callback,
                    message_handler=message_handler,
                ) as client_session:
                    await client_session.initialize()
                    yield client_session
            finally:
                tg.cancel_scope.cancel()
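create_connected_server_and_client_session is the in-memory equivalent of a full transport: it runs the server in a task group, connects a ClientSession over the paired streams, and calls initialize() before yielding. A sketch of using it in a test follows; the Server instance and the send_ping() convenience call are assumptions for illustration.

# Sketch only: connect a client to a server entirely in memory and ping it.
import anyio

from mcp.server import Server
from mcp.shared.memory import create_connected_server_and_client_session


async def main() -> None:
    server = Server("demo")  # assumed low-level server instance

    async with create_connected_server_and_client_session(server) as session:
        # The helper has already run session.initialize() for us.
        await session.send_ping()  # assumed ClientSession convenience method


anyio.run(main)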
84
.venv/lib/python3.10/site-packages/mcp/shared/progress.py
Normal file
@@ -0,0 +1,84 @@
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Generic

from pydantic import BaseModel

from mcp.shared.context import LifespanContextT, RequestContext
from mcp.shared.session import (
    BaseSession,
    ReceiveNotificationT,
    ReceiveRequestT,
    SendNotificationT,
    SendRequestT,
    SendResultT,
)
from mcp.types import ProgressToken


class Progress(BaseModel):
    progress: float
    total: float | None


@dataclass
class ProgressContext(
    Generic[
        SendRequestT,
        SendNotificationT,
        SendResultT,
        ReceiveRequestT,
        ReceiveNotificationT,
    ]
):
    session: BaseSession[
        SendRequestT,
        SendNotificationT,
        SendResultT,
        ReceiveRequestT,
        ReceiveNotificationT,
    ]
    progress_token: ProgressToken
    total: float | None
    current: float = field(default=0.0, init=False)

    async def progress(self, amount: float) -> None:
        self.current += amount

        await self.session.send_progress_notification(
            self.progress_token, self.current, total=self.total
        )


@contextmanager
def progress(
    ctx: RequestContext[
        BaseSession[
            SendRequestT,
            SendNotificationT,
            SendResultT,
            ReceiveRequestT,
            ReceiveNotificationT,
        ],
        LifespanContextT,
    ],
    total: float | None = None,
) -> Generator[
    ProgressContext[
        SendRequestT,
        SendNotificationT,
        SendResultT,
        ReceiveRequestT,
        ReceiveNotificationT,
    ],
    None,
]:
    if ctx.meta is None or ctx.meta.progressToken is None:
        raise ValueError("No progress token provided")

    progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total)
    try:
        yield progress_ctx
    finally:
        pass
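The progress() helper turns the request metadata into a ProgressContext so a handler can report incremental progress without tracking the token itself; it raises ValueError when the request carried no progressToken. A sketch of using it inside a handler follows; ctx and the work loop are illustrative.

# Sketch only: report incremental progress for a long-running request.
from mcp.shared.progress import progress


async def process_items(ctx, items) -> None:
    # ctx is assumed to be the RequestContext of the current request and must
    # carry a progress token in its metadata, otherwise progress() raises ValueError.
    with progress(ctx, total=float(len(items))) as p:
        for item in items:
            ...  # one unit of work (placeholder)
            await p.progress(1.0)  # bumps `current` and sends a progress notification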
394
.venv/lib/python3.10/site-packages/mcp/shared/session.py
Normal file
@@ -0,0 +1,394 @@
import logging
from collections.abc import Callable
from contextlib import AsyncExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, TypeVar

import anyio
import anyio.lowlevel
import httpx
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import BaseModel
from typing_extensions import Self

from mcp.shared.exceptions import McpError
from mcp.types import (
    CancelledNotification,
    ClientNotification,
    ClientRequest,
    ClientResult,
    ErrorData,
    JSONRPCError,
    JSONRPCMessage,
    JSONRPCNotification,
    JSONRPCRequest,
    JSONRPCResponse,
    RequestParams,
    ServerNotification,
    ServerRequest,
    ServerResult,
)

SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar(
    "ReceiveNotificationT", ClientNotification, ServerNotification
)

RequestId = str | int


class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
    """Handles responding to MCP requests and manages request lifecycle.

    This class MUST be used as a context manager to ensure proper cleanup and
    cancellation handling:

    Example:
        with request_responder as resp:
            await resp.respond(result)

    The context manager ensures:
    1. Proper cancellation scope setup and cleanup
    2. Request completion tracking
    3. Cleanup of in-flight requests
    """

    def __init__(
        self,
        request_id: RequestId,
        request_meta: RequestParams.Meta | None,
        request: ReceiveRequestT,
        session: """BaseSession[
            SendRequestT,
            SendNotificationT,
            SendResultT,
            ReceiveRequestT,
            ReceiveNotificationT
        ]""",
        on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
    ) -> None:
        self.request_id = request_id
        self.request_meta = request_meta
        self.request = request
        self._session = session
        self._completed = False
        self._cancel_scope = anyio.CancelScope()
        self._on_complete = on_complete
        self._entered = False  # Track if we're in a context manager

    def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
        """Enter the context manager, enabling request cancellation tracking."""
        self._entered = True
        self._cancel_scope = anyio.CancelScope()
        self._cancel_scope.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the context manager, performing cleanup and notifying completion."""
        try:
            if self._completed:
                self._on_complete(self)
        finally:
            self._entered = False
            if not self._cancel_scope:
                raise RuntimeError("No active cancel scope")
            self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)

    async def respond(self, response: SendResultT | ErrorData) -> None:
        """Send a response for this request.

        Must be called within a context manager block.
        Raises:
            RuntimeError: If not used within a context manager
            AssertionError: If request was already responded to
        """
        if not self._entered:
            raise RuntimeError("RequestResponder must be used as a context manager")
        assert not self._completed, "Request already responded to"

        if not self.cancelled:
            self._completed = True

            await self._session._send_response(  # type: ignore[reportPrivateUsage]
                request_id=self.request_id, response=response
            )

    async def cancel(self) -> None:
        """Cancel this request and mark it as completed."""
        if not self._entered:
            raise RuntimeError("RequestResponder must be used as a context manager")
        if not self._cancel_scope:
            raise RuntimeError("No active cancel scope")

        self._cancel_scope.cancel()
        self._completed = True  # Mark as completed so it's removed from in_flight
        # Send an error response to indicate cancellation
        await self._session._send_response(  # type: ignore[reportPrivateUsage]
            request_id=self.request_id,
            response=ErrorData(code=0, message="Request cancelled", data=None),
        )

    @property
    def in_flight(self) -> bool:
        return not self._completed and not self.cancelled

    @property
    def cancelled(self) -> bool:
        return self._cancel_scope.cancel_called


class BaseSession(
    Generic[
        SendRequestT,
        SendNotificationT,
        SendResultT,
        ReceiveRequestT,
        ReceiveNotificationT,
    ],
):
    """
    Implements an MCP "session" on top of read/write streams, including features
    like request/response linking, notifications, and progress.

    This class is an async context manager that automatically starts processing
    messages when entered.
    """

    _response_streams: dict[
        RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
    ]
    _request_id: int
    _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]

    def __init__(
        self,
        read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
        write_stream: MemoryObjectSendStream[JSONRPCMessage],
        receive_request_type: type[ReceiveRequestT],
        receive_notification_type: type[ReceiveNotificationT],
        # If none, reading will never time out
        read_timeout_seconds: timedelta | None = None,
    ) -> None:
        self._read_stream = read_stream
        self._write_stream = write_stream
        self._response_streams = {}
        self._request_id = 0
        self._receive_request_type = receive_request_type
        self._receive_notification_type = receive_notification_type
        self._read_timeout_seconds = read_timeout_seconds
        self._in_flight = {}

        self._exit_stack = AsyncExitStack()

    async def __aenter__(self) -> Self:
        self._task_group = anyio.create_task_group()
        await self._task_group.__aenter__()
        self._task_group.start_soon(self._receive_loop)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        await self._exit_stack.aclose()
        # Using BaseSession as a context manager should not block on exit (this
        # would be very surprising behavior), so make sure to cancel the tasks
        # in the task group.
        self._task_group.cancel_scope.cancel()
        return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

    async def send_request(
        self,
        request: SendRequestT,
        result_type: type[ReceiveResultT],
    ) -> ReceiveResultT:
        """
        Sends a request and wait for a response. Raises an McpError if the
        response contains an error.

        Do not use this method to emit notifications! Use send_notification()
        instead.
        """

        request_id = self._request_id
        self._request_id = request_id + 1

        response_stream, response_stream_reader = anyio.create_memory_object_stream[
            JSONRPCResponse | JSONRPCError
        ](1)
        self._response_streams[request_id] = response_stream

        self._exit_stack.push_async_callback(lambda: response_stream.aclose())
        self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose())

        jsonrpc_request = JSONRPCRequest(
            jsonrpc="2.0",
            id=request_id,
            **request.model_dump(by_alias=True, mode="json", exclude_none=True),
        )

        # TODO: Support progress callbacks

        await self._write_stream.send(JSONRPCMessage(jsonrpc_request))

        try:
            with anyio.fail_after(
                None
                if self._read_timeout_seconds is None
                else self._read_timeout_seconds.total_seconds()
            ):
                response_or_error = await response_stream_reader.receive()
        except TimeoutError:
            raise McpError(
                ErrorData(
                    code=httpx.codes.REQUEST_TIMEOUT,
                    message=(
                        f"Timed out while waiting for response to "
                        f"{request.__class__.__name__}. Waited "
                        f"{self._read_timeout_seconds} seconds."
                    ),
                )
            )

        if isinstance(response_or_error, JSONRPCError):
            raise McpError(response_or_error.error)
        else:
            return result_type.model_validate(response_or_error.result)

    async def send_notification(self, notification: SendNotificationT) -> None:
        """
        Emits a notification, which is a one-way message that does not expect
        a response.
        """
        jsonrpc_notification = JSONRPCNotification(
            jsonrpc="2.0",
            **notification.model_dump(by_alias=True, mode="json", exclude_none=True),
        )

        await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))

    async def _send_response(
        self, request_id: RequestId, response: SendResultT | ErrorData
    ) -> None:
        if isinstance(response, ErrorData):
            jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
            await self._write_stream.send(JSONRPCMessage(jsonrpc_error))
        else:
            jsonrpc_response = JSONRPCResponse(
                jsonrpc="2.0",
                id=request_id,
                result=response.model_dump(
                    by_alias=True, mode="json", exclude_none=True
                ),
            )
            await self._write_stream.send(JSONRPCMessage(jsonrpc_response))

    async def _receive_loop(self) -> None:
        async with (
            self._read_stream,
            self._write_stream,
        ):
            async for message in self._read_stream:
                if isinstance(message, Exception):
                    await self._handle_incoming(message)
                elif isinstance(message.root, JSONRPCRequest):
                    validated_request = self._receive_request_type.model_validate(
                        message.root.model_dump(
                            by_alias=True, mode="json", exclude_none=True
                        )
                    )

                    responder = RequestResponder(
                        request_id=message.root.id,
                        request_meta=validated_request.root.params.meta
                        if validated_request.root.params
                        else None,
                        request=validated_request,
                        session=self,
                        on_complete=lambda r: self._in_flight.pop(r.request_id, None),
                    )

                    self._in_flight[responder.request_id] = responder
                    await self._received_request(responder)

                    if not responder._completed:  # type: ignore[reportPrivateUsage]
                        await self._handle_incoming(responder)

                elif isinstance(message.root, JSONRPCNotification):
                    try:
                        notification = self._receive_notification_type.model_validate(
                            message.root.model_dump(
                                by_alias=True, mode="json", exclude_none=True
                            )
                        )
                        # Handle cancellation notifications
                        if isinstance(notification.root, CancelledNotification):
                            cancelled_id = notification.root.params.requestId
                            if cancelled_id in self._in_flight:
                                await self._in_flight[cancelled_id].cancel()
                        else:
                            await self._received_notification(notification)
                            await self._handle_incoming(notification)
                    except Exception as e:
                        # For other validation errors, log and continue
                        logging.warning(
                            f"Failed to validate notification: {e}. "
                            f"Message was: {message.root}"
                        )
                else:  # Response or error
                    stream = self._response_streams.pop(message.root.id, None)
                    if stream:
                        await stream.send(message.root)
                    else:
                        await self._handle_incoming(
                            RuntimeError(
                                "Received response with an unknown "
                                f"request ID: {message}"
                            )
                        )

    async def _received_request(
        self, responder: RequestResponder[ReceiveRequestT, SendResultT]
    ) -> None:
        """
        Can be overridden by subclasses to handle a request without needing to
        listen on the message stream.

        If the request is responded to within this method, it will not be
        forwarded on to the message stream.
        """

    async def _received_notification(self, notification: ReceiveNotificationT) -> None:
        """
        Can be overridden by subclasses to handle a notification without needing
        to listen on the message stream.
        """

    async def send_progress_notification(
        self, progress_token: str | int, progress: float, total: float | None = None
    ) -> None:
        """
        Sends a progress notification for a request that is currently being
        processed.
        """

    async def _handle_incoming(
        self,
        req: RequestResponder[ReceiveRequestT, SendResultT]
        | ReceiveNotificationT
        | Exception,
    ) -> None:
        """A generic handler for incoming messages. Overwritten by subclasses."""
        pass
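RequestResponder is handed to _received_request()/_handle_incoming() for every incoming request, and its docstring above mandates the context-manager protocol. A sketch of the intended usage follows; the answer() helper and the result value are illustrative, and an ErrorData can be passed to respond() instead of a result to signal failure.

# Sketch only: respond to an incoming request via its RequestResponder.
async def answer(responder, result) -> None:
    with responder:  # sets up the cancellation scope and completion tracking
        if responder.in_flight:
            await responder.respond(result)  # marks the request as completed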
3
.venv/lib/python3.10/site-packages/mcp/shared/version.py
Normal file
@@ -0,0 +1,3 @@
from mcp.types import LATEST_PROTOCOL_VERSION

SUPPORTED_PROTOCOL_VERSIONS: tuple[int, str] = (1, LATEST_PROTOCOL_VERSION)
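SUPPORTED_PROTOCOL_VERSIONS is what the session layer consults during initialization. A sketch of the kind of check it supports follows; the negotiate() helper is illustrative, and the actual negotiation lives in the session classes rather than in this module.

# Sketch only: accept a requested protocol version only if it is supported.
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types import LATEST_PROTOCOL_VERSION


def negotiate(requested_version: str | int) -> str | int:
    # Fall back to the latest known version when the request is unsupported.
    if requested_version in SUPPORTED_PROTOCOL_VERSIONS:
        return requested_version
    return LATEST_PROTOCOL_VERSION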
1130
.venv/lib/python3.10/site-packages/mcp/types.py
Normal file
File diff suppressed because it is too large