Upgrade to rmcp 0.10 with Streamable HTTP transport support

- Upgrade rmcp from 0.1.5 to 0.10.0
  - Add Streamable HTTP transport with SSE for remote server deployment
  - Update to MCP protocol version 2025-06-18
  - Add CLI arguments: --transport, --host, --port
  - Fix server identity to show actual package name/version
  - Add comprehensive HTTP transport tests
  - Update documentation with transport modes and usage
  - Bump version to 0.3.0
This commit is contained in:
Gianluca Brigandi 2025-12-05 21:03:03 -08:00
parent c706f1a824
commit 4db8ebe635
11 changed files with 589 additions and 164 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "mcp-server-wazuh"
version = "0.2.5"
version = "0.3.0"
edition = "2021"
description = "Wazuh SIEM MCP Server"
authors = ["Gianluca Brigandi <gbrigand@gmail.com>"]
@ -10,7 +10,7 @@ readme = "README.md"
[dependencies]
wazuh-client = "0.1.8"
rmcp = { version = "0.1.5", features = ["server", "transport-io"] }
rmcp = { version = "0.10", features = ["server", "transport-io"] }
tokio = { version = "1", features = ["full"] }
reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false }
serde = { version = "1.0", features = ["derive"] }
@ -18,8 +18,9 @@ serde_json = "1.0"
anyhow = "1.0"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
schemars = "0.8"
schemars = "1.0"
clap = { version = "4.5", features = ["derive"] }
axum = { version = "0.8", optional = true }
dotenv = "0.15"
thiserror = "2.0"
chrono = "0.4.41"
@ -36,4 +37,9 @@ regex = "1.11"
tokio-test = "0.4"
serde_json = "1.0"
tempfile = "3.0"
reqwest = { version = "0.12", features = ["json"] }
[features]
default = []
http = ["dep:axum", "rmcp/transport-streamable-http-server"]

View File

@ -105,7 +105,12 @@ For enhanced threat intelligence capabilities, the Wazuh MCP Server can be combi
```bash
git clone https://github.com/gbrigandi/mcp-server-wazuh.git
cd mcp-server-wazuh
# Build with stdio transport only (default)
cargo build --release
# Build with HTTP transport support
cargo build --release --features http
```
The binary will be available at `target/release/mcp-server-wazuh`.
@ -218,18 +223,72 @@ The "Required: Yes" indicates that these variables are essential for the server
- Edit the `.env` file with your specific Wazuh API details (e.g. `WAZUH_API_HOST`, `WAZUH_API_PORT`).
3. **Build:**
```bash
# Build with default features (stdio transport only)
cargo build
# Build with HTTP transport support
cargo build --features http
```
4. **Run:**
```bash
# Run with stdio transport (default)
cargo run
# Run with HTTP transport (requires --features http during build)
cargo run --features http -- --transport http
# Or use the run script (which might set up stdio mode):
# ./run.sh
```
## Transport Modes
The Wazuh MCP Server supports two transport modes for communication with MCP clients:
### stdio Transport (Default)
The stdio transport is the default mode, ideal for local integrations where the MCP client launches the server as a child process. Communication occurs via stdin/stdout using JSON-RPC 2.0 messages.
```bash
# Run with stdio transport (default)
mcp-server-wazuh
# Explicit stdio transport
mcp-server-wazuh --transport stdio
```
### Streamable HTTP Transport
The HTTP transport enables remote server deployment, allowing MCP clients to connect over the network. This mode implements the MCP Streamable HTTP specification with Server-Sent Events (SSE) support.
```bash
# Run with HTTP transport on default address (127.0.0.1:8080)
mcp-server-wazuh --transport http
# Run with custom host and port
mcp-server-wazuh --transport http --host 0.0.0.0 --port 3000
```
**HTTP Transport Features:**
- Single `/mcp` endpoint for all MCP communication
- POST requests with JSON-RPC messages
- Server-Sent Events (SSE) for streaming responses
- Session management with `MCP-Session-Id` header
- Protocol version: `2025-06-18` (MCP spec supported by rmcp 0.10)
**Security Note:** By default, HTTP transport binds to `127.0.0.1` (localhost only). When binding to `0.0.0.0` for remote access, ensure proper network security measures (firewall rules, reverse proxy with TLS, etc.) are in place.
### CLI Arguments
| Argument | Description | Default |
|----------|-------------|---------|
| `--transport` | Transport mode: `stdio` or `http` | `stdio` |
| `--host` | HTTP server bind address (only for http transport) | `127.0.0.1` |
| `--port` | HTTP server port (only for http transport) | `8080` |
## Architecture
The server is built using the [rmcp](https://crates.io/crates/rmcp) framework and facilitates communication between MCP clients (e.g., Claude Desktop, IDE extensions) and the Wazuh MCP Server via stdio transport. The server interacts with the Wazuh Indexer and Wazuh Manager APIs to fetch security alerts and other data.
The server is built using the [rmcp](https://crates.io/crates/rmcp) framework (v0.10+) and facilitates communication between MCP clients (e.g., Claude Desktop, IDE extensions) and the Wazuh MCP Server. The server supports both stdio and Streamable HTTP transports and interacts with the Wazuh Indexer and Wazuh Manager APIs to fetch security alerts and other data.
```mermaid
sequenceDiagram
@ -284,7 +343,7 @@ Example interaction flow:
"id": 0,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"protocolVersion": "2025-06-18",
"capabilities": {
"sampling": {},
"roots": { "listChanged": true }
@ -303,7 +362,7 @@ Example interaction flow:
"jsonrpc": "2.0",
"id": 1,
"result": {
"protocolVersion": "2024-11-05",
"protocolVersion": "2025-06-18",
"capabilities": {
"prompts": {},
"resources": {},
@ -311,7 +370,7 @@ Example interaction flow:
},
"serverInfo": {
"name": "mcp-server-wazuh",
"version": "0.2.5"
"version": "0.3.0"
},
"instructions": "This server provides tools to interact with a Wazuh SIEM instance for security monitoring and analysis.\nAvailable tools:\n- 'get_wazuh_alert_summary': Retrieves a summary of Wazuh security alerts. Optionally takes 'limit' parameter to control the number of alerts returned (defaults to 100)."
}

View File

@ -11,11 +11,11 @@
//
// Structure:
// - `main()`: Entry point of the application. Initializes logging (tracing),
// sets up the `WazuhToolsServer`, and starts the MCP server using stdio transport.
// sets up the `WazuhToolsServer`, and starts the MCP server using stdio or HTTP transport.
//
// - `WazuhToolsServer`: The core orchestrator struct that implements the `rmcp::ServerHandler` trait
// and the `#[tool(tool_box)]` attribute. It acts as a facade that delegates tool calls to
// specialized domain modules:
// and uses `#[tool_router]` and `#[tool_handler]` macros. It acts as a facade that delegates
// tool calls to specialized domain modules:
// - Holds instances of domain-specific tool modules (AgentTools, AlertTools, RuleTools, etc.)
// - Its methods, decorated with `#[tool(...)]`, define the MCP tool interface and delegate
// to the appropriate domain module for actual implementation
@ -27,59 +27,6 @@
// - `VulnerabilityTools` (`tools/vulnerabilities.rs`): Processes vulnerability data via Wazuh Manager API
// - `AgentTools` (`tools/agents.rs`): Handles agent management and system information queries
// - `StatsTools` (`tools/stats.rs`): Provides logging, statistics, and cluster health monitoring
// Each module encapsulates:
// - Domain-specific business logic and data formatting
// - Parameter validation and error handling
// - Client interaction patterns for their respective Wazuh APIs
// - Rich output formatting with structured text and emojis
//
// - Tool Parameter Structs (e.g., `GetAlertSummaryParams`):
// - These structs define the expected input parameters for each tool.
// - They use `serde::Deserialize` for parsing input and `schemars::JsonSchema`
// for generating a schema that MCP clients can use to understand how to call the tools.
// - Located within their respective domain modules for better organization
//
// - `wazuh_client` crate:
// - This external crate is used to interact with both the Wazuh Manager API and the Wazuh Indexer API.
// - `WazuhClientFactory` is used to create specific clients (e.g., `WazuhIndexerClient`, `RulesClient`, `AgentsClient`, `LogsClient`, `ClusterClient`, `VulnerabilityClient`).
// - Clients are wrapped in `Arc<Mutex<>>` for thread-safe access across async operations
//
// Workflow:
// 1. Server starts and listens for MCP requests on stdio
// 2. MCP client sends a `call_tool` request to `WazuhToolsServer`
// 3. `WazuhToolsServer` routes the request to the appropriate domain-specific tool module
// 4. The domain module validates parameters, interacts with the relevant Wazuh client, and formats results
// 5. The result (success with formatted data or error) is packaged into a `CallToolResult`
// and sent back to the MCP client via the main server
//
// Exposed Tools:
// The server exposes a set of tools categorized by the Wazuh component they interact with:
//
// Alert Management (via AlertTools):
// - `get_wazuh_alert_summary`: Retrieves a summary of security alerts from the Wazuh Indexer.
//
// Rule Management (via RuleTools):
// - `get_wazuh_rules_summary`: Fetches security rules defined in the Wazuh Manager.
//
// Vulnerability Management (via VulnerabilityTools):
// - `get_wazuh_vulnerability_summary`: Gets vulnerability scan results for a specific agent from the Wazuh Manager.
// - `get_wazuh_critical_vulnerabilities`: Retrieves critical vulnerabilities for an agent.
//
// Agent Management (via AgentTools):
// - `get_wazuh_agents`: Lists active and inactive agents connected to the Wazuh Manager.
// - `get_wazuh_agent_processes`: Retrieves running processes on a specific agent (via Syscollector).
// - `get_wazuh_agent_ports`: Lists open network ports on a specific agent (via Syscollector).
//
// Statistics and Monitoring (via StatsTools):
// - `search_wazuh_manager_logs`: Searches logs generated by the Wazuh Manager.
// - `get_wazuh_manager_error_logs`: Retrieves error-specific logs from the Wazuh Manager.
// - `get_wazuh_log_collector_stats`: Gets log collection statistics for an agent.
// - `get_wazuh_remoted_stats`: Fetches statistics from the Wazuh Manager's remoted daemon.
// - `get_wazuh_weekly_stats`: Retrieves aggregated weekly statistics from the Wazuh Manager.
// - `get_wazuh_cluster_health`: Checks the health status of the Wazuh Manager cluster.
// - `get_wazuh_cluster_nodes`: Lists nodes participating in the Wazuh Manager cluster.
//
// (Detailed parameters and descriptions for each tool are available via the MCP `get_tools` command or in the server's `get_info` response.)
//
// Configuration:
// The server requires the following environment variables to connect to the Wazuh instance:
@ -98,10 +45,15 @@
use clap::Parser;
use dotenv::dotenv;
use rmcp::{
handler::server::{router::tool::ToolRouter, wrapper::Parameters},
model::{CallToolResult, Implementation, ProtocolVersion, ServerCapabilities, ServerInfo},
tool,
tool, tool_handler, tool_router,
transport::stdio,
Error as McpError, ServerHandler, ServiceExt,
ErrorData as McpError, ServerHandler, ServiceExt,
};
#[cfg(feature = "http")]
use rmcp::transport::streamable_http_server::{
session::local::LocalSessionManager, StreamableHttpServerConfig, StreamableHttpService,
};
use std::env;
use std::sync::Arc;
@ -125,8 +77,17 @@ use tools::vulnerabilities::{
#[command(name = "mcp-server-wazuh")]
#[command(about = "Wazuh SIEM MCP Server")]
struct Args {
// Currently only stdio transport is supported
// Future versions may add HTTP-SSE transport
/// Transport mode: stdio or http
#[arg(long, default_value = "stdio")]
transport: String,
/// HTTP server bind address (only for http transport)
#[arg(long, default_value = "127.0.0.1")]
host: String,
/// HTTP server port (only for http transport)
#[arg(long, default_value = "8080")]
port: u16,
}
#[derive(Clone)]
@ -136,9 +97,9 @@ struct WazuhToolsServer {
rule_tools: RuleTools,
stats_tools: StatsTools,
vulnerability_tools: VulnerabilityTools,
tool_router: ToolRouter<Self>,
}
#[tool(tool_box)]
impl WazuhToolsServer {
fn new() -> Result<Self, anyhow::Error> {
dotenv().ok();
@ -214,16 +175,20 @@ impl WazuhToolsServer {
rule_tools,
stats_tools,
vulnerability_tools,
tool_router: Self::tool_router(),
})
}
}
#[tool_router]
impl WazuhToolsServer {
#[tool(
name = "get_wazuh_alert_summary",
description = "Retrieves a summary of Wazuh security alerts. Returns formatted alert information including ID, timestamp, and description."
)]
async fn get_wazuh_alert_summary(
&self,
#[tool(aggr)] params: GetAlertSummaryParams,
Parameters(params): Parameters<GetAlertSummaryParams>,
) -> Result<CallToolResult, McpError> {
self.alert_tools.get_wazuh_alert_summary(params).await
}
@ -234,7 +199,7 @@ impl WazuhToolsServer {
)]
async fn get_wazuh_rules_summary(
&self,
#[tool(aggr)] params: GetRulesSummaryParams,
Parameters(params): Parameters<GetRulesSummaryParams>,
) -> Result<CallToolResult, McpError> {
self.rule_tools.get_wazuh_rules_summary(params).await
}
@ -245,7 +210,7 @@ impl WazuhToolsServer {
)]
async fn get_wazuh_vulnerability_summary(
&self,
#[tool(aggr)] params: GetVulnerabilitiesSummaryParams,
Parameters(params): Parameters<GetVulnerabilitiesSummaryParams>,
) -> Result<CallToolResult, McpError> {
self.vulnerability_tools
.get_wazuh_vulnerability_summary(params)
@ -258,7 +223,7 @@ impl WazuhToolsServer {
)]
async fn get_wazuh_critical_vulnerabilities(
&self,
#[tool(aggr)] params: GetCriticalVulnerabilitiesParams,
Parameters(params): Parameters<GetCriticalVulnerabilitiesParams>,
) -> Result<CallToolResult, McpError> {
self.vulnerability_tools
.get_wazuh_critical_vulnerabilities(params)
@ -271,7 +236,7 @@ impl WazuhToolsServer {
)]
async fn get_wazuh_agents(
&self,
#[tool(aggr)] params: GetAgentsParams,
Parameters(params): Parameters<GetAgentsParams>,
) -> Result<CallToolResult, McpError> {
self.agent_tools.get_wazuh_agents(params).await
}
@ -282,7 +247,7 @@ impl WazuhToolsServer {
)]
async fn get_wazuh_agent_processes(
&self,
#[tool(aggr)] params: GetAgentProcessesParams,
Parameters(params): Parameters<GetAgentProcessesParams>,
) -> Result<CallToolResult, McpError> {
self.agent_tools.get_wazuh_agent_processes(params).await
}
@ -293,7 +258,7 @@ impl WazuhToolsServer {
)]
async fn get_wazuh_cluster_health(
&self,
#[tool(aggr)] params: GetClusterHealthParams,
Parameters(params): Parameters<GetClusterHealthParams>,
) -> Result<CallToolResult, McpError> {
self.stats_tools.get_wazuh_cluster_health(params).await
}
@ -304,7 +269,7 @@ impl WazuhToolsServer {
)]
async fn get_wazuh_cluster_nodes(
&self,
#[tool(aggr)] params: GetClusterNodesParams,
Parameters(params): Parameters<GetClusterNodesParams>,
) -> Result<CallToolResult, McpError> {
self.stats_tools.get_wazuh_cluster_nodes(params).await
}
@ -315,7 +280,7 @@ impl WazuhToolsServer {
)]
async fn search_wazuh_manager_logs(
&self,
#[tool(aggr)] params: SearchManagerLogsParams,
Parameters(params): Parameters<SearchManagerLogsParams>,
) -> Result<CallToolResult, McpError> {
self.stats_tools.search_wazuh_manager_logs(params).await
}
@ -326,7 +291,7 @@ impl WazuhToolsServer {
)]
async fn get_wazuh_manager_error_logs(
&self,
#[tool(aggr)] params: GetManagerErrorLogsParams,
Parameters(params): Parameters<GetManagerErrorLogsParams>,
) -> Result<CallToolResult, McpError> {
self.stats_tools.get_wazuh_manager_error_logs(params).await
}
@ -337,7 +302,7 @@ impl WazuhToolsServer {
)]
async fn get_wazuh_log_collector_stats(
&self,
#[tool(aggr)] params: GetLogCollectorStatsParams,
Parameters(params): Parameters<GetLogCollectorStatsParams>,
) -> Result<CallToolResult, McpError> {
self.stats_tools.get_wazuh_log_collector_stats(params).await
}
@ -348,7 +313,7 @@ impl WazuhToolsServer {
)]
async fn get_wazuh_remoted_stats(
&self,
#[tool(aggr)] params: GetRemotedStatsParams,
Parameters(params): Parameters<GetRemotedStatsParams>,
) -> Result<CallToolResult, McpError> {
self.stats_tools.get_wazuh_remoted_stats(params).await
}
@ -359,7 +324,7 @@ impl WazuhToolsServer {
)]
async fn get_wazuh_agent_ports(
&self,
#[tool(aggr)] params: GetAgentPortsParams,
Parameters(params): Parameters<GetAgentPortsParams>,
) -> Result<CallToolResult, McpError> {
self.agent_tools.get_wazuh_agent_ports(params).await
}
@ -370,64 +335,36 @@ impl WazuhToolsServer {
)]
async fn get_wazuh_weekly_stats(
&self,
#[tool(aggr)] params: GetWeeklyStatsParams,
Parameters(params): Parameters<GetWeeklyStatsParams>,
) -> Result<CallToolResult, McpError> {
self.stats_tools.get_wazuh_weekly_stats(params).await
}
}
#[tool(tool_box)]
#[tool_handler]
impl ServerHandler for WazuhToolsServer {
fn get_info(&self) -> ServerInfo {
ServerInfo {
protocol_version: ProtocolVersion::V_2024_11_05,
capabilities: ServerCapabilities::builder()
.enable_prompts()
.enable_resources()
.enable_tools()
.build(),
server_info: Implementation::from_build_env(),
protocol_version: ProtocolVersion::V_2025_06_18,
server_info: Implementation {
name: env!("CARGO_PKG_NAME").to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
..Default::default()
},
instructions: Some(
"This server provides tools to interact with a Wazuh SIEM instance for security monitoring and analysis.\n\
Available tools:\n\
- 'get_wazuh_alert_summary': Retrieves a summary of Wazuh security alerts. \
Optionally takes 'limit' parameter to control the number of alerts returned (defaults to 100).\n\
- 'get_wazuh_rules_summary': Retrieves a summary of Wazuh security rules. \
Supports filtering by 'level', 'group', and 'filename' parameters, with 'limit' to control the number of rules returned (defaults to 100).\n\
- 'get_wazuh_vulnerability_summary': Retrieves a summary of Wazuh vulnerability detections for a specific agent. \
Requires an 'agent_id' parameter. This must be provided as a string, representing the numeric ID of the agent (e.g., \"0\", \"1\", \"12\", \"001\", \"012\"). The server will automatically format this string into a three-digit, zero-padded identifier. For instance, an input of \"0\" will be treated as \"000\", \"1\" as \"001\", and \"12\" as \"012\". Supports filtering by 'severity' and 'cve' parameters, with 'limit' to control the number of vulnerabilities returned (defaults to 100).\n\
- 'get_wazuh_critical_vulnerabilities': Retrieves only critical vulnerabilities for a specific agent. \
Requires an 'agent_id' parameter. This must be provided as a string, representing the numeric ID of the agent (e.g., \"0\", \"1\", \"12\", \"001\", \"012\"). The server will automatically format this string into a three-digit, zero-padded identifier. For instance, an input of \"0\" will be treated as \"000\", \"1\" as \"001\", and \"12\" as \"012\". Returns detailed information about vulnerabilities with 'Critical' severity level.\n\
- 'get_wazuh_running_agents': Retrieves a list of Wazuh agents with their current status and details. \
Supports filtering by 'status' (active, disconnected, pending, never_connected), 'name', 'ip', 'group', 'os_platform', and 'version' parameters, with 'limit' to control the number of agents returned (defaults to 100, status defaults to 'active').\n\
- 'get_wazuh_agent_processes': Retrieves a list of running processes for a specific Wazuh agent. \
Requires an 'agent_id' parameter (formatted as described for other agent-specific tools). Supports 'limit' (default 100) and 'search' (to filter by process name or command line) parameters.\n\
- 'get_wazuh_agent_ports': Retrieves a list of open network ports for a specific Wazuh agent. \
Requires an 'agent_id' parameter (formatted as described for other agent-specific tools). Supports 'limit' (default 100), 'protocol' (e.g., \"tcp\", \"udp\"), and 'state' (e.g., \"LISTENING\", \"ESTABLISHED\") parameters to filter the results. Note: State filtering is performed client-side by this server.\n\
The 'state' parameter filters results:
- If 'state' is 'LISTENING' (case-insensitive): Only ports explicitly in the 'LISTENING' state are returned. Ports with other states, no state, or an empty state string are filtered out.
- If 'state' is any other value (e.g., 'ESTABLISHED'): Ports that are *not* in the 'LISTENING' state are returned. This includes ports with other defined states (like 'ESTABLISHED', 'TIME_WAIT', etc.) and ports that have *no state* defined. Ports with an empty state string are always filtered out.
Note: State filtering is performed client-side by this server. \
- 'search_wazuh_manager_logs': Searches Wazuh manager logs. \
Optional parameters: 'limit' (default 100), 'offset' (default 0), 'level' (e.g., \"error\", \"info\"), 'tag' (e.g., \"wazuh-modulesd\"), 'search_term' (for free-text search in log descriptions).\n\
- 'get_wazuh_manager_error_logs': Retrieves Wazuh manager error logs. \
Optional parameter: 'limit' (default 100).\n\
- 'get_wazuh_log_collector_stats': Retrieves log collector statistics for a specific Wazuh agent. \
Requires an 'agent_id' parameter (formatted as described for other agent-specific tools). Returns detailed information for 'global' and 'interval' periods, including start/end times, and for each log file: location, events processed, bytes, and target-specific drop counts.\n\
- 'get_wazuh_remoted_stats': Retrieves statistics from the Wazuh remoted daemon (manager-wide).\n\
- 'get_wazuh_weekly_stats': Retrieves weekly statistics from the Wazuh manager. Returns a JSON object detailing various metrics aggregated over the past week. No parameters required.\n\
- 'get_wazuh_cluster_health': Checks the health of the Wazuh cluster. Returns a textual summary of the cluster's health status (e.g., enabled, running, connected nodes). No parameters required.\n\
- 'get_wazuh_cluster_nodes': Retrieves a list of nodes in the Wazuh cluster. \
Optional parameters: 'limit' (max nodes, API default 500), 'offset' (default 0), 'node_type' (e.g., \"master\", \"worker\")."
"This server provides tools to interact with a Wazuh SIEM instance for security monitoring and analysis."
.to_string(),
),
capabilities: ServerCapabilities::builder()
.enable_tools()
.build(),
}
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let _args = Args::parse();
let args = Args::parse();
tracing_subscriber::fmt()
.with_env_filter(
@ -439,14 +376,53 @@ async fn main() -> anyhow::Result<()> {
tracing::info!("Starting Wazuh MCP Server...");
// Create an instance of our Wazuh tools server
let server = WazuhToolsServer::new().expect("Error initializing Wazuh tools server");
match args.transport.as_str() {
"stdio" => {
tracing::info!("Using stdio transport");
let server = WazuhToolsServer::new().expect("Error initializing Wazuh tools server");
let service = server.serve(stdio()).await.inspect_err(|e| {
tracing::error!("serving error: {:?}", e);
})?;
service.waiting().await?;
}
#[cfg(feature = "http")]
"http" => {
use axum::Router;
tracing::info!("Using stdio transport");
let service = server.serve(stdio()).await.inspect_err(|e| {
tracing::error!("serving error: {:?}", e);
})?;
tracing::info!("Starting HTTP server on {}:{}", args.host, args.port);
let addr = format!("{}:{}", args.host, args.port);
let service = StreamableHttpService::new(
|| Ok(WazuhToolsServer::new().expect("Error initializing Wazuh tools server")),
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default(),
);
let router = Router::new().nest_service("/mcp", service);
let tcp_listener = tokio::net::TcpListener::bind(&addr).await?;
tracing::info!("Listening on http://{}/mcp", addr);
axum::serve(tcp_listener, router)
.with_graceful_shutdown(async {
tokio::signal::ctrl_c().await.unwrap();
tracing::info!("Received Ctrl-C, shutting down...");
})
.await?;
}
#[cfg(not(feature = "http"))]
"http" => {
anyhow::bail!(
"HTTP transport is not enabled. Rebuild with the 'http' feature: cargo build --features http"
);
}
_ => {
anyhow::bail!(
"Unknown transport: '{}'. Use 'stdio' or 'http'",
args.transport
);
}
}
service.waiting().await?;
Ok(())
}

View File

@ -8,7 +8,7 @@
use super::{ToolModule, ToolUtils};
use reqwest::StatusCode;
use rmcp::model::{CallToolResult, Content};
use rmcp::Error as McpError;
use rmcp::ErrorData as McpError;
use std::sync::Arc;
use tokio::sync::Mutex;
use wazuh_client::{AgentsClient, Port as WazuhPort, VulnerabilityClient};
@ -33,7 +33,7 @@ pub struct GetAgentsParams {
pub version: Option<String>,
}
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
#[derive(Debug, serde::Deserialize, rmcp::schemars::JsonSchema)]
pub struct GetAgentProcessesParams {
#[schemars(
description = "Agent ID to get processes for (required, e.g., \"0\", \"1\", \"001\")"
@ -45,7 +45,7 @@ pub struct GetAgentProcessesParams {
pub search: Option<String>,
}
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
#[derive(Debug, serde::Deserialize, rmcp::schemars::JsonSchema)]
pub struct GetAgentPortsParams {
#[schemars(
description = "Agent ID to get network ports for (required, e.g., \"001\", \"002\", \"003\")"

View File

@ -4,7 +4,7 @@
//! from the Wazuh Indexer.
use rmcp::{
Error as McpError,
ErrorData as McpError,
model::{CallToolResult, Content},
tool,
};
@ -36,7 +36,7 @@ impl AlertTools {
)]
pub async fn get_wazuh_alert_summary(
&self,
#[tool(aggr)] params: GetAlertSummaryParams,
params: GetAlertSummaryParams,
) -> Result<CallToolResult, McpError> {
let limit = params.limit.unwrap_or(300);

View File

@ -10,7 +10,7 @@ pub mod stats;
pub mod vulnerabilities;
use rmcp::model::{CallToolResult, Content};
use rmcp::Error as McpError;
use rmcp::ErrorData as McpError;
pub trait ToolModule {
fn format_error(component: &str, operation: &str, error: &dyn std::fmt::Display) -> String {

View File

@ -4,7 +4,7 @@
//! from the Wazuh Manager.
use rmcp::{
Error as McpError,
ErrorData as McpError,
model::{CallToolResult, Content},
tool,
};
@ -41,7 +41,7 @@ impl RuleTools {
)]
pub async fn get_wazuh_rules_summary(
&self,
#[tool(aggr)] params: GetRulesSummaryParams,
params: GetRulesSummaryParams,
) -> Result<CallToolResult, McpError> {
let limit = params.limit.unwrap_or(300);

View File

@ -6,7 +6,7 @@
use super::{ToolModule, ToolUtils};
use reqwest::StatusCode;
use rmcp::model::{CallToolResult, Content};
use rmcp::Error as McpError;
use rmcp::ErrorData as McpError;
use std::sync::Arc;
use tokio::sync::Mutex;
use wazuh_client::{ClusterClient, LogsClient};

View File

@ -4,7 +4,7 @@
//! from the Wazuh Manager.
use rmcp::{
Error as McpError,
ErrorData as McpError,
model::{CallToolResult, Content},
schemars,
};

395
tests/mcp_http_test.rs Normal file
View File

@ -0,0 +1,395 @@
//! Tests for MCP protocol communication via HTTP transport
//!
//! These tests verify the Streamable HTTP transport implementation.
//! Run with: cargo test --features http --test mcp_http_test
use std::process::{Child, Command, Stdio};
use std::time::Duration;
use tokio::time::sleep;
use serde_json::{json, Value};
struct McpHttpServer {
child: Child,
base_url: String,
#[allow(dead_code)]
port: u16,
}
impl McpHttpServer {
async fn start() -> Result<Self, Box<dyn std::error::Error>> {
// Find an available port
let port = portpicker::pick_unused_port().unwrap_or(18080);
let base_url = format!("http://127.0.0.1:{}", port);
let child = Command::new("cargo")
.args([
"run",
"--features", "http",
"--bin", "mcp-server-wazuh",
"--",
"--transport", "http",
"--host", "127.0.0.1",
"--port", &port.to_string(),
])
.env("WAZUH_API_HOST", "nonexistent.example.com")
.env("WAZUH_API_PORT", "9999")
.env("WAZUH_API_USER", "test")
.env("WAZUH_API_PASS", "test")
.env("WAZUH_INDEXER_HOST", "nonexistent.example.com")
.env("WAZUH_INDEXER_PORT", "8888")
.env("WAZUH_INDEXER_USER", "test")
.env("WAZUH_INDEXER_PASS", "test")
.env("VERIFY_SSL", "false")
.env("RUST_LOG", "error")
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.spawn()?;
// Wait for server to start
let server = McpHttpServer { child, base_url: base_url.clone(), port };
// Poll until the server is ready (max 30 seconds for cargo build + startup)
let client = reqwest::Client::new();
let max_retries = 60;
for i in 0..max_retries {
sleep(Duration::from_millis(500)).await;
// Try to connect to the server - use both Accept headers
let result = client
.post(format!("{}/mcp", base_url))
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.json(&json!({
"jsonrpc": "2.0",
"id": 0,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "startup-check", "version": "1.0.0"}
}
}))
.send()
.await;
if let Ok(resp) = result {
if resp.status().is_success() {
return Ok(server);
}
}
if i % 10 == 0 {
eprintln!("Waiting for HTTP server to start... (attempt {}/{})", i + 1, max_retries);
}
}
Err("Server failed to start within timeout".into())
}
fn url(&self) -> &str {
&self.base_url
}
}
impl Drop for McpHttpServer {
fn drop(&mut self) {
let _ = self.child.kill();
let _ = self.child.wait();
}
}
struct McpHttpClient {
client: reqwest::Client,
base_url: String,
session_id: Option<String>,
}
impl McpHttpClient {
fn new(base_url: &str) -> Self {
Self {
client: reqwest::Client::new(),
base_url: base_url.to_string(),
session_id: None,
}
}
async fn send_request(&mut self, message: &Value) -> Result<Value, Box<dyn std::error::Error>> {
let mut request = self.client
.post(format!("{}/mcp", self.base_url))
.header("Content-Type", "application/json")
// MCP Streamable HTTP spec requires accepting both JSON and SSE
.header("Accept", "application/json, text/event-stream");
// Add session ID header if we have one
if let Some(session_id) = &self.session_id {
request = request.header("Mcp-Session-Id", session_id);
}
let response = request.json(message).send().await?;
// Check for error status
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
return Err(format!("HTTP error {}: {}", status, text).into());
}
// Extract session ID from response headers if present
if let Some(session_id) = response.headers().get("Mcp-Session-Id") {
self.session_id = Some(session_id.to_str()?.to_string());
}
let text = response.text().await?;
// Handle SSE format - the response is in SSE format with "data:" prefix
let json_str = text.lines()
.filter(|line| line.starts_with("data:"))
.map(|line| line.trim_start_matches("data:").trim())
.filter(|s| !s.is_empty())
.next()
.ok_or_else(|| format!("No JSON data found in SSE response: {}", text))?;
let response: Value = serde_json::from_str(json_str)?;
Ok(response)
}
async fn initialize(&mut self) -> Result<Value, Box<dyn std::error::Error>> {
let init_request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
});
self.send_request(&init_request).await
}
async fn send_initialized_notification(&mut self) -> Result<(), Box<dyn std::error::Error>> {
let notification = json!({
"jsonrpc": "2.0",
"method": "notifications/initialized"
});
// Notifications don't expect a response, but we send it anyway
let mut request = self.client
.post(format!("{}/mcp", self.base_url))
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream");
if let Some(session_id) = &self.session_id {
request = request.header("Mcp-Session-Id", session_id);
}
let _ = request.json(&notification).send().await?;
Ok(())
}
async fn list_tools(&mut self) -> Result<Value, Box<dyn std::error::Error>> {
let request = json!({
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list",
"params": {}
});
self.send_request(&request).await
}
}
#[cfg(feature = "http")]
mod http_tests {
use super::*;
// Helper to check if we should skip HTTP tests
fn should_skip() -> bool {
std::env::var("SKIP_HTTP_TESTS").is_ok()
}
#[tokio::test]
async fn test_http_protocol_initialization() -> Result<(), Box<dyn std::error::Error>> {
if should_skip() {
eprintln!("Skipping HTTP test (SKIP_HTTP_TESTS is set)");
return Ok(());
}
let server = McpHttpServer::start().await?;
let mut client = McpHttpClient::new(server.url());
let response = client.initialize().await?;
// Verify JSON-RPC 2.0 compliance
assert_eq!(response["jsonrpc"], "2.0");
assert_eq!(response["id"], 1);
assert!(response["result"].is_object(), "Expected result object, got: {:?}", response);
assert!(response["error"].is_null(), "Unexpected error: {:?}", response["error"]);
// Verify MCP initialize response structure
let result = &response["result"];
// Protocol version depends on rmcp library version - accept valid MCP versions
let protocol_version = result["protocolVersion"].as_str().unwrap();
assert!(
protocol_version.starts_with("202"),
"Expected MCP protocol version (e.g., 2025-03-26), got: {}",
protocol_version
);
assert!(result["capabilities"].is_object());
assert!(result["serverInfo"].is_object());
// Verify server info
let server_info = &result["serverInfo"];
assert!(server_info["name"].is_string());
assert!(server_info["version"].is_string());
// Verify session ID was set
assert!(client.session_id.is_some(), "Session ID should be set after initialization");
Ok(())
}
#[tokio::test]
async fn test_http_tools_list() -> Result<(), Box<dyn std::error::Error>> {
if should_skip() {
eprintln!("Skipping HTTP test (SKIP_HTTP_TESTS is set)");
return Ok(());
}
let server = McpHttpServer::start().await?;
let mut client = McpHttpClient::new(server.url());
// Initialize first
client.initialize().await?;
client.send_initialized_notification().await?;
// Request tools list
let response = client.list_tools().await?;
// Verify response structure
assert_eq!(response["jsonrpc"], "2.0");
assert_eq!(response["id"], 2);
assert!(response["result"].is_object(), "Expected result object, got: {:?}", response);
let result = &response["result"];
assert!(result["tools"].is_array(), "Expected tools array");
let tools = result["tools"].as_array().unwrap();
assert!(!tools.is_empty(), "Tools list should not be empty");
// Verify tool structure
for tool in tools {
assert!(tool["name"].is_string(), "Tool should have name");
assert!(tool["description"].is_string(), "Tool should have description");
assert!(tool["inputSchema"].is_object(), "Tool should have inputSchema");
}
Ok(())
}
#[tokio::test]
async fn test_http_session_management() -> Result<(), Box<dyn std::error::Error>> {
if should_skip() {
eprintln!("Skipping HTTP test (SKIP_HTTP_TESTS is set)");
return Ok(());
}
let server = McpHttpServer::start().await?;
let mut client = McpHttpClient::new(server.url());
// Initialize and get session ID
client.initialize().await?;
let session_id = client.session_id.clone();
assert!(session_id.is_some(), "Should receive session ID on init");
client.send_initialized_notification().await?;
// Subsequent requests should work with the same session
let response = client.list_tools().await?;
assert!(response["result"].is_object(), "Should get valid response with session");
Ok(())
}
#[tokio::test]
async fn test_http_multiple_requests() -> Result<(), Box<dyn std::error::Error>> {
if should_skip() {
eprintln!("Skipping HTTP test (SKIP_HTTP_TESTS is set)");
return Ok(());
}
let server = McpHttpServer::start().await?;
let mut client = McpHttpClient::new(server.url());
// Initialize
client.initialize().await?;
client.send_initialized_notification().await?;
// Send multiple requests
for i in 0..3 {
let request = json!({
"jsonrpc": "2.0",
"id": 10 + i,
"method": "tools/list",
"params": {}
});
let response = client.send_request(&request).await?;
assert_eq!(response["id"], 10 + i);
assert!(response["result"].is_object());
}
Ok(())
}
#[tokio::test]
async fn test_http_content_type_headers() -> Result<(), Box<dyn std::error::Error>> {
if should_skip() {
eprintln!("Skipping HTTP test (SKIP_HTTP_TESTS is set)");
return Ok(());
}
let server = McpHttpServer::start().await?;
let client = reqwest::Client::new();
// Test that server accepts application/json with both Accept types
let response = client
.post(format!("{}/mcp", server.url()))
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.json(&json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "test", "version": "1.0"}
}
}))
.send()
.await?;
assert!(response.status().is_success(), "Expected success status, got: {}", response.status());
Ok(())
}
}
// Port picker utility for tests
mod portpicker {
use std::net::TcpListener;
pub fn pick_unused_port() -> Option<u16> {
// Try to bind to port 0, which lets the OS assign an available port
TcpListener::bind("127.0.0.1:0")
.ok()
.and_then(|listener| listener.local_addr().ok())
.map(|addr| addr.port())
}
}

View File

@ -241,7 +241,7 @@ async fn test_invalid_json_rpc_request() -> Result<(), Box<dyn std::error::Error
#[tokio::test]
async fn test_unsupported_method() -> Result<(), Box<dyn std::error::Error>> {
let mut client = McpStdioClient::start()?;
sleep(Duration::from_millis(500)).await;
// Initialize first
@ -264,41 +264,30 @@ async fn test_unsupported_method() -> Result<(), Box<dyn std::error::Error>> {
});
client.send_message(&initialized_notification)?;
let unsupported_request = json!({
// In rmcp 0.10, unsupported methods are treated as custom notifications
// (no response expected). Send as a notification (no id).
let unsupported_notification = json!({
"jsonrpc": "2.0",
"id": 2,
"method": "unsupported/method"
// Omitting "params": {} as it might be causing deserialization issues
// in rmcp for unknown methods. The JSON-RPC spec allows params to be omitted.
});
client.send_message(&unsupported_notification)?;
// Send the unsupported request. We don't expect a valid JSON-RPC response.
// Instead, the server is likely to close the connection due to deserialization issues
// in rmcp when encountering an unknown method, as it cannot match it to a known JsonRpcMessage variant.
client.send_message(&unsupported_request)?;
// Brief pause to let the server process the notification
sleep(Duration::from_millis(100)).await;
// Attempt to send a subsequent valid request to confirm the connection was dropped.
// Verify the server is still responsive after receiving the unsupported method
let list_tools_request = json!({
"jsonrpc": "2.0",
"id": 3, // Use a new ID
"id": 2,
"method": "tools/list",
"params": {}
});
let result = client.send_and_receive(&list_tools_request);
// Assert that the operation failed, indicating the connection was likely closed.
assert!(result.is_err(), "Server should have closed the connection after the unsupported method request, leading to an error here.");
// Optionally, check the error type more specifically if needed, e.g., for EOF.
if let Err(e) = result {
let error_message = e.to_string().to_lowercase();
assert!(
error_message.contains("eof") || error_message.contains("broken pipe") || error_message.contains("connection reset"),
"Expected EOF, broken pipe, or connection reset error, but got: {}", e
);
}
let tools_response = client.send_and_receive(&list_tools_request)?;
assert!(
tools_response.get("result").is_some(),
"Server should remain responsive after receiving unsupported method notification"
);
Ok(())
}