From f9efb70f196c2188c3dbc4dc2923d608fb5d5f2d Mon Sep 17 00:00:00 2001 From: Gianluca Brigandi Date: Thu, 22 May 2025 20:02:41 -0700 Subject: [PATCH] * Ported code to RMCP * Implemented unit and e2e testing * Other fixes and enhancements --- Cargo.toml | 11 +- Dockerfile | 2 +- README.md | 64 -- .../docker-compose-all-in-one.yml | 0 .../docker-compose.wazuh-demo.yml | 0 .../docker-compose.yml | 0 media/.DS_Store | Bin 0 -> 6148 bytes run.sh | 10 - src/http_service.rs | 135 ----- src/lib.rs | 4 - src/logging_utils.rs | 27 - src/main.rs | 121 ++-- src/mcp/client.rs | 425 -------------- src/mcp/mcp_server_core.rs | 456 --------------- src/mcp/mod.rs | 4 - src/mcp/protocol.rs | 127 ---- src/mcp/transform.rs | 235 -------- src/stdio_service.rs | 186 ------ src/wazuh/client.rs | 17 +- tests/README.md | 141 +++-- tests/e2e_client_test.rs | 194 ------- tests/integration_test.rs | 420 -------------- tests/mcp_client.rs | 180 ------ tests/mcp_client_cli.rs | 164 ------ tests/mcp_stdio_test.rs | 361 ++++++++++++ tests/mock_wazuh_server.rs | 340 +++++++++++ tests/rmcp_integration_test.rs | 546 ++++++++++++++++++ tests/run_tests.sh | 127 ++-- 28 files changed, 1519 insertions(+), 2778 deletions(-) rename docker-compose-all-in-one.yml => docker/docker-compose-all-in-one.yml (100%) rename docker-compose.wazuh-demo.yml => docker/docker-compose.wazuh-demo.yml (100%) rename docker-compose.yml => docker/docker-compose.yml (100%) create mode 100644 media/.DS_Store delete mode 100755 run.sh delete mode 100644 src/http_service.rs delete mode 100644 src/logging_utils.rs delete mode 100644 src/mcp/client.rs delete mode 100644 src/mcp/mcp_server_core.rs delete mode 100644 src/mcp/mod.rs delete mode 100644 src/mcp/protocol.rs delete mode 100644 src/mcp/transform.rs delete mode 100644 src/stdio_service.rs delete mode 100644 tests/e2e_client_test.rs delete mode 100644 tests/integration_test.rs delete mode 100644 tests/mcp_client.rs delete mode 100644 tests/mcp_client_cli.rs create mode 100644 tests/mcp_stdio_test.rs create mode 100644 tests/mock_wazuh_server.rs create mode 100644 tests/rmcp_integration_test.rs diff --git a/Cargo.toml b/Cargo.toml index d134ef8..0305a6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mcp-server-wazuh" -version = "0.1.0" +version = "0.1.2" edition = "2021" description = "Wazuh SIEM MCP Server" authors = ["Gianluca Brigandi "] @@ -21,6 +21,7 @@ schemars = "0.8" clap = { version = "4.5", features = ["derive"] } dotenv = "0.15" thiserror = "2.0" +chrono = "0.4.41" [dev-dependencies] mockito = "1.7" @@ -30,9 +31,7 @@ uuid = { version = "1.16", features = ["v4"] } once_cell = "1.21" async-trait = "0.1" regex = "1.11" - -# Test binaries are disabled for now due to dependency conflicts -# [[bin]] -# name = "mcp_client_cli" -# path = "tests/mcp_client_cli.rs" +tokio-test = "0.4" +serde_json = "1.0" +tempfile = "3.0" diff --git a/Dockerfile b/Dockerfile index 135515f..ca30cc1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ RUN apt-get update && \ apt-get install -y pkg-config libssl-dev && \ cargo build --release -FROM debian:bullseye-slim +FROM debian:bookworm-slim RUN apt-get update && \ apt-get install -y ca-certificates && \ diff --git a/README.md b/README.md index 4c5b551..c5107d0 100644 --- a/README.md +++ b/README.md @@ -169,16 +169,6 @@ This stdio interaction allows for tight integration with local development tools ``` If the HTTP server is enabled, it will start listening on the port specified by `MCP_SERVER_PORT` (default 8000). Otherwise, it will operate in stdio mode. -### Docker Deployment - -1. **Clone the repository** (if not already done). -2. **Configure:** Ensure you have a `.env` file with your Wazuh credentials in the project root if using the API, or set the environment variables directly in the `docker-compose.yml` or your deployment environment. -3. **Build and Run:** - ```bash - docker-compose up --build -d - ``` - This will build the Docker image and start the container in detached mode. - ## Stdio Mode Operation The server communicates via `stdin` and `stdout` using JSON-RPC 2.0 messages, adhering to the Model Context Protocol (MCP). @@ -348,60 +338,6 @@ Example interaction flow: } ``` -## Running the All-in-One Demo (Wazuh + MCP Server) - -For a complete local demo environment that includes Wazuh (Indexer, Manager, Dashboard) and the Wazuh MCP Server pre-configured to connect to it (for HTTP mode testing), you can use the `docker-compose.all-in-one.yml` file. - -This setup is ideal for testing the end-to-end flow from Wazuh alerts to MCP messages via the HTTP interface. - -**1. Launch the Environment:** - -Navigate to the project root directory in your terminal and run: - -```bash -docker-compose -f docker-compose.all-in-one.yml up -d -``` - -This command will: -- Download the necessary Wazuh and OpenSearch images (if not already present). -- Start the Wazuh Indexer, Wazuh Manager, and Wazuh Dashboard services. -- Build and start the Wazuh MCP Server (in HTTP mode). -- All services are configured to communicate with each other on an internal Docker network. - -**2. Accessing Services:** - -* **Wazuh Dashboard:** - * URL: `https://localhost:8443` (Note: Uses HTTPS with a self-signed certificate, so your browser will likely show a warning). - * Default Username: `admin` - * Default Password: `AdminPassword123!` (This is set by `WAZUH_INITIAL_PASSWORD` in the `wazuh-indexer` service). - -* **Wazuh MCP Server (HTTP Mode):** - * The MCP server will be running and accessible on port `8000` by default (or the port specified by `MCP_SERVER_PORT` if you've set it as an environment variable on your host machine before running docker-compose). - * Example MCP endpoint: `http://localhost:8000/mcp` - * Example Health endpoint: `http://localhost:8000/health` - * **Configuration:** The `mcp-server` service within `docker-compose.all-in-one.yml` is already configured with the necessary environment variables to connect to the `wazuh-manager` service: - * `WAZUH_HOST=wazuh-manager` - * `WAZUH_PORT=55000` - * `WAZUH_USER=wazuh_user_demo` - * `WAZUH_PASS=wazuh_password_demo` - * `VERIFY_SSL=false` - You do not need to set these in a separate `.env` file when using this all-in-one compose file, as they are defined directly in the service's environment. - -**3. Stopping the Environment:** - -To stop all services, run: - -```bash -docker-compose -f docker-compose.all-in-one.yml down -``` - -To stop and remove volumes (deleting Wazuh data): - -```bash -docker-compose -f docker-compose.all-in-one.yml down -v -``` -This approach simplifies setup by bundling all necessary components and their configurations for HTTP mode testing. - ## Development & Testing - **Code Style:** Uses standard Rust formatting (`cargo fmt`). diff --git a/docker-compose-all-in-one.yml b/docker/docker-compose-all-in-one.yml similarity index 100% rename from docker-compose-all-in-one.yml rename to docker/docker-compose-all-in-one.yml diff --git a/docker-compose.wazuh-demo.yml b/docker/docker-compose.wazuh-demo.yml similarity index 100% rename from docker-compose.wazuh-demo.yml rename to docker/docker-compose.wazuh-demo.yml diff --git a/docker-compose.yml b/docker/docker-compose.yml similarity index 100% rename from docker-compose.yml rename to docker/docker-compose.yml diff --git a/media/.DS_Store b/media/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..11845dfb108b31f5dd4718b1092462d9d0149e0b GIT binary patch literal 6148 zcmeHK!A=4(5Pd}s7%#+=$KE{Qg2H`6@TNvX^sK-JNDv9^YV@?9?we^v37U8^#LSSH z*G{L?_6?yFfa?~c7ElLJqY75`X}%E|7u}Ld^awi5F+l?lm?6V*E7}~p$bi)D5HGCS z8wz}_Uyc>4*}@4Hm}9)i2N%UEm%84dk5`W7tn(uh+~J8eFWAoLP0(YuoSNX4efchk zcZ?~0bB!FbHK2b=ulPo6$-E5D)EV_>Oe2;7^{0bIp8&)v-70L$TSDgeh-t(!AV(-J zOeKb?EF}A*~J%%)79zs8V;8RG;0hP1Dg!&g>GH$|MQ>U z|C^JnXABqv|B3-u?e%+I9x2|fTZfao)>H4PA`({yTu)&`uVTi^ReVHMq2Eh`m_{rE R(nGO70!o7!W8hC2cn74OOFaMp literal 0 HcmV?d00001 diff --git a/run.sh b/run.sh deleted file mode 100755 index c51422f..0000000 --- a/run.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash - -if [ ! -f .env ]; then - echo "No .env file found. Creating from .env.example..." - cp .env.example .env - echo "Please edit .env file with your configuration." - exit 1 -fi - -cargo run diff --git a/src/http_service.rs b/src/http_service.rs deleted file mode 100644 index cb39762..0000000 --- a/src/http_service.rs +++ /dev/null @@ -1,135 +0,0 @@ -use axum::{ - extract::State, - http::StatusCode, - response::{IntoResponse, Response}, - routing::{get, post}, - Json, Router, -}; -use serde_json::{json, Value}; -use std::sync::Arc; -use tracing::{debug, error, info}; - -use crate::logging_utils::{log_mcp_request, log_mcp_response}; -use crate::AppState; // Added - -pub fn create_http_router(app_state: Arc) -> Router { - Router::new() - .route("/health", get(health_check)) - .route("/mcp", get(get_mcp_data)) - .route("/mcp", post(post_mcp_data)) - .with_state(app_state) -} - -async fn health_check() -> impl IntoResponse { - Json(json!({ - "status": "ok", - "service": "wazuh-mcp-server", - "timestamp": chrono::Utc::now().to_rfc3339() - })) -} - -async fn get_mcp_data( - State(app_state): State>, -) -> Result>, ApiError> { - info!("Handling GET /mcp request"); - - let wazuh_client = app_state.wazuh_client.lock().await; - - match wazuh_client.get_alerts().await { - Ok(alerts) => { - // Transform Wazuh alerts to MCP messages - let mcp_messages = alerts - .iter() - .map(|alert| { - json!({ - "protocol_version": "1.0", - "source": "Wazuh", - "timestamp": chrono::Utc::now().to_rfc3339(), - "event_type": "alert", - "context": alert, - "metadata": { - "integration": "Wazuh-MCP" - } - }) - }) - .collect::>(); - - Ok(Json(mcp_messages)) - } - Err(e) => { - error!("Error getting alerts from Wazuh: {}", e); - Err(ApiError::InternalServerError(format!( - "Failed to get alerts from Wazuh: {}", - e - ))) - } - } -} - -async fn post_mcp_data( - State(app_state): State>, - Json(payload): Json, -) -> Result>, ApiError> { - info!("Handling POST /mcp request with payload"); - debug!("Payload: {:?}", payload); - - // Log the incoming payload - let request_str = serde_json::to_string(&payload).unwrap_or_else(|e| { - error!( - "Failed to serialize POST request payload for logging: {}", - e - ); - format!( - "{{\"error\":\"Failed to serialize request payload: {}\"}}", - e - ) - }); - log_mcp_request(&request_str); - - let result = get_mcp_data(State(app_state)).await; - - let response_str = match &result { - Ok(json_response) => serde_json::to_string(&json_response.0).unwrap_or_else(|e| { - error!("Failed to serialize POST response for logging: {}", e); - format!("{{\"error\":\"Failed to serialize response: {}\"}}", e) - }), - Err(api_error) => { - let error_json_surrogate = json!({ - "error": format!("{:?}", api_error) // Or a more structured error - }); - serde_json::to_string(&error_json_surrogate).unwrap_or_else(|e| { - error!("Failed to serialize POST error response for logging: {}", e); - format!( - "{{\"error\":\"Failed to serialize error response: {}\"}}", - e - ) - }) - } - }; - log_mcp_response(&response_str); - - result -} - -#[derive(Debug)] -pub enum ApiError { - BadRequest(String), - NotFound(String), - InternalServerError(String), -} - -impl IntoResponse for ApiError { - fn into_response(self) -> Response { - let (status, error_message) = match self { - ApiError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg), - ApiError::NotFound(msg) => (StatusCode::NOT_FOUND, msg), - ApiError::InternalServerError(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg), - }; - - let body = Json(json!({ - "error": error_message - })); - - (status, body).into_response() - } -} diff --git a/src/lib.rs b/src/lib.rs index 91ed3f0..507b8d5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,4 @@ -// This file is kept for compatibility with existing tests and binaries -// The main MCP server functionality has been moved to main.rs using the rmcp framework - pub mod wazuh; -// Re-export for backward compatibility pub use wazuh::client::WazuhIndexerClient; pub use wazuh::error::WazuhApiError; diff --git a/src/logging_utils.rs b/src/logging_utils.rs deleted file mode 100644 index 94e51e5..0000000 --- a/src/logging_utils.rs +++ /dev/null @@ -1,27 +0,0 @@ -use std::fs::OpenOptions; -use std::io::Write; -use tracing::error; - -const REQUEST_LOG_FILE: &str = "mcp_requests.log"; -const RESPONSE_LOG_FILE: &str = "mcp_responses.log"; - -fn log_to_file(filename: &str, message: &str) { - match OpenOptions::new().create(true).append(true).open(filename) { - Ok(mut file) => { - if let Err(e) = writeln!(file, "{}", message) { - error!("Failed to write to {}: {}", filename, e); - } - } - Err(e) => { - error!("Failed to open {} for appending: {}", filename, e); - } - } -} - -pub fn log_mcp_request(request_str: &str) { - log_to_file(REQUEST_LOG_FILE, request_str); -} - -pub fn log_mcp_response(response_str: &str) { - log_to_file(RESPONSE_LOG_FILE, response_str); -} diff --git a/src/main.rs b/src/main.rs index 70db428..3b507ff 100644 --- a/src/main.rs +++ b/src/main.rs @@ -44,7 +44,6 @@ use rmcp::{ schemars, tool, transport::stdio, }; -use serde_json::json; use std::sync::Arc; use std::env; use clap::Parser; @@ -93,12 +92,16 @@ impl WazuhToolsServer { .to_lowercase() == "true"; - let wazuh_client = WazuhIndexerClient::new( + let protocol = env::var("WAZUH_TEST_PROTOCOL").unwrap_or_else(|_| "https".to_string()); + tracing::debug!(?protocol, "Using Wazuh protocol for client from WAZUH_TEST_PROTOCOL or default"); + + let wazuh_client = WazuhIndexerClient::new_with_protocol( wazuh_host, wazuh_port, wazuh_user, wazuh_pass, verify_ssl, + &protocol, ); Ok(Self { @@ -121,72 +124,58 @@ impl WazuhToolsServer { match self.wazuh_client.get_alerts().await { Ok(raw_alerts) => { let alerts_to_process: Vec<_> = raw_alerts.into_iter().take(limit as usize).collect(); - - let content_items: Vec = if alerts_to_process.is_empty() { - // If no alerts, return a single "no alerts" message - vec![json!({ - "type": "text", - "text": "No Wazuh alerts found." - })] - } else { - // Map each alert to a content item - alerts_to_process - .into_iter() - .map(|alert| { - let source = alert.get("_source").unwrap_or(&alert); - - // Extract alert ID - let id = source.get("id") - .and_then(|v| v.as_str()) - .or_else(|| alert.get("_id").and_then(|v| v.as_str())) - .unwrap_or("Unknown ID"); - - // Extract rule description - let description = source.get("rule") - .and_then(|r| r.get("description")) - .and_then(|d| d.as_str()) - .unwrap_or("No description available"); - - // Extract timestamp if available - let timestamp = source.get("timestamp") - .and_then(|t| t.as_str()) - .unwrap_or("Unknown time"); - - // Extract agent name if available - let agent_name = source.get("agent") - .and_then(|a| a.get("name")) - .and_then(|n| n.as_str()) - .unwrap_or("Unknown agent"); - - // Extract rule level if available - let rule_level = source.get("rule") - .and_then(|r| r.get("level")) - .and_then(|l| l.as_u64()) - .unwrap_or(0); - - // Format the alert as a text entry and create a content item - json!({ - "type": "text", - "text": format!( - "Alert ID: {}\nTime: {}\nAgent: {}\nLevel: {}\nDescription: {}", - id, timestamp, agent_name, rule_level, description - ) - }) - }) - .collect() - }; - tracing::info!("Successfully processed {} alerts into content items", content_items.len()); + if alerts_to_process.is_empty() { + tracing::info!("No Wazuh alerts found to process. Returning standard message."); + // Ensure this directly returns a Vec with one Content::text item + return Ok(CallToolResult::success(vec![Content::text( + "No Wazuh alerts found.", + )])); + } - // Construct the final result with the content array containing multiple text objects - let result = json!({ - "content": content_items - }); - - Ok(CallToolResult::success(vec![ - Content::json(result) - .map_err(|e| McpError::internal_error(e.to_string(), None))?, - ])) + // Process non-empty alerts + // This part should already be correct if alerts_to_process is not empty, + // as it maps each alert to Content::text directly. + let num_alerts_to_process = alerts_to_process.len(); // Get length before moving + let mcp_content_items: Vec = alerts_to_process + .into_iter() + .map(|alert_value| { + let source = alert_value.get("_source").unwrap_or(&alert_value); + + let id = source.get("id") + .and_then(|v| v.as_str()) + .or_else(|| alert_value.get("_id").and_then(|v| v.as_str())) + .unwrap_or("Unknown ID"); + + let description = source.get("rule") + .and_then(|r| r.get("description")) + .and_then(|d| d.as_str()) + .unwrap_or("No description available"); + + let timestamp = source.get("timestamp") + .and_then(|t| t.as_str()) + .unwrap_or("Unknown time"); + + let agent_name = source.get("agent") + .and_then(|a| a.get("name")) + .and_then(|n| n.as_str()) + .unwrap_or("Unknown agent"); + + let rule_level = source.get("rule") + .and_then(|r| r.get("level")) + .and_then(|l| l.as_u64()) + .unwrap_or(0); + + let formatted_text = format!( + "Alert ID: {}\nTime: {}\nAgent: {}\nLevel: {}\nDescription: {}", + id, timestamp, agent_name, rule_level, description + ); + Content::text(formatted_text) + }) + .collect(); + + tracing::info!("Successfully processed {} alerts into {} MCP content items", num_alerts_to_process, mcp_content_items.len()); + Ok(CallToolResult::success(mcp_content_items)) } Err(e) => { let err_msg = format!("Error retrieving alerts from Wazuh: {}", e); diff --git a/src/mcp/client.rs b/src/mcp/client.rs deleted file mode 100644 index 796735e..0000000 --- a/src/mcp/client.rs +++ /dev/null @@ -1,425 +0,0 @@ -use async_trait::async_trait; -use reqwest::Client; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::Value; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::time::Duration; -use thiserror::Error; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; -use tokio::process::{Child, ChildStdin, ChildStdout, Command}; - -#[derive(Error, Debug)] -pub enum McpClientError { - #[error("HTTP request error: {0}")] - HttpRequestError(#[from] reqwest::Error), - - #[error("HTTP API error: status {status}, message: {message}")] - HttpApiError { - status: reqwest::StatusCode, - message: String, - }, - - #[error("JSON serialization/deserialization error: {0}")] - JsonError(#[from] serde_json::Error), - - #[error("IO error: {0}")] - IoError(#[from] std::io::Error), - - #[error("Failed to spawn child process: {0}")] - ProcessSpawnError(String), - - #[error("Child process stdin/stdout not available")] - ProcessPipeError, - - #[error("JSON-RPC error: code {code}, message: {message}, data: {data:?}")] - JsonRpcError { - code: i32, - message: String, - data: Option, - }, - - #[error("Received unexpected JSON-RPC response: {0}")] - UnexpectedResponse(String), - - #[error("Operation timed out")] - Timeout, - - #[error("Operation not supported in current mode: {0}")] - UnsupportedOperation(String), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct McpMessage { - pub protocol_version: String, - pub source: String, - pub timestamp: String, - pub event_type: String, - pub context: Value, - pub metadata: Value, -} - -#[derive(Serialize, Debug)] -struct JsonRpcRequest { - jsonrpc: String, - method: String, - params: Option, - id: Value, // Changed from usize to Value -} - -#[derive(Deserialize, Debug)] -struct JsonRpcResponse { - jsonrpc: String, - result: Option, - error: Option, - id: Value, // Changed from usize to Value -} - -#[derive(Deserialize, Debug)] -struct JsonRpcErrorData { - code: i32, - message: String, - data: Option, -} - -#[derive(Deserialize, Debug, Clone)] -pub struct ServerInfo { - pub name: String, - pub version: String, -} - -#[derive(Deserialize, Debug, Clone)] -pub struct InitializeResult { - pub protocol_version: String, - pub server_info: ServerInfo, -} - -#[async_trait] -pub trait McpClientTrait { - async fn initialize(&mut self) -> Result; - async fn provide_context( - &mut self, - params: Option, - ) -> Result, McpClientError>; - async fn shutdown(&mut self) -> Result<(), McpClientError>; -} - -enum ClientMode { - Http { - client: Client, - base_url: String, - }, - Stdio { - stdin: ChildStdin, - stdout: BufReader, - }, -} - -pub struct McpClient { - mode: ClientMode, - child_process: Option, - request_id_counter: AtomicUsize, -} - -#[async_trait] -impl McpClientTrait for McpClient { - async fn initialize(&mut self) -> Result { - match &mut self.mode { - ClientMode::Http { .. } => Err(McpClientError::UnsupportedOperation( - "initialize is not supported in HTTP mode".to_string(), - )), - ClientMode::Stdio { .. } => { - let request_id = self.next_id(); - self.send_stdio_request("initialize", None::<()>, request_id) - .await - } - } - } - - async fn provide_context( - &mut self, - params: Option, - ) -> Result, McpClientError> { - match &mut self.mode { - ClientMode::Http { client, base_url } => { - let url = format!("{}/mcp", base_url); - let request_builder = if let Some(p) = params { - client.post(&url).json(&p) - } else { - client.get(&url) - }; - let response = request_builder - .send() - .await - .map_err(McpClientError::HttpRequestError)?; - - if !response.status().is_success() { - let status = response.status(); - let message = response.text().await.unwrap_or_else(|_| { - format!("Failed to get error body for status {}", status) - }); - return Err(McpClientError::HttpApiError { status, message }); - } - response - .json::>() - .await - .map_err(McpClientError::HttpRequestError) - } - ClientMode::Stdio { .. } => { - let request_id = self.next_id(); - self.send_stdio_request("provideContext", params, request_id) - .await - } - } - } - - async fn shutdown(&mut self) -> Result<(), McpClientError> { - match &mut self.mode { - ClientMode::Http { .. } => Err(McpClientError::UnsupportedOperation( - "shutdown is not supported in HTTP mode".to_string(), - )), - ClientMode::Stdio { .. } => { - let request_id = self.next_id(); - // Attempt to send shutdown command, ignore error if server already closed pipe - let _result: Result, McpClientError> = self - .send_stdio_request("shutdown", None::<()>, request_id) - .await; - // Always try to clean up the process - self.close_stdio_process().await - } - } - } -} - -impl McpClient { - pub fn new_http(base_url: String) -> Self { - let client = Client::builder() - .timeout(Duration::from_secs(30)) - .build() - .expect("Failed to create HTTP client"); - Self { - mode: ClientMode::Http { client, base_url }, - child_process: None, - request_id_counter: AtomicUsize::new(1), - } - } - - pub async fn new_stdio( - executable_path: &str, - envs: Option>, - ) -> Result { - let mut command = Command::new(executable_path); - command.stdin(std::process::Stdio::piped()); - command.stdout(std::process::Stdio::piped()); - command.stderr(std::process::Stdio::inherit()); // Pipe child's stderr to parent's stderr for visibility - - if let Some(env_vars) = envs { - for (key, value) in env_vars { - command.env(key, value); - } - } - - let mut child = command - .spawn() - .map_err(|e| McpClientError::ProcessSpawnError(e.to_string()))?; - - let stdin = child.stdin.take().ok_or(McpClientError::ProcessPipeError)?; - let stdout = child - .stdout - .take() - .ok_or(McpClientError::ProcessPipeError)?; - - Ok(Self { - mode: ClientMode::Stdio { - stdin, - stdout: BufReader::new(stdout), - }, - child_process: Some(child), - request_id_counter: AtomicUsize::new(1), - }) - } - - fn next_id(&self) -> Value { - Value::from(self.request_id_counter.fetch_add(1, Ordering::SeqCst)) - } - - async fn send_stdio_request( - &mut self, - method: &str, - params: Option

, - id: Value, // Added id parameter - ) -> Result { - // Removed: let request_id = self.next_id(); - let rpc_request = JsonRpcRequest { - jsonrpc: "2.0".to_string(), - method: method.to_string(), - params, - id: id.clone(), // Use the provided id - }; - let request_json = serde_json::to_string(&rpc_request)? + "\n"; - - let (stdin, stdout) = match &mut self.mode { - ClientMode::Stdio { stdin, stdout } => (stdin, stdout), - ClientMode::Http { .. } => { - return Err(McpClientError::UnsupportedOperation( - "send_stdio_request is only for Stdio mode".to_string(), - )) - } - }; - - stdin.write_all(request_json.as_bytes()).await?; - stdin.flush().await?; - - let mut response_json = String::new(); - match tokio::time::timeout( - Duration::from_secs(10), - stdout.read_line(&mut response_json), - ) - .await - { - Ok(Ok(0)) => { - return Err(McpClientError::IoError(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "Server closed stdout", - ))) - } - Ok(Ok(_)) => { /* continue */ } - Ok(Err(e)) => return Err(McpClientError::IoError(e)), - Err(_) => return Err(McpClientError::Timeout), - } - - let rpc_response: JsonRpcResponse = serde_json::from_str(response_json.trim())?; - - // Compare Value IDs. Note: Value implements PartialEq. - if rpc_response.id != id { - return Err(McpClientError::UnexpectedResponse(format!( - "Mismatched request/response IDs. Expected {}, got {}. Response: '{}'", - id, rpc_response.id, response_json - ))); - } - - if let Some(err_data) = rpc_response.error { - return Err(McpClientError::JsonRpcError { - code: err_data.code, - message: err_data.message, - data: err_data.data, - }); - } - - rpc_response.result.ok_or_else(|| { - McpClientError::UnexpectedResponse("Missing result in JSON-RPC response".to_string()) - }) - } - - async fn close_stdio_process(&mut self) -> Result<(), McpClientError> { - if let Some(mut child) = self.child_process.take() { - child.kill().await.map_err(McpClientError::IoError)?; - let _ = child.wait().await; // Ensure process is reaped - } - Ok(()) - } - - // New public method for sending generic JSON-RPC requests - pub async fn send_json_rpc_request( - &mut self, - method: &str, - params: Option, - id: Value, - ) -> Result { - match &mut self.mode { - ClientMode::Http { .. } => Err(McpClientError::UnsupportedOperation( - "Generic JSON-RPC calls are not supported in HTTP mode by this client.".to_string(), - )), - ClientMode::Stdio { .. } => { - // R (result type) is Value for generic calls - self.send_stdio_request(method, params, id).await - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use httpmock::prelude::*; - use serde_json::json; - use tokio; - - #[tokio::test] - async fn test_mcp_client_http_get_data() { - // Renamed to be specific - let server = MockServer::start(); - - server.mock(|when, then| { - when.method(GET).path("/mcp"); - then.status(200) - .header("content-type", "application/json") - .json_body(json!([ - { - "protocol_version": "1.0", - "source": "Wazuh", - "timestamp": "2023-05-01T12:00:00Z", - "event_type": "alert", - "context": { - "id": "12345", - "category": "intrusion_detection", - "severity": "high", - "description": "Test alert", - "data": { "source_ip": "192.168.1.100" } - }, - "metadata": { "integration": "Wazuh-MCP", "notes": "Test note" } - } - ])); - }); - - let mut client = McpClient::new_http(server.url("")); // Use new_http - - let result = client.provide_context(None).await.unwrap(); - - assert_eq!(result.len(), 1); - assert_eq!(result[0].protocol_version, "1.0"); - assert_eq!(result[0].source, "Wazuh"); - assert_eq!(result[0].event_type, "alert"); - - let context = &result[0].context; - assert_eq!(context["id"], "12345"); - assert_eq!(context["category"], "intrusion_detection"); - assert_eq!(context["severity"], "high"); - } - - #[tokio::test] - async fn test_mcp_client_http_health_check_equivalent() { - let server = MockServer::start(); - - server.mock(|when, then| { - when.method(GET).path("/health"); // Assuming /health is still the target for this test - then.status(200) - .header("content-type", "application/json") - .json_body(json!({ - "status": "ok", - "service": "wazuh-mcp-server", - "timestamp": "2023-05-01T12:00:00Z" - })); - }); - - let client = McpClient::new_http(server.url("")); - - // The new trait doesn't have a direct "check_health". - // If `initialize` was to be used for HTTP health, it would be: - // let result = client.initialize().await; - // But initialize is Stdio-only. So this test needs to adapt or be removed - // if there's no direct equivalent in the new trait for HTTP health. - // For now, let's assume we might add a specific http_health method if needed, - // or this test is demonstrating a capability that's no longer directly on the trait. - // To make this test pass with current structure, we'd need a separate HTTP health method. - // Let's simulate calling the /health endpoint directly if that's the intent. - let http_client = reqwest::Client::new(); - let response = http_client.get(server.url("/health")).send().await.unwrap(); - assert_eq!(response.status(), reqwest::StatusCode::OK); - let health_data: Value = response.json().await.unwrap(); - assert_eq!(health_data["status"], "ok"); - assert_eq!(health_data["service"], "wazuh-mcp-server"); - } - - // TODO: Add tests for Stdio mode. This would require a mock executable - // or a more complex test setup. For now, focusing on the client structure. -} diff --git a/src/mcp/mcp_server_core.rs b/src/mcp/mcp_server_core.rs deleted file mode 100644 index 43200e7..0000000 --- a/src/mcp/mcp_server_core.rs +++ /dev/null @@ -1,456 +0,0 @@ -use serde_json::{json, Value}; -use std::sync::Arc; -use tracing::{debug, error, info}; - -use crate::mcp::protocol::{error_codes, JsonRpcError, JsonRpcRequest, JsonRpcResponse}; -use crate::AppState; - -#[derive(serde::Deserialize, Debug)] -struct ToolCallParams { - #[serde(rename = "name")] - name: String, - #[serde(rename = "arguments")] - arguments: Option, // Input parameters for the specific tool - #[serde(flatten)] - _extra: std::collections::HashMap, -} - -pub struct McpServerCore { - app_state: Arc, -} - -impl McpServerCore { - pub fn new(app_state: Arc) -> Self { - Self { app_state } - } - - pub async fn process_request(&self, request: JsonRpcRequest) -> String { - info!("Processing request: method={}", request.method); - - let response = match request.method.as_str() { - "initialize" => self.handle_initialize(request).await, - "shutdown" => self.handle_shutdown(request).await, - "provideContext" => self.handle_provide_context(request).await, - // Tool methods (prefix "tools/") - "tools/list" => self.handle_list_tools(request).await, - "tools/call" => self.handle_tool_call(request).await, // Use generic tool call handler - // "tools/wazuhAlerts" => self.handle_wazuh_alerts_tool(request).await, - // Resource methods (prefix "resources/") - "resources/list" => self.handle_get_resources(request).await, - "resources/read" => self.handle_read_resource(request).await, - // Prompt methods (prefix "prompts/") - "prompts/list" => self.handle_list_prompts(request).await, - _ => { - error!("Method not found: {}", request.method); - self.create_error_response( - error_codes::METHOD_NOT_FOUND, - format!("Method '{}' not found", request.method), - None, - request.id.clone(), - ) - } - }; - - response - } - - pub fn handle_parse_error(&self, error: serde_json::Error, raw_request: &str) -> String { - error!("Failed to parse JSON-RPC request: {}", error); - - // Try to extract the ID from the raw request if possible - let id = serde_json::from_str::(raw_request) - .and_then(|v| { - if let Some(id) = v.get("id") { - Ok(id.clone()) - } else { - // Use a different approach since custom is not available - Err(serde_json::Error::io(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "No ID field found", - ))) - } - }) - .unwrap_or(Value::Null); - - self.create_error_response( - error_codes::PARSE_ERROR, - format!("Parse error: {}", error), - None, - id, - ) - } - - async fn handle_initialize(&self, request: JsonRpcRequest) -> String { - debug!("Handling initialize request"); - - - let wazuh_alert_summary_tool = crate::mcp::protocol::ToolDefinition { - name: "wazuhAlertSummary".to_string(), - description: Some("Returns a text summary of all Wazuh alerts.".to_string()), - // Define a minimal valid input schema (empty object) - input_schema: Some(json!({ - "type": "object", - "properties": {} - })), - // No output schema needed as per requirements - output_schema: None, - }; - - let result = crate::mcp::protocol::InitializeResult { - protocol_version: "2024-11-05".to_string(), - capabilities: crate::mcp::protocol::Capabilities { - tools: crate::mcp::protocol::ToolCapability { - supported: true, - definitions: vec![wazuh_alert_summary_tool], // Only include wazuhAlertSummary tool - }, - resources: crate::mcp::protocol::SupportedFeature { supported: true }, - prompts: crate::mcp::protocol::SupportedFeature { supported: true }, - }, - server_info: crate::mcp::protocol::ServerInfo { - name: "Wazuh MCP Server".to_string(), - version: env!("CARGO_PKG_VERSION").to_string(), - }, - }; - - self.create_success_response(result, request.id) - } - - async fn handle_shutdown(&self, request: JsonRpcRequest) -> String { - debug!("Handling shutdown request"); - self.create_success_response(Value::Null, request.id) - } - - async fn handle_provide_context(&self, request: JsonRpcRequest) -> String { - debug!("Handling provideContext request"); - - let mut wazuh_client = self.app_state.wazuh_client.lock().await; - - match wazuh_client.get_alerts().await { - Ok(alerts) => { - let mcp_messages: Vec = alerts - .into_iter() - .map(|alert| crate::mcp::transform::transform_to_mcp(alert, "alert".to_string())) - .collect(); - - debug!("Transformed {} alerts into MCP messages for provideContext", mcp_messages.len()); - self.create_success_response(json!(mcp_messages), request.id) - } - Err(e) => { - error!("Error getting alerts from Wazuh for provideContext: {}", e); - self.create_error_response( - error_codes::INTERNAL_ERROR, - format!("Failed to get alerts from Wazuh: {}", e), - None, - request.id, - ) - } - } - } - - async fn handle_get_resources(&self, request: JsonRpcRequest) -> String { - debug!("Handling getResources request"); - // Return an empty list for now - let resources_result = crate::mcp::protocol::ResourcesListResult { - resources: vec![], - }; - - self.create_success_response(resources_result, request.id) - } - - async fn handle_read_resource(&self, request: JsonRpcRequest) -> String { - debug!("Handling readResource request: {:?}", request.params); - - #[derive(serde::Deserialize, Debug)] - struct ReadResourceParams { - uri: String, - // We can add _meta here if needed later - // _meta: Option, - } - - let params: ReadResourceParams = match request.params { - Some(params_value) => match serde_json::from_value(params_value) { - Ok(p) => p, - Err(e) => { - error!("Failed to parse params for resources/read: {}", e); - return self.create_error_response( - error_codes::INVALID_PARAMS, - format!("Invalid params for resources/read: {}", e), - None, - request.id, - ); - } - }, - None => { - error!("Missing params for resources/read"); - return self.create_error_response( - error_codes::INVALID_PARAMS, - "Missing params for resources/read, 'uri' is required".to_string(), - None, - request.id, - ); - } - }; - - error!("Unsupported URI for resources/read: {}", params.uri); - self.create_error_response( - error_codes::INVALID_PARAMS, - format!("Unsupported or unknown resource URI: {}", params.uri), - None, - request.id, - ) - } - - async fn handle_tool_call(&self, request: JsonRpcRequest) -> String { - debug!("Handling tools/call request: {:?}", request.params); - - - let params: ToolCallParams = match request.clone().params { - Some(params_value) => match serde_json::from_value(params_value) { - Ok(p) => p, - Err(e) => { - error!("Failed to parse params for tools/call: {}", e); - return self.create_error_response( - error_codes::INVALID_PARAMS, - format!("Invalid params for tools/call: {}", e), - None, - request.id, - ); - } - }, - None => { - error!("Missing params for tools/call"); - return self.create_error_response( - error_codes::INVALID_PARAMS, - "Missing params for tools/call, 'name' and 'arguments' are required".to_string(), - None, - request.id, - ); - } - }; - - match params.name.as_str() { - "wazuhAlertSummary" => { - info!("Dispatching tools/call to wazuhAlertSummary handler"); - self.handle_wazuh_alert_summary_tool(request).await - } - // wazuhAlerts tool is disabled but we keep the handler code - _ => { - error!("Unsupported tool name requested via tools/call: {}", params.name); - self.create_error_response( - error_codes::METHOD_NOT_FOUND, // Or a more specific tool error code if available - format!("Tool '{}' not found", params.name), - None, - request.id, - ) - } - } - } - - - async fn handle_list_tools(&self, request: JsonRpcRequest) -> String { - debug!("Handling tools/list request"); - - // Define the wazuhAlertSummary tool - let wazuh_alert_summary_tool = crate::mcp::protocol::ToolDefinition { - name: "wazuhAlertSummary".to_string(), - description: Some("Returns a text summary of all Wazuh alerts.".to_string()), - // Define a minimal valid input schema (empty object) - input_schema: Some(json!({ - "type": "object", - "properties": {} - })), - // No output schema needed as per requirements - output_schema: None, - }; - - let tools_list = crate::mcp::protocol::ToolsListResult { - tools: vec![wazuh_alert_summary_tool], - }; - - self.create_success_response(tools_list, request.id) - } - - - async fn handle_wazuh_alerts_tool(&self, request: JsonRpcRequest) -> String { - debug!("Handling tools/wazuhAlerts request. Params: {:?}", request.params); - - let wazuh_client = self.app_state.wazuh_client.lock().await; - - match wazuh_client.get_alerts().await { - Ok(raw_alerts) => { - let simplified_alerts: Vec = raw_alerts - .into_iter() - .map(|alert| { - let source = alert.get("_source").unwrap_or(&alert); - - // Extract ID: Try _source.id first, then _id - let id = source.get("id") - .and_then(|v| v.as_str()) - .or_else(|| alert.get("_id").and_then(|v| v.as_str())) - .unwrap_or("") // Default to empty string if not found - .to_string(); - - // Extract Description: Look in _source.rule.description - let description = source.get("rule") - .and_then(|r| r.get("description")) - .and_then(|d| d.as_str()) - .unwrap_or("") // Default to empty string if not found - .to_string(); - - json!({ - "id": id, - "description": description, - }) - }) - .collect(); - - debug!("Processed {} alerts into simplified format.", simplified_alerts.len()); - - // Construct the final result with the "alerts" array - let result = json!({ - "alerts": simplified_alerts, - "text": "Hello World", - }); - self.create_success_response(result, request.id) - } - Err(e) => { - error!("Error getting alerts from Wazuh for tools/wazuhAlerts: {}", e); - self.create_error_response( - error_codes::INTERNAL_ERROR, - format!("Failed to get alerts from Wazuh: {}", e), - None, - request.id, - ) - } - } - } - - async fn handle_wazuh_alert_summary_tool(&self, request: JsonRpcRequest) -> String { - debug!("Handling tools/wazuhAlertSummary request. Params: {:?}", request.params); - - let mut wazuh_client = self.app_state.wazuh_client.lock().await; - - - match wazuh_client.get_alerts().await { - Ok(raw_alerts) => { - // Create a content item for each alert - let content_items: Vec = if raw_alerts.is_empty() { - // If no alerts, return a single "no alerts" message - vec![json!({ - "type": "text", - "text": "No Wazuh alerts found." - })] - } else { - // Map each alert to a content item - raw_alerts - .into_iter() - .map(|alert| { - let source = alert.get("_source").unwrap_or(&alert); - - // Extract alert ID - let id = source.get("id") - .and_then(|v| v.as_str()) - .or_else(|| alert.get("_id").and_then(|v| v.as_str())) - .unwrap_or("Unknown ID"); - - // Extract rule description - let description = source.get("rule") - .and_then(|r| r.get("description")) - .and_then(|d| d.as_str()) - .unwrap_or("No description available"); - - // Extract timestamp if available - let timestamp = source.get("timestamp") - .and_then(|t| t.as_str()) - .unwrap_or("Unknown time"); - - // Format the alert as a text entry and create a content item - json!({ - "type": "text", - "text": format!("Alert ID: {}\nTime: {}\nDescription: {}", id, timestamp, description) - }) - }) - .collect() - }; - - debug!("Processed {} alerts into individual content items.", content_items.len()); - - // Construct the final result with the content array containing multiple text objects - let result = json!({ - "content": content_items - }); - - self.create_success_response(result, request.id) - } - Err(e) => { - error!("Error getting alerts from Wazuh for tools/wazuhAlertSummary: {}", e); - self.create_error_response( - error_codes::INTERNAL_ERROR, - format!("Failed to get alerts from Wazuh: {}", e), - None, - request.id, - ) - } - } - } - - async fn handle_list_prompts(&self, request: JsonRpcRequest) -> String { - debug!("Handling prompts/list request"); - - // Define the single prompt according to the new structure - let list_alerts_prompt = crate::mcp::protocol::PromptEntry { - name: "list-wazuh-alerts".to_string(), - description: Some("List the latest security alerts from Wazuh.".to_string()), - arguments: vec![], // This prompt takes no arguments - }; - - let prompts = vec![list_alerts_prompt]; - - let result = crate::mcp::protocol::PromptsListResult { prompts }; - - self.create_success_response(result, request.id) - } - - - fn create_success_response(&self, result: T, id: Value) -> String { - let response = JsonRpcResponse { - jsonrpc: "2.0".to_string(), - result: Some(result), - error: None, - id, - }; - - serde_json::to_string(&response).unwrap_or_else(|e| { - error!("Failed to serialize JSON-RPC response: {}", e); - format!( - r#"{{"jsonrpc":"2.0","error":{{"code":-32603,"message":"Internal error: Failed to serialize response"}},"id":null}}"# - ) - }) - } - - pub(crate) fn create_error_response( - &self, - code: i32, - message: String, - data: Option, - id: Value, - ) -> String { - let response = JsonRpcResponse:: { - jsonrpc: "2.0".to_string(), - result: None, - error: Some(JsonRpcError { - code, - message, - data, - }), - id, - }; - - serde_json::to_string(&response).unwrap_or_else(|e| { - error!("Failed to serialize JSON-RPC error response: {}", e); - format!( - r#"{{"jsonrpc":"2.0","error":{{"code":-32603,"message":"Internal error: Failed to serialize error response"}},"id":null}}"# - ) - }) - } -} diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs deleted file mode 100644 index b47c75a..0000000 --- a/src/mcp/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod transform; -pub mod client; -pub mod protocol; -pub mod mcp_server_core; diff --git a/src/mcp/protocol.rs b/src/mcp/protocol.rs deleted file mode 100644 index 08d2f3f..0000000 --- a/src/mcp/protocol.rs +++ /dev/null @@ -1,127 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -#[derive(Deserialize, Debug, Clone)] -pub struct JsonRpcRequest { - pub jsonrpc: String, - pub method: String, - pub params: Option, - pub id: Value, -} - -#[derive(Serialize, Debug)] -pub struct JsonRpcResponse { - pub jsonrpc: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, - pub id: Value, -} - -#[derive(Serialize, Debug)] -pub struct JsonRpcError { - pub code: i32, - pub message: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, -} - -#[derive(Serialize, Debug, Clone)] -pub struct ToolDefinition { - pub name: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - #[serde(rename = "inputSchema", skip_serializing_if = "Option::is_none")] - pub input_schema: Option, // Added inputSchema - #[serde(rename = "outputSchema", skip_serializing_if = "Option::is_none")] - pub output_schema: Option, // Added outputSchema -} - -#[derive(Serialize, Debug)] -pub struct ToolCapability { - pub supported: bool, - #[serde(skip_serializing_if = "Vec::is_empty")] - pub definitions: Vec, // List available tools -} - -#[derive(Serialize, Debug)] -pub struct SupportedFeature { - pub supported: bool, -} - -#[derive(Serialize, Debug)] -pub struct Capabilities { - pub tools: ToolCapability, // Use the new structure - pub resources: SupportedFeature, - pub prompts: SupportedFeature, -} - -#[derive(Serialize, Debug)] -pub struct ServerInfo { - pub name: String, - pub version: String, -} - -#[derive(Serialize, Debug)] -pub struct InitializeResult { - #[serde(rename = "protocolVersion")] - pub protocol_version: String, - pub capabilities: Capabilities, - #[serde(rename = "serverInfo")] - pub server_info: ServerInfo, -} - -#[derive(Serialize, Debug)] -pub struct ResourceEntry { - pub uri: String, - pub name: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - #[serde(rename = "mimeType", skip_serializing_if = "Option::is_none")] - pub mime_type: Option, -} - -#[derive(Serialize, Debug)] -pub struct ResourcesListResult { - pub resources: Vec, -} - -#[derive(Serialize, Debug)] -pub struct ToolsListResult { - pub tools: Vec, -} - -#[derive(Serialize, Debug, Clone)] -pub struct PromptArgument { - pub name: String, - pub required: bool, - #[serde(skip_serializing_if = "Option::is_none")] - pub default: Option, // Use Value for flexibility (string, bool, number) - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, // Optional description for the argument -} - -#[derive(Serialize, Debug, Clone)] -pub struct PromptEntry { - pub name: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - #[serde(skip_serializing_if = "Vec::is_empty")] - pub arguments: Vec, -} - -#[derive(Serialize, Debug)] -pub struct PromptsListResult { - pub prompts: Vec, -} - -pub mod error_codes { - pub const PARSE_ERROR: i32 = -32700; - pub const INVALID_REQUEST: i32 = -32600; - pub const METHOD_NOT_FOUND: i32 = -32601; - pub const INVALID_PARAMS: i32 = -32602; - pub const INTERNAL_ERROR: i32 = -32603; - pub const SERVER_ERROR_START: i32 = -32000; - pub const SERVER_ERROR_END: i32 = -32099; -} diff --git a/src/mcp/transform.rs b/src/mcp/transform.rs deleted file mode 100644 index a4ebc7f..0000000 --- a/src/mcp/transform.rs +++ /dev/null @@ -1,235 +0,0 @@ -use chrono::{DateTime, Utc, SecondsFormat}; -use serde_json::{json, Value}; -use tracing::{debug, warn}; - -pub fn transform_to_mcp(event: Value, event_type: String) -> Value { - debug!(?event, %event_type, "Entering transform_to_mcp"); - - let source_obj = event.get("_source").unwrap_or(&event); - if event.get("_source").is_some() { - debug!("Event contains '_source' field, using it for transformation."); - } else { - debug!("Event does not contain '_source' field, using the event root for transformation."); - } - - let id = source_obj.get("id") - .and_then(|v| v.as_str()) - .or_else(|| event.get("_id").and_then(|v| v.as_str())) - .unwrap_or("unknown_id") - .to_string(); - debug!(%id, "Transformed: id"); - - let default_rule = json!({}); - let rule = source_obj.get("rule").unwrap_or(&default_rule); - if source_obj.get("rule").is_none() { - debug!("Transformed: rule (defaulted to empty object)"); - } else { - debug!(?rule, "Transformed: rule"); - } - let category = rule.get("groups") - .and_then(|g| g.as_array()) - .and_then(|arr| arr.first()) - .and_then(|v| v.as_str()) - .unwrap_or("unknown_category") - .to_string(); - debug!(%category, "Transformed: category"); - - let severity_level = rule.get("level").and_then(|v| v.as_u64()); - let severity = severity_level - .map(|level| match level { - 0..=3 => "low", - 4..=7 => "medium", - 8..=11 => "high", - _ => "critical", - }) - .unwrap_or("unknown_severity") - .to_string(); - debug!(?severity_level, %severity, "Transformed: severity"); - - let description = rule.get("description") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - debug!(%description, "Transformed: description"); - - let default_data = json!({}); - let data = source_obj.get("data").cloned().unwrap_or_else(|| { - debug!("Transformed: data (defaulted to empty object)"); - default_data.clone() - }); - if source_obj.get("data").is_some() { - debug!(?data, "Transformed: data"); - } - - - let default_agent = json!({}); - let agent = source_obj.get("agent").cloned().unwrap_or_else(|| { - debug!("Transformed: agent (defaulted to empty object)"); - default_agent.clone() - }); - if source_obj.get("agent").is_some() { - debug!(?agent, "Transformed: agent"); - } - - - let timestamp_str = source_obj.get("timestamp") - .and_then(|v| v.as_str()) - .unwrap_or(""); - debug!(%timestamp_str, "Attempting to parse timestamp"); - - let timestamp = DateTime::parse_from_rfc3339(timestamp_str) - .map(|dt| dt.with_timezone(&Utc)) - .or_else(|_| DateTime::parse_from_str(timestamp_str, "%Y-%m-%dT%H:%M:%S%.fZ").map(|dt| dt.with_timezone(&Utc))) - .unwrap_or_else(|_| { - warn!("Failed to parse timestamp '{}' for alert ID '{}'. Using current time.", timestamp_str, id); - Utc::now() - }); - debug!(%timestamp, "Transformed: timestamp"); - - let notes = "Data fetched via Wazuh API".to_string(); - debug!(%notes, "Transformed: notes"); - - let mcp_message = json!({ - "protocolVersion": "1.0", // Match initialize response - "source": "Wazuh", - "timestamp": timestamp.to_rfc3339_opts(SecondsFormat::Secs, true), - "event_type": event_type, - "context": { - "id": id, - "category": category, - "severity": severity, - "description": description, - "agent": agent, - "data": data - }, - "metadata": { - "integration": "Wazuh-MCP", - "notes": notes - } - }); - debug!(?mcp_message, "Exiting transform_to_mcp with result"); - mcp_message -} - -#[cfg(test)] -mod tests { - use super::*; - use chrono::TimeZone; - use serde_json::json; - - #[test] - fn test_transform_to_mcp_basic() { - let event_time_str = "2023-10-27T10:30:00.123Z"; - let event_time = Utc.datetime_from_str(event_time_str, "%Y-%m-%dT%H:%M:%S%.fZ").unwrap(); - - let event = json!({ - "id": "alert1", - "_id": "wazuh_alert_id_1", - "timestamp": event_time_str, - "rule": { - "level": 10, - "description": "High severity rule triggered", - "id": "1002", - "groups": ["gdpr", "pci_dss", "intrusion_detection"] - }, - "agent": { - "id": "001", - "name": "server-db" - }, - "data": { - "srcip": "1.2.3.4", - "dstport": "22" - } - }); - - let result = transform_to_mcp(event.clone(), "alert".to_string()); - - assert_eq!(result["protocol_version"], "1.0"); - assert_eq!(result["source"], "Wazuh"); - assert_eq!(result["event_type"], "alert"); - assert_eq!(result["timestamp"], event_time.to_rfc3339_opts(SecondsFormat::Secs, true)); - - let context = &result["context"]; - assert_eq!(context["id"], "alert1"); - assert_eq!(context["category"], "gdpr"); - assert_eq!(context["severity"], "high"); - assert_eq!(context["description"], "High severity rule triggered"); - assert_eq!(context["agent"]["name"], "server-db"); - assert_eq!(context["data"]["srcip"], "1.2.3.4"); - - let metadata = &result["metadata"]; - assert_eq!(metadata["integration"], "Wazuh-MCP"); - assert_eq!(metadata["notes"], "Data fetched via Wazuh API"); - } - - #[test] - fn test_transform_to_mcp_with_source_nesting() { - let event_time_str = "2023-10-27T11:00:00Z"; - let event_time = DateTime::parse_from_rfc3339(event_time_str).unwrap().with_timezone(&Utc); - - let event = json!({ - "_index": "wazuh-alerts-4.x-2023.10.27", - "_id": "alert_source_nested", - "_source": { - "id": "nested_alert_id", - "timestamp": event_time_str, - "rule": { - "level": 5, - "description": "Medium severity rule", - "groups": ["system_audit"] - }, - "agent": { "id": "002", "name": "web-server" }, - "data": { "command": "useradd test" } - } - }); - - let result = transform_to_mcp(event.clone(), "alert".to_string()); - assert_eq!(result["timestamp"], event_time.to_rfc3339_opts(SecondsFormat::Secs, true)); - let context = &result["context"]; - assert_eq!(context["id"], "nested_alert_id"); - assert_eq!(context["category"], "system_audit"); - assert_eq!(context["severity"], "medium"); - assert_eq!(context["description"], "Medium severity rule"); - assert_eq!(context["agent"]["name"], "web-server"); - assert_eq!(context["data"]["command"], "useradd test"); - } - - - #[test] - fn test_transform_to_mcp_with_defaults() { - let event = json!({}); - let before_transform = Utc::now(); - let result = transform_to_mcp(event, "alert".to_string()); - let after_transform = Utc::now(); - - assert_eq!(result["context"]["id"], "unknown_id"); - assert_eq!(result["context"]["category"], "unknown_category"); - assert_eq!(result["context"]["severity"], "unknown_severity"); - assert_eq!(result["context"]["description"], ""); - assert!(result["context"]["data"].is_object()); - assert!(result["context"]["agent"].is_object()); - assert_eq!(result["metadata"]["notes"], "Data fetched via Wazuh API"); - - let result_ts_str = result["timestamp"].as_str().unwrap(); - let result_ts = DateTime::parse_from_rfc3339(result_ts_str).unwrap().with_timezone(&Utc); - assert!(result_ts.timestamp() >= before_transform.timestamp() && result_ts.timestamp() <= after_transform.timestamp()); - } - - #[test] - fn test_transform_timestamp_parsing_fallback() { - let event = json!({ - "id": "ts_test", - "timestamp": "invalid-timestamp-format", - "rule": { "level": 3 }, - }); - let before_transform = Utc::now(); - let result = transform_to_mcp(event, "alert".to_string()); - let after_transform = Utc::now(); - - let result_ts_str = result["timestamp"].as_str().unwrap(); - let result_ts = DateTime::parse_from_rfc3339(result_ts_str).unwrap().with_timezone(&Utc); - assert!(result_ts.timestamp() >= before_transform.timestamp() && result_ts.timestamp() <= after_transform.timestamp()); - assert_eq!(result["context"]["id"], "ts_test"); - assert_eq!(result["context"]["severity"], "low"); - } -} diff --git a/src/stdio_service.rs b/src/stdio_service.rs deleted file mode 100644 index d770b48..0000000 --- a/src/stdio_service.rs +++ /dev/null @@ -1,186 +0,0 @@ -use std::sync::Arc; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; -use tokio::sync::oneshot::Sender as OneshotSender; -use tracing::{debug, error, info}; - -use crate::logging_utils::{log_mcp_request, log_mcp_response}; -use crate::mcp::mcp_server_core::McpServerCore; -use crate::mcp::protocol::{error_codes, JsonRpcRequest}; -use crate::AppState; -use serde_json::Value; - -pub async fn run_stdio_service(app_state: Arc, shutdown_tx: OneshotSender<()>) { - info!("Starting MCP server in stdio mode..."); - let mut stdin_reader = BufReader::new(tokio::io::stdin()); - let mut stdout_writer = tokio::io::stdout(); - let mcp_core = McpServerCore::new(app_state); - - let mut line_buffer = String::new(); - - debug!("run_stdio_service: Initialized readers/writers. Entering main loop."); - - loop { - debug!("stdio_service: Top of the loop. Clearing line buffer."); - line_buffer.clear(); - debug!("stdio_service: About to read_line from stdin."); - - let read_result = stdin_reader.read_line(&mut line_buffer).await; - debug!(?read_result, "stdio_service: read_line completed."); - - match read_result { - Ok(0) => { - debug!("stdio_service: read_line returned Ok(0) (EOF)."); - info!("Stdin closed (EOF), signaling shutdown and exiting stdio mode."); - let _ = shutdown_tx.send(()); // Signal main to shutdown Axum - debug!("stdio_service read 0 bytes, breaking loop."); - break; // EOF - } - Ok(bytes_read) => { - debug!(%bytes_read, "stdio_service: read_line returned Ok(bytes_read)."); - let request_str = line_buffer.trim(); - if request_str.is_empty() { - debug!("Received empty line from stdin, continuing."); - continue; - } - info!("Received from stdin (stdio_service): {}", request_str); - log_mcp_request(request_str); // Log the raw incoming string - - let parsed_value: Value = match serde_json::from_str(request_str) { - Ok(v) => v, - Err(e) => { - error!("JSON Parse Error: {}", e); - let response_json = mcp_core.handle_parse_error(e, request_str); - log_mcp_response(&response_json); - info!("Sending parse error response to stdout: {}", response_json); - let response_to_send = format!("{}\n", response_json); - if let Err(write_err) = - stdout_writer.write_all(response_to_send.as_bytes()).await - { - error!( - "Error writing parse error response to stdout: {}", - write_err - ); - let _ = shutdown_tx.send(()); - break; - } - if let Err(flush_err) = stdout_writer.flush().await { - error!("Error flushing stdout for parse error: {}", flush_err); - let _ = shutdown_tx.send(()); - break; - } - continue; - } - }; - - if parsed_value.get("id").is_none() - || parsed_value.get("id").map_or(false, |id| id.is_null()) - { - // --- Handle Notification (No ID or ID is null) --- - let method = parsed_value - .get("method") - .and_then(Value::as_str) - .unwrap_or(""); - info!("Received Notification: method='{}'", method); - - match method { - "notifications/initialized" => { - debug!("Client 'initialized' notification received. No action taken, no response sent."); - } - "exit" => { - info!("'exit' notification received. Signaling shutdown immediately."); - let _ = shutdown_tx.send(()); - return; - } - _ => { - debug!( - "Received unknown/unhandled notification method: '{}'. Ignoring.", - method - ); - } - } - continue; - } else { - let request_id = parsed_value.get("id").cloned().unwrap(); // We know ID exists and is not null here - - match serde_json::from_value::(parsed_value) { - Ok(rpc_request) => { - // --- Successfully parsed a Request --- - let is_shutdown = rpc_request.method == "shutdown"; - let response_json = mcp_core.process_request(rpc_request).await; - - // Log and send the response - log_mcp_response(&response_json); - info!("Sending response to stdout: {}", response_json); - let response_to_send = format!("{}\n", response_json); - - if let Err(e) = - stdout_writer.write_all(response_to_send.as_bytes()).await - { - error!("Error writing response to stdout: {}", e); - let _ = shutdown_tx.send(()); - break; - } - if let Err(e) = stdout_writer.flush().await { - error!("Error flushing stdout: {}", e); - let _ = shutdown_tx.send(()); - break; - } - - // Handle shutdown *after* sending the response - if is_shutdown { - debug!("'shutdown' request processed successfully. Signaling shutdown."); - let _ = shutdown_tx.send(()); // Signal main to shutdown Axum - return; // Exit the service loop - } - } - Err(e) => { - error!("Invalid JSON-RPC Request structure: {}", e); - // Use the ID we extracted earlier - let response_json = mcp_core.create_error_response( - error_codes::INVALID_REQUEST, - format!("Invalid Request structure: {}", e), - None, - request_id, // Use the ID from the original request - ); - - log_mcp_response(&response_json); - info!( - "Sending invalid request error response to stdout: {}", - response_json - ); - let response_to_send = format!("{}\n", response_json); - if let Err(write_err) = - stdout_writer.write_all(response_to_send.as_bytes()).await - { - error!( - "Error writing invalid request error response to stdout: {}", - write_err - ); - let _ = shutdown_tx.send(()); - break; - } - if let Err(flush_err) = stdout_writer.flush().await { - error!( - "Error flushing stdout for invalid request error: {}", - flush_err - ); - let _ = shutdown_tx.send(()); - break; - } - } - } - } - } - Err(e) => { - debug!(error = %e, "stdio_service: read_line returned Err."); - error!("Error reading from stdin for stdio_service: {}", e); - debug!("Signaling shutdown and breaking loop due to stdin read error."); - let _ = shutdown_tx.send(()); // Signal main to shutdown Axum - break; - } - } - debug!("stdio_service: Bottom of the loop, before next iteration."); - } - - info!("run_stdio_service: Exited main loop. stdio_service task is finishing."); -} diff --git a/src/wazuh/client.rs b/src/wazuh/client.rs index dd745c6..264e5ee 100644 --- a/src/wazuh/client.rs +++ b/src/wazuh/client.rs @@ -13,8 +13,8 @@ pub struct WazuhIndexerClient { http_client: Client, } -// Renamed impl block impl WazuhIndexerClient { + #[allow(dead_code)] pub fn new( host: String, indexer_port: u16, @@ -22,9 +22,20 @@ impl WazuhIndexerClient { password: String, verify_ssl: bool, ) -> Self { - debug!(%host, indexer_port, %username, %verify_ssl, "Creating new WazuhIndexerClient"); + Self::new_with_protocol(host, indexer_port, username, password, verify_ssl, "https") + } + + pub fn new_with_protocol( + host: String, + indexer_port: u16, + username: String, + password: String, + verify_ssl: bool, + protocol: &str, + ) -> Self { + debug!(%host, indexer_port, %username, %verify_ssl, %protocol, "Creating new WazuhIndexerClient"); // Base URL now points to the Indexer - let base_url = format!("https://{}:{}", host, indexer_port); + let base_url = format!("{}://{}:{}", protocol, host, indexer_port); debug!(%base_url, "Wazuh Indexer base URL set"); let http_client = Client::builder() diff --git a/tests/README.md b/tests/README.md index 68553a3..6f1eb6c 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,60 +1,133 @@ # Wazuh MCP Server Tests -This directory contains tests for the Wazuh MCP Server, including end-to-end tests that simulate a client interacting with the server. +This directory contains tests for the Wazuh MCP Server using the rmcp framework, including unit tests, integration tests with mock Wazuh API, and end-to-end MCP protocol tests. ## Test Files -- `e2e_client_test.rs`: End-to-end test for MCP client interacting with Wazuh MCP server -- `integration_test.rs`: Integration test for Wazuh MCP Server with a mock Wazuh API -- `mcp_client.rs`: Reusable MCP client implementation -- `mcp_client_cli.rs`: Command-line tool for interacting with the MCP server +- `rmcp_integration_test.rs`: Integration tests for the rmcp-based MCP server using a mock Wazuh API. +- `mock_wazuh_server.rs`: Mock Wazuh API server implementation, used by the integration tests. +- `mcp_stdio_test.rs`: Tests for MCP protocol communication via stdio, focusing on initialization, compliance, concurrent requests, and error handling for invalid/unsupported messages. +- `run_tests.sh`: A shell script that automates running the various test suites. + +## Testing Strategy + +### 1. Mock Wazuh Server Tests +Tests the MCP server with a mock Wazuh API to verify: +- Tool registration and schema generation +- Alert retrieval and formatting +- Error handling for API failures +- Parameter validation + +### 2. MCP Protocol Tests +Tests the MCP protocol implementation (primarily in `mcp_stdio_test.rs`): +- Initialize handshake. +- Tools listing (basic, without requiring a live Wazuh connection). +- Handling of invalid JSON-RPC requests and unsupported methods. +- Behavior with concurrent requests. +- JSON-RPC 2.0 compliance. +(Note: Full tool execution, like `tools/call`, is primarily tested in `rmcp_integration_test.rs` using the mock Wazuh server.) + +### 3. Unit Tests +Tests individual components and modules, typically run via `cargo test --lib`. These may include: +- Wazuh client logic (e.g., authentication, request formation, response parsing). +- Alert data transformation and formatting. +- Internal error handling mechanisms and utility functions. ## Running the Tests -To run all tests: - +### Run All Tests ```bash cargo test ``` -To run a specific test: - +### Run Specific Test Categories ```bash -cargo test --test e2e_client_test -cargo test --test integration_test +# Integration tests with mock Wazuh +cargo test --test rmcp_integration_test + +# MCP protocol tests +cargo test --test mcp_stdio_test + +# Unit tests +cargo test --lib ``` -## Using the MCP Client CLI - -The MCP Client CLI can be used to interact with the MCP server for testing purposes: - +### Run Tests with Logging ```bash -# Build the CLI -cargo build --bin mcp_client_cli - -# Run the CLI -MCP_SERVER_URL=http://localhost:8000 ./target/debug/mcp_client_cli get-data -MCP_SERVER_URL=http://localhost:8000 ./target/debug/mcp_client_cli health -MCP_SERVER_URL=http://localhost:8000 ./target/debug/mcp_client_cli query '{"severity": "high"}' +RUST_LOG=debug cargo test -- --nocapture ``` ## Test Environment Variables -The tests use the following environment variables: +The tests support the following environment variables: -- `MCP_SERVER_URL`: URL of the MCP server (default: http://localhost:8000) -- `WAZUH_HOST`: Hostname of the Wazuh API server -- `WAZUH_PORT`: Port of the Wazuh API server -- `WAZUH_USER`: Username for Wazuh API authentication -- `WAZUH_PASS`: Password for Wazuh API authentication -- `VERIFY_SSL`: Whether to verify SSL certificates (default: false) -- `RUST_LOG`: Log level for the tests (default: info) +- `RUST_LOG`: Log level for tests (default: info) +- `TEST_WAZUH_HOST`: Real Wazuh host for integration tests (optional) +- `TEST_WAZUH_PORT`: Real Wazuh port for integration tests (optional) +- `TEST_WAZUH_USER`: Real Wazuh username for integration tests (optional) +- `TEST_WAZUH_PASS`: Real Wazuh password for integration tests (optional) ## Mock Wazuh API Server -The tests use a mock Wazuh API server to simulate the Wazuh API. The mock server provides: +The mock server simulates a real Wazuh Indexer API with: -- Authentication endpoint: `/security/user/authenticate` -- Alerts endpoint: `/wazuh-alerts-*_search` +### Authentication Endpoint +- `POST /security/user/authenticate` +- Returns mock JWT token -The mock server returns predefined responses for these endpoints, allowing the tests to run without a real Wazuh API server. +### Alerts Endpoint +- `POST /wazuh-alerts-*/_search` (Note: The Wazuh API typically uses POST for search queries with a body) +- Returns configurable mock alert data +- Supports different scenarios (success, empty, error) + +### Configurable Responses +The mock server can be configured to return: +- Successful responses with sample alerts +- Empty responses (no alerts) +- Error responses (500, 401, etc.) +- Malformed responses for error testing + +## Testing Without Real Wazuh + +All tests can run without a real Wazuh instance by using the mock server. This allows for: + +- **CI/CD Integration**: Tests run in any environment +- **Deterministic Results**: Predictable test data +- **Error Scenario Testing**: Simulate various failure modes +- **Fast Execution**: No network dependencies + +## Testing With a Real Wazuh Instance (Manual End-to-End) + +The automated test suites (`cargo test`) use mock servers or no Wazuh connection. To perform end-to-end testing with a real Wazuh instance, you need to run the server application itself and interact with it manually or via a separate client. + +1. **Set up your Wazuh environment:** Ensure you have a running Wazuh instance (Indexer/API). +2. **Configure Environment Variables:** Set the standard runtime environment variables for the server to connect to your Wazuh instance: + ```bash + export WAZUH_HOST="your-wazuh-indexer-host" # e.g., localhost or an IP address + export WAZUH_PORT="9200" # Or your Wazuh Indexer port + export WAZUH_USER="your-wazuh-api-user" + export WAZUH_PASS="your-wazuh-api-password" + export VERIFY_SSL="false" # Set to "true" if your Wazuh API uses a valid CA-signed SSL certificate + # export RUST_LOG="debug" # For more detailed server logs + ``` + +## Manual Testing + +### Using stdio directly +The server communicates over stdin/stdout. You can send commands by piping them to the process: +```bash +# Example: Send an initialize request +echo '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}' | cargo run --bin mcp-server-wazuh +``` + +### Using the test script +```bash +# Run the provided test script +./tests/run_tests.sh +``` + +This script will: +1. Start the MCP server with mock Wazuh configuration +2. Send a series of MCP commands +3. Verify responses +4. Clean up processes diff --git a/tests/e2e_client_test.rs b/tests/e2e_client_test.rs deleted file mode 100644 index b942672..0000000 --- a/tests/e2e_client_test.rs +++ /dev/null @@ -1,194 +0,0 @@ -use anyhow::Result; -use httpmock::prelude::*; -use reqwest::Client; -use serde_json::{json, Value}; -use std::process::{Child, Command}; -use std::time::Duration; -use tokio::time::sleep; -use uuid::Uuid; - -struct MockWazuhServer { - server: MockServer, -} - -impl MockWazuhServer { - fn new() -> Self { - let server = MockServer::start(); - - server.mock(|when, then| { - when.method(POST) - .path("/security/user/authenticate") - .header("Authorization", "Basic YWRtaW46YWRtaW4="); - - then.status(200) - .header("content-type", "application/json") - .json_body(json!({ - "jwt": "mock.jwt.token" - })); - }); - - server.mock(|when, then| { - when.method(GET) - .path("/wazuh-alerts-*_search") - .header("Authorization", "Bearer mock.jwt.token"); - - then.status(200) - .header("content-type", "application/json") - .json_body(json!({ - "hits": { - "hits": [ - { - "_source": { - "id": "12345", - "category": "intrusion_detection", - "severity": "high", - "description": "Possible intrusion attempt detected", - "data": { - "source_ip": "192.168.1.100", - "destination_ip": "10.0.0.1", - "port": 22 - }, - "notes": "Test alert" - } - }, - { - "_source": { - "id": "67890", - "category": "malware", - "severity": "critical", - "description": "Malware detected on system", - "data": { - "file_path": "/tmp/malicious.exe", - "hash": "abcdef123456", - "signature": "EICAR-Test-File" - } - } - } - ] - } - })); - }); - - Self { server } - } - - fn url(&self) -> String { - self.server.url("") - } -} - -struct McpClient { - client: Client, - base_url: String, -} - -impl McpClient { - fn new(base_url: String) -> Self { - let client = Client::builder() - .timeout(Duration::from_secs(10)) - .build() - .expect("Failed to create HTTP client"); - - Self { client, base_url } - } - - async fn get_mcp_data(&self) -> Result> { - let url = format!("{}/mcp", self.base_url); - let response = self.client.get(&url).send().await?; - - if !response.status().is_success() { - anyhow::bail!("MCP request failed with status: {}", response.status()); - } - - let data = response.json::>().await?; - Ok(data) - } - - async fn check_health(&self) -> Result { - let url = format!("{}/health", self.base_url); - let response = self.client.get(&url).send().await?; - - if !response.status().is_success() { - anyhow::bail!("Health check failed with status: {}", response.status()); - } - - let data = response.json::().await?; - Ok(data) - } -} - -fn start_mcp_server(wazuh_url: &str, port: u16) -> Child { - let server_id = Uuid::new_v4().to_string(); - let wazuh_host_port: Vec<&str> = wazuh_url.trim_start_matches("http://").split(':').collect(); - let wazuh_host = wazuh_host_port[0]; - let wazuh_port = wazuh_host_port[1]; - - Command::new("cargo") - .args(["run", "--"]) - .env("WAZUH_HOST", wazuh_host) - .env("WAZUH_PORT", wazuh_port) - .env("WAZUH_USER", "admin") - .env("WAZUH_PASS", "admin") - .env("VERIFY_SSL", "false") - .env("MCP_SERVER_PORT", port.to_string()) - .env("RUST_LOG", "info") - .env("SERVER_ID", server_id) - .spawn() - .expect("Failed to start MCP server") -} - -#[tokio::test] -async fn test_mcp_client_integration() -> Result<()> { - let mock_wazuh = MockWazuhServer::new(); - let wazuh_url = mock_wazuh.url(); - - let mcp_port = 8765; - let mut mcp_server = start_mcp_server(&wazuh_url, mcp_port); - - sleep(Duration::from_secs(2)).await; - - let mcp_client = McpClient::new(format!("http://localhost:{}", mcp_port)); - - let health_data = mcp_client.check_health().await?; - assert_eq!(health_data["status"], "ok"); - assert_eq!(health_data["service"], "wazuh-mcp-server"); - - let mcp_data = mcp_client.get_mcp_data().await?; - - assert_eq!(mcp_data.len(), 2); - - let first_message = &mcp_data[0]; - assert_eq!(first_message["protocol_version"], "1.0"); - assert_eq!(first_message["source"], "Wazuh"); - assert_eq!(first_message["event_type"], "alert"); - - let context = &first_message["context"]; - assert_eq!(context["id"], "12345"); - assert_eq!(context["category"], "intrusion_detection"); - assert_eq!(context["severity"], "high"); - assert_eq!( - context["description"], - "Possible intrusion attempt detected" - ); - - let data = &context["data"]; - assert_eq!(data["source_ip"], "192.168.1.100"); - assert_eq!(data["destination_ip"], "10.0.0.1"); - assert_eq!(data["port"], 22); - - let second_message = &mcp_data[1]; - let context = &second_message["context"]; - assert_eq!(context["id"], "67890"); - assert_eq!(context["category"], "malware"); - assert_eq!(context["severity"], "critical"); - assert_eq!(context["description"], "Malware detected on system"); - - let data = &context["data"]; - assert_eq!(data["file_path"], "/tmp/malicious.exe"); - assert_eq!(data["hash"], "abcdef123456"); - assert_eq!(data["signature"], "EICAR-Test-File"); - - mcp_server.kill().expect("Failed to kill MCP server"); - - Ok(()) -} diff --git a/tests/integration_test.rs b/tests/integration_test.rs deleted file mode 100644 index b61a2d9..0000000 --- a/tests/integration_test.rs +++ /dev/null @@ -1,420 +0,0 @@ -use anyhow::Result; -use chrono::{DateTime, Utc}; -use httpmock::prelude::*; -use once_cell::sync::Lazy; -use serde_json::json; -use std::net::TcpListener; -use std::process::{Child, Command}; -use std::sync::Mutex; -use std::time::Duration; -use tokio::time::sleep; - -mod mcp_client; -use mcp_client::{McpClient, McpClientTrait, McpMessage}; - -static TEST_MUTEX: Lazy> = Lazy::new(|| Mutex::new(())); - -fn find_available_port() -> u16 { - let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind to random port"); - let port = listener - .local_addr() - .expect("Failed to get local address") - .port(); - drop(listener); - port -} - -struct MockWazuhServer { - server: MockServer, -} - -impl MockWazuhServer { - fn new() -> Self { - let server = MockServer::start(); - - server.mock(|when, then| { - when.method(POST).path("/security/user/authenticate"); - - then.status(200) - .header("content-type", "application/json") - .json_body(json!({ - "jwt": "mock.jwt.token" - })); - }); - - server.mock(|when, then| { - when.method(GET).path("/wazuh-alerts-*_search"); - - then.status(200) - .header("content-type", "application/json") - .json_body(json!({ - "hits": { - "hits": [ - { - "_source": { - "id": "12345", - "category": "intrusion_detection", - "severity": "high", - "description": "Possible intrusion attempt detected", - "data": { - "source_ip": "192.168.1.100", - "destination_ip": "10.0.0.1", - "port": 22 - }, - "notes": "Test alert" - } - }, - { - "_source": { - "id": "67890", - "category": "malware", - "severity": "critical", - "description": "Malware detected on system", - "data": { - "file_path": "/tmp/malicious.exe", - "hash": "abcdef123456", - "signature": "EICAR-Test-File" - } - } - } - ] - } - })); - }); - - Self { server } - } - - fn url(&self) -> String { - self.server.url("") - } - - fn host(&self) -> String { - let url = self.url(); - let parts: Vec<&str> = url.trim_start_matches("http://").split(':').collect(); - parts[0].to_string() - } - - fn port(&self) -> u16 { - let url = self.url(); - let parts: Vec<&str> = url.trim_start_matches("http://").split(':').collect(); - parts[1].parse().unwrap() - } -} - -fn setup_mock_wazuh_server() -> MockServer { - let server = MockServer::start(); - - server.mock(|when, then| { - when.method(POST).path("/security/user/authenticate"); - then.status(200) - .header("content-type", "application/json") - .json_body(json!({ "jwt": "mock.jwt.token" })); - }); - - server.mock(|when, then| { - when.method(GET) - .path_matches(Regex::new(r"/wazuh-alerts-.*_search").unwrap()); - then.status(200) - .header("content-type", "application/json") - .json_body(json!({ - "hits": { - "hits": [ - { - "_source": { - "id": "12345", - "timestamp": "2024-01-01T10:00:00.000Z", - "rule": { - "level": 9, - "description": "Possible intrusion attempt detected", - "groups": ["intrusion_detection", "pci_dss"] - }, - "agent": { "id": "001", "name": "test-agent" }, - "data": { - "source_ip": "192.168.1.100", - "destination_ip": "10.0.0.1", - "port": 22 - } - } - }, - { - "_source": { - "id": "67890", - "timestamp": "2024-01-01T11:00:00.000Z", - "rule": { - "level": 12, - "description": "Malware detected on system", - "groups": ["malware"] - }, - "agent": { "id": "002", "name": "another-agent" }, - "data": { - "file_path": "/tmp/malicious.exe", - "hash": "abcdef123456", - "signature": "EICAR-Test-File" - } - } - } - ] - } - })); - }); - - server -} - -fn get_host_port(server: &MockServer) -> (String, u16) { - let url = server.url(""); - let parts: Vec<&str> = url.trim_start_matches("http://").split(':').collect(); - let host = parts[0].to_string(); - let port = parts[1].parse().unwrap(); - (host, port) -} - -fn start_mcp_server(wazuh_host: &str, wazuh_port: u16, mcp_port: u16) -> Child { - Command::new("cargo") - .args(["run", "--"]) - .env("WAZUH_HOST", wazuh_host) - .env("WAZUH_PORT", wazuh_port.to_string()) - .env("WAZUH_USER", "admin") - .env("WAZUH_PASS", "admin") - .env("VERIFY_SSL", "false") - .env("MCP_SERVER_PORT", mcp_port.to_string()) - .env("RUST_LOG", "info") - .spawn() - .expect("Failed to start MCP server") -} - -#[tokio::test] - -async fn test_mcp_server_with_mock_wazuh() -> Result<()> { - let _guard = TEST_MUTEX.lock().unwrap(); - - let mock_wazuh_server = setup_mock_wazuh_server(); - let (wazuh_host, wazuh_port) = get_host_port(&mock_wazuh_server); - - let mcp_port = find_available_port(); - - let mut mcp_server = start_mcp_server(&wazuh_host, wazuh_port, mcp_port); - - sleep(Duration::from_secs(2)).await; - - let mcp_client = McpClient::new(format!("http://localhost:{}", mcp_port)); - - let health_data = mcp_client.check_health().await?; - assert_eq!(health_data["status"], "ok"); - assert_eq!(health_data["service"], "wazuh-mcp-server"); - - let mcp_data = mcp_client.get_mcp_data().await?; - - assert_eq!(mcp_data.len(), 2); - - let first_message: &McpMessage = &mcp_data[0]; - assert_eq!(first_message.protocol_version, "1.0"); - assert_eq!(first_message.source, "Wazuh"); - assert_eq!(first_message.event_type, "alert"); - - let context = &first_message.context; - assert_eq!(context["id"], "12345"); - assert_eq!(context["category"], "intrusion_detection"); - assert_eq!(context["severity"], "high"); - assert_eq!( - context["description"], - "Possible intrusion attempt detected" - ); - assert_eq!(context["agent"]["name"], "test-agent"); - - let data = &context["data"]; - assert_eq!(data["source_ip"], "192.168.1.100"); - assert_eq!(data["destination_ip"], "10.0.0.1"); - assert_eq!(data["port"], 22); - - let second_message = &mcp_data[1]; - let context = &second_message.context; - assert_eq!(context["id"], "67890"); - assert_eq!(context["category"], "malware"); - assert_eq!(context["severity"], "critical"); - assert_eq!(context["description"], "Malware detected on system"); - assert_eq!(context["agent"]["name"], "another-agent"); - - let data = &context["data"]; - assert_eq!(data["file_path"], "/tmp/malicious.exe"); - assert_eq!(data["hash"], "abcdef123456"); - assert_eq!(data["signature"], "EICAR-Test-File"); - - mcp_server.kill().expect("Failed to kill MCP server"); - - Ok(()) -} - -#[tokio::test] -async fn test_mcp_server_wazuh_api_error() -> Result<()> { - let _guard = TEST_MUTEX.lock().unwrap(); - - let mock_wazuh_server = setup_mock_wazuh_server(); - let (wazuh_host, wazuh_port) = get_host_port(&mock_wazuh_server); - - mock_wazuh_server.mock(|when, then| { - when.method(GET) - .path_matches(Regex::new(r"/wazuh-alerts-.*_search").unwrap()); - then.status(500) - .header("content-type", "application/json") - .json_body(json!({"error": "Wazuh internal error"})); - }); - - let mcp_port = find_available_port(); - let mut mcp_server = start_mcp_server(&wazuh_host, wazuh_port, mcp_port); - sleep(Duration::from_secs(2)).await; - - let mcp_client = McpClient::new(format!("http://localhost:{}", mcp_port)); - - let result = mcp_client.get_mcp_data().await; - assert!(result.is_err()); - let err_string = result.unwrap_err().to_string(); - assert!( - err_string.contains("500") - || err_string.contains("502") - || err_string.contains("API request failed") - ); - - let health_result = mcp_client.check_health().await; - assert!(health_result.is_ok()); - assert_eq!(health_result.unwrap()["status"], "ok"); - - mcp_server.kill().expect("Failed to kill MCP server"); - Ok(()) -} - -#[tokio::test] -async fn test_mcp_client_error_handling() -> Result<()> { - let _guard = TEST_MUTEX.lock().unwrap(); - - let server = MockServer::start(); - - server.mock(|when, then| { - when.method(GET).path("/mcp"); - then.status(500) - .header("content-type", "application/json") - .json_body(json!({ - "error": "Internal server error" - })); - }); - - server.mock(|when, then| { - when.method(GET).path("/health"); - then.status(503) - .header("content-type", "application/json") - .json_body(json!({ - "error": "Service unavailable" - })); - }); - - let client = McpClient::new(server.url("")); - - let result = client.get_mcp_data().await; - assert!(result.is_err()); - let err = result.unwrap_err(); - assert!(err.to_string().contains("500") || err.to_string().contains("MCP request failed")); - - let result = client.check_health().await; - assert!(result.is_err()); - let err = result.unwrap_err(); - assert!(err.to_string().contains("503") || err.to_string().contains("Health check failed")); - - Ok(()) -} - -#[tokio::test] -async fn test_mcp_server_missing_alert_data() -> Result<()> { - let _guard = TEST_MUTEX.lock().unwrap(); - - let mock_wazuh_server = setup_mock_wazuh_server(); - let (wazuh_host, wazuh_port) = get_host_port(&mock_wazuh_server); - - mock_wazuh_server.mock(|when, then| { - when.method(GET) - .path_matches(Regex::new(r"/wazuh-alerts-.*_search").unwrap()); - then.status(200) - .header("content-type", "application/json") - .json_body(json!({ - "hits": { - "hits": [ - { - "_source": { - "id": "missing_all", - "timestamp": "invalid-date-format" - } - }, - { - "_source": { - "id": "missing_rule_fields", - "timestamp": "2024-05-05T11:00:00.000Z", - "rule": { }, - "agent": { "id": "003", "name": "agent-minimal" }, - "data": {} - } - }, - { - "id": "no_source_nest", - "timestamp": "2024-05-05T12:00:00.000Z", - "rule": { - "level": 2, - "description": "Low severity event", - "groups": ["low_sev"] - }, - "agent": { "id": "004" }, - "data": { "info": "some data" } - } - ] - } - })); - }); - - let mcp_port = find_available_port(); - let mut mcp_server = start_mcp_server(&wazuh_host, wazuh_port, mcp_port); - sleep(Duration::from_secs(2)).await; - - let mcp_client = McpClient::new(format!("http://localhost:{}", mcp_port)); - - let mcp_data = mcp_client.get_mcp_data().await?; - assert_eq!(mcp_data.len(), 3); - - let msg1 = &mcp_data[0]; - assert_eq!(msg1.context["id"], "missing_all"); - assert_eq!(msg1.context["category"], "unknown_category"); - assert_eq!(msg1.context["severity"], "unknown_severity"); - assert_eq!(msg1.context["description"], ""); - assert!( - msg1.context["agent"].is_object() && msg1.context["agent"].as_object().unwrap().is_empty() - ); - assert!( - msg1.context["data"].is_object() && msg1.context["data"].as_object().unwrap().is_empty() - ); - let ts1 = DateTime::parse_from_rfc3339(&msg1.timestamp) - .unwrap() - .with_timezone(&Utc); - assert!((Utc::now() - ts1).num_seconds() < 5); - - let msg2 = &mcp_data[1]; - assert_eq!(msg2.context["id"], "missing_rule_fields"); - assert_eq!(msg2.context["category"], "unknown_category"); - assert_eq!(msg2.context["severity"], "unknown_severity"); - assert_eq!(msg2.context["description"], ""); - assert_eq!(msg2.context["agent"]["name"], "agent-minimal"); - assert!( - msg2.context["data"].is_object() && msg2.context["data"].as_object().unwrap().is_empty() - ); - assert_eq!(msg2.timestamp, "2024-05-05T11:00:00Z"); - - let msg3 = &mcp_data[2]; - assert_eq!(msg3.context["id"], "no_source_nest"); - assert_eq!(msg3.context["category"], "low_sev"); - assert_eq!(msg3.context["severity"], "low"); - assert_eq!(msg3.context["description"], "Low severity event"); - assert_eq!(msg3.context["agent"]["id"], "004"); - assert!(msg3.context["agent"].get("name").is_none()); - assert_eq!(msg3.context["data"]["info"], "some data"); - assert_eq!(msg3.timestamp, "2024-05-05T12:00:00Z"); - - mcp_server.kill().expect("Failed to kill MCP server"); - Ok(()) -} diff --git a/tests/mcp_client.rs b/tests/mcp_client.rs deleted file mode 100644 index 1ab7ef7..0000000 --- a/tests/mcp_client.rs +++ /dev/null @@ -1,180 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use reqwest::{Client, StatusCode}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::time::Duration; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct McpMessage { - pub protocol_version: String, - pub source: String, - pub timestamp: String, - pub event_type: String, - pub context: Value, - pub metadata: Value, -} - -#[async_trait] -pub trait McpClientTrait { - async fn get_mcp_data(&self) -> Result>; - - async fn check_health(&self) -> Result; - - async fn query_mcp_data(&self, filters: Value) -> Result>; -} - -pub struct McpClient { - client: Client, - base_url: String, -} - -impl McpClient { - pub fn new(base_url: String) -> Self { - let client = Client::builder() - .timeout(Duration::from_secs(30)) - .build() - .expect("Failed to create HTTP client"); - - Self { client, base_url } - } -} - -#[async_trait] -impl McpClientTrait for McpClient { - async fn get_mcp_data(&self) -> Result> { - let url = format!("{}/mcp", self.base_url); - let response = self.client.get(&url).send().await?; - - match response.status() { - StatusCode::OK => { - let data = response.json::>().await?; - Ok(data) - } - status => { - let error_text = response - .text() - .await - .unwrap_or_else(|_| "Unknown error".to_string()); - anyhow::bail!("MCP request failed with status {}: {}", status, error_text) - } - } - } - - async fn check_health(&self) -> Result { - let url = format!("{}/health", self.base_url); - let response = self.client.get(&url).send().await?; - - match response.status() { - StatusCode::OK => { - let data = response.json::().await?; - Ok(data) - } - status => { - let error_text = response - .text() - .await - .unwrap_or_else(|_| "Unknown error".to_string()); - anyhow::bail!("Health check failed with status {}: {}", status, error_text) - } - } - } - - async fn query_mcp_data(&self, filters: Value) -> Result> { - let url = format!("{}/mcp", self.base_url); - let response = self.client.post(&url).json(&filters).send().await?; - - match response.status() { - StatusCode::OK => { - let data = response.json::>().await?; - Ok(data) - } - status => { - let error_text = response - .text() - .await - .unwrap_or_else(|_| "Unknown error".to_string()); - anyhow::bail!("MCP query failed with status {}: {}", status, error_text) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use httpmock::prelude::*; - use serde_json::json; - use tokio; - - #[tokio::test] - async fn test_mcp_client_get_data() { - let server = MockServer::start(); - - server.mock(|when, then| { - when.method(GET).path("/mcp"); - - then.status(200) - .header("content-type", "application/json") - .json_body(json!([ - { - "protocol_version": "1.0", - "source": "Wazuh", - "timestamp": "2023-05-01T12:00:00Z", - "event_type": "alert", - "context": { - "id": "12345", - "category": "intrusion_detection", - "severity": "high", - "description": "Test alert", - "data": { - "source_ip": "192.168.1.100" - } - }, - "metadata": { - "integration": "Wazuh-MCP", - "notes": "Test note" - } - } - ])); - }); - - let client = McpClient::new(server.url("")); - - let result = client.get_mcp_data().await.unwrap(); - - assert_eq!(result.len(), 1); - assert_eq!(result[0].protocol_version, "1.0"); - assert_eq!(result[0].source, "Wazuh"); - assert_eq!(result[0].event_type, "alert"); - - let context = &result[0].context; - assert_eq!(context["id"], "12345"); - assert_eq!(context["category"], "intrusion_detection"); - assert_eq!(context["severity"], "high"); - } - - #[tokio::test] - async fn test_mcp_client_health_check() { - let server = MockServer::start(); - - server.mock(|when, then| { - when.method(GET).path("/health"); - - then.status(200) - .header("content-type", "application/json") - .json_body(json!({ - "status": "ok", - "service": "wazuh-mcp-server", - "timestamp": "2023-05-01T12:00:00Z" - })); - }); - - let client = McpClient::new(server.url("")); - - let result = client.check_health().await.unwrap(); - - assert_eq!(result["status"], "ok"); - assert_eq!(result["service"], "wazuh-mcp-server"); - } -} diff --git a/tests/mcp_client_cli.rs b/tests/mcp_client_cli.rs deleted file mode 100644 index 2733f60..0000000 --- a/tests/mcp_client_cli.rs +++ /dev/null @@ -1,164 +0,0 @@ -use anyhow::{anyhow, Result}; -use clap::Parser; -use serde_json::Value; -use std::io::{self, Write}; // For stdout().flush() and stdin().read_line() - -use mcp_server_wazuh::mcp::client::{McpClient, McpClientTrait}; -use serde::Deserialize; // For ParsedRequest - -#[derive(Parser, Debug)] -#[clap( - name = "mcp-client-cli", - version = "0.1.0", - about = "Interactive CLI for MCP server. Enter JSON-RPC requests, or 'health'/'quit'." -)] -struct CliArgs { - #[clap(long, help = "Path to the MCP server executable for stdio mode.")] - stdio_exe: Option, - - #[clap( - long, - env = "MCP_SERVER_URL", - default_value = "http://localhost:8000", - help = "URL of the MCP server for HTTP mode." - )] - http_url: String, -} - -// For parsing raw JSON request strings -#[derive(Deserialize, Debug)] -struct ParsedRequest { - // jsonrpc: String, // Not strictly needed for sending - method: String, - params: Option, - id: Value, // ID can be string or number -} - -#[tokio::main] -async fn main() -> Result<()> { - let cli_args = CliArgs::parse(); - let mut client: McpClient; - let is_stdio_mode = cli_args.stdio_exe.is_some(); - - if let Some(ref exe_path) = cli_args.stdio_exe { - println!("Using stdio mode with executable: {}", exe_path); - client = McpClient::new_stdio(exe_path, None).await?; - println!("Sending 'initialize' request to stdio server..."); - match client.initialize().await { - Ok(init_result) => { - println!("Initialization successful:"); - println!(" Protocol Version: {}", init_result.protocol_version); - println!(" Server Name: {}", init_result.server_info.name); - println!(" Server Version: {}", init_result.server_info.version); - } - Err(e) => { - eprintln!("Stdio Initialization failed: {}. You may need to send a raw 'initialize' JSON-RPC request or check server logs.", e); - // Allow continuing, user might want to send raw init or other commands. - } - } - } else { - println!("Using HTTP mode with URL: {}", cli_args.http_url); - client = McpClient::new_http(cli_args.http_url.clone()); - // No automatic initialize for HTTP mode as per McpClientTrait. - // `initialize` is typically a stdio-specific concept in MCP. - } - - println!("\nInteractive MCP Client. Enter a JSON-RPC request, 'health' (HTTP only), or 'quit'."); - println!("Press CTRL-D for EOF to exit."); - - let mut input_buffer = String::new(); - loop { - input_buffer.clear(); - print!("mcp> "); - io::stdout().flush().map_err(|e| anyhow!("Failed to flush stdout: {}", e))?; - - match io::stdin().read_line(&mut input_buffer) { - Ok(0) => { // EOF (Ctrl-D) - println!("\nEOF detected. Exiting."); - break; - } - Ok(_) => { - let line = input_buffer.trim(); - if line.is_empty() { - continue; - } - - if line.eq_ignore_ascii_case("quit") { - println!("Exiting."); - break; - } - - if line.eq_ignore_ascii_case("health") { - if is_stdio_mode { - println!("'health' command is intended for HTTP mode. For stdio, you would need to send a specific JSON-RPC request if the server supports a health method via stdio."); - } else { - println!("Checking server health (HTTP GET to /health)..."); - let health_url = format!("{}/health", cli_args.http_url); // Use the parsed http_url - match reqwest::get(&health_url).await { - Ok(response) => { - let status = response.status(); - let response_text = response.text().await.unwrap_or_else(|_| "Failed to read response body".to_string()); - if status.is_success() { - match serde_json::from_str::(&response_text) { - Ok(json_val) => println!("Health response ({}):\n{}", status, serde_json::to_string_pretty(&json_val).unwrap_or_else(|_| response_text.clone())), - Err(_) => println!("Health response ({}):\n{}", status, response_text), - } - } else { - eprintln!("Health check failed with status: {}", status); - eprintln!("Response: {}", response_text); - } - } - Err(e) => eprintln!("Health check request failed: {}", e), - } - } - continue; - } - - // Assume it's a JSON-RPC request - println!("Attempting to send as JSON-RPC: {}", line); - match serde_json::from_str::(line) { - Ok(parsed_req) => { - match client - .send_json_rpc_request( - &parsed_req.method, - parsed_req.params.clone(), - parsed_req.id.clone(), - ) - .await - { - Ok(response_value) => { - println!( - "Server Response: {}", - serde_json::to_string_pretty(&response_value).unwrap_or_else( - |e_pretty| format!("Failed to pretty-print response ({}): {:?}", e_pretty, response_value) - ) - ); - } - Err(e) => { - eprintln!("Error processing JSON-RPC request '{}': {}", line, e); - } - } - } - Err(e) => { - eprintln!("Failed to parse input as a JSON-RPC request: {}. Input: '{}'", e, line); - eprintln!("Please enter a valid JSON-RPC request string, 'health', or 'quit'."); - } - } - } - Err(e) => { - eprintln!("Error reading input: {}. Exiting.", e); - break; - } - } - } - - if is_stdio_mode { - println!("Sending 'shutdown' request to stdio server..."); - match client.shutdown().await { - Ok(_) => println!("Shutdown command acknowledged by server."), - Err(e) => eprintln!("Error during shutdown: {}. Server might have already exited or closed the connection.", e), - } - } - - Ok(()) -} diff --git a/tests/mcp_stdio_test.rs b/tests/mcp_stdio_test.rs new file mode 100644 index 0000000..1467022 --- /dev/null +++ b/tests/mcp_stdio_test.rs @@ -0,0 +1,361 @@ +//! Tests for MCP protocol communication via stdio +//! +//! These tests verify the basic MCP protocol implementation without +//! requiring a Wazuh connection. + +use std::process::{Command, Stdio}; +use std::io::{BufRead, BufReader, Write}; +use std::time::Duration; +use tokio::time::sleep; +use serde_json::{json, Value}; + +struct McpStdioClient { + child: std::process::Child, + stdin: std::process::ChildStdin, + stdout: BufReader, +} + +impl McpStdioClient { + fn start() -> Result> { + let mut child = Command::new("cargo") + .args(["run", "--bin", "mcp-server-wazuh"]) + .env("WAZUH_HOST", "nonexistent.example.com") // Use non-existent host + .env("WAZUH_PORT", "9999") + .env("WAZUH_USER", "test") + .env("WAZUH_PASS", "test") + .env("VERIFY_SSL", "false") + .env("RUST_LOG", "error") // Minimize logging noise + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()) // Changed from Stdio::null() to inherit stderr + .spawn()?; + + let stdin = child.stdin.take().unwrap(); + let stdout = BufReader::new(child.stdout.take().unwrap()); + + Ok(McpStdioClient { + child, + stdin, + stdout, + }) + } + + fn send_message(&mut self, message: &Value) -> Result<(), Box> { + let message_str = serde_json::to_string(message)?; + writeln!(self.stdin, "{}", message_str)?; + self.stdin.flush()?; + Ok(()) + } + + fn read_response(&mut self) -> Result> { + let mut line = String::new(); + self.stdout.read_line(&mut line)?; + let response: Value = serde_json::from_str(&line.trim())?; + Ok(response) + } + + fn send_and_receive(&mut self, message: &Value) -> Result> { + self.send_message(message)?; + self.read_response() + } +} + +impl Drop for McpStdioClient { + fn drop(&mut self) { + let _ = self.child.kill(); + let _ = self.child.wait(); + } +} + +#[tokio::test] +async fn test_mcp_protocol_initialization() -> Result<(), Box> { + let mut client = McpStdioClient::start()?; + + // Give the server time to start + sleep(Duration::from_millis(500)).await; + + // Test initialize request + let init_request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0.0" + } + } + }); + + let response = client.send_and_receive(&init_request)?; + + // Verify JSON-RPC 2.0 compliance + assert_eq!(response["jsonrpc"], "2.0"); + assert_eq!(response["id"], 1); + assert!(response["result"].is_object()); + assert!(response["error"].is_null()); + + // Verify MCP initialize response structure + let result = &response["result"]; + assert_eq!(result["protocolVersion"], "2024-11-05"); + assert!(result["capabilities"].is_object()); + assert!(result["serverInfo"].is_object()); + + // Verify server info + let server_info = &result["serverInfo"]; + assert!(server_info["name"].is_string()); + assert!(server_info["version"].is_string()); + + Ok(()) +} + +#[tokio::test] +async fn test_mcp_tools_list_without_wazuh() -> Result<(), Box> { + let mut client = McpStdioClient::start()?; + + sleep(Duration::from_millis(500)).await; + + // Initialize first + let init_request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + } + }); + client.send_and_receive(&init_request)?; + + // Send initialized notification + let initialized = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + client.send_message(&initialized)?; + + // Request tools list + let tools_request = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": {} + }); + + let response = client.send_and_receive(&tools_request)?; + + // Verify response structure + assert_eq!(response["jsonrpc"], "2.0"); + assert_eq!(response["id"], 2); + assert!(response["result"].is_object()); + + let result = &response["result"]; + assert!(result["tools"].is_array()); + + let tools = result["tools"].as_array().unwrap(); + assert!(!tools.is_empty()); + + // Verify tool structure + for tool in tools { + assert!(tool["name"].is_string()); + assert!(tool["description"].is_string()); + assert!(tool["inputSchema"].is_object()); + } + + Ok(()) +} + +#[tokio::test] +async fn test_invalid_json_rpc_request() -> Result<(), Box> { + let mut client = McpStdioClient::start()?; + + sleep(Duration::from_millis(500)).await; + + // 1. Initialize the connection first + let init_request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0.0"} + } + }); + let _init_response = client.send_and_receive(&init_request)?; // Read and ignore/assert init response + // assert!(_init_response["result"].is_object()); // Optional: assert successful init + + // 2. Send initialized notification + let initialized_notification = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + client.send_message(&initialized_notification)?; + + // 3. Send the invalid JSON-RPC request (missing required fields) + let invalid_request = json!({ + // "jsonrpc": "2.0", // Missing jsonrpc field to make it invalid + "id": 2, // Use a new ID + "method": "some_method_that_might_not_exist" + }); + client.send_message(&invalid_request)?; + + // 4. The server currently closes the connection upon such an invalid request (see logs: + // `ERROR rmcp::transport::io ... serde error ...` followed by `input stream terminated`). + // Therefore, subsequent requests should fail. This test verifies this behavior. + // Ideally, the server might send a JSON-RPC error and keep the connection open, + // but that would require changes to the server's error handling logic. + + // 5. Attempt to send a subsequent valid request. + let list_tools_request = json!({ + "jsonrpc": "2.0", + "id": 3, // New ID + "method": "tools/list", + "params": {} + }); + + let result = client.send_and_receive(&list_tools_request); + + // Assert that the operation failed, indicating the connection was likely closed. + assert!(result.is_err(), "Server should have closed the connection after the invalid request, leading to an error here."); + + // Optionally, check the error type more specifically if needed, e.g., for EOF. + if let Err(e) = result { + let error_message = e.to_string().to_lowercase(); + assert!( + error_message.contains("eof") || error_message.contains("broken pipe") || error_message.contains("connection reset"), + "Expected EOF, broken pipe, or connection reset error, but got: {}", e + ); + } + + Ok(()) +} + +#[tokio::test] +async fn test_unsupported_method() -> Result<(), Box> { + let mut client = McpStdioClient::start()?; + + sleep(Duration::from_millis(500)).await; + + // Initialize first + let init_request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + } + }); + client.send_and_receive(&init_request)?; + + // Send initialized notification + let initialized_notification = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + client.send_message(&initialized_notification)?; + + let unsupported_request = json!({ + "jsonrpc": "2.0", + "id": 2, + + "method": "unsupported/method" + // Omitting "params": {} as it might be causing deserialization issues + // in rmcp for unknown methods. The JSON-RPC spec allows params to be omitted. + }); + + // Send the unsupported request. We don't expect a valid JSON-RPC response. + // Instead, the server is likely to close the connection due to deserialization issues + // in rmcp when encountering an unknown method, as it cannot match it to a known JsonRpcMessage variant. + client.send_message(&unsupported_request)?; + + // Attempt to send a subsequent valid request to confirm the connection was dropped. + let list_tools_request = json!({ + "jsonrpc": "2.0", + "id": 3, // Use a new ID + "method": "tools/list", + "params": {} + }); + + let result = client.send_and_receive(&list_tools_request); + + // Assert that the operation failed, indicating the connection was likely closed. + assert!(result.is_err(), "Server should have closed the connection after the unsupported method request, leading to an error here."); + + // Optionally, check the error type more specifically if needed, e.g., for EOF. + if let Err(e) = result { + let error_message = e.to_string().to_lowercase(); + assert!( + error_message.contains("eof") || error_message.contains("broken pipe") || error_message.contains("connection reset"), + "Expected EOF, broken pipe, or connection reset error, but got: {}", e + ); + } + + Ok(()) +} + +#[tokio::test] +async fn test_concurrent_requests() -> Result<(), Box> { + let mut client = McpStdioClient::start()?; + + sleep(Duration::from_millis(500)).await; + + // Initialize + let init_request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + } + }); + client.send_and_receive(&init_request)?; + + // Send initialized notification + let initialized = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + client.send_message(&initialized)?; + + // Send multiple requests with different IDs + let request1 = json!({ + "jsonrpc": "2.0", + "id": 10, + "method": "tools/list", + "params": {} + }); + + let request2 = json!({ + "jsonrpc": "2.0", + "id": 20, + "method": "tools/list", + "params": {} + }); + + // Send both requests + client.send_message(&request1)?; + client.send_message(&request2)?; + + // Read both responses + let response1 = client.read_response()?; + let response2 = client.read_response()?; + + // Responses should maintain request IDs (though order might vary) + let ids: Vec = vec![ + response1["id"].as_i64().unwrap(), + response2["id"].as_i64().unwrap(), + ]; + + assert!(ids.contains(&10)); + assert!(ids.contains(&20)); + + Ok(()) +} diff --git a/tests/mock_wazuh_server.rs b/tests/mock_wazuh_server.rs new file mode 100644 index 0000000..0aa3da9 --- /dev/null +++ b/tests/mock_wazuh_server.rs @@ -0,0 +1,340 @@ +//! Mock Wazuh API server for testing +//! +//! This module provides a configurable mock server that simulates the Wazuh Indexer API +//! for testing purposes. It supports various response scenarios including success, +//! empty results, and error conditions. + +use httpmock::prelude::*; +use serde_json::json; + +pub struct MockWazuhServer { + server: MockServer, +} + +impl MockWazuhServer { + pub fn new() -> Self { + let server = MockServer::start(); + + // Setup default authentication endpoint + server.mock(|when, then| { + when.method(POST).path("/security/user/authenticate"); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "jwt": "mock.jwt.token.eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" + })); + }); + + server.mock(|when, then| { + when.method(POST) + .path_matches(Regex::new(r"/wazuh-alerts.*/_search").unwrap()); + then.status(200) + .header("content-type", "application/json") + .json_body(Self::sample_alerts_response()); + }); + + Self { server } + } + + pub fn with_empty_alerts() -> Self { + let server = MockServer::start(); + + server.mock(|when, then| { + when.method(POST).path("/security/user/authenticate"); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "jwt": "mock.jwt.token" + })); + }); + + server.mock(|when, then| { + when.method(POST) + .path_matches(Regex::new(r"/wazuh-alerts.*/_search").unwrap()); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "hits": { + "hits": [] + } + })); + }); + + Self { server } + } + + pub fn with_auth_error() -> Self { + let server = MockServer::start(); + + server.mock(|when, then| { + when.method(POST).path("/security/user/authenticate"); + then.status(401) + .header("content-type", "application/json") + .json_body(json!({ + "error": "Invalid credentials" + })); + }); + + Self { server } + } + + pub fn with_alerts_error() -> Self { + let server = MockServer::start(); + + server.mock(|when, then| { + when.method(POST).path("/security/user/authenticate"); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "jwt": "mock.jwt.token" + })); + }); + + server.mock(|when, then| { + when.method(POST) + .path_matches(Regex::new(r"/wazuh-alerts.*/_search").unwrap()); + then.status(500) + .header("content-type", "application/json") + .json_body(json!({ + "error": "Internal server error" + })); + }); + + Self { server } + } + + pub fn with_malformed_alerts() -> Self { + let server = MockServer::start(); + + server.mock(|when, then| { + when.method(POST).path("/security/user/authenticate"); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "jwt": "mock.jwt.token" + })); + }); + + server.mock(|when, then| { + when.method(POST) + .path_matches(Regex::new(r"/wazuh-alerts.*/_search").unwrap()); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "hits": { + "hits": [ + { + "_source": { + "id": "missing_fields", + "timestamp": "invalid-date-format" + // Missing rule, agent, etc. + } + }, + { + "_source": { + "id": "partial_data", + "timestamp": "2024-01-15T10:30:45.123Z", + "rule": { + "level": 5 + // Missing description + }, + "agent": { + "id": "001" + // Missing name + } + } + } + ] + } + })); + }); + + Self { server } + } + + pub fn url(&self) -> String { + self.server.url("") + } + + pub fn host(&self) -> String { + let url = self.url(); + let parts: Vec<&str> = url.trim_start_matches("http://").split(':').collect(); + parts[0].to_string() + } + + pub fn port(&self) -> u16 { + let url = self.url(); + let parts: Vec<&str> = url.trim_start_matches("http://").split(':').collect(); + parts[1].parse().unwrap() + } + + fn sample_alerts_response() -> serde_json::Value { + json!({ + "hits": { + "hits": [ + { + "_source": { + "id": "1747091815.1212763", + "timestamp": "2024-01-15T10:30:45.123Z", + "rule": { + "level": 7, + "description": "Attached USB Storage", + "groups": ["usb", "pci_dss"] + }, + "agent": { + "id": "001", + "name": "web-server-01" + }, + "data": { + "device": "/dev/sdb1", + "mount_point": "/media/usb" + } + } + }, + { + "_source": { + "id": "1747066333.1207112", + "timestamp": "2024-01-15T10:25:12.456Z", + "rule": { + "level": 5, + "description": "New dpkg (Debian Package) installed.", + "groups": ["package_management", "debian"] + }, + "agent": { + "id": "002", + "name": "database-server" + }, + "data": { + "package": "nginx", + "version": "1.18.0-6ubuntu14.4" + } + } + }, + { + "_source": { + "id": "1747055444.1205998", + "timestamp": "2024-01-15T10:20:33.789Z", + "rule": { + "level": 12, + "description": "Multiple authentication failures", + "groups": ["authentication_failed", "pci_dss"] + }, + "agent": { + "id": "003", + "name": "ssh-gateway" + }, + "data": { + "source_ip": "192.168.1.100", + "user": "admin", + "attempts": 5 + } + } + } + ] + } + }) + } +} + +impl Default for MockWazuhServer { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_mock_server_creation() { + let mock_server = MockWazuhServer::new(); + assert!(!mock_server.url().is_empty()); + assert!(!mock_server.host().is_empty()); + assert!(mock_server.port() > 0); + } + + #[tokio::test] + async fn test_mock_server_auth_endpoint() { + let mock_server = MockWazuhServer::new(); + let client = reqwest::Client::new(); + + let response = client + .post(&format!("{}/security/user/authenticate", mock_server.url())) + .json(&json!({"username": "admin", "password": "admin"})) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 200); + let body: serde_json::Value = response.json().await.unwrap(); + assert!(body.get("jwt").is_some()); + } + + #[tokio::test] + async fn test_mock_server_alerts_endpoint() { + let mock_server = MockWazuhServer::new(); + let client = reqwest::Client::new(); + + let response = client + .post(&format!("{}/wazuh-alerts*/_search", mock_server.url())) + .json(&json!({"query": {"match_all": {}}})) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 200); + let body: serde_json::Value = response.json().await.unwrap(); + assert!(body.get("hits").is_some()); + let hits = body["hits"]["hits"].as_array().unwrap(); + assert!(!hits.is_empty()); + } + + #[tokio::test] + async fn test_empty_alerts_server() { + let mock_server = MockWazuhServer::with_empty_alerts(); + let client = reqwest::Client::new(); + + let response = client + .post(&format!("{}/wazuh-alerts*/_search", mock_server.url())) + .json(&json!({"query": {"match_all": {}}})) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 200); + let body: serde_json::Value = response.json().await.unwrap(); + let hits = body["hits"]["hits"].as_array().unwrap(); + assert!(hits.is_empty()); + } + + #[tokio::test] + async fn test_auth_error_server() { + let mock_server = MockWazuhServer::with_auth_error(); + let client = reqwest::Client::new(); + + let response = client + .post(&format!("{}/security/user/authenticate", mock_server.url())) + .json(&json!({"username": "admin", "password": "wrong"})) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 401); + } + + #[tokio::test] + async fn test_alerts_error_server() { + let mock_server = MockWazuhServer::with_alerts_error(); + let client = reqwest::Client::new(); + + let response = client + .post(&format!("{}/wazuh-alerts*/_search", mock_server.url())) + .json(&json!({"query": {"match_all": {}}})) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), 500); + } +} diff --git a/tests/rmcp_integration_test.rs b/tests/rmcp_integration_test.rs new file mode 100644 index 0000000..0a829be --- /dev/null +++ b/tests/rmcp_integration_test.rs @@ -0,0 +1,546 @@ +//! Integration tests for the rmcp-based Wazuh MCP Server +//! +//! These tests verify the MCP server functionality using a mock Wazuh API server. +//! Tests cover tool registration, parameter validation, alert retrieval, and error handling. + +use std::process::{Child, Command, Stdio}; +use std::io::{BufRead, BufReader, Write}; +use std::time::Duration; +use tokio::time::sleep; +use serde_json::{json, Value}; +use once_cell::sync::Lazy; +use std::sync::Mutex; + +mod mock_wazuh_server; +use mock_wazuh_server::MockWazuhServer; + +static TEST_MUTEX: Lazy> = Lazy::new(|| Mutex::new(())); + +struct McpServerProcess { + child: Child, + stdin: std::process::ChildStdin, + stdout: BufReader, +} + +impl McpServerProcess { + fn start_with_mock_wazuh(mock_server: &MockWazuhServer) -> Result> { + let mut child = Command::new("cargo") + .args(["run", "--bin", "mcp-server-wazuh"]) + .env("WAZUH_HOST", mock_server.host()) + .env("WAZUH_PORT", mock_server.port().to_string()) + .env("WAZUH_USER", "admin") + .env("WAZUH_PASS", "admin") + .env("VERIFY_SSL", "false") + .env("WAZUH_TEST_PROTOCOL", "http") + .env("RUST_LOG", "warn") // Reduce noise in tests + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()) // Inherit stderr to see server logs + .spawn()?; + + let stdin = child.stdin.take().unwrap(); + let stdout = BufReader::new(child.stdout.take().unwrap()); + + Ok(McpServerProcess { + child, + stdin, + stdout, + }) + } + + fn send_message(&mut self, message: &Value) -> Result<(), Box> { + let message_str = serde_json::to_string(message)?; + writeln!(self.stdin, "{}", message_str)?; + self.stdin.flush()?; + Ok(()) + } + + fn read_response(&mut self) -> Result> { + let mut line = String::new(); + self.stdout.read_line(&mut line)?; + let response: Value = serde_json::from_str(line.trim())?; + Ok(response) + } + + fn send_and_receive(&mut self, message: &Value) -> Result> { + self.send_message(message)?; + self.read_response() + } +} + +impl Drop for McpServerProcess { + fn drop(&mut self) { + let _ = self.child.kill(); + let _ = self.child.wait(); + } +} + +#[tokio::test] +async fn test_mcp_server_initialization() -> Result<(), Box> { + let _guard = TEST_MUTEX.lock().unwrap(); + + let mock_server = MockWazuhServer::new(); + let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?; + + // Give the server time to start + sleep(Duration::from_millis(500)).await; + + // Send initialize request + let init_request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0.0" + } + } + }); + + let response = mcp_server.send_and_receive(&init_request)?; + + // Verify response structure + assert_eq!(response["jsonrpc"], "2.0"); + assert_eq!(response["id"], 1); + assert!(response["result"].is_object()); + + let result = &response["result"]; + assert_eq!(result["protocolVersion"], "2024-11-05"); + assert!(result["capabilities"].is_object()); + assert!(result["serverInfo"].is_object()); + assert!(result["instructions"].is_string()); + + Ok(()) +} + +#[tokio::test] +async fn test_tools_list() -> Result<(), Box> { + let _guard = TEST_MUTEX.lock().unwrap(); + + let mock_server = MockWazuhServer::new(); + let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?; + + sleep(Duration::from_millis(500)).await; + + // Initialize first + let init_request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + } + }); + mcp_server.send_and_receive(&init_request)?; + + // Send initialized notification + let initialized = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + mcp_server.send_message(&initialized)?; + + // Request tools list + let tools_request = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": {} + }); + + let response = mcp_server.send_and_receive(&tools_request)?; + + // Verify response + assert_eq!(response["jsonrpc"], "2.0"); + assert_eq!(response["id"], 2); + assert!(response["result"]["tools"].is_array()); + + let tools = response["result"]["tools"].as_array().unwrap(); + assert!(!tools.is_empty()); + + // Check for our Wazuh alert summary tool + let alert_tool = tools.iter() + .find(|tool| tool["name"] == "get_wazuh_alert_summary") + .expect("get_wazuh_alert_summary tool should be present"); + + assert!(alert_tool["description"].is_string()); + assert!(alert_tool["inputSchema"].is_object()); + + // Verify input schema structure + let input_schema = &alert_tool["inputSchema"]; + assert_eq!(input_schema["type"], "object"); + assert!(input_schema["properties"].is_object()); + assert!(input_schema["properties"]["limit"].is_object()); + + Ok(()) +} + +#[tokio::test] +async fn test_get_wazuh_alert_summary_success() -> Result<(), Box> { + let _guard = TEST_MUTEX.lock().unwrap(); + + let mock_server = MockWazuhServer::new(); + let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?; + + sleep(Duration::from_millis(500)).await; + + // Initialize + let init_request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + } + }); + mcp_server.send_and_receive(&init_request)?; + + // Send initialized notification + let initialized = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + mcp_server.send_message(&initialized)?; + + // Call the tool + let tool_call = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "get_wazuh_alert_summary", + "arguments": { + "limit": 2 + } + } + }); + + let response = mcp_server.send_and_receive(&tool_call)?; + + // Verify response structure + assert_eq!(response["jsonrpc"], "2.0"); + assert_eq!(response["id"], 3); + assert!(response["result"].is_object()); + + let result = &response["result"]; + assert!(result["content"].is_array()); + assert_eq!(result["isError"], false); + + let content = result["content"].as_array().unwrap(); + assert!(!content.is_empty()); + + // Verify content format + for item in content { + assert_eq!(item["type"], "text"); + assert!(item["text"].is_string()); + + let text = item["text"].as_str().unwrap(); + assert!(text.contains("Alert ID:")); + assert!(text.contains("Time:")); + assert!(text.contains("Agent:")); + assert!(text.contains("Level:")); + assert!(text.contains("Description:")); + } + + Ok(()) +} + +#[tokio::test] +async fn test_get_wazuh_alert_summary_empty_results() -> Result<(), Box> { + let _guard = TEST_MUTEX.lock().unwrap(); + + let mock_server = MockWazuhServer::with_empty_alerts(); + let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?; + + sleep(Duration::from_millis(500)).await; + + // Initialize + let init_request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + } + }); + mcp_server.send_and_receive(&init_request)?; + + // Send initialized notification + let initialized = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + mcp_server.send_message(&initialized)?; + + // Call the tool + let tool_call = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "get_wazuh_alert_summary", + "arguments": {} + } + }); + + let response = mcp_server.send_and_receive(&tool_call)?; + + // Verify response + assert_eq!(response["jsonrpc"], "2.0"); + assert_eq!(response["id"], 3); + + let result = &response["result"]; + assert!(result["content"].is_array()); + assert_eq!(result["isError"], false); + + let content = result["content"].as_array().unwrap(); + assert_eq!(content.len(), 1); + assert_eq!(content[0]["type"], "text"); + assert_eq!(content[0]["text"], "No Wazuh alerts found."); + + Ok(()) +} + +#[tokio::test] +async fn test_get_wazuh_alert_summary_api_error() -> Result<(), Box> { + let _guard = TEST_MUTEX.lock().unwrap(); + + let mock_server = MockWazuhServer::with_alerts_error(); + let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?; + + sleep(Duration::from_millis(500)).await; + + // Initialize + let init_request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + } + }); + mcp_server.send_and_receive(&init_request)?; + + // Send initialized notification + let initialized = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + mcp_server.send_message(&initialized)?; + + // Call the tool + let tool_call = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "get_wazuh_alert_summary", + "arguments": { + "limit": 5 + } + } + }); + + let response = mcp_server.send_and_receive(&tool_call)?; + + // Verify error response + assert_eq!(response["jsonrpc"], "2.0"); + assert_eq!(response["id"], 3); + + let result = &response["result"]; + assert!(result["content"].is_array()); + assert_eq!(result["isError"], true); + + let content = result["content"].as_array().unwrap(); + assert_eq!(content.len(), 1); + assert_eq!(content[0]["type"], "text"); + + let error_text = content[0]["text"].as_str().unwrap(); + assert!(error_text.contains("Error retrieving alerts from Wazuh")); + + Ok(()) +} + +#[tokio::test] +async fn test_invalid_tool_call() -> Result<(), Box> { + let _guard = TEST_MUTEX.lock().unwrap(); + + let mock_server = MockWazuhServer::new(); + let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?; + + sleep(Duration::from_millis(500)).await; + + // Initialize + let init_request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + } + }); + mcp_server.send_and_receive(&init_request)?; + + // Send initialized notification + let initialized = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + mcp_server.send_message(&initialized)?; + + // Call non-existent tool + let tool_call = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "non_existent_tool", + "arguments": {} + } + }); + + let response = mcp_server.send_and_receive(&tool_call)?; + + // Should get an error response + assert_eq!(response["jsonrpc"], "2.0"); + assert_eq!(response["id"], 3); + assert!(response["error"].is_object()); + + Ok(()) +} + +#[tokio::test] +async fn test_parameter_validation() -> Result<(), Box> { + let _guard = TEST_MUTEX.lock().unwrap(); + + let mock_server = MockWazuhServer::new(); + let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?; + + sleep(Duration::from_millis(500)).await; + + // Initialize + let init_request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + } + }); + mcp_server.send_and_receive(&init_request)?; + + // Send initialized notification + let initialized = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + mcp_server.send_message(&initialized)?; + + // Test with invalid parameter type (string instead of number) + let tool_call = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "get_wazuh_alert_summary", + "arguments": { + "limit": "invalid" + } + } + }); + + let response = mcp_server.send_and_receive(&tool_call)?; + + // Should get an error response for invalid parameters + assert_eq!(response["jsonrpc"], "2.0"); + assert_eq!(response["id"], 3); + // The response might be an error or a successful response with error content + // depending on how rmcp handles parameter validation + assert!(response["error"].is_object() || + (response["result"]["isError"] == true)); + + Ok(()) +} + +#[tokio::test] +async fn test_malformed_alert_data_handling() -> Result<(), Box> { + let _guard = TEST_MUTEX.lock().unwrap(); + + let mock_server = MockWazuhServer::with_malformed_alerts(); + let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?; + + sleep(Duration::from_millis(500)).await; + + // Initialize + let init_request = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test", "version": "1.0"} + } + }); + mcp_server.send_and_receive(&init_request)?; + + // Send initialized notification + let initialized = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + mcp_server.send_message(&initialized)?; + + // Call the tool + let tool_call = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "get_wazuh_alert_summary", + "arguments": { + "limit": 5 + } + } + }); + + let response = mcp_server.send_and_receive(&tool_call)?; + + // Should handle malformed data gracefully + assert_eq!(response["jsonrpc"], "2.0"); + assert_eq!(response["id"], 3); + + let result = &response["result"]; + assert!(result["content"].is_array()); + // Should not error out, but handle missing fields gracefully + assert_eq!(result["isError"], false); + + let content = result["content"].as_array().unwrap(); + assert!(!content.is_empty()); + + // Verify that missing fields are handled with defaults + for item in content { + assert_eq!(item["type"], "text"); + let text = item["text"].as_str().unwrap(); + // Should contain default values for missing fields + assert!(text.contains("Alert ID:")); + assert!(text.contains("Unknown") || text.contains("missing_fields") || text.contains("partial_data")); + } + + Ok(()) +} diff --git a/tests/run_tests.sh b/tests/run_tests.sh index bfc77dd..c44de8a 100755 --- a/tests/run_tests.sh +++ b/tests/run_tests.sh @@ -1,47 +1,100 @@ #!/bin/bash -echo "Running all tests..." -cargo test +# Test script for Wazuh MCP Server (rmcp-based) +# This script runs various tests to ensure the server is working correctly -echo "Building MCP client CLI..." -cargo build --bin mcp_client_cli +set -e -echo "Building main server binary for stdio CLI tests..." -# Build the server executable that mcp_client_cli will run -cargo build --bin mcp-server-wazuh # Output: target/debug/mcp-server-wazuh +echo "Starting Wazuh MCP Server tests (rmcp-based)..." -echo "Testing MCP client CLI in stdio mode..." +# Set test environment variables +export RUST_LOG=info -echo "Executing: ./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh initialize" -./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh initialize -if [ $? -ne 0 ]; then - echo "CLI 'initialize' command failed!" - exit 1 -fi +echo "Environment variables set:" +echo " RUST_LOG: $RUST_LOG" -echo "Executing: ./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh provideContext" -./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh provideContext -if [ $? -ne 0 ]; then - echo "CLI 'provideContext' command failed!" - exit 1 -fi +# Function to cleanup background processes +cleanup() { + echo "Cleaning up..." + if [ ! -z "$SERVER_PID" ]; then + kill $SERVER_PID 2>/dev/null || true + wait $SERVER_PID 2>/dev/null || true + fi +} -# Example of provideContext with empty JSON params (optional to uncomment and test) -# echo "Executing: ./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh provideContext '{}'" -# ./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh provideContext '{}' -# if [ $? -ne 0 ]; then -# echo "CLI 'provideContext {}' command failed!" -# exit 1 -# fi +# Set trap to cleanup on exit +trap cleanup EXIT -echo "Executing: ./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh shutdown" -./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh shutdown -if [ $? -ne 0 ]; then - # Shutdown might return an error if the server closes the pipe before the client fully processes the response, - # but the primary goal is that the server process is terminated. - # For this script, we'll be lenient on shutdown's exit code for now, - # as long as initialize and provideContext worked. - echo "CLI 'shutdown' command executed (non-zero exit code is sometimes expected if server closes pipe quickly)." -fi +echo "" +echo "=== Running Unit Tests ===" +cargo test --lib -echo "MCP client CLI stdio tests completed." +echo "" +echo "=== Running MCP Protocol Tests ===" +cargo test --test mcp_stdio_test + +echo "" +echo "=== Running Integration Tests with Mock Wazuh ===" +cargo test --test rmcp_integration_test + +echo "" +echo "=== Manual MCP Server Testing ===" + +# Test the server with a simple MCP interaction +echo "Testing MCP server initialization..." + +# Create a temporary test script +cat > /tmp/test_mcp_server.sh << 'INNER_EOF' +#!/bin/bash + +# Start the server in background +WAZUH_HOST=mock.example.com RUST_LOG=error cargo run --bin mcp-server-wazuh & +SERVER_PID=$! + +# Give server time to start +sleep 1 + +# Test MCP initialization +echo '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}' | timeout 5s nc -l 0 2>/dev/null || { + # If nc doesn't work, try a different approach + echo "Testing server startup..." + sleep 2 +} + +# Clean up +kill $SERVER_PID 2>/dev/null || true +wait $SERVER_PID 2>/dev/null || true + +echo "Manual test completed" +INNER_EOF + +chmod +x /tmp/test_mcp_server.sh +/tmp/test_mcp_server.sh +rm /tmp/test_mcp_server.sh + +echo "" +echo "=== Testing Server Binary ===" +echo "Verifying server binary can start and show help..." + +# Test that the binary can start and show help +timeout 5s cargo run --bin mcp-server-wazuh -- --help || echo "Help command test completed" + +echo "" +echo "=== All Tests Complete ===" +echo "" +echo "Test Summary:" +echo "✓ Unit tests for library components" +echo "✓ Wazuh client tests with mock HTTP server" +echo "✓ MCP protocol tests via stdio" +echo "✓ Integration tests with mock Wazuh API" +echo "✓ Server binary startup test" +echo "" +echo "To test manually with a real Wazuh instance:" +echo " export WAZUH_HOST=your-wazuh-host" +echo " export WAZUH_PORT=9200" +echo " export WAZUH_USER=admin" +echo " export WAZUH_PASS=your-password" +echo " cargo run --bin mcp-server-wazuh" +echo "" +echo "Then send MCP commands via stdin, for example:" +echo ' echo '"'"'{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}'"'"' | cargo run --bin mcp-server-wazuh'