structure saas with tools

Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
from .fastmcp import FastMCP
from .lowlevel import NotificationOptions, Server
from .models import InitializationOptions
__all__ = ["Server", "FastMCP", "NotificationOptions", "InitializationOptions"]

View File

@@ -0,0 +1,50 @@
import importlib.metadata
import logging
import sys
import anyio
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server
from mcp.types import ServerCapabilities
if not sys.warnoptions:
import warnings
warnings.simplefilter("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("server")
async def receive_loop(session: ServerSession):
logger.info("Starting receive loop")
async for message in session.incoming_messages:
if isinstance(message, Exception):
logger.error("Error: %s", message)
continue
logger.info("Received message from client: %s", message)
async def main():
version = importlib.metadata.version("mcp")
async with stdio_server() as (read_stream, write_stream):
async with (
ServerSession(
read_stream,
write_stream,
InitializationOptions(
server_name="mcp",
server_version=version,
capabilities=ServerCapabilities(),
),
) as session,
write_stream,
):
await receive_loop(session)
if __name__ == "__main__":
anyio.run(main, backend="trio")

View File

@@ -0,0 +1,9 @@
"""FastMCP - A more ergonomic interface for MCP servers."""
from importlib.metadata import version
from .server import Context, FastMCP
from .utilities.types import Image
__version__ = version("mcp")
__all__ = ["FastMCP", "Context", "Image"]

View File

@@ -0,0 +1,21 @@
"""Custom exceptions for FastMCP."""
class FastMCPError(Exception):
"""Base error for FastMCP."""
class ValidationError(FastMCPError):
"""Error in validating parameters or return values."""
class ResourceError(FastMCPError):
"""Error in resource operations."""
class ToolError(FastMCPError):
"""Error in tool operations."""
class InvalidSignature(Exception):
"""Invalid signature for use with FastMCP."""

View File

@@ -0,0 +1,4 @@
from .base import Prompt
from .manager import PromptManager
__all__ = ["Prompt", "PromptManager"]

View File

@@ -0,0 +1,167 @@
"""Base classes for FastMCP prompts."""
import inspect
import json
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, Literal
import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call
from mcp.types import EmbeddedResource, ImageContent, TextContent
CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource
class Message(BaseModel):
"""Base class for all prompt messages."""
role: Literal["user", "assistant"]
content: CONTENT_TYPES
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
if isinstance(content, str):
content = TextContent(type="text", text=content)
super().__init__(content=content, **kwargs)
class UserMessage(Message):
"""A message from the user."""
role: Literal["user", "assistant"] = "user"
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
super().__init__(content=content, **kwargs)
class AssistantMessage(Message):
"""A message from the assistant."""
role: Literal["user", "assistant"] = "assistant"
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
super().__init__(content=content, **kwargs)
message_validator = TypeAdapter[UserMessage | AssistantMessage](
UserMessage | AssistantMessage
)
SyncPromptResult = (
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
)
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]
class PromptArgument(BaseModel):
"""An argument that can be passed to a prompt."""
name: str = Field(description="Name of the argument")
description: str | None = Field(
None, description="Description of what the argument does"
)
required: bool = Field(
default=False, description="Whether the argument is required"
)
class Prompt(BaseModel):
"""A prompt template that can be rendered with parameters."""
name: str = Field(description="Name of the prompt")
description: str | None = Field(
None, description="Description of what the prompt does"
)
arguments: list[PromptArgument] | None = Field(
None, description="Arguments that can be passed to the prompt"
)
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
@classmethod
def from_function(
cls,
fn: Callable[..., PromptResult | Awaitable[PromptResult]],
name: str | None = None,
description: str | None = None,
) -> "Prompt":
"""Create a Prompt from a function.
The function can return:
- A string (converted to a message)
- A Message object
- A dict (converted to a message)
- A sequence of any of the above
"""
func_name = name or fn.__name__
if func_name == "<lambda>":
raise ValueError("You must provide a name for lambda functions")
# Get schema from TypeAdapter - will fail if function isn't properly typed
parameters = TypeAdapter(fn).json_schema()
# Convert parameters to PromptArguments
arguments: list[PromptArgument] = []
if "properties" in parameters:
for param_name, param in parameters["properties"].items():
required = param_name in parameters.get("required", [])
arguments.append(
PromptArgument(
name=param_name,
description=param.get("description"),
required=required,
)
)
# ensure the arguments are properly cast
fn = validate_call(fn)
return cls(
name=func_name,
description=description or fn.__doc__ or "",
arguments=arguments,
fn=fn,
)
async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]:
"""Render the prompt with arguments."""
# Validate required arguments
if self.arguments:
required = {arg.name for arg in self.arguments if arg.required}
provided = set(arguments or {})
missing = required - provided
if missing:
raise ValueError(f"Missing required arguments: {missing}")
try:
# Call function and check if result is a coroutine
result = self.fn(**(arguments or {}))
if inspect.iscoroutine(result):
result = await result
# Validate messages
if not isinstance(result, list | tuple):
result = [result]
# Convert result to messages
messages: list[Message] = []
for msg in result: # type: ignore[reportUnknownVariableType]
try:
if isinstance(msg, Message):
messages.append(msg)
elif isinstance(msg, dict):
messages.append(message_validator.validate_python(msg))
elif isinstance(msg, str):
content = TextContent(type="text", text=msg)
messages.append(UserMessage(content=content))
else:
content = json.dumps(pydantic_core.to_jsonable_python(msg))
messages.append(Message(role="user", content=content))
except Exception:
raise ValueError(
f"Could not convert prompt result to message: {msg}"
)
return messages
except Exception as e:
raise ValueError(f"Error rendering prompt {self.name}: {e}")

View File

@@ -0,0 +1,50 @@
"""Prompt management functionality."""
from typing import Any
from mcp.server.fastmcp.prompts.base import Message, Prompt
from mcp.server.fastmcp.utilities.logging import get_logger
logger = get_logger(__name__)
class PromptManager:
"""Manages FastMCP prompts."""
def __init__(self, warn_on_duplicate_prompts: bool = True):
self._prompts: dict[str, Prompt] = {}
self.warn_on_duplicate_prompts = warn_on_duplicate_prompts
def get_prompt(self, name: str) -> Prompt | None:
"""Get prompt by name."""
return self._prompts.get(name)
def list_prompts(self) -> list[Prompt]:
"""List all registered prompts."""
return list(self._prompts.values())
def add_prompt(
self,
prompt: Prompt,
) -> Prompt:
"""Add a prompt to the manager."""
# Check for duplicates
existing = self._prompts.get(prompt.name)
if existing:
if self.warn_on_duplicate_prompts:
logger.warning(f"Prompt already exists: {prompt.name}")
return existing
self._prompts[prompt.name] = prompt
return prompt
async def render_prompt(
self, name: str, arguments: dict[str, Any] | None = None
) -> list[Message]:
"""Render a prompt by name with arguments."""
prompt = self.get_prompt(name)
if not prompt:
raise ValueError(f"Unknown prompt: {name}")
return await prompt.render(arguments)
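For completeness, a short sketch of driving this manager directly (not part of the commit; it assumes this is the PromptManager re-exported from mcp.server.fastmcp.prompts):

import anyio
from mcp.server.fastmcp.prompts import Prompt, PromptManager

def summarize(text: str) -> str:
    """Ask the model to summarize a block of text."""
    return f"Summarize the following text:\n\n{text}"

manager = PromptManager()
manager.add_prompt(Prompt.from_function(summarize))

async def demo() -> None:
    # Plain string results are wrapped in a UserMessage during rendering.
    messages = await manager.render_prompt("summarize", {"text": "FastMCP is an ergonomic MCP server API."})
    print(messages[0].role, messages[0].content.text)

anyio.run(demo)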

View File

@@ -0,0 +1,33 @@
"""Prompt management functionality."""
from mcp.server.fastmcp.prompts.base import Prompt
from mcp.server.fastmcp.utilities.logging import get_logger
logger = get_logger(__name__)
class PromptManager:
"""Manages FastMCP prompts."""
def __init__(self, warn_on_duplicate_prompts: bool = True):
self._prompts: dict[str, Prompt] = {}
self.warn_on_duplicate_prompts = warn_on_duplicate_prompts
def add_prompt(self, prompt: Prompt) -> Prompt:
"""Add a prompt to the manager."""
logger.debug(f"Adding prompt: {prompt.name}")
existing = self._prompts.get(prompt.name)
if existing:
if self.warn_on_duplicate_prompts:
logger.warning(f"Prompt already exists: {prompt.name}")
return existing
self._prompts[prompt.name] = prompt
return prompt
def get_prompt(self, name: str) -> Prompt | None:
"""Get prompt by name."""
return self._prompts.get(name)
def list_prompts(self) -> list[Prompt]:
"""List all registered prompts."""
return list(self._prompts.values())

View File

@@ -0,0 +1,23 @@
from .base import Resource
from .resource_manager import ResourceManager
from .templates import ResourceTemplate
from .types import (
BinaryResource,
DirectoryResource,
FileResource,
FunctionResource,
HttpResource,
TextResource,
)
__all__ = [
"Resource",
"TextResource",
"BinaryResource",
"FunctionResource",
"FileResource",
"HttpResource",
"DirectoryResource",
"ResourceTemplate",
"ResourceManager",
]

View File

@@ -0,0 +1,48 @@
"""Base classes and interfaces for FastMCP resources."""
import abc
from typing import Annotated
from pydantic import (
AnyUrl,
BaseModel,
ConfigDict,
Field,
UrlConstraints,
ValidationInfo,
field_validator,
)
class Resource(BaseModel, abc.ABC):
"""Base class for all resources."""
model_config = ConfigDict(validate_default=True)
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(
default=..., description="URI of the resource"
)
name: str | None = Field(description="Name of the resource", default=None)
description: str | None = Field(
description="Description of the resource", default=None
)
mime_type: str = Field(
default="text/plain",
description="MIME type of the resource content",
pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+$",
)
@field_validator("name", mode="before")
@classmethod
def set_default_name(cls, name: str | None, info: ValidationInfo) -> str:
"""Set default name from URI if not provided."""
if name:
return name
if uri := info.data.get("uri"):
return str(uri)
raise ValueError("Either name or uri must be provided")
@abc.abstractmethod
async def read(self) -> str | bytes:
"""Read the resource content."""
pass
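A sketch of the abstract base in use (illustrative, not part of the commit): a concrete resource only needs to implement read(), and the name validator falls back to the URI when no name is given.

import anyio
from mcp.server.fastmcp.resources.base import Resource

class StaticResource(Resource):
    """A trivial concrete resource returning fixed text."""

    async def read(self) -> str:
        return "hello from a custom resource"

async def demo() -> None:
    res = StaticResource(uri="demo://static/hello")  # name defaults to the URI
    print(res.name, await res.read())

anyio.run(demo)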

View File

@@ -0,0 +1,95 @@
"""Resource manager functionality."""
from collections.abc import Callable
from typing import Any
from pydantic import AnyUrl
from mcp.server.fastmcp.resources.base import Resource
from mcp.server.fastmcp.resources.templates import ResourceTemplate
from mcp.server.fastmcp.utilities.logging import get_logger
logger = get_logger(__name__)
class ResourceManager:
"""Manages FastMCP resources."""
def __init__(self, warn_on_duplicate_resources: bool = True):
self._resources: dict[str, Resource] = {}
self._templates: dict[str, ResourceTemplate] = {}
self.warn_on_duplicate_resources = warn_on_duplicate_resources
def add_resource(self, resource: Resource) -> Resource:
"""Add a resource to the manager.
Args:
resource: A Resource instance to add
Returns:
The added resource. If a resource with the same URI already exists,
returns the existing resource.
"""
logger.debug(
"Adding resource",
extra={
"uri": resource.uri,
"type": type(resource).__name__,
"resource_name": resource.name,
},
)
existing = self._resources.get(str(resource.uri))
if existing:
if self.warn_on_duplicate_resources:
logger.warning(f"Resource already exists: {resource.uri}")
return existing
self._resources[str(resource.uri)] = resource
return resource
def add_template(
self,
fn: Callable[..., Any],
uri_template: str,
name: str | None = None,
description: str | None = None,
mime_type: str | None = None,
) -> ResourceTemplate:
"""Add a template from a function."""
template = ResourceTemplate.from_function(
fn,
uri_template=uri_template,
name=name,
description=description,
mime_type=mime_type,
)
self._templates[template.uri_template] = template
return template
async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
"""Get resource by URI, checking concrete resources first, then templates."""
uri_str = str(uri)
logger.debug("Getting resource", extra={"uri": uri_str})
# First check concrete resources
if resource := self._resources.get(uri_str):
return resource
# Then check templates
for template in self._templates.values():
if params := template.matches(uri_str):
try:
return await template.create_resource(uri_str, params)
except Exception as e:
raise ValueError(f"Error creating resource from template: {e}")
raise ValueError(f"Unknown resource: {uri}")
def list_resources(self) -> list[Resource]:
"""List all registered resources."""
logger.debug("Listing resources", extra={"count": len(self._resources)})
return list(self._resources.values())
def list_templates(self) -> list[ResourceTemplate]:
"""List all registered templates."""
logger.debug("Listing templates", extra={"count": len(self._templates)})
return list(self._templates.values())
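A short usage sketch for the manager (not part of the commit; the config:// URI and the lambda are made up for illustration):

import anyio
from mcp.server.fastmcp.resources import FunctionResource, ResourceManager

manager = ResourceManager()
resource = FunctionResource(
    uri="config://app/settings",   # hypothetical URI
    name="app-settings",
    fn=lambda: {"debug": True},    # non-str results are serialized to JSON on read
)
manager.add_resource(resource)

async def demo() -> None:
    found = await manager.get_resource(resource.uri)  # look up by the stored URI
    print(await found.read())                         # -> {"debug": true}

anyio.run(demo)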

View File

@@ -0,0 +1,85 @@
"""Resource template functionality."""
from __future__ import annotations
import inspect
import re
from collections.abc import Callable
from typing import Any
from pydantic import BaseModel, Field, TypeAdapter, validate_call
from mcp.server.fastmcp.resources.types import FunctionResource, Resource
class ResourceTemplate(BaseModel):
"""A template for dynamically creating resources."""
uri_template: str = Field(
description="URI template with parameters (e.g. weather://{city}/current)"
)
name: str = Field(description="Name of the resource")
description: str | None = Field(description="Description of what the resource does")
mime_type: str = Field(
default="text/plain", description="MIME type of the resource content"
)
fn: Callable[..., Any] = Field(exclude=True)
parameters: dict[str, Any] = Field(
description="JSON schema for function parameters"
)
@classmethod
def from_function(
cls,
fn: Callable[..., Any],
uri_template: str,
name: str | None = None,
description: str | None = None,
mime_type: str | None = None,
) -> ResourceTemplate:
"""Create a template from a function."""
func_name = name or fn.__name__
if func_name == "<lambda>":
raise ValueError("You must provide a name for lambda functions")
# Get schema from TypeAdapter - will fail if function isn't properly typed
parameters = TypeAdapter(fn).json_schema()
# ensure the arguments are properly cast
fn = validate_call(fn)
return cls(
uri_template=uri_template,
name=func_name,
description=description or fn.__doc__ or "",
mime_type=mime_type or "text/plain",
fn=fn,
parameters=parameters,
)
def matches(self, uri: str) -> dict[str, Any] | None:
"""Check if URI matches template and extract parameters."""
# Convert template to regex pattern
pattern = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
match = re.match(f"^{pattern}$", uri)
if match:
return match.groupdict()
return None
async def create_resource(self, uri: str, params: dict[str, Any]) -> Resource:
"""Create a resource from the template with the given parameters."""
try:
# Call function and check if result is a coroutine
result = self.fn(**params)
if inspect.iscoroutine(result):
result = await result
return FunctionResource(
uri=uri, # type: ignore
name=self.name,
description=self.description,
mime_type=self.mime_type,
fn=lambda: result, # Capture result in closure
)
except Exception as e:
raise ValueError(f"Error creating resource from template: {e}")

View File

@@ -0,0 +1,185 @@
"""Concrete resource implementations."""
import inspect
import json
from collections.abc import Callable
from pathlib import Path
from typing import Any
import anyio
import anyio.to_thread
import httpx
import pydantic.json
import pydantic_core
from pydantic import Field, ValidationInfo
from mcp.server.fastmcp.resources.base import Resource
class TextResource(Resource):
"""A resource that reads from a string."""
text: str = Field(description="Text content of the resource")
async def read(self) -> str:
"""Read the text content."""
return self.text
class BinaryResource(Resource):
"""A resource that reads from bytes."""
data: bytes = Field(description="Binary content of the resource")
async def read(self) -> bytes:
"""Read the binary content."""
return self.data
class FunctionResource(Resource):
"""A resource that defers data loading by wrapping a function.
The function is only called when the resource is read, allowing for lazy loading
of potentially expensive data. This is particularly useful when listing resources,
as the function won't be called until the resource is actually accessed.
The function can return:
- str for text content (default)
- bytes for binary content
- other types will be converted to JSON
"""
fn: Callable[[], Any] = Field(exclude=True)
async def read(self) -> str | bytes:
"""Read the resource by calling the wrapped function."""
try:
result = (
await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn()
)
if isinstance(result, Resource):
return await result.read()
if isinstance(result, bytes):
return result
if isinstance(result, str):
return result
try:
return json.dumps(pydantic_core.to_jsonable_python(result))
except (TypeError, pydantic_core.PydanticSerializationError):
# If JSON serialization fails, try str()
return str(result)
except Exception as e:
raise ValueError(f"Error reading resource {self.uri}: {e}")
class FileResource(Resource):
"""A resource that reads from a file.
Set is_binary=True to read file as binary data instead of text.
"""
path: Path = Field(description="Path to the file")
is_binary: bool = Field(
default=False,
description="Whether to read the file as binary data",
)
mime_type: str = Field(
default="text/plain",
description="MIME type of the resource content",
)
@pydantic.field_validator("path")
@classmethod
def validate_absolute_path(cls, path: Path) -> Path:
"""Ensure path is absolute."""
if not path.is_absolute():
raise ValueError("Path must be absolute")
return path
@pydantic.field_validator("is_binary")
@classmethod
def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> bool:
"""Set is_binary based on mime_type if not explicitly set."""
if is_binary:
return True
mime_type = info.data.get("mime_type", "text/plain")
return not mime_type.startswith("text/")
async def read(self) -> str | bytes:
"""Read the file content."""
try:
if self.is_binary:
return await anyio.to_thread.run_sync(self.path.read_bytes)
return await anyio.to_thread.run_sync(self.path.read_text)
except Exception as e:
raise ValueError(f"Error reading file {self.path}: {e}")
class HttpResource(Resource):
"""A resource that reads from an HTTP endpoint."""
url: str = Field(description="URL to fetch content from")
mime_type: str = Field(
default="application/json", description="MIME type of the resource content"
)
async def read(self) -> str | bytes:
"""Read the HTTP content."""
async with httpx.AsyncClient() as client:
response = await client.get(self.url)
response.raise_for_status()
return response.text
class DirectoryResource(Resource):
"""A resource that lists files in a directory."""
path: Path = Field(description="Path to the directory")
recursive: bool = Field(
default=False, description="Whether to list files recursively"
)
pattern: str | None = Field(
default=None, description="Optional glob pattern to filter files"
)
mime_type: str = Field(
default="application/json", description="MIME type of the resource content"
)
@pydantic.field_validator("path")
@classmethod
def validate_absolute_path(cls, path: Path) -> Path:
"""Ensure path is absolute."""
if not path.is_absolute():
raise ValueError("Path must be absolute")
return path
def list_files(self) -> list[Path]:
"""List files in the directory."""
if not self.path.exists():
raise FileNotFoundError(f"Directory not found: {self.path}")
if not self.path.is_dir():
raise NotADirectoryError(f"Not a directory: {self.path}")
try:
if self.pattern:
return (
list(self.path.glob(self.pattern))
if not self.recursive
else list(self.path.rglob(self.pattern))
)
return (
list(self.path.glob("*"))
if not self.recursive
else list(self.path.rglob("*"))
)
except Exception as e:
raise ValueError(f"Error listing directory {self.path}: {e}")
async def read(self) -> str: # Always returns JSON string
"""Read the directory listing."""
try:
files = await anyio.to_thread.run_sync(self.list_files)
file_list = [str(f.relative_to(self.path)) for f in files if f.is_file()]
return json.dumps({"files": file_list}, indent=2)
except Exception as e:
raise ValueError(f"Error reading directory {self.path}: {e}")

View File

@@ -0,0 +1,713 @@
"""FastMCP - A more ergonomic interface for MCP servers."""
from __future__ import annotations as _annotations
import inspect
import json
import re
from collections.abc import AsyncIterator, Callable, Iterable, Sequence
from contextlib import (
AbstractAsyncContextManager,
asynccontextmanager,
)
from itertools import chain
from typing import Any, Generic, Literal
import anyio
import pydantic_core
import uvicorn
from pydantic import BaseModel, Field
from pydantic.networks import AnyUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount, Route
from mcp.server.fastmcp.exceptions import ResourceError
from mcp.server.fastmcp.prompts import Prompt, PromptManager
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
from mcp.server.fastmcp.tools import ToolManager
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
from mcp.server.fastmcp.utilities.types import Image
from mcp.server.lowlevel.helper_types import ReadResourceContents
from mcp.server.lowlevel.server import LifespanResultT
from mcp.server.lowlevel.server import Server as MCPServer
from mcp.server.lowlevel.server import lifespan as default_lifespan
from mcp.server.session import ServerSession, ServerSessionT
from mcp.server.sse import SseServerTransport
from mcp.server.stdio import stdio_server
from mcp.shared.context import LifespanContextT, RequestContext
from mcp.types import (
AnyFunction,
EmbeddedResource,
GetPromptResult,
ImageContent,
TextContent,
)
from mcp.types import Prompt as MCPPrompt
from mcp.types import PromptArgument as MCPPromptArgument
from mcp.types import Resource as MCPResource
from mcp.types import ResourceTemplate as MCPResourceTemplate
from mcp.types import Tool as MCPTool
logger = get_logger(__name__)
class Settings(BaseSettings, Generic[LifespanResultT]):
"""FastMCP server settings.
All settings can be configured via environment variables with the prefix FASTMCP_.
For example, FASTMCP_DEBUG=true will set debug=True.
"""
model_config = SettingsConfigDict(
env_prefix="FASTMCP_",
env_file=".env",
extra="ignore",
)
# Server settings
debug: bool = False
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
# HTTP settings
host: str = "0.0.0.0"
port: int = 8000
sse_path: str = "/sse"
message_path: str = "/messages/"
# resource settings
warn_on_duplicate_resources: bool = True
# tool settings
warn_on_duplicate_tools: bool = True
# prompt settings
warn_on_duplicate_prompts: bool = True
dependencies: list[str] = Field(
default_factory=list,
description="List of dependencies to install in the server environment",
)
lifespan: (
Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None
) = Field(None, description="Lifespan context manager")
def lifespan_wrapper(
app: FastMCP,
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
@asynccontextmanager
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
async with lifespan(app) as context:
yield context
return wrap
class FastMCP:
def __init__(
self, name: str | None = None, instructions: str | None = None, **settings: Any
):
self.settings = Settings(**settings)
self._mcp_server = MCPServer(
name=name or "FastMCP",
instructions=instructions,
lifespan=lifespan_wrapper(self, self.settings.lifespan)
if self.settings.lifespan
else default_lifespan,
)
self._tool_manager = ToolManager(
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
)
self._resource_manager = ResourceManager(
warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources
)
self._prompt_manager = PromptManager(
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts
)
self.dependencies = self.settings.dependencies
# Set up MCP protocol handlers
self._setup_handlers()
# Configure logging
configure_logging(self.settings.log_level)
@property
def name(self) -> str:
return self._mcp_server.name
@property
def instructions(self) -> str | None:
return self._mcp_server.instructions
def run(self, transport: Literal["stdio", "sse"] = "stdio") -> None:
"""Run the FastMCP server. Note this is a synchronous function.
Args:
transport: Transport protocol to use ("stdio" or "sse")
"""
TRANSPORTS = Literal["stdio", "sse"]
if transport not in TRANSPORTS.__args__: # type: ignore
raise ValueError(f"Unknown transport: {transport}")
if transport == "stdio":
anyio.run(self.run_stdio_async)
else: # transport == "sse"
anyio.run(self.run_sse_async)
def _setup_handlers(self) -> None:
"""Set up core MCP protocol handlers."""
self._mcp_server.list_tools()(self.list_tools)
self._mcp_server.call_tool()(self.call_tool)
self._mcp_server.list_resources()(self.list_resources)
self._mcp_server.read_resource()(self.read_resource)
self._mcp_server.list_prompts()(self.list_prompts)
self._mcp_server.get_prompt()(self.get_prompt)
self._mcp_server.list_resource_templates()(self.list_resource_templates)
async def list_tools(self) -> list[MCPTool]:
"""List all available tools."""
tools = self._tool_manager.list_tools()
return [
MCPTool(
name=info.name,
description=info.description,
inputSchema=info.parameters,
)
for info in tools
]
def get_context(self) -> Context[ServerSession, object]:
"""
Returns a Context object. Note that the context will only be valid
during a request; outside a request, most methods will error.
"""
try:
request_context = self._mcp_server.request_context
except LookupError:
request_context = None
return Context(request_context=request_context, fastmcp=self)
async def call_tool(
self, name: str, arguments: dict[str, Any]
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
"""Call a tool by name with arguments."""
context = self.get_context()
result = await self._tool_manager.call_tool(name, arguments, context=context)
converted_result = _convert_to_content(result)
return converted_result
async def list_resources(self) -> list[MCPResource]:
"""List all available resources."""
resources = self._resource_manager.list_resources()
return [
MCPResource(
uri=resource.uri,
name=resource.name or "",
description=resource.description,
mimeType=resource.mime_type,
)
for resource in resources
]
async def list_resource_templates(self) -> list[MCPResourceTemplate]:
templates = self._resource_manager.list_templates()
return [
MCPResourceTemplate(
uriTemplate=template.uri_template,
name=template.name,
description=template.description,
)
for template in templates
]
async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]:
"""Read a resource by URI."""
resource = await self._resource_manager.get_resource(uri)
if not resource:
raise ResourceError(f"Unknown resource: {uri}")
try:
content = await resource.read()
return [ReadResourceContents(content=content, mime_type=resource.mime_type)]
except Exception as e:
logger.error(f"Error reading resource {uri}: {e}")
raise ResourceError(str(e))
def add_tool(
self,
fn: AnyFunction,
name: str | None = None,
description: str | None = None,
) -> None:
"""Add a tool to the server.
The tool function can optionally request a Context object by adding a parameter
with the Context type annotation. See the @tool decorator for examples.
Args:
fn: The function to register as a tool
name: Optional name for the tool (defaults to function name)
description: Optional description of what the tool does
"""
self._tool_manager.add_tool(fn, name=name, description=description)
def tool(
self, name: str | None = None, description: str | None = None
) -> Callable[[AnyFunction], AnyFunction]:
"""Decorator to register a tool.
Tools can optionally request a Context object by adding a parameter with the
Context type annotation. The context provides access to MCP capabilities like
logging, progress reporting, and resource access.
Args:
name: Optional name for the tool (defaults to function name)
description: Optional description of what the tool does
Example:
@server.tool()
def my_tool(x: int) -> str:
return str(x)
@server.tool()
def tool_with_context(x: int, ctx: Context) -> str:
ctx.info(f"Processing {x}")
return str(x)
@server.tool()
async def async_tool(x: int, context: Context) -> str:
await context.report_progress(50, 100)
return str(x)
"""
# Check if user passed function directly instead of calling decorator
if callable(name):
raise TypeError(
"The @tool decorator was used incorrectly. "
"Did you forget to call it? Use @tool() instead of @tool"
)
def decorator(fn: AnyFunction) -> AnyFunction:
self.add_tool(fn, name=name, description=description)
return fn
return decorator
def add_resource(self, resource: Resource) -> None:
"""Add a resource to the server.
Args:
resource: A Resource instance to add
"""
self._resource_manager.add_resource(resource)
def resource(
self,
uri: str,
*,
name: str | None = None,
description: str | None = None,
mime_type: str | None = None,
) -> Callable[[AnyFunction], AnyFunction]:
"""Decorator to register a function as a resource.
The function will be called when the resource is read to generate its content.
The function can return:
- str for text content
- bytes for binary content
- other types will be converted to JSON
If the URI contains parameters (e.g. "resource://{param}") or the function
has parameters, it will be registered as a template resource.
Args:
uri: URI for the resource (e.g. "resource://my-resource" or "resource://{param}")
name: Optional name for the resource
description: Optional description of the resource
mime_type: Optional MIME type for the resource
Example:
@server.resource("resource://my-resource")
def get_data() -> str:
return "Hello, world!"
@server.resource("resource://my-resource")
async def get_data() -> str:
data = await fetch_data()
return f"Hello, world! {data}"
@server.resource("resource://{city}/weather")
def get_weather(city: str) -> str:
return f"Weather for {city}"
@server.resource("resource://{city}/weather")
async def get_weather(city: str) -> str:
data = await fetch_weather(city)
return f"Weather for {city}: {data}"
"""
# Check if user passed function directly instead of calling decorator
if callable(uri):
raise TypeError(
"The @resource decorator was used incorrectly. "
"Did you forget to call it? Use @resource('uri') instead of @resource"
)
def decorator(fn: AnyFunction) -> AnyFunction:
# Check if this should be a template
has_uri_params = "{" in uri and "}" in uri
has_func_params = bool(inspect.signature(fn).parameters)
if has_uri_params or has_func_params:
# Validate that URI params match function params
uri_params = set(re.findall(r"{(\w+)}", uri))
func_params = set(inspect.signature(fn).parameters.keys())
if uri_params != func_params:
raise ValueError(
f"Mismatch between URI parameters {uri_params} "
f"and function parameters {func_params}"
)
# Register as template
self._resource_manager.add_template(
fn=fn,
uri_template=uri,
name=name,
description=description,
mime_type=mime_type or "text/plain",
)
else:
# Register as regular resource
resource = FunctionResource(
uri=AnyUrl(uri),
name=name,
description=description,
mime_type=mime_type or "text/plain",
fn=fn,
)
self.add_resource(resource)
return fn
return decorator
def add_prompt(self, prompt: Prompt) -> None:
"""Add a prompt to the server.
Args:
prompt: A Prompt instance to add
"""
self._prompt_manager.add_prompt(prompt)
def prompt(
self, name: str | None = None, description: str | None = None
) -> Callable[[AnyFunction], AnyFunction]:
"""Decorator to register a prompt.
Args:
name: Optional name for the prompt (defaults to function name)
description: Optional description of what the prompt does
Example:
@server.prompt()
def analyze_table(table_name: str) -> list[Message]:
schema = read_table_schema(table_name)
return [
{
"role": "user",
"content": f"Analyze this schema:\n{schema}"
}
]
@server.prompt()
async def analyze_file(path: str) -> list[Message]:
content = await read_file(path)
return [
{
"role": "user",
"content": {
"type": "resource",
"resource": {
"uri": f"file://{path}",
"text": content
}
}
}
]
"""
# Check if user passed function directly instead of calling decorator
if callable(name):
raise TypeError(
"The @prompt decorator was used incorrectly. "
"Did you forget to call it? Use @prompt() instead of @prompt"
)
def decorator(func: AnyFunction) -> AnyFunction:
prompt = Prompt.from_function(func, name=name, description=description)
self.add_prompt(prompt)
return func
return decorator
async def run_stdio_async(self) -> None:
"""Run the server using stdio transport."""
async with stdio_server() as (read_stream, write_stream):
await self._mcp_server.run(
read_stream,
write_stream,
self._mcp_server.create_initialization_options(),
)
async def run_sse_async(self) -> None:
"""Run the server using SSE transport."""
starlette_app = self.sse_app()
config = uvicorn.Config(
starlette_app,
host=self.settings.host,
port=self.settings.port,
log_level=self.settings.log_level.lower(),
)
server = uvicorn.Server(config)
await server.serve()
def sse_app(self) -> Starlette:
"""Return an instance of the SSE server app."""
sse = SseServerTransport(self.settings.message_path)
async def handle_sse(request: Request) -> None:
async with sse.connect_sse(
request.scope,
request.receive,
request._send, # type: ignore[reportPrivateUsage]
) as streams:
await self._mcp_server.run(
streams[0],
streams[1],
self._mcp_server.create_initialization_options(),
)
return Starlette(
debug=self.settings.debug,
routes=[
Route(self.settings.sse_path, endpoint=handle_sse),
Mount(self.settings.message_path, app=sse.handle_post_message),
],
)
async def list_prompts(self) -> list[MCPPrompt]:
"""List all available prompts."""
prompts = self._prompt_manager.list_prompts()
return [
MCPPrompt(
name=prompt.name,
description=prompt.description,
arguments=[
MCPPromptArgument(
name=arg.name,
description=arg.description,
required=arg.required,
)
for arg in (prompt.arguments or [])
],
)
for prompt in prompts
]
async def get_prompt(
self, name: str, arguments: dict[str, Any] | None = None
) -> GetPromptResult:
"""Get a prompt by name with arguments."""
try:
messages = await self._prompt_manager.render_prompt(name, arguments)
return GetPromptResult(messages=pydantic_core.to_jsonable_python(messages))
except Exception as e:
logger.error(f"Error getting prompt {name}: {e}")
raise ValueError(str(e))
def _convert_to_content(
result: Any,
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
"""Convert a result to a sequence of content objects."""
if result is None:
return []
if isinstance(result, TextContent | ImageContent | EmbeddedResource):
return [result]
if isinstance(result, Image):
return [result.to_image_content()]
if isinstance(result, list | tuple):
return list(chain.from_iterable(_convert_to_content(item) for item in result)) # type: ignore[reportUnknownVariableType]
if not isinstance(result, str):
try:
result = json.dumps(pydantic_core.to_jsonable_python(result))
except Exception:
result = str(result)
return [TextContent(type="text", text=result)]
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
"""Context object providing access to MCP capabilities.
This provides a cleaner interface to MCP's RequestContext functionality.
It gets injected into tool and resource functions that request it via type hints.
To use context in a tool function, add a parameter with the Context type annotation:
```python
@server.tool()
def my_tool(x: int, ctx: Context) -> str:
# Log messages to the client
ctx.info(f"Processing {x}")
ctx.debug("Debug info")
ctx.warning("Warning message")
ctx.error("Error message")
# Report progress
ctx.report_progress(50, 100)
# Access resources
data = ctx.read_resource("resource://data")
# Get request info
request_id = ctx.request_id
client_id = ctx.client_id
return str(x)
```
The context parameter name can be anything as long as it's annotated with Context.
The context is optional - tools that don't need it can omit the parameter.
"""
_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
_fastmcp: FastMCP | None
def __init__(
self,
*,
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
fastmcp: FastMCP | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
self._request_context = request_context
self._fastmcp = fastmcp
@property
def fastmcp(self) -> FastMCP:
"""Access to the FastMCP server."""
if self._fastmcp is None:
raise ValueError("Context is not available outside of a request")
return self._fastmcp
@property
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
"""Access to the underlying request context."""
if self._request_context is None:
raise ValueError("Context is not available outside of a request")
return self._request_context
async def report_progress(
self, progress: float, total: float | None = None
) -> None:
"""Report progress for the current operation.
Args:
progress: Current progress value e.g. 24
total: Optional total value e.g. 100
"""
progress_token = (
self.request_context.meta.progressToken
if self.request_context.meta
else None
)
if progress_token is None:
return
await self.request_context.session.send_progress_notification(
progress_token=progress_token, progress=progress, total=total
)
async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:
"""Read a resource by URI.
Args:
uri: Resource URI to read
Returns:
The resource content as either text or bytes
"""
assert (
self._fastmcp is not None
), "Context is not available outside of a request"
return await self._fastmcp.read_resource(uri)
async def log(
self,
level: Literal["debug", "info", "warning", "error"],
message: str,
*,
logger_name: str | None = None,
) -> None:
"""Send a log message to the client.
Args:
level: Log level (debug, info, warning, error)
message: Log message
logger_name: Optional logger name
"""
await self.request_context.session.send_log_message(
level=level, data=message, logger=logger_name
)
@property
def client_id(self) -> str | None:
"""Get the client ID if available."""
return (
getattr(self.request_context.meta, "client_id", None)
if self.request_context.meta
else None
)
@property
def request_id(self) -> str:
"""Get the unique ID for this request."""
return str(self.request_context.request_id)
@property
def session(self):
"""Access to the underlying session for advanced usage."""
return self.request_context.session
# Convenience methods for common log levels
async def debug(self, message: str, **extra: Any) -> None:
"""Send a debug log message."""
await self.log("debug", message, **extra)
async def info(self, message: str, **extra: Any) -> None:
"""Send an info log message."""
await self.log("info", message, **extra)
async def warning(self, message: str, **extra: Any) -> None:
"""Send a warning log message."""
await self.log("warning", message, **extra)
async def error(self, message: str, **extra: Any) -> None:
"""Send an error log message."""
await self.log("error", message, **extra)

View File

@@ -0,0 +1,4 @@
from .base import Tool
from .tool_manager import ToolManager
__all__ = ["Tool", "ToolManager"]

View File

@@ -0,0 +1,92 @@
from __future__ import annotations as _annotations
import inspect
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT
class Tool(BaseModel):
"""Internal tool registration info."""
fn: Callable[..., Any] = Field(exclude=True)
name: str = Field(description="Name of the tool")
description: str = Field(description="Description of what the tool does")
parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
fn_metadata: FuncMetadata = Field(
description="Metadata about the function including a pydantic model for tool"
" arguments"
)
is_async: bool = Field(description="Whether the tool is async")
context_kwarg: str | None = Field(
None, description="Name of the kwarg that should receive context"
)
@classmethod
def from_function(
cls,
fn: Callable[..., Any],
name: str | None = None,
description: str | None = None,
context_kwarg: str | None = None,
) -> Tool:
"""Create a Tool from a function."""
from mcp.server.fastmcp import Context
func_name = name or fn.__name__
if func_name == "<lambda>":
raise ValueError("You must provide a name for lambda functions")
func_doc = description or fn.__doc__ or ""
is_async = inspect.iscoroutinefunction(fn)
if context_kwarg is None:
sig = inspect.signature(fn)
for param_name, param in sig.parameters.items():
if param.annotation is Context:
context_kwarg = param_name
break
func_arg_metadata = func_metadata(
fn,
skip_names=[context_kwarg] if context_kwarg is not None else [],
)
parameters = func_arg_metadata.arg_model.model_json_schema()
return cls(
fn=fn,
name=func_name,
description=func_doc,
parameters=parameters,
fn_metadata=func_arg_metadata,
is_async=is_async,
context_kwarg=context_kwarg,
)
async def run(
self,
arguments: dict[str, Any],
context: Context[ServerSessionT, LifespanContextT] | None = None,
) -> Any:
"""Run the tool with arguments."""
try:
return await self.fn_metadata.call_fn_with_arg_validation(
self.fn,
self.is_async,
arguments,
{self.context_kwarg: context}
if self.context_kwarg is not None
else None,
)
except Exception as e:
raise ToolError(f"Error executing tool {self.name}: {e}") from e
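A sketch of using Tool directly (not part of the commit): from_function derives a JSON schema from the signature and run() validates arguments before calling the wrapped function.

import anyio
from mcp.server.fastmcp.tools.base import Tool

def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

tool = Tool.from_function(multiply)
print(tool.parameters)  # JSON schema generated from the signature

async def demo() -> None:
    print(await tool.run({"a": 6, "b": 7}))  # -> 42

anyio.run(demo)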

View File

@@ -0,0 +1,60 @@
from __future__ import annotations as _annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools.base import Tool
from mcp.server.fastmcp.utilities.logging import get_logger
from mcp.shared.context import LifespanContextT
if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
logger = get_logger(__name__)
class ToolManager:
"""Manages FastMCP tools."""
def __init__(self, warn_on_duplicate_tools: bool = True):
self._tools: dict[str, Tool] = {}
self.warn_on_duplicate_tools = warn_on_duplicate_tools
def get_tool(self, name: str) -> Tool | None:
"""Get tool by name."""
return self._tools.get(name)
def list_tools(self) -> list[Tool]:
"""List all registered tools."""
return list(self._tools.values())
def add_tool(
self,
fn: Callable[..., Any],
name: str | None = None,
description: str | None = None,
) -> Tool:
"""Add a tool to the server."""
tool = Tool.from_function(fn, name=name, description=description)
existing = self._tools.get(tool.name)
if existing:
if self.warn_on_duplicate_tools:
logger.warning(f"Tool already exists: {tool.name}")
return existing
self._tools[tool.name] = tool
return tool
async def call_tool(
self,
name: str,
arguments: dict[str, Any],
context: Context[ServerSessionT, LifespanContextT] | None = None,
) -> Any:
"""Call a tool by name with arguments."""
tool = self.get_tool(name)
if not tool:
raise ToolError(f"Unknown tool: {name}")
return await tool.run(arguments, context=context)
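And the same flow through the manager (illustrative sketch, not part of the commit):

import anyio
from mcp.server.fastmcp.tools import ToolManager

def shout(text: str) -> str:
    """Upper-case the given text."""
    return text.upper()

manager = ToolManager()
manager.add_tool(shout)

async def demo() -> None:
    print(await manager.call_tool("shout", {"text": "hi"}))  # -> HI

anyio.run(demo)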

View File

@@ -0,0 +1 @@
"""FastMCP utility modules."""

View File

@@ -0,0 +1,214 @@
import inspect
import json
from collections.abc import Awaitable, Callable, Sequence
from typing import (
Annotated,
Any,
ForwardRef,
)
from pydantic import BaseModel, ConfigDict, Field, WithJsonSchema, create_model
from pydantic._internal._typing_extra import eval_type_backport
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from mcp.server.fastmcp.exceptions import InvalidSignature
from mcp.server.fastmcp.utilities.logging import get_logger
logger = get_logger(__name__)
class ArgModelBase(BaseModel):
"""A model representing the arguments to a function."""
def model_dump_one_level(self) -> dict[str, Any]:
"""Return a dict of the model's fields, one level deep.
That is, sub-models etc are not dumped - they are kept as pydantic models.
"""
kwargs: dict[str, Any] = {}
for field_name in self.model_fields.keys():
kwargs[field_name] = getattr(self, field_name)
return kwargs
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
class FuncMetadata(BaseModel):
arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)]
# We can add things in the future like
# - Maybe some args are excluded from attempting to parse from JSON
# - Maybe some args are special (like context) for dependency injection
async def call_fn_with_arg_validation(
self,
fn: Callable[..., Any] | Awaitable[Any],
fn_is_async: bool,
arguments_to_validate: dict[str, Any],
arguments_to_pass_directly: dict[str, Any] | None,
) -> Any:
"""Call the given function with arguments validated and injected.
Arguments are parsed from JSON where appropriate, then validated against
the argument model, before being passed to the function.
"""
arguments_pre_parsed = self.pre_parse_json(arguments_to_validate)
arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed)
arguments_parsed_dict = arguments_parsed_model.model_dump_one_level()
arguments_parsed_dict |= arguments_to_pass_directly or {}
if fn_is_async:
if isinstance(fn, Awaitable):
return await fn
return await fn(**arguments_parsed_dict)
if isinstance(fn, Callable):
return fn(**arguments_parsed_dict)
raise TypeError("fn must be either Callable or Awaitable")
def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
"""Pre-parse data from JSON.
Return a dict with the same keys as input but with values parsed from JSON
if appropriate.
This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside
a string rather than an actual list. Claude desktop is prone to this - in fact
it seems incapable of NOT doing this. For sub-models, it tends to pass
dicts (JSON objects) as JSON strings, which can be pre-parsed here.
"""
new_data = data.copy() # Shallow copy
for field_name, _field_info in self.arg_model.model_fields.items():
if field_name not in data.keys():
continue
if isinstance(data[field_name], str):
try:
pre_parsed = json.loads(data[field_name])
except json.JSONDecodeError:
continue # Not JSON - skip
if isinstance(pre_parsed, str | int | float):
# The raw value is likely a plain scalar such as `"hello"`; parsing it as
# JSON would strip the quotes and change the value to `hello`, so we leave
# it as the original string.
continue
new_data[field_name] = pre_parsed
assert new_data.keys() == data.keys()
return new_data
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def func_metadata(
func: Callable[..., Any], skip_names: Sequence[str] = ()
) -> FuncMetadata:
"""Given a function, return metadata including a pydantic model representing its
signature.
The use case for this is
```
meta = func_metadata(func)
validated_args = meta.arg_model.model_validate(some_raw_data_dict)
return func(**validated_args.model_dump_one_level())
```
**critically** it also provides a pre-parse helper that attempts to parse values from
JSON.
Args:
func: The function to convert to a pydantic model
skip_names: A list of parameter names to skip. These will not be included in
the model.
Returns:
A pydantic model representing the function's signature.
"""
sig = _get_typed_signature(func)
params = sig.parameters
dynamic_pydantic_model_params: dict[str, Any] = {}
globalns = getattr(func, "__globals__", {})
for param in params.values():
if param.name.startswith("_"):
raise InvalidSignature(
f"Parameter {param.name} of {func.__name__} cannot start with '_'"
)
if param.name in skip_names:
continue
annotation = param.annotation
# `x: None` / `x: None = None`
if annotation is None:
annotation = Annotated[
None,
Field(
default=param.default
if param.default is not inspect.Parameter.empty
else PydanticUndefined
),
]
# Untyped field
if annotation is inspect.Parameter.empty:
annotation = Annotated[
Any,
Field(),
# 🤷
WithJsonSchema({"title": param.name, "type": "string"}),
]
field_info = FieldInfo.from_annotated_attribute(
_get_typed_annotation(annotation, globalns),
param.default
if param.default is not inspect.Parameter.empty
else PydanticUndefined,
)
dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)
continue
arguments_model = create_model(
f"{func.__name__}Arguments",
**dynamic_pydantic_model_params,
__base__=ArgModelBase,
)
resp = FuncMetadata(arg_model=arguments_model)
return resp
def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
def try_eval_type(
value: Any, globalns: dict[str, Any], localns: dict[str, Any]
) -> tuple[Any, bool]:
try:
return eval_type_backport(value, globalns, localns), True
except NameError:
return value, False
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation, status = try_eval_type(annotation, globalns, globalns)
# This check and raise could perhaps be skipped, and we (FastMCP) just call
# model_rebuild right before using it 🤷
if status is False:
raise InvalidSignature(f"Unable to evaluate type annotation {annotation}")
return annotation
def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
"""Get function signature while evaluating forward references"""
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=_get_typed_annotation(param.annotation, globalns),
)
for param in signature.parameters.values()
]
typed_signature = inspect.Signature(typed_params)
return typed_signature
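A small sketch of the pre-parse behaviour described above (not part of the commit): a list argument sent as a JSON string is recovered before validation.

from mcp.server.fastmcp.utilities.func_metadata import func_metadata

def pick(items: list[str], index: int) -> str:
    """Return one element of a list."""
    return items[index]

meta = func_metadata(pick)

# Clients sometimes send structured arguments as JSON strings; pre_parse_json
# turns '["a", "b", "c"]' back into a list before the model validates it.
raw = {"items": '["a", "b", "c"]', "index": 1}
validated = meta.arg_model.model_validate(meta.pre_parse_json(raw))
print(pick(**validated.model_dump_one_level()))  # -> b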

View File

@@ -0,0 +1,43 @@
"""Logging utilities for FastMCP."""
import logging
from typing import Literal
def get_logger(name: str) -> logging.Logger:
"""Get a logger nested under MCPnamespace.
Args:
name: the name of the logger, which will be prefixed with 'FastMCP.'
Returns:
a configured logger instance
"""
return logging.getLogger(name)
def configure_logging(
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
) -> None:
"""Configure logging for MCP.
Args:
level: the log level to use
"""
handlers: list[logging.Handler] = []
try:
from rich.console import Console
from rich.logging import RichHandler
handlers.append(RichHandler(console=Console(stderr=True), rich_tracebacks=True))
except ImportError:
pass
if not handlers:
handlers.append(logging.StreamHandler())
logging.basicConfig(
level=level,
format="%(message)s",
handlers=handlers,
)
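Usage is a one-liner (sketch, not part of the commit):

from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger

configure_logging("DEBUG")   # falls back to a plain StreamHandler if rich is absent
logger = get_logger(__name__)
logger.debug("server starting up")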

View File

@@ -0,0 +1,54 @@
"""Common types used across FastMCP."""
import base64
from pathlib import Path
from mcp.types import ImageContent
class Image:
"""Helper class for returning images from tools."""
def __init__(
self,
path: str | Path | None = None,
data: bytes | None = None,
format: str | None = None,
):
if path is None and data is None:
raise ValueError("Either path or data must be provided")
if path is not None and data is not None:
raise ValueError("Only one of path or data can be provided")
self.path = Path(path) if path else None
self.data = data
self._format = format
self._mime_type = self._get_mime_type()
def _get_mime_type(self) -> str:
"""Get MIME type from format or guess from file extension."""
if self._format:
return f"image/{self._format.lower()}"
if self.path:
suffix = self.path.suffix.lower()
return {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".webp": "image/webp",
}.get(suffix, "application/octet-stream")
return "image/png" # default for raw binary data
def to_image_content(self) -> ImageContent:
"""Convert to MCP ImageContent."""
if self.path:
with open(self.path, "rb") as f:
data = base64.b64encode(f.read()).decode()
elif self.data is not None:
data = base64.b64encode(self.data).decode()
else:
raise ValueError("No image data available")
return ImageContent(type="image", data=data, mimeType=self._mime_type)
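A sketch of the Image helper (not part of the commit; the bytes below are placeholders rather than a real PNG):

from mcp.server.fastmcp.utilities.types import Image

img = Image(data=b"\x00" * 16, format="png")  # placeholder bytes, not a valid image
content = img.to_image_content()
print(content.mimeType, len(content.data))    # image/png plus the base64 payload length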

View File

@@ -0,0 +1,3 @@
from .server import NotificationOptions, Server
__all__ = ["Server", "NotificationOptions"]

View File

@@ -0,0 +1,9 @@
from dataclasses import dataclass
@dataclass
class ReadResourceContents:
"""Contents returned from a read_resource call."""
content: str | bytes
mime_type: str | None = None

View File

@@ -0,0 +1,590 @@
"""
MCP Server Module
This module provides a framework for creating an MCP (Model Context Protocol) server.
It allows you to easily define and handle various types of requests and notifications
in an asynchronous manner.
Usage:
1. Create a Server instance:
server = Server("your_server_name")
2. Define request handlers using decorators:
@server.list_prompts()
async def handle_list_prompts() -> list[types.Prompt]:
# Implementation
@server.get_prompt()
async def handle_get_prompt(
name: str, arguments: dict[str, str] | None
) -> types.GetPromptResult:
# Implementation
@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
# Implementation
@server.call_tool()
async def handle_call_tool(
name: str, arguments: dict | None
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
# Implementation
@server.list_resource_templates()
async def handle_list_resource_templates() -> list[types.ResourceTemplate]:
# Implementation
3. Define notification handlers if needed:
@server.progress_notification()
async def handle_progress(
progress_token: str | int, progress: float, total: float | None
) -> None:
# Implementation
4. Run the server:
async def main():
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
await server.run(
read_stream,
write_stream,
InitializationOptions(
server_name="your_server_name",
server_version="your_version",
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},
),
),
)
asyncio.run(main())
The Server class provides methods to register handlers for various MCP requests and
notifications. It automatically manages the request context and handles incoming
messages from the client.
"""
from __future__ import annotations as _annotations
import contextvars
import logging
import warnings
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from typing import Any, Generic, TypeVar
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
import mcp.types as types
from mcp.server.lowlevel.helper_types import ReadResourceContents
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.session import RequestResponder
logger = logging.getLogger(__name__)
LifespanResultT = TypeVar("LifespanResultT")
# This will be properly typed in each Server instance's context
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
contextvars.ContextVar("request_ctx")
)
class NotificationOptions:
def __init__(
self,
prompts_changed: bool = False,
resources_changed: bool = False,
tools_changed: bool = False,
):
self.prompts_changed = prompts_changed
self.resources_changed = resources_changed
self.tools_changed = tools_changed
@asynccontextmanager
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
"""Default lifespan context manager that does nothing.
Args:
server: The server instance this lifespan is managing
Returns:
An empty context object
"""
yield {}
class Server(Generic[LifespanResultT]):
def __init__(
self,
name: str,
version: str | None = None,
instructions: str | None = None,
lifespan: Callable[
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
] = lifespan,
):
self.name = name
self.version = version
self.instructions = instructions
self.lifespan = lifespan
self.request_handlers: dict[
type, Callable[..., Awaitable[types.ServerResult]]
] = {
types.PingRequest: _ping_handler,
}
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
self.notification_options = NotificationOptions()
logger.debug(f"Initializing server '{name}'")
def create_initialization_options(
self,
notification_options: NotificationOptions | None = None,
experimental_capabilities: dict[str, dict[str, Any]] | None = None,
) -> InitializationOptions:
"""Create initialization options from this server instance."""
def pkg_version(package: str) -> str:
try:
from importlib.metadata import version
return version(package)
except Exception:
pass
return "unknown"
return InitializationOptions(
server_name=self.name,
server_version=self.version if self.version else pkg_version("mcp"),
capabilities=self.get_capabilities(
notification_options or NotificationOptions(),
experimental_capabilities or {},
),
instructions=self.instructions,
)
def get_capabilities(
self,
notification_options: NotificationOptions,
experimental_capabilities: dict[str, dict[str, Any]],
) -> types.ServerCapabilities:
"""Convert existing handlers to a ServerCapabilities object."""
prompts_capability = None
resources_capability = None
tools_capability = None
logging_capability = None
# Set prompt capabilities if handler exists
if types.ListPromptsRequest in self.request_handlers:
prompts_capability = types.PromptsCapability(
listChanged=notification_options.prompts_changed
)
# Set resource capabilities if handler exists
if types.ListResourcesRequest in self.request_handlers:
resources_capability = types.ResourcesCapability(
subscribe=False, listChanged=notification_options.resources_changed
)
# Set tool capabilities if handler exists
if types.ListToolsRequest in self.request_handlers:
tools_capability = types.ToolsCapability(
listChanged=notification_options.tools_changed
)
# Set logging capabilities if handler exists
if types.SetLevelRequest in self.request_handlers:
logging_capability = types.LoggingCapability()
return types.ServerCapabilities(
prompts=prompts_capability,
resources=resources_capability,
tools=tools_capability,
logging=logging_capability,
experimental=experimental_capabilities,
)
@property
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
"""If called outside of a request context, this will raise a LookupError."""
return request_ctx.get()
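    # Usage sketch for ``request_context`` (hypothetical handler): inside any
    # registered handler, the active request's context is available via
    #
    #     @server.call_tool()
    #     async def my_tool(name: str, arguments: dict | None):
    #         ctx = server.request_context
    #         await ctx.session.send_log_message("info", f"calling tool {name}")
    #         return [types.TextContent(type="text", text="done")]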
def list_prompts(self):
def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]):
logger.debug("Registering handler for PromptListRequest")
async def handler(_: Any):
prompts = await func()
return types.ServerResult(types.ListPromptsResult(prompts=prompts))
self.request_handlers[types.ListPromptsRequest] = handler
return func
return decorator
def get_prompt(self):
def decorator(
func: Callable[
[str, dict[str, str] | None], Awaitable[types.GetPromptResult]
],
):
logger.debug("Registering handler for GetPromptRequest")
async def handler(req: types.GetPromptRequest):
prompt_get = await func(req.params.name, req.params.arguments)
return types.ServerResult(prompt_get)
self.request_handlers[types.GetPromptRequest] = handler
return func
return decorator
def list_resources(self):
def decorator(func: Callable[[], Awaitable[list[types.Resource]]]):
logger.debug("Registering handler for ListResourcesRequest")
async def handler(_: Any):
resources = await func()
return types.ServerResult(
types.ListResourcesResult(resources=resources)
)
self.request_handlers[types.ListResourcesRequest] = handler
return func
return decorator
def list_resource_templates(self):
def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]):
logger.debug("Registering handler for ListResourceTemplatesRequest")
async def handler(_: Any):
templates = await func()
return types.ServerResult(
types.ListResourceTemplatesResult(resourceTemplates=templates)
)
self.request_handlers[types.ListResourceTemplatesRequest] = handler
return func
return decorator
def read_resource(self):
def decorator(
func: Callable[
[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]
],
):
logger.debug("Registering handler for ReadResourceRequest")
async def handler(req: types.ReadResourceRequest):
result = await func(req.params.uri)
def create_content(data: str | bytes, mime_type: str | None):
match data:
case str() as data:
return types.TextResourceContents(
uri=req.params.uri,
text=data,
mimeType=mime_type or "text/plain",
)
case bytes() as data:
import base64
return types.BlobResourceContents(
uri=req.params.uri,
blob=base64.b64encode(data).decode(),
mimeType=mime_type or "application/octet-stream",
)
match result:
case str() | bytes() as data:
warnings.warn(
"Returning str or bytes from read_resource is deprecated. "
"Use Iterable[ReadResourceContents] instead.",
DeprecationWarning,
stacklevel=2,
)
content = create_content(data, None)
case Iterable() as contents:
contents_list = [
create_content(content_item.content, content_item.mime_type)
for content_item in contents
]
return types.ServerResult(
types.ReadResourceResult(
contents=contents_list,
)
)
case _:
raise ValueError(
f"Unexpected return type from read_resource: {type(result)}"
)
return types.ServerResult(
types.ReadResourceResult(
contents=[content],
)
)
self.request_handlers[types.ReadResourceRequest] = handler
return func
return decorator
def set_logging_level(self):
def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]):
logger.debug("Registering handler for SetLevelRequest")
async def handler(req: types.SetLevelRequest):
await func(req.params.level)
return types.ServerResult(types.EmptyResult())
self.request_handlers[types.SetLevelRequest] = handler
return func
return decorator
def subscribe_resource(self):
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug("Registering handler for SubscribeRequest")
async def handler(req: types.SubscribeRequest):
await func(req.params.uri)
return types.ServerResult(types.EmptyResult())
self.request_handlers[types.SubscribeRequest] = handler
return func
return decorator
def unsubscribe_resource(self):
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug("Registering handler for UnsubscribeRequest")
async def handler(req: types.UnsubscribeRequest):
await func(req.params.uri)
return types.ServerResult(types.EmptyResult())
self.request_handlers[types.UnsubscribeRequest] = handler
return func
return decorator
def list_tools(self):
def decorator(func: Callable[[], Awaitable[list[types.Tool]]]):
logger.debug("Registering handler for ListToolsRequest")
async def handler(_: Any):
tools = await func()
return types.ServerResult(types.ListToolsResult(tools=tools))
self.request_handlers[types.ListToolsRequest] = handler
return func
return decorator
def call_tool(self):
def decorator(
func: Callable[
...,
Awaitable[
Iterable[
types.TextContent | types.ImageContent | types.EmbeddedResource
]
],
],
):
logger.debug("Registering handler for CallToolRequest")
async def handler(req: types.CallToolRequest):
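                # Tool failures are reported inside the CallToolResult
                # (isError=True) rather than as protocol-level errors, so the
                # client always receives a result it can surface to the model.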
try:
results = await func(req.params.name, (req.params.arguments or {}))
return types.ServerResult(
types.CallToolResult(content=list(results), isError=False)
)
except Exception as e:
return types.ServerResult(
types.CallToolResult(
content=[types.TextContent(type="text", text=str(e))],
isError=True,
)
)
self.request_handlers[types.CallToolRequest] = handler
return func
return decorator
def progress_notification(self):
def decorator(
func: Callable[[str | int, float, float | None], Awaitable[None]],
):
logger.debug("Registering handler for ProgressNotification")
async def handler(req: types.ProgressNotification):
await func(
req.params.progressToken, req.params.progress, req.params.total
)
self.notification_handlers[types.ProgressNotification] = handler
return func
return decorator
def completion(self):
"""Provides completions for prompts and resource templates"""
def decorator(
func: Callable[
[
types.PromptReference | types.ResourceReference,
types.CompletionArgument,
],
Awaitable[types.Completion | None],
],
):
logger.debug("Registering handler for CompleteRequest")
async def handler(req: types.CompleteRequest):
completion = await func(req.params.ref, req.params.argument)
return types.ServerResult(
types.CompleteResult(
completion=completion
if completion is not None
else types.Completion(values=[], total=None, hasMore=None),
)
)
self.request_handlers[types.CompleteRequest] = handler
return func
return decorator
async def run(
self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
initialization_options: InitializationOptions,
# When False, exceptions are returned as messages to the client.
# When True, exceptions are raised, which will cause the server to shut down
# but also make tracing exceptions much easier during testing and when using
# in-process servers.
raise_exceptions: bool = False,
):
async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self))
session = await stack.enter_async_context(
ServerSession(read_stream, write_stream, initialization_options)
)
async with anyio.create_task_group() as tg:
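                # Each message is handled on its own task so one slow handler
                # does not block the receive loop.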
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")
tg.start_soon(
self._handle_message,
message,
session,
lifespan_context,
raise_exceptions,
)
async def _handle_message(
self,
message: RequestResponder[types.ClientRequest, types.ServerResult]
| types.ClientNotification
| Exception,
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
# TODO(Marcelo): We should be checking if message is Exception here.
match message: # type: ignore[reportMatchNotExhaustive]
case (
RequestResponder(request=types.ClientRequest(root=req)) as responder
):
with responder:
await self._handle_request(
message, req, session, lifespan_context, raise_exceptions
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
for warning in w:
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
async def _handle_request(
self,
message: RequestResponder[types.ClientRequest, types.ServerResult],
req: Any,
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool,
):
logger.info(f"Processing request of type {type(req).__name__}")
if type(req) in self.request_handlers:
handler = self.request_handlers[type(req)]
logger.debug(f"Dispatching request of type {type(req).__name__}")
token = None
try:
                    # Set the global request context so handlers can retrieve
                    # it via the ``request_context`` property
token = request_ctx.set(
RequestContext(
message.request_id,
message.request_meta,
session,
lifespan_context,
)
)
response = await handler(req)
except McpError as err:
response = err.error
except Exception as err:
if raise_exceptions:
raise err
response = types.ErrorData(code=0, message=str(err), data=None)
finally:
# Reset the global state after we are done
if token is not None:
request_ctx.reset(token)
await message.respond(response)
else:
await message.respond(
types.ErrorData(
code=types.METHOD_NOT_FOUND,
message="Method not found",
)
)
logger.debug("Response sent")
async def _handle_notification(self, notify: Any):
if type(notify) in self.notification_handlers:
handler = self.notification_handlers[type(notify)]
            logger.debug(f"Dispatching notification of type {type(notify).__name__}")
try:
await handler(notify)
except Exception as err:
logger.error(f"Uncaught exception in notification handler: " f"{err}")
async def _ping_handler(request: types.PingRequest) -> types.ServerResult:
return types.ServerResult(types.EmptyResult())

View File

@@ -0,0 +1,17 @@
"""
This module provides simpler types to use with the server for managing prompts
and tools.
"""
from pydantic import BaseModel
from mcp.types import (
ServerCapabilities,
)
class InitializationOptions(BaseModel):
server_name: str
server_version: str
capabilities: ServerCapabilities
instructions: str | None = None

View File

@@ -0,0 +1,317 @@
"""
ServerSession Module
This module provides the ServerSession class, which manages communication between the
server and client in the MCP (Model Context Protocol) framework. It is most commonly
used in MCP servers to interact with the client.
Common usage pattern:
```
server = Server(name)
@server.call_tool()
async def handle_tool_call(name: str, arguments: dict[str, Any] | None) -> Any:
    ctx = server.request_context
    # Check client capabilities before proceeding
    if ctx.session.check_client_capability(
        types.ClientCapabilities(experimental={"advanced_tools": dict()})
    ):
        # Perform advanced tool operations
        result = await perform_advanced_tool_operation(arguments)
    else:
        # Fall back to basic tool operations
        result = await perform_basic_tool_operation(arguments)
    return result
@server.list_prompts()
async def handle_list_prompts() -> list[types.Prompt]:
    # Access the session for any checks or client-specific behavior
    ctx = server.request_context
    if ctx.session.client_params:
        # Customize prompts based on client initialization parameters
        return generate_custom_prompts(ctx.session.client_params)
    else:
        return default_prompts
```
The ServerSession class is typically used internally by the Server class and should not
be instantiated directly by users of the MCP framework.
"""
from enum import Enum
from typing import Any, TypeVar
import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
import mcp.types as types
from mcp.server.models import InitializationOptions
from mcp.shared.session import (
BaseSession,
RequestResponder,
)
class InitializationState(Enum):
NotInitialized = 1
Initializing = 2
Initialized = 3
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
ServerRequestResponder = (
RequestResponder[types.ClientRequest, types.ServerResult]
| types.ClientNotification
| Exception
)
class ServerSession(
BaseSession[
types.ServerRequest,
types.ServerNotification,
types.ServerResult,
types.ClientRequest,
types.ClientNotification,
]
):
_initialized: InitializationState = InitializationState.NotInitialized
_client_params: types.InitializeRequestParams | None = None
def __init__(
self,
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
init_options: InitializationOptions,
) -> None:
super().__init__(
read_stream, write_stream, types.ClientRequest, types.ClientNotification
)
self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[ServerRequestResponder](0)
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_reader.aclose()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_writer.aclose()
)
@property
def client_params(self) -> types.InitializeRequestParams | None:
return self._client_params
def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
"""Check if the client supports a specific capability."""
if self._client_params is None:
return False
# Get client capabilities from initialization params
client_caps = self._client_params.capabilities
# Check each specified capability in the passed in capability object
if capability.roots is not None:
if client_caps.roots is None:
return False
if capability.roots.listChanged and not client_caps.roots.listChanged:
return False
if capability.sampling is not None:
if client_caps.sampling is None:
return False
if capability.experimental is not None:
if client_caps.experimental is None:
return False
# Check each experimental capability
for exp_key, exp_value in capability.experimental.items():
if (
exp_key not in client_caps.experimental
or client_caps.experimental[exp_key] != exp_value
):
return False
return True
async def _received_request(
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
):
match responder.request.root:
case types.InitializeRequest(params=params):
self._initialization_state = InitializationState.Initializing
self._client_params = params
with responder:
await responder.respond(
types.ServerResult(
types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities,
serverInfo=types.Implementation(
name=self._init_options.server_name,
version=self._init_options.server_version,
),
instructions=self._init_options.instructions,
)
)
)
case _:
if self._initialization_state != InitializationState.Initialized:
raise RuntimeError(
"Received request before initialization was complete"
)
async def _received_notification(
self, notification: types.ClientNotification
) -> None:
# Need this to avoid ASYNC910
await anyio.lowlevel.checkpoint()
match notification.root:
case types.InitializedNotification():
self._initialization_state = InitializationState.Initialized
case _:
if self._initialization_state != InitializationState.Initialized:
raise RuntimeError(
"Received notification before initialization was complete"
)
async def send_log_message(
self, level: types.LoggingLevel, data: Any, logger: str | None = None
) -> None:
"""Send a log message notification."""
await self.send_notification(
types.ServerNotification(
types.LoggingMessageNotification(
method="notifications/message",
params=types.LoggingMessageNotificationParams(
level=level,
data=data,
logger=logger,
),
)
)
)
async def send_resource_updated(self, uri: AnyUrl) -> None:
"""Send a resource updated notification."""
await self.send_notification(
types.ServerNotification(
types.ResourceUpdatedNotification(
method="notifications/resources/updated",
params=types.ResourceUpdatedNotificationParams(uri=uri),
)
)
)
async def create_message(
self,
messages: list[types.SamplingMessage],
*,
max_tokens: int,
system_prompt: str | None = None,
include_context: types.IncludeContext | None = None,
temperature: float | None = None,
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: types.ModelPreferences | None = None,
) -> types.CreateMessageResult:
"""Send a sampling/create_message request."""
return await self.send_request(
types.ServerRequest(
types.CreateMessageRequest(
method="sampling/createMessage",
params=types.CreateMessageRequestParams(
messages=messages,
systemPrompt=system_prompt,
includeContext=include_context,
temperature=temperature,
maxTokens=max_tokens,
stopSequences=stop_sequences,
metadata=metadata,
modelPreferences=model_preferences,
),
)
),
types.CreateMessageResult,
)
async def list_roots(self) -> types.ListRootsResult:
"""Send a roots/list request."""
return await self.send_request(
types.ServerRequest(
types.ListRootsRequest(
method="roots/list",
)
),
types.ListRootsResult,
)
async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
return await self.send_request(
types.ServerRequest(
types.PingRequest(
method="ping",
)
),
types.EmptyResult,
)
async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""Send a progress notification."""
await self.send_notification(
types.ServerNotification(
types.ProgressNotification(
method="notifications/progress",
params=types.ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
),
)
)
)
async def send_resource_list_changed(self) -> None:
"""Send a resource list changed notification."""
await self.send_notification(
types.ServerNotification(
types.ResourceListChangedNotification(
method="notifications/resources/list_changed",
)
)
)
async def send_tool_list_changed(self) -> None:
"""Send a tool list changed notification."""
await self.send_notification(
types.ServerNotification(
types.ToolListChangedNotification(
method="notifications/tools/list_changed",
)
)
)
async def send_prompt_list_changed(self) -> None:
"""Send a prompt list changed notification."""
await self.send_notification(
types.ServerNotification(
types.PromptListChangedNotification(
method="notifications/prompts/list_changed",
)
)
)
async def _handle_incoming(self, req: ServerRequestResponder) -> None:
await self._incoming_message_stream_writer.send(req)
@property
def incoming_messages(
self,
) -> MemoryObjectReceiveStream[ServerRequestResponder]:
return self._incoming_message_stream_reader

View File

@@ -0,0 +1,175 @@
"""
SSE Server Transport Module
This module implements a Server-Sent Events (SSE) transport layer for MCP servers.
Example usage:
```
# Create an SSE transport at an endpoint
sse = SseServerTransport("/messages/")
# Create Starlette routes for SSE and message handling
routes = [
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
]
# Define handler functions
async def handle_sse(request):
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
await app.run(
streams[0], streams[1], app.create_initialization_options()
)
# Create and run Starlette app
starlette_app = Starlette(routes=routes)
uvicorn.run(starlette_app, host="0.0.0.0", port=port)
```
See SseServerTransport class documentation for more details.
"""
import logging
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import quote
from uuid import UUID, uuid4
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Receive, Scope, Send
import mcp.types as types
logger = logging.getLogger(__name__)
class SseServerTransport:
"""
SSE server transport for MCP. This class provides _two_ ASGI applications,
suitable to be used with a framework like Starlette and a server like Hypercorn:
1. connect_sse() is an ASGI application which receives incoming GET requests,
and sets up a new SSE stream to send server messages to the client.
2. handle_post_message() is an ASGI application which receives incoming POST
requests, which should contain client messages that link to a
previously-established SSE session.
"""
_endpoint: str
_read_stream_writers: dict[
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
]
def __init__(self, endpoint: str) -> None:
"""
Creates a new SSE server transport, which will direct the client to POST
messages to the relative or absolute URL given.
"""
super().__init__()
self._endpoint = endpoint
self._read_stream_writers = {}
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
@asynccontextmanager
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
logger.error("connect_sse received non-HTTP request")
raise ValueError("connect_sse can only handle HTTP requests")
logger.debug("Setting up SSE connection")
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
session_id = uuid4()
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
self._read_stream_writers[session_id] = read_stream_writer
logger.debug(f"Created new session with ID: {session_id}")
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
dict[str, Any]
](0)
async def sse_writer():
logger.debug("Starting SSE writer")
async with sse_stream_writer, write_stream_reader:
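                # The first SSE event tells the client which URL (including the
                # session_id) to POST its JSON-RPC messages to.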
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
logger.debug(f"Sent endpoint event: {session_uri}")
async for message in write_stream_reader:
logger.debug(f"Sending message via SSE: {message}")
await sse_stream_writer.send(
{
"event": "message",
"data": message.model_dump_json(
by_alias=True, exclude_none=True
),
}
)
async with anyio.create_task_group() as tg:
response = EventSourceResponse(
content=sse_stream_reader, data_sender_callable=sse_writer
)
logger.debug("Starting SSE response task")
tg.start_soon(response, scope, receive, send)
logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)
async def handle_post_message(
self, scope: Scope, receive: Receive, send: Send
) -> None:
logger.debug("Handling POST message")
request = Request(scope, receive)
session_id_param = request.query_params.get("session_id")
if session_id_param is None:
logger.warning("Received request without session_id")
response = Response("session_id is required", status_code=400)
return await response(scope, receive, send)
try:
session_id = UUID(hex=session_id_param)
logger.debug(f"Parsed session ID: {session_id}")
except ValueError:
logger.warning(f"Received invalid session ID: {session_id_param}")
response = Response("Invalid session ID", status_code=400)
return await response(scope, receive, send)
writer = self._read_stream_writers.get(session_id)
if not writer:
logger.warning(f"Could not find session for ID: {session_id}")
response = Response("Could not find session", status_code=404)
return await response(scope, receive, send)
body = await request.body()
logger.debug(f"Received JSON: {body}")
try:
message = types.JSONRPCMessage.model_validate_json(body)
logger.debug(f"Validated client message: {message}")
except ValidationError as err:
logger.error(f"Failed to parse message: {err}")
response = Response("Could not parse message", status_code=400)
await response(scope, receive, send)
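            # Also forward the validation error into the session's read stream so
            # the server loop can observe and log it.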
await writer.send(err)
return
logger.debug(f"Sending message to writer: {message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
await writer.send(message)

View File

@@ -0,0 +1,86 @@
"""
Stdio Server Transport Module
This module provides functionality for creating an stdio-based transport layer
that can be used to communicate with an MCP client through standard input/output
streams.
Example usage:
```
async def run_server():
async with stdio_server() as (read_stream, write_stream):
# read_stream contains incoming JSONRPCMessages from stdin
# write_stream allows sending JSONRPCMessages to stdout
server = await create_my_server()
await server.run(read_stream, write_stream, init_options)
anyio.run(run_server)
```
"""
import sys
from contextlib import asynccontextmanager
from io import TextIOWrapper
import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import mcp.types as types
@asynccontextmanager
async def stdio_server(
stdin: anyio.AsyncFile[str] | None = None,
stdout: anyio.AsyncFile[str] | None = None,
):
"""
Server transport for stdio: this communicates with an MCP client by reading
from the current process' stdin and writing to stdout.
"""
# Purposely not using context managers for these, as we don't want to close
# standard process handles. Encoding of stdin/stdout as text streams on
    # Python is platform-dependent (Windows is particularly problematic), so we
# re-wrap the underlying binary stream to ensure UTF-8.
if not stdin:
stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8"))
if not stdout:
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
async def stdin_reader():
try:
async with read_stream_writer:
async for line in stdin:
try:
message = types.JSONRPCMessage.model_validate_json(line)
except Exception as exc:
await read_stream_writer.send(exc)
continue
await read_stream_writer.send(message)
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()
async def stdout_writer():
try:
async with write_stream_reader:
async for message in write_stream_reader:
json = message.model_dump_json(by_alias=True, exclude_none=True)
await stdout.write(json + "\n")
await stdout.flush()
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()
async with anyio.create_task_group() as tg:
tg.start_soon(stdin_reader)
tg.start_soon(stdout_writer)
yield read_stream, write_stream

View File

@@ -0,0 +1,60 @@
import logging
from contextlib import asynccontextmanager
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic_core import ValidationError
from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocket
import mcp.types as types
logger = logging.getLogger(__name__)
@asynccontextmanager
async def websocket_server(scope: Scope, receive: Receive, send: Send):
"""
WebSocket server transport for MCP. This is an ASGI application, suitable to be
used with a framework like Starlette and a server like Hypercorn.
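    Example (a minimal sketch, assuming an existing ``Server`` instance named
    ``server`` created elsewhere):
    ```
    async def handle_ws(scope: Scope, receive: Receive, send: Send):
        async with websocket_server(scope, receive, send) as (read_stream, write_stream):
            await server.run(
                read_stream, write_stream, server.create_initialization_options()
            )
    ```
    Mount ``handle_ws`` wherever your ASGI framework routes WebSocket connections.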
"""
websocket = WebSocket(scope, receive, send)
await websocket.accept(subprotocol="mcp")
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
async def ws_reader():
try:
async with read_stream_writer:
async for msg in websocket.iter_text():
try:
client_message = types.JSONRPCMessage.model_validate_json(msg)
except ValidationError as exc:
await read_stream_writer.send(exc)
continue
await read_stream_writer.send(client_message)
except anyio.ClosedResourceError:
await websocket.close()
async def ws_writer():
try:
async with write_stream_reader:
async for message in write_stream_reader:
obj = message.model_dump_json(by_alias=True, exclude_none=True)
await websocket.send_text(obj)
except anyio.ClosedResourceError:
await websocket.close()
async with anyio.create_task_group() as tg:
tg.start_soon(ws_reader)
tg.start_soon(ws_writer)
yield (read_stream, write_stream)