* Ported code to RMCP

* Implemented unit and e2e testing
* Other fixes and enhancements
This commit is contained in:
Gianluca Brigandi
2025-05-22 20:02:41 -07:00
parent 6661523c0f
commit d59d67b8db
28 changed files with 1519 additions and 2778 deletions

View File

@@ -1,135 +0,0 @@
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use serde_json::{json, Value};
use std::sync::Arc;
use tracing::{debug, error, info};
use crate::logging_utils::{log_mcp_request, log_mcp_response};
use crate::AppState; // Added
pub fn create_http_router(app_state: Arc<AppState>) -> Router {
Router::new()
.route("/health", get(health_check))
.route("/mcp", get(get_mcp_data))
.route("/mcp", post(post_mcp_data))
.with_state(app_state)
}
async fn health_check() -> impl IntoResponse {
Json(json!({
"status": "ok",
"service": "wazuh-mcp-server",
"timestamp": chrono::Utc::now().to_rfc3339()
}))
}
async fn get_mcp_data(
State(app_state): State<Arc<AppState>>,
) -> Result<Json<Vec<Value>>, ApiError> {
info!("Handling GET /mcp request");
let wazuh_client = app_state.wazuh_client.lock().await;
match wazuh_client.get_alerts().await {
Ok(alerts) => {
// Transform Wazuh alerts to MCP messages
let mcp_messages = alerts
.iter()
.map(|alert| {
json!({
"protocol_version": "1.0",
"source": "Wazuh",
"timestamp": chrono::Utc::now().to_rfc3339(),
"event_type": "alert",
"context": alert,
"metadata": {
"integration": "Wazuh-MCP"
}
})
})
.collect::<Vec<_>>();
Ok(Json(mcp_messages))
}
Err(e) => {
error!("Error getting alerts from Wazuh: {}", e);
Err(ApiError::InternalServerError(format!(
"Failed to get alerts from Wazuh: {}",
e
)))
}
}
}
async fn post_mcp_data(
State(app_state): State<Arc<AppState>>,
Json(payload): Json<Value>,
) -> Result<Json<Vec<Value>>, ApiError> {
info!("Handling POST /mcp request with payload");
debug!("Payload: {:?}", payload);
// Log the incoming payload
let request_str = serde_json::to_string(&payload).unwrap_or_else(|e| {
error!(
"Failed to serialize POST request payload for logging: {}",
e
);
format!(
"{{\"error\":\"Failed to serialize request payload: {}\"}}",
e
)
});
log_mcp_request(&request_str);
let result = get_mcp_data(State(app_state)).await;
let response_str = match &result {
Ok(json_response) => serde_json::to_string(&json_response.0).unwrap_or_else(|e| {
error!("Failed to serialize POST response for logging: {}", e);
format!("{{\"error\":\"Failed to serialize response: {}\"}}", e)
}),
Err(api_error) => {
let error_json_surrogate = json!({
"error": format!("{:?}", api_error) // Or a more structured error
});
serde_json::to_string(&error_json_surrogate).unwrap_or_else(|e| {
error!("Failed to serialize POST error response for logging: {}", e);
format!(
"{{\"error\":\"Failed to serialize error response: {}\"}}",
e
)
})
}
};
log_mcp_response(&response_str);
result
}
#[derive(Debug)]
pub enum ApiError {
BadRequest(String),
NotFound(String),
InternalServerError(String),
}
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
ApiError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
ApiError::NotFound(msg) => (StatusCode::NOT_FOUND, msg),
ApiError::InternalServerError(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
};
let body = Json(json!({
"error": error_message
}));
(status, body).into_response()
}
}

View File

@@ -1,8 +1,4 @@
// This file is kept for compatibility with existing tests and binaries
// The main MCP server functionality has been moved to main.rs using the rmcp framework
pub mod wazuh;
// Re-export for backward compatibility
pub use wazuh::client::WazuhIndexerClient;
pub use wazuh::error::WazuhApiError;

View File

@@ -1,27 +0,0 @@
use std::fs::OpenOptions;
use std::io::Write;
use tracing::error;
const REQUEST_LOG_FILE: &str = "mcp_requests.log";
const RESPONSE_LOG_FILE: &str = "mcp_responses.log";
fn log_to_file(filename: &str, message: &str) {
match OpenOptions::new().create(true).append(true).open(filename) {
Ok(mut file) => {
if let Err(e) = writeln!(file, "{}", message) {
error!("Failed to write to {}: {}", filename, e);
}
}
Err(e) => {
error!("Failed to open {} for appending: {}", filename, e);
}
}
}
pub fn log_mcp_request(request_str: &str) {
log_to_file(REQUEST_LOG_FILE, request_str);
}
pub fn log_mcp_response(response_str: &str) {
log_to_file(RESPONSE_LOG_FILE, response_str);
}

View File

@@ -44,7 +44,6 @@ use rmcp::{
schemars, tool,
transport::stdio,
};
use serde_json::json;
use std::sync::Arc;
use std::env;
use clap::Parser;
@@ -93,12 +92,16 @@ impl WazuhToolsServer {
.to_lowercase()
== "true";
let wazuh_client = WazuhIndexerClient::new(
let protocol = env::var("WAZUH_TEST_PROTOCOL").unwrap_or_else(|_| "https".to_string());
tracing::debug!(?protocol, "Using Wazuh protocol for client from WAZUH_TEST_PROTOCOL or default");
let wazuh_client = WazuhIndexerClient::new_with_protocol(
wazuh_host,
wazuh_port,
wazuh_user,
wazuh_pass,
verify_ssl,
&protocol,
);
Ok(Self {
@@ -121,72 +124,58 @@ impl WazuhToolsServer {
match self.wazuh_client.get_alerts().await {
Ok(raw_alerts) => {
let alerts_to_process: Vec<_> = raw_alerts.into_iter().take(limit as usize).collect();
let content_items: Vec<serde_json::Value> = if alerts_to_process.is_empty() {
// If no alerts, return a single "no alerts" message
vec![json!({
"type": "text",
"text": "No Wazuh alerts found."
})]
} else {
// Map each alert to a content item
alerts_to_process
.into_iter()
.map(|alert| {
let source = alert.get("_source").unwrap_or(&alert);
// Extract alert ID
let id = source.get("id")
.and_then(|v| v.as_str())
.or_else(|| alert.get("_id").and_then(|v| v.as_str()))
.unwrap_or("Unknown ID");
// Extract rule description
let description = source.get("rule")
.and_then(|r| r.get("description"))
.and_then(|d| d.as_str())
.unwrap_or("No description available");
// Extract timestamp if available
let timestamp = source.get("timestamp")
.and_then(|t| t.as_str())
.unwrap_or("Unknown time");
// Extract agent name if available
let agent_name = source.get("agent")
.and_then(|a| a.get("name"))
.and_then(|n| n.as_str())
.unwrap_or("Unknown agent");
// Extract rule level if available
let rule_level = source.get("rule")
.and_then(|r| r.get("level"))
.and_then(|l| l.as_u64())
.unwrap_or(0);
// Format the alert as a text entry and create a content item
json!({
"type": "text",
"text": format!(
"Alert ID: {}\nTime: {}\nAgent: {}\nLevel: {}\nDescription: {}",
id, timestamp, agent_name, rule_level, description
)
})
})
.collect()
};
tracing::info!("Successfully processed {} alerts into content items", content_items.len());
if alerts_to_process.is_empty() {
tracing::info!("No Wazuh alerts found to process. Returning standard message.");
// Ensure this directly returns a Vec<Content> with one Content::text item
return Ok(CallToolResult::success(vec![Content::text(
"No Wazuh alerts found.",
)]));
}
// Construct the final result with the content array containing multiple text objects
let result = json!({
"content": content_items
});
Ok(CallToolResult::success(vec![
Content::json(result)
.map_err(|e| McpError::internal_error(e.to_string(), None))?,
]))
// Process non-empty alerts
// This part should already be correct if alerts_to_process is not empty,
// as it maps each alert to Content::text directly.
let num_alerts_to_process = alerts_to_process.len(); // Get length before moving
let mcp_content_items: Vec<Content> = alerts_to_process
.into_iter()
.map(|alert_value| {
let source = alert_value.get("_source").unwrap_or(&alert_value);
let id = source.get("id")
.and_then(|v| v.as_str())
.or_else(|| alert_value.get("_id").and_then(|v| v.as_str()))
.unwrap_or("Unknown ID");
let description = source.get("rule")
.and_then(|r| r.get("description"))
.and_then(|d| d.as_str())
.unwrap_or("No description available");
let timestamp = source.get("timestamp")
.and_then(|t| t.as_str())
.unwrap_or("Unknown time");
let agent_name = source.get("agent")
.and_then(|a| a.get("name"))
.and_then(|n| n.as_str())
.unwrap_or("Unknown agent");
let rule_level = source.get("rule")
.and_then(|r| r.get("level"))
.and_then(|l| l.as_u64())
.unwrap_or(0);
let formatted_text = format!(
"Alert ID: {}\nTime: {}\nAgent: {}\nLevel: {}\nDescription: {}",
id, timestamp, agent_name, rule_level, description
);
Content::text(formatted_text)
})
.collect();
tracing::info!("Successfully processed {} alerts into {} MCP content items", num_alerts_to_process, mcp_content_items.len());
Ok(CallToolResult::success(mcp_content_items))
}
Err(e) => {
let err_msg = format!("Error retrieving alerts from Wazuh: {}", e);

View File

@@ -1,425 +0,0 @@
use async_trait::async_trait;
use reqwest::Client;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use thiserror::Error;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
#[derive(Error, Debug)]
pub enum McpClientError {
#[error("HTTP request error: {0}")]
HttpRequestError(#[from] reqwest::Error),
#[error("HTTP API error: status {status}, message: {message}")]
HttpApiError {
status: reqwest::StatusCode,
message: String,
},
#[error("JSON serialization/deserialization error: {0}")]
JsonError(#[from] serde_json::Error),
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
#[error("Failed to spawn child process: {0}")]
ProcessSpawnError(String),
#[error("Child process stdin/stdout not available")]
ProcessPipeError,
#[error("JSON-RPC error: code {code}, message: {message}, data: {data:?}")]
JsonRpcError {
code: i32,
message: String,
data: Option<Value>,
},
#[error("Received unexpected JSON-RPC response: {0}")]
UnexpectedResponse(String),
#[error("Operation timed out")]
Timeout,
#[error("Operation not supported in current mode: {0}")]
UnsupportedOperation(String),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpMessage {
pub protocol_version: String,
pub source: String,
pub timestamp: String,
pub event_type: String,
pub context: Value,
pub metadata: Value,
}
#[derive(Serialize, Debug)]
struct JsonRpcRequest<T: Serialize> {
jsonrpc: String,
method: String,
params: Option<T>,
id: Value, // Changed from usize to Value
}
#[derive(Deserialize, Debug)]
struct JsonRpcResponse<T> {
jsonrpc: String,
result: Option<T>,
error: Option<JsonRpcErrorData>,
id: Value, // Changed from usize to Value
}
#[derive(Deserialize, Debug)]
struct JsonRpcErrorData {
code: i32,
message: String,
data: Option<Value>,
}
#[derive(Deserialize, Debug, Clone)]
pub struct ServerInfo {
pub name: String,
pub version: String,
}
#[derive(Deserialize, Debug, Clone)]
pub struct InitializeResult {
pub protocol_version: String,
pub server_info: ServerInfo,
}
#[async_trait]
pub trait McpClientTrait {
async fn initialize(&mut self) -> Result<InitializeResult, McpClientError>;
async fn provide_context(
&mut self,
params: Option<Value>,
) -> Result<Vec<McpMessage>, McpClientError>;
async fn shutdown(&mut self) -> Result<(), McpClientError>;
}
enum ClientMode {
Http {
client: Client,
base_url: String,
},
Stdio {
stdin: ChildStdin,
stdout: BufReader<ChildStdout>,
},
}
pub struct McpClient {
mode: ClientMode,
child_process: Option<Child>,
request_id_counter: AtomicUsize,
}
#[async_trait]
impl McpClientTrait for McpClient {
async fn initialize(&mut self) -> Result<InitializeResult, McpClientError> {
match &mut self.mode {
ClientMode::Http { .. } => Err(McpClientError::UnsupportedOperation(
"initialize is not supported in HTTP mode".to_string(),
)),
ClientMode::Stdio { .. } => {
let request_id = self.next_id();
self.send_stdio_request("initialize", None::<()>, request_id)
.await
}
}
}
async fn provide_context(
&mut self,
params: Option<Value>,
) -> Result<Vec<McpMessage>, McpClientError> {
match &mut self.mode {
ClientMode::Http { client, base_url } => {
let url = format!("{}/mcp", base_url);
let request_builder = if let Some(p) = params {
client.post(&url).json(&p)
} else {
client.get(&url)
};
let response = request_builder
.send()
.await
.map_err(McpClientError::HttpRequestError)?;
if !response.status().is_success() {
let status = response.status();
let message = response.text().await.unwrap_or_else(|_| {
format!("Failed to get error body for status {}", status)
});
return Err(McpClientError::HttpApiError { status, message });
}
response
.json::<Vec<McpMessage>>()
.await
.map_err(McpClientError::HttpRequestError)
}
ClientMode::Stdio { .. } => {
let request_id = self.next_id();
self.send_stdio_request("provideContext", params, request_id)
.await
}
}
}
async fn shutdown(&mut self) -> Result<(), McpClientError> {
match &mut self.mode {
ClientMode::Http { .. } => Err(McpClientError::UnsupportedOperation(
"shutdown is not supported in HTTP mode".to_string(),
)),
ClientMode::Stdio { .. } => {
let request_id = self.next_id();
// Attempt to send shutdown command, ignore error if server already closed pipe
let _result: Result<Option<Value>, McpClientError> = self
.send_stdio_request("shutdown", None::<()>, request_id)
.await;
// Always try to clean up the process
self.close_stdio_process().await
}
}
}
}
impl McpClient {
pub fn new_http(base_url: String) -> Self {
let client = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.expect("Failed to create HTTP client");
Self {
mode: ClientMode::Http { client, base_url },
child_process: None,
request_id_counter: AtomicUsize::new(1),
}
}
pub async fn new_stdio(
executable_path: &str,
envs: Option<Vec<(String, String)>>,
) -> Result<Self, McpClientError> {
let mut command = Command::new(executable_path);
command.stdin(std::process::Stdio::piped());
command.stdout(std::process::Stdio::piped());
command.stderr(std::process::Stdio::inherit()); // Pipe child's stderr to parent's stderr for visibility
if let Some(env_vars) = envs {
for (key, value) in env_vars {
command.env(key, value);
}
}
let mut child = command
.spawn()
.map_err(|e| McpClientError::ProcessSpawnError(e.to_string()))?;
let stdin = child.stdin.take().ok_or(McpClientError::ProcessPipeError)?;
let stdout = child
.stdout
.take()
.ok_or(McpClientError::ProcessPipeError)?;
Ok(Self {
mode: ClientMode::Stdio {
stdin,
stdout: BufReader::new(stdout),
},
child_process: Some(child),
request_id_counter: AtomicUsize::new(1),
})
}
fn next_id(&self) -> Value {
Value::from(self.request_id_counter.fetch_add(1, Ordering::SeqCst))
}
async fn send_stdio_request<P: Serialize, R: DeserializeOwned>(
&mut self,
method: &str,
params: Option<P>,
id: Value, // Added id parameter
) -> Result<R, McpClientError> {
// Removed: let request_id = self.next_id();
let rpc_request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
method: method.to_string(),
params,
id: id.clone(), // Use the provided id
};
let request_json = serde_json::to_string(&rpc_request)? + "\n";
let (stdin, stdout) = match &mut self.mode {
ClientMode::Stdio { stdin, stdout } => (stdin, stdout),
ClientMode::Http { .. } => {
return Err(McpClientError::UnsupportedOperation(
"send_stdio_request is only for Stdio mode".to_string(),
))
}
};
stdin.write_all(request_json.as_bytes()).await?;
stdin.flush().await?;
let mut response_json = String::new();
match tokio::time::timeout(
Duration::from_secs(10),
stdout.read_line(&mut response_json),
)
.await
{
Ok(Ok(0)) => {
return Err(McpClientError::IoError(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"Server closed stdout",
)))
}
Ok(Ok(_)) => { /* continue */ }
Ok(Err(e)) => return Err(McpClientError::IoError(e)),
Err(_) => return Err(McpClientError::Timeout),
}
let rpc_response: JsonRpcResponse<R> = serde_json::from_str(response_json.trim())?;
// Compare Value IDs. Note: Value implements PartialEq.
if rpc_response.id != id {
return Err(McpClientError::UnexpectedResponse(format!(
"Mismatched request/response IDs. Expected {}, got {}. Response: '{}'",
id, rpc_response.id, response_json
)));
}
if let Some(err_data) = rpc_response.error {
return Err(McpClientError::JsonRpcError {
code: err_data.code,
message: err_data.message,
data: err_data.data,
});
}
rpc_response.result.ok_or_else(|| {
McpClientError::UnexpectedResponse("Missing result in JSON-RPC response".to_string())
})
}
async fn close_stdio_process(&mut self) -> Result<(), McpClientError> {
if let Some(mut child) = self.child_process.take() {
child.kill().await.map_err(McpClientError::IoError)?;
let _ = child.wait().await; // Ensure process is reaped
}
Ok(())
}
// New public method for sending generic JSON-RPC requests
pub async fn send_json_rpc_request(
&mut self,
method: &str,
params: Option<Value>,
id: Value,
) -> Result<Value, McpClientError> {
match &mut self.mode {
ClientMode::Http { .. } => Err(McpClientError::UnsupportedOperation(
"Generic JSON-RPC calls are not supported in HTTP mode by this client.".to_string(),
)),
ClientMode::Stdio { .. } => {
// R (result type) is Value for generic calls
self.send_stdio_request(method, params, id).await
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use httpmock::prelude::*;
use serde_json::json;
use tokio;
#[tokio::test]
async fn test_mcp_client_http_get_data() {
// Renamed to be specific
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET).path("/mcp");
then.status(200)
.header("content-type", "application/json")
.json_body(json!([
{
"protocol_version": "1.0",
"source": "Wazuh",
"timestamp": "2023-05-01T12:00:00Z",
"event_type": "alert",
"context": {
"id": "12345",
"category": "intrusion_detection",
"severity": "high",
"description": "Test alert",
"data": { "source_ip": "192.168.1.100" }
},
"metadata": { "integration": "Wazuh-MCP", "notes": "Test note" }
}
]));
});
let mut client = McpClient::new_http(server.url("")); // Use new_http
let result = client.provide_context(None).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].protocol_version, "1.0");
assert_eq!(result[0].source, "Wazuh");
assert_eq!(result[0].event_type, "alert");
let context = &result[0].context;
assert_eq!(context["id"], "12345");
assert_eq!(context["category"], "intrusion_detection");
assert_eq!(context["severity"], "high");
}
#[tokio::test]
async fn test_mcp_client_http_health_check_equivalent() {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET).path("/health"); // Assuming /health is still the target for this test
then.status(200)
.header("content-type", "application/json")
.json_body(json!({
"status": "ok",
"service": "wazuh-mcp-server",
"timestamp": "2023-05-01T12:00:00Z"
}));
});
let client = McpClient::new_http(server.url(""));
// The new trait doesn't have a direct "check_health".
// If `initialize` was to be used for HTTP health, it would be:
// let result = client.initialize().await;
// But initialize is Stdio-only. So this test needs to adapt or be removed
// if there's no direct equivalent in the new trait for HTTP health.
// For now, let's assume we might add a specific http_health method if needed,
// or this test is demonstrating a capability that's no longer directly on the trait.
// To make this test pass with current structure, we'd need a separate HTTP health method.
// Let's simulate calling the /health endpoint directly if that's the intent.
let http_client = reqwest::Client::new();
let response = http_client.get(server.url("/health")).send().await.unwrap();
assert_eq!(response.status(), reqwest::StatusCode::OK);
let health_data: Value = response.json().await.unwrap();
assert_eq!(health_data["status"], "ok");
assert_eq!(health_data["service"], "wazuh-mcp-server");
}
// TODO: Add tests for Stdio mode. This would require a mock executable
// or a more complex test setup. For now, focusing on the client structure.
}

View File

@@ -1,456 +0,0 @@
use serde_json::{json, Value};
use std::sync::Arc;
use tracing::{debug, error, info};
use crate::mcp::protocol::{error_codes, JsonRpcError, JsonRpcRequest, JsonRpcResponse};
use crate::AppState;
#[derive(serde::Deserialize, Debug)]
struct ToolCallParams {
#[serde(rename = "name")]
name: String,
#[serde(rename = "arguments")]
arguments: Option<Value>, // Input parameters for the specific tool
#[serde(flatten)]
_extra: std::collections::HashMap<String, Value>,
}
pub struct McpServerCore {
app_state: Arc<AppState>,
}
impl McpServerCore {
pub fn new(app_state: Arc<AppState>) -> Self {
Self { app_state }
}
pub async fn process_request(&self, request: JsonRpcRequest) -> String {
info!("Processing request: method={}", request.method);
let response = match request.method.as_str() {
"initialize" => self.handle_initialize(request).await,
"shutdown" => self.handle_shutdown(request).await,
"provideContext" => self.handle_provide_context(request).await,
// Tool methods (prefix "tools/")
"tools/list" => self.handle_list_tools(request).await,
"tools/call" => self.handle_tool_call(request).await, // Use generic tool call handler
// "tools/wazuhAlerts" => self.handle_wazuh_alerts_tool(request).await,
// Resource methods (prefix "resources/")
"resources/list" => self.handle_get_resources(request).await,
"resources/read" => self.handle_read_resource(request).await,
// Prompt methods (prefix "prompts/")
"prompts/list" => self.handle_list_prompts(request).await,
_ => {
error!("Method not found: {}", request.method);
self.create_error_response(
error_codes::METHOD_NOT_FOUND,
format!("Method '{}' not found", request.method),
None,
request.id.clone(),
)
}
};
response
}
pub fn handle_parse_error(&self, error: serde_json::Error, raw_request: &str) -> String {
error!("Failed to parse JSON-RPC request: {}", error);
// Try to extract the ID from the raw request if possible
let id = serde_json::from_str::<Value>(raw_request)
.and_then(|v| {
if let Some(id) = v.get("id") {
Ok(id.clone())
} else {
// Use a different approach since custom is not available
Err(serde_json::Error::io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"No ID field found",
)))
}
})
.unwrap_or(Value::Null);
self.create_error_response(
error_codes::PARSE_ERROR,
format!("Parse error: {}", error),
None,
id,
)
}
async fn handle_initialize(&self, request: JsonRpcRequest) -> String {
debug!("Handling initialize request");
let wazuh_alert_summary_tool = crate::mcp::protocol::ToolDefinition {
name: "wazuhAlertSummary".to_string(),
description: Some("Returns a text summary of all Wazuh alerts.".to_string()),
// Define a minimal valid input schema (empty object)
input_schema: Some(json!({
"type": "object",
"properties": {}
})),
// No output schema needed as per requirements
output_schema: None,
};
let result = crate::mcp::protocol::InitializeResult {
protocol_version: "2024-11-05".to_string(),
capabilities: crate::mcp::protocol::Capabilities {
tools: crate::mcp::protocol::ToolCapability {
supported: true,
definitions: vec![wazuh_alert_summary_tool], // Only include wazuhAlertSummary tool
},
resources: crate::mcp::protocol::SupportedFeature { supported: true },
prompts: crate::mcp::protocol::SupportedFeature { supported: true },
},
server_info: crate::mcp::protocol::ServerInfo {
name: "Wazuh MCP Server".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
},
};
self.create_success_response(result, request.id)
}
async fn handle_shutdown(&self, request: JsonRpcRequest) -> String {
debug!("Handling shutdown request");
self.create_success_response(Value::Null, request.id)
}
async fn handle_provide_context(&self, request: JsonRpcRequest) -> String {
debug!("Handling provideContext request");
let mut wazuh_client = self.app_state.wazuh_client.lock().await;
match wazuh_client.get_alerts().await {
Ok(alerts) => {
let mcp_messages: Vec<Value> = alerts
.into_iter()
.map(|alert| crate::mcp::transform::transform_to_mcp(alert, "alert".to_string()))
.collect();
debug!("Transformed {} alerts into MCP messages for provideContext", mcp_messages.len());
self.create_success_response(json!(mcp_messages), request.id)
}
Err(e) => {
error!("Error getting alerts from Wazuh for provideContext: {}", e);
self.create_error_response(
error_codes::INTERNAL_ERROR,
format!("Failed to get alerts from Wazuh: {}", e),
None,
request.id,
)
}
}
}
async fn handle_get_resources(&self, request: JsonRpcRequest) -> String {
debug!("Handling getResources request");
// Return an empty list for now
let resources_result = crate::mcp::protocol::ResourcesListResult {
resources: vec![],
};
self.create_success_response(resources_result, request.id)
}
async fn handle_read_resource(&self, request: JsonRpcRequest) -> String {
debug!("Handling readResource request: {:?}", request.params);
#[derive(serde::Deserialize, Debug)]
struct ReadResourceParams {
uri: String,
// We can add _meta here if needed later
// _meta: Option<Value>,
}
let params: ReadResourceParams = match request.params {
Some(params_value) => match serde_json::from_value(params_value) {
Ok(p) => p,
Err(e) => {
error!("Failed to parse params for resources/read: {}", e);
return self.create_error_response(
error_codes::INVALID_PARAMS,
format!("Invalid params for resources/read: {}", e),
None,
request.id,
);
}
},
None => {
error!("Missing params for resources/read");
return self.create_error_response(
error_codes::INVALID_PARAMS,
"Missing params for resources/read, 'uri' is required".to_string(),
None,
request.id,
);
}
};
error!("Unsupported URI for resources/read: {}", params.uri);
self.create_error_response(
error_codes::INVALID_PARAMS,
format!("Unsupported or unknown resource URI: {}", params.uri),
None,
request.id,
)
}
async fn handle_tool_call(&self, request: JsonRpcRequest) -> String {
debug!("Handling tools/call request: {:?}", request.params);
let params: ToolCallParams = match request.clone().params {
Some(params_value) => match serde_json::from_value(params_value) {
Ok(p) => p,
Err(e) => {
error!("Failed to parse params for tools/call: {}", e);
return self.create_error_response(
error_codes::INVALID_PARAMS,
format!("Invalid params for tools/call: {}", e),
None,
request.id,
);
}
},
None => {
error!("Missing params for tools/call");
return self.create_error_response(
error_codes::INVALID_PARAMS,
"Missing params for tools/call, 'name' and 'arguments' are required".to_string(),
None,
request.id,
);
}
};
match params.name.as_str() {
"wazuhAlertSummary" => {
info!("Dispatching tools/call to wazuhAlertSummary handler");
self.handle_wazuh_alert_summary_tool(request).await
}
// wazuhAlerts tool is disabled but we keep the handler code
_ => {
error!("Unsupported tool name requested via tools/call: {}", params.name);
self.create_error_response(
error_codes::METHOD_NOT_FOUND, // Or a more specific tool error code if available
format!("Tool '{}' not found", params.name),
None,
request.id,
)
}
}
}
async fn handle_list_tools(&self, request: JsonRpcRequest) -> String {
debug!("Handling tools/list request");
// Define the wazuhAlertSummary tool
let wazuh_alert_summary_tool = crate::mcp::protocol::ToolDefinition {
name: "wazuhAlertSummary".to_string(),
description: Some("Returns a text summary of all Wazuh alerts.".to_string()),
// Define a minimal valid input schema (empty object)
input_schema: Some(json!({
"type": "object",
"properties": {}
})),
// No output schema needed as per requirements
output_schema: None,
};
let tools_list = crate::mcp::protocol::ToolsListResult {
tools: vec![wazuh_alert_summary_tool],
};
self.create_success_response(tools_list, request.id)
}
async fn handle_wazuh_alerts_tool(&self, request: JsonRpcRequest) -> String {
debug!("Handling tools/wazuhAlerts request. Params: {:?}", request.params);
let wazuh_client = self.app_state.wazuh_client.lock().await;
match wazuh_client.get_alerts().await {
Ok(raw_alerts) => {
let simplified_alerts: Vec<Value> = raw_alerts
.into_iter()
.map(|alert| {
let source = alert.get("_source").unwrap_or(&alert);
// Extract ID: Try _source.id first, then _id
let id = source.get("id")
.and_then(|v| v.as_str())
.or_else(|| alert.get("_id").and_then(|v| v.as_str()))
.unwrap_or("") // Default to empty string if not found
.to_string();
// Extract Description: Look in _source.rule.description
let description = source.get("rule")
.and_then(|r| r.get("description"))
.and_then(|d| d.as_str())
.unwrap_or("") // Default to empty string if not found
.to_string();
json!({
"id": id,
"description": description,
})
})
.collect();
debug!("Processed {} alerts into simplified format.", simplified_alerts.len());
// Construct the final result with the "alerts" array
let result = json!({
"alerts": simplified_alerts,
"text": "Hello World",
});
self.create_success_response(result, request.id)
}
Err(e) => {
error!("Error getting alerts from Wazuh for tools/wazuhAlerts: {}", e);
self.create_error_response(
error_codes::INTERNAL_ERROR,
format!("Failed to get alerts from Wazuh: {}", e),
None,
request.id,
)
}
}
}
async fn handle_wazuh_alert_summary_tool(&self, request: JsonRpcRequest) -> String {
debug!("Handling tools/wazuhAlertSummary request. Params: {:?}", request.params);
let mut wazuh_client = self.app_state.wazuh_client.lock().await;
match wazuh_client.get_alerts().await {
Ok(raw_alerts) => {
// Create a content item for each alert
let content_items: Vec<Value> = if raw_alerts.is_empty() {
// If no alerts, return a single "no alerts" message
vec![json!({
"type": "text",
"text": "No Wazuh alerts found."
})]
} else {
// Map each alert to a content item
raw_alerts
.into_iter()
.map(|alert| {
let source = alert.get("_source").unwrap_or(&alert);
// Extract alert ID
let id = source.get("id")
.and_then(|v| v.as_str())
.or_else(|| alert.get("_id").and_then(|v| v.as_str()))
.unwrap_or("Unknown ID");
// Extract rule description
let description = source.get("rule")
.and_then(|r| r.get("description"))
.and_then(|d| d.as_str())
.unwrap_or("No description available");
// Extract timestamp if available
let timestamp = source.get("timestamp")
.and_then(|t| t.as_str())
.unwrap_or("Unknown time");
// Format the alert as a text entry and create a content item
json!({
"type": "text",
"text": format!("Alert ID: {}\nTime: {}\nDescription: {}", id, timestamp, description)
})
})
.collect()
};
debug!("Processed {} alerts into individual content items.", content_items.len());
// Construct the final result with the content array containing multiple text objects
let result = json!({
"content": content_items
});
self.create_success_response(result, request.id)
}
Err(e) => {
error!("Error getting alerts from Wazuh for tools/wazuhAlertSummary: {}", e);
self.create_error_response(
error_codes::INTERNAL_ERROR,
format!("Failed to get alerts from Wazuh: {}", e),
None,
request.id,
)
}
}
}
async fn handle_list_prompts(&self, request: JsonRpcRequest) -> String {
debug!("Handling prompts/list request");
// Define the single prompt according to the new structure
let list_alerts_prompt = crate::mcp::protocol::PromptEntry {
name: "list-wazuh-alerts".to_string(),
description: Some("List the latest security alerts from Wazuh.".to_string()),
arguments: vec![], // This prompt takes no arguments
};
let prompts = vec![list_alerts_prompt];
let result = crate::mcp::protocol::PromptsListResult { prompts };
self.create_success_response(result, request.id)
}
fn create_success_response<T: serde::Serialize>(&self, result: T, id: Value) -> String {
let response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
result: Some(result),
error: None,
id,
};
serde_json::to_string(&response).unwrap_or_else(|e| {
error!("Failed to serialize JSON-RPC response: {}", e);
format!(
r#"{{"jsonrpc":"2.0","error":{{"code":-32603,"message":"Internal error: Failed to serialize response"}},"id":null}}"#
)
})
}
pub(crate) fn create_error_response(
&self,
code: i32,
message: String,
data: Option<Value>,
id: Value,
) -> String {
let response = JsonRpcResponse::<Value> {
jsonrpc: "2.0".to_string(),
result: None,
error: Some(JsonRpcError {
code,
message,
data,
}),
id,
};
serde_json::to_string(&response).unwrap_or_else(|e| {
error!("Failed to serialize JSON-RPC error response: {}", e);
format!(
r#"{{"jsonrpc":"2.0","error":{{"code":-32603,"message":"Internal error: Failed to serialize error response"}},"id":null}}"#
)
})
}
}

View File

@@ -1,4 +0,0 @@
pub mod transform;
pub mod client;
pub mod protocol;
pub mod mcp_server_core;

View File

@@ -1,127 +0,0 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Deserialize, Debug, Clone)]
pub struct JsonRpcRequest {
pub jsonrpc: String,
pub method: String,
pub params: Option<Value>,
pub id: Value,
}
#[derive(Serialize, Debug)]
pub struct JsonRpcResponse<T: Serialize> {
pub jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<T>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
pub id: Value,
}
#[derive(Serialize, Debug)]
pub struct JsonRpcError {
pub code: i32,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<Value>,
}
#[derive(Serialize, Debug, Clone)]
pub struct ToolDefinition {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(rename = "inputSchema", skip_serializing_if = "Option::is_none")]
pub input_schema: Option<Value>, // Added inputSchema
#[serde(rename = "outputSchema", skip_serializing_if = "Option::is_none")]
pub output_schema: Option<Value>, // Added outputSchema
}
#[derive(Serialize, Debug)]
pub struct ToolCapability {
pub supported: bool,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub definitions: Vec<ToolDefinition>, // List available tools
}
#[derive(Serialize, Debug)]
pub struct SupportedFeature {
pub supported: bool,
}
#[derive(Serialize, Debug)]
pub struct Capabilities {
pub tools: ToolCapability, // Use the new structure
pub resources: SupportedFeature,
pub prompts: SupportedFeature,
}
#[derive(Serialize, Debug)]
pub struct ServerInfo {
pub name: String,
pub version: String,
}
#[derive(Serialize, Debug)]
pub struct InitializeResult {
#[serde(rename = "protocolVersion")]
pub protocol_version: String,
pub capabilities: Capabilities,
#[serde(rename = "serverInfo")]
pub server_info: ServerInfo,
}
#[derive(Serialize, Debug)]
pub struct ResourceEntry {
pub uri: String,
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(rename = "mimeType", skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
}
#[derive(Serialize, Debug)]
pub struct ResourcesListResult {
pub resources: Vec<ResourceEntry>,
}
#[derive(Serialize, Debug)]
pub struct ToolsListResult {
pub tools: Vec<ToolDefinition>,
}
#[derive(Serialize, Debug, Clone)]
pub struct PromptArgument {
pub name: String,
pub required: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<Value>, // Use Value for flexibility (string, bool, number)
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>, // Optional description for the argument
}
#[derive(Serialize, Debug, Clone)]
pub struct PromptEntry {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub arguments: Vec<PromptArgument>,
}
#[derive(Serialize, Debug)]
pub struct PromptsListResult {
pub prompts: Vec<PromptEntry>,
}
pub mod error_codes {
pub const PARSE_ERROR: i32 = -32700;
pub const INVALID_REQUEST: i32 = -32600;
pub const METHOD_NOT_FOUND: i32 = -32601;
pub const INVALID_PARAMS: i32 = -32602;
pub const INTERNAL_ERROR: i32 = -32603;
pub const SERVER_ERROR_START: i32 = -32000;
pub const SERVER_ERROR_END: i32 = -32099;
}

View File

@@ -1,235 +0,0 @@
use chrono::{DateTime, Utc, SecondsFormat};
use serde_json::{json, Value};
use tracing::{debug, warn};
pub fn transform_to_mcp(event: Value, event_type: String) -> Value {
debug!(?event, %event_type, "Entering transform_to_mcp");
let source_obj = event.get("_source").unwrap_or(&event);
if event.get("_source").is_some() {
debug!("Event contains '_source' field, using it for transformation.");
} else {
debug!("Event does not contain '_source' field, using the event root for transformation.");
}
let id = source_obj.get("id")
.and_then(|v| v.as_str())
.or_else(|| event.get("_id").and_then(|v| v.as_str()))
.unwrap_or("unknown_id")
.to_string();
debug!(%id, "Transformed: id");
let default_rule = json!({});
let rule = source_obj.get("rule").unwrap_or(&default_rule);
if source_obj.get("rule").is_none() {
debug!("Transformed: rule (defaulted to empty object)");
} else {
debug!(?rule, "Transformed: rule");
}
let category = rule.get("groups")
.and_then(|g| g.as_array())
.and_then(|arr| arr.first())
.and_then(|v| v.as_str())
.unwrap_or("unknown_category")
.to_string();
debug!(%category, "Transformed: category");
let severity_level = rule.get("level").and_then(|v| v.as_u64());
let severity = severity_level
.map(|level| match level {
0..=3 => "low",
4..=7 => "medium",
8..=11 => "high",
_ => "critical",
})
.unwrap_or("unknown_severity")
.to_string();
debug!(?severity_level, %severity, "Transformed: severity");
let description = rule.get("description")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
debug!(%description, "Transformed: description");
let default_data = json!({});
let data = source_obj.get("data").cloned().unwrap_or_else(|| {
debug!("Transformed: data (defaulted to empty object)");
default_data.clone()
});
if source_obj.get("data").is_some() {
debug!(?data, "Transformed: data");
}
let default_agent = json!({});
let agent = source_obj.get("agent").cloned().unwrap_or_else(|| {
debug!("Transformed: agent (defaulted to empty object)");
default_agent.clone()
});
if source_obj.get("agent").is_some() {
debug!(?agent, "Transformed: agent");
}
let timestamp_str = source_obj.get("timestamp")
.and_then(|v| v.as_str())
.unwrap_or("");
debug!(%timestamp_str, "Attempting to parse timestamp");
let timestamp = DateTime::parse_from_rfc3339(timestamp_str)
.map(|dt| dt.with_timezone(&Utc))
.or_else(|_| DateTime::parse_from_str(timestamp_str, "%Y-%m-%dT%H:%M:%S%.fZ").map(|dt| dt.with_timezone(&Utc)))
.unwrap_or_else(|_| {
warn!("Failed to parse timestamp '{}' for alert ID '{}'. Using current time.", timestamp_str, id);
Utc::now()
});
debug!(%timestamp, "Transformed: timestamp");
let notes = "Data fetched via Wazuh API".to_string();
debug!(%notes, "Transformed: notes");
let mcp_message = json!({
"protocolVersion": "1.0", // Match initialize response
"source": "Wazuh",
"timestamp": timestamp.to_rfc3339_opts(SecondsFormat::Secs, true),
"event_type": event_type,
"context": {
"id": id,
"category": category,
"severity": severity,
"description": description,
"agent": agent,
"data": data
},
"metadata": {
"integration": "Wazuh-MCP",
"notes": notes
}
});
debug!(?mcp_message, "Exiting transform_to_mcp with result");
mcp_message
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::TimeZone;
use serde_json::json;
#[test]
fn test_transform_to_mcp_basic() {
let event_time_str = "2023-10-27T10:30:00.123Z";
let event_time = Utc.datetime_from_str(event_time_str, "%Y-%m-%dT%H:%M:%S%.fZ").unwrap();
let event = json!({
"id": "alert1",
"_id": "wazuh_alert_id_1",
"timestamp": event_time_str,
"rule": {
"level": 10,
"description": "High severity rule triggered",
"id": "1002",
"groups": ["gdpr", "pci_dss", "intrusion_detection"]
},
"agent": {
"id": "001",
"name": "server-db"
},
"data": {
"srcip": "1.2.3.4",
"dstport": "22"
}
});
let result = transform_to_mcp(event.clone(), "alert".to_string());
assert_eq!(result["protocol_version"], "1.0");
assert_eq!(result["source"], "Wazuh");
assert_eq!(result["event_type"], "alert");
assert_eq!(result["timestamp"], event_time.to_rfc3339_opts(SecondsFormat::Secs, true));
let context = &result["context"];
assert_eq!(context["id"], "alert1");
assert_eq!(context["category"], "gdpr");
assert_eq!(context["severity"], "high");
assert_eq!(context["description"], "High severity rule triggered");
assert_eq!(context["agent"]["name"], "server-db");
assert_eq!(context["data"]["srcip"], "1.2.3.4");
let metadata = &result["metadata"];
assert_eq!(metadata["integration"], "Wazuh-MCP");
assert_eq!(metadata["notes"], "Data fetched via Wazuh API");
}
#[test]
fn test_transform_to_mcp_with_source_nesting() {
let event_time_str = "2023-10-27T11:00:00Z";
let event_time = DateTime::parse_from_rfc3339(event_time_str).unwrap().with_timezone(&Utc);
let event = json!({
"_index": "wazuh-alerts-4.x-2023.10.27",
"_id": "alert_source_nested",
"_source": {
"id": "nested_alert_id",
"timestamp": event_time_str,
"rule": {
"level": 5,
"description": "Medium severity rule",
"groups": ["system_audit"]
},
"agent": { "id": "002", "name": "web-server" },
"data": { "command": "useradd test" }
}
});
let result = transform_to_mcp(event.clone(), "alert".to_string());
assert_eq!(result["timestamp"], event_time.to_rfc3339_opts(SecondsFormat::Secs, true));
let context = &result["context"];
assert_eq!(context["id"], "nested_alert_id");
assert_eq!(context["category"], "system_audit");
assert_eq!(context["severity"], "medium");
assert_eq!(context["description"], "Medium severity rule");
assert_eq!(context["agent"]["name"], "web-server");
assert_eq!(context["data"]["command"], "useradd test");
}
#[test]
fn test_transform_to_mcp_with_defaults() {
let event = json!({});
let before_transform = Utc::now();
let result = transform_to_mcp(event, "alert".to_string());
let after_transform = Utc::now();
assert_eq!(result["context"]["id"], "unknown_id");
assert_eq!(result["context"]["category"], "unknown_category");
assert_eq!(result["context"]["severity"], "unknown_severity");
assert_eq!(result["context"]["description"], "");
assert!(result["context"]["data"].is_object());
assert!(result["context"]["agent"].is_object());
assert_eq!(result["metadata"]["notes"], "Data fetched via Wazuh API");
let result_ts_str = result["timestamp"].as_str().unwrap();
let result_ts = DateTime::parse_from_rfc3339(result_ts_str).unwrap().with_timezone(&Utc);
assert!(result_ts.timestamp() >= before_transform.timestamp() && result_ts.timestamp() <= after_transform.timestamp());
}
#[test]
fn test_transform_timestamp_parsing_fallback() {
let event = json!({
"id": "ts_test",
"timestamp": "invalid-timestamp-format",
"rule": { "level": 3 },
});
let before_transform = Utc::now();
let result = transform_to_mcp(event, "alert".to_string());
let after_transform = Utc::now();
let result_ts_str = result["timestamp"].as_str().unwrap();
let result_ts = DateTime::parse_from_rfc3339(result_ts_str).unwrap().with_timezone(&Utc);
assert!(result_ts.timestamp() >= before_transform.timestamp() && result_ts.timestamp() <= after_transform.timestamp());
assert_eq!(result["context"]["id"], "ts_test");
assert_eq!(result["context"]["severity"], "low");
}
}

View File

@@ -1,186 +0,0 @@
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::oneshot::Sender as OneshotSender;
use tracing::{debug, error, info};
use crate::logging_utils::{log_mcp_request, log_mcp_response};
use crate::mcp::mcp_server_core::McpServerCore;
use crate::mcp::protocol::{error_codes, JsonRpcRequest};
use crate::AppState;
use serde_json::Value;
pub async fn run_stdio_service(app_state: Arc<AppState>, shutdown_tx: OneshotSender<()>) {
info!("Starting MCP server in stdio mode...");
let mut stdin_reader = BufReader::new(tokio::io::stdin());
let mut stdout_writer = tokio::io::stdout();
let mcp_core = McpServerCore::new(app_state);
let mut line_buffer = String::new();
debug!("run_stdio_service: Initialized readers/writers. Entering main loop.");
loop {
debug!("stdio_service: Top of the loop. Clearing line buffer.");
line_buffer.clear();
debug!("stdio_service: About to read_line from stdin.");
let read_result = stdin_reader.read_line(&mut line_buffer).await;
debug!(?read_result, "stdio_service: read_line completed.");
match read_result {
Ok(0) => {
debug!("stdio_service: read_line returned Ok(0) (EOF).");
info!("Stdin closed (EOF), signaling shutdown and exiting stdio mode.");
let _ = shutdown_tx.send(()); // Signal main to shutdown Axum
debug!("stdio_service read 0 bytes, breaking loop.");
break; // EOF
}
Ok(bytes_read) => {
debug!(%bytes_read, "stdio_service: read_line returned Ok(bytes_read).");
let request_str = line_buffer.trim();
if request_str.is_empty() {
debug!("Received empty line from stdin, continuing.");
continue;
}
info!("Received from stdin (stdio_service): {}", request_str);
log_mcp_request(request_str); // Log the raw incoming string
let parsed_value: Value = match serde_json::from_str(request_str) {
Ok(v) => v,
Err(e) => {
error!("JSON Parse Error: {}", e);
let response_json = mcp_core.handle_parse_error(e, request_str);
log_mcp_response(&response_json);
info!("Sending parse error response to stdout: {}", response_json);
let response_to_send = format!("{}\n", response_json);
if let Err(write_err) =
stdout_writer.write_all(response_to_send.as_bytes()).await
{
error!(
"Error writing parse error response to stdout: {}",
write_err
);
let _ = shutdown_tx.send(());
break;
}
if let Err(flush_err) = stdout_writer.flush().await {
error!("Error flushing stdout for parse error: {}", flush_err);
let _ = shutdown_tx.send(());
break;
}
continue;
}
};
if parsed_value.get("id").is_none()
|| parsed_value.get("id").map_or(false, |id| id.is_null())
{
// --- Handle Notification (No ID or ID is null) ---
let method = parsed_value
.get("method")
.and_then(Value::as_str)
.unwrap_or("");
info!("Received Notification: method='{}'", method);
match method {
"notifications/initialized" => {
debug!("Client 'initialized' notification received. No action taken, no response sent.");
}
"exit" => {
info!("'exit' notification received. Signaling shutdown immediately.");
let _ = shutdown_tx.send(());
return;
}
_ => {
debug!(
"Received unknown/unhandled notification method: '{}'. Ignoring.",
method
);
}
}
continue;
} else {
let request_id = parsed_value.get("id").cloned().unwrap(); // We know ID exists and is not null here
match serde_json::from_value::<JsonRpcRequest>(parsed_value) {
Ok(rpc_request) => {
// --- Successfully parsed a Request ---
let is_shutdown = rpc_request.method == "shutdown";
let response_json = mcp_core.process_request(rpc_request).await;
// Log and send the response
log_mcp_response(&response_json);
info!("Sending response to stdout: {}", response_json);
let response_to_send = format!("{}\n", response_json);
if let Err(e) =
stdout_writer.write_all(response_to_send.as_bytes()).await
{
error!("Error writing response to stdout: {}", e);
let _ = shutdown_tx.send(());
break;
}
if let Err(e) = stdout_writer.flush().await {
error!("Error flushing stdout: {}", e);
let _ = shutdown_tx.send(());
break;
}
// Handle shutdown *after* sending the response
if is_shutdown {
debug!("'shutdown' request processed successfully. Signaling shutdown.");
let _ = shutdown_tx.send(()); // Signal main to shutdown Axum
return; // Exit the service loop
}
}
Err(e) => {
error!("Invalid JSON-RPC Request structure: {}", e);
// Use the ID we extracted earlier
let response_json = mcp_core.create_error_response(
error_codes::INVALID_REQUEST,
format!("Invalid Request structure: {}", e),
None,
request_id, // Use the ID from the original request
);
log_mcp_response(&response_json);
info!(
"Sending invalid request error response to stdout: {}",
response_json
);
let response_to_send = format!("{}\n", response_json);
if let Err(write_err) =
stdout_writer.write_all(response_to_send.as_bytes()).await
{
error!(
"Error writing invalid request error response to stdout: {}",
write_err
);
let _ = shutdown_tx.send(());
break;
}
if let Err(flush_err) = stdout_writer.flush().await {
error!(
"Error flushing stdout for invalid request error: {}",
flush_err
);
let _ = shutdown_tx.send(());
break;
}
}
}
}
}
Err(e) => {
debug!(error = %e, "stdio_service: read_line returned Err.");
error!("Error reading from stdin for stdio_service: {}", e);
debug!("Signaling shutdown and breaking loop due to stdin read error.");
let _ = shutdown_tx.send(()); // Signal main to shutdown Axum
break;
}
}
debug!("stdio_service: Bottom of the loop, before next iteration.");
}
info!("run_stdio_service: Exited main loop. stdio_service task is finishing.");
}

View File

@@ -13,8 +13,8 @@ pub struct WazuhIndexerClient {
http_client: Client,
}
// Renamed impl block
impl WazuhIndexerClient {
#[allow(dead_code)]
pub fn new(
host: String,
indexer_port: u16,
@@ -22,9 +22,20 @@ impl WazuhIndexerClient {
password: String,
verify_ssl: bool,
) -> Self {
debug!(%host, indexer_port, %username, %verify_ssl, "Creating new WazuhIndexerClient");
Self::new_with_protocol(host, indexer_port, username, password, verify_ssl, "https")
}
pub fn new_with_protocol(
host: String,
indexer_port: u16,
username: String,
password: String,
verify_ssl: bool,
protocol: &str,
) -> Self {
debug!(%host, indexer_port, %username, %verify_ssl, %protocol, "Creating new WazuhIndexerClient");
// Base URL now points to the Indexer
let base_url = format!("https://{}:{}", host, indexer_port);
let base_url = format!("{}://{}:{}", protocol, host, indexer_port);
debug!(%base_url, "Wazuh Indexer base URL set");
let http_client = Client::builder()