From 5f452bac9d58820aac563d3c0c7611861ded08f2 Mon Sep 17 00:00:00 2001 From: Gianluca Brigandi Date: Thu, 10 Jul 2025 14:56:37 -0700 Subject: [PATCH] feat: Refactor tools and upgrade wazuh-client This commit introduces a major refactoring of the tool implementation by splitting the tools into separate modules based on their domain (agents, alerts, rules, stats, vulnerabilities). This improves modularity and maintainability. Key changes: - Upgraded wazuh-client to version 0.1.7 to leverage the new builder pattern for client instantiation. - Refactored the main WazuhToolsServer to delegate tool calls to the new domain-specific tool modules. - Created a tools module with submodules for each domain, each containing the relevant tool implementations and parameter structs. - Updated the default limit for most tools from 100 to 300, while the vulnerability summary limit is set to 10,000 to ensure comprehensive scans. - Removed a problematic manual test from the test script that was causing it to hang. --- Cargo.toml | 2 +- src/main.rs | 26 ++++++++-------- src/tools/agents.rs | 12 ++++---- src/tools/alerts.rs | 14 ++++----- src/tools/mod.rs | 9 ++++-- src/tools/rules.rs | 4 +-- src/tools/stats.rs | 8 ++--- src/tools/vulnerabilities.rs | 59 +++++++++++------------------------- tests/run_tests.sh | 41 +++---------------------- 9 files changed, 61 insertions(+), 114 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 287a9ed..0484507 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ repository = "https://github.com/gbrigandi/mcp-server-wazuh" readme = "README.md" [dependencies] -wazuh-client = "0.1.6" +wazuh-client = "0.1.7" rmcp = { version = "0.1.5", features = ["server", "transport-io"] } tokio = { version = "1", features = ["full"] } reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false } diff --git a/src/main.rs b/src/main.rs index fcbbcbe..eb30281 100644 --- a/src/main.rs +++ b/src/main.rs @@ -171,18 +171,20 @@ impl WazuhToolsServer { .ok() .or_else(|| Some("https".to_string())); - let wazuh_factory = WazuhClientFactory::new( - api_host, - api_port, - api_username, - api_password, - indexer_host, - indexer_port, - indexer_username, - indexer_password, - verify_ssl, - test_protocol, - ); + let mut builder = WazuhClientFactory::builder() + .api_host(api_host) + .api_port(api_port) + .api_credentials(&api_username, &api_password) + .indexer_host(indexer_host) + .indexer_port(indexer_port) + .indexer_credentials(&indexer_username, &indexer_password) + .verify_ssl(verify_ssl); + + if let Some(protocol) = test_protocol { + builder = builder.protocol(protocol); + } + + let wazuh_factory = builder.build(); let wazuh_indexer_client = wazuh_factory.create_indexer_client(); let wazuh_rules_client = wazuh_factory.create_rules_client(); diff --git a/src/tools/agents.rs b/src/tools/agents.rs index a8f508c..5d7ad4c 100644 --- a/src/tools/agents.rs +++ b/src/tools/agents.rs @@ -15,7 +15,7 @@ use wazuh_client::{AgentsClient, Port as WazuhPort, VulnerabilityClient}; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct GetAgentsParams { - #[schemars(description = "Maximum number of agents to retrieve (default: 100)")] + #[schemars(description = "Maximum number of agents to retrieve (default: 300)")] pub limit: Option, #[schemars( description = "Agent status filter (active, disconnected, pending, never_connected)" @@ -39,7 +39,7 @@ pub struct GetAgentProcessesParams { description = "Agent ID to get processes for (required, e.g., \"0\", \"1\", \"001\")" )] pub agent_id: String, - #[schemars(description = "Maximum number of processes to retrieve (default: 100)")] + #[schemars(description = "Maximum number of processes to retrieve (default: 300)")] pub limit: Option, #[schemars(description = "Search string to filter processes by name or command (optional)")] pub search: Option, @@ -51,7 +51,7 @@ pub struct GetAgentPortsParams { description = "Agent ID to get network ports for (required, e.g., \"001\", \"002\", \"003\")" )] pub agent_id: String, - #[schemars(description = "Maximum number of ports to retrieve (default: 100)")] + #[schemars(description = "Maximum number of ports to retrieve (default: 300)")] pub limit: Option, #[schemars(description = "Protocol to filter by (e.g., \"tcp\", \"udp\")")] pub protocol: String, @@ -80,7 +80,7 @@ impl AgentTools { &self, params: GetAgentsParams, ) -> Result { - let limit = params.limit.unwrap_or(100); + let limit = params.limit.unwrap_or(300); tracing::info!( limit = %limit, @@ -277,7 +277,7 @@ impl AgentTools { return Self::error_result(err_msg); } }; - let limit = params.limit.unwrap_or(100); + let limit = params.limit.unwrap_or(300); let offset = 0; tracing::info!( @@ -401,7 +401,7 @@ impl AgentTools { return Self::error_result(err_msg); } }; - let limit = params.limit.unwrap_or(100); + let limit = params.limit.unwrap_or(300); let offset = 0; // Default offset tracing::info!( diff --git a/src/tools/alerts.rs b/src/tools/alerts.rs index 4c4b0dc..2212f5a 100644 --- a/src/tools/alerts.rs +++ b/src/tools/alerts.rs @@ -15,7 +15,7 @@ use super::ToolModule; /// Parameters for getting alert summary #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct GetAlertSummaryParams { - #[schemars(description = "Maximum number of alerts to retrieve (default: 100)")] + #[schemars(description = "Maximum number of alerts to retrieve (default: 300)")] pub limit: Option, } @@ -38,21 +38,19 @@ impl AlertTools { &self, #[tool(aggr)] params: GetAlertSummaryParams, ) -> Result { - let limit = params.limit.unwrap_or(100); + let limit = params.limit.unwrap_or(300); tracing::info!(limit = %limit, "Retrieving Wazuh alert summary"); - match self.indexer_client.get_alerts().await { + match self.indexer_client.get_alerts(Some(limit)).await { Ok(raw_alerts) => { - let alerts_to_process: Vec<_> = raw_alerts.into_iter().take(limit as usize).collect(); - - if alerts_to_process.is_empty() { + if raw_alerts.is_empty() { tracing::info!("No Wazuh alerts found to process. Returning standard message."); return Self::not_found_result("Wazuh alerts"); } - let num_alerts_to_process = alerts_to_process.len(); - let mcp_content_items: Vec = alerts_to_process + let num_alerts_to_process = raw_alerts.len(); + let mcp_content_items: Vec = raw_alerts .into_iter() .map(|alert_value| { let source = alert_value.get("_source").unwrap_or(&alert_value); diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 17359e6..0734ec3 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -26,9 +26,12 @@ pub trait ToolModule { } fn not_found_result(resource: &str) -> Result { - Ok(CallToolResult::success(vec![Content::text(format!( - "No {} found matching the specified criteria.", resource - ))])) + let message = if resource == "Wazuh alerts" { + "No Wazuh alerts found.".to_string() + } else { + format!("No {} found matching the specified criteria.", resource) + }; + Ok(CallToolResult::success(vec![Content::text(message)])) } } diff --git a/src/tools/rules.rs b/src/tools/rules.rs index 3094375..9910ca1 100644 --- a/src/tools/rules.rs +++ b/src/tools/rules.rs @@ -15,7 +15,7 @@ use super::ToolModule; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct GetRulesSummaryParams { - #[schemars(description = "Maximum number of rules to retrieve (default: 100)")] + #[schemars(description = "Maximum number of rules to retrieve (default: 300)")] pub limit: Option, #[schemars(description = "Rule level to filter by (optional)")] pub level: Option, @@ -43,7 +43,7 @@ impl RuleTools { &self, #[tool(aggr)] params: GetRulesSummaryParams, ) -> Result { - let limit = params.limit.unwrap_or(100); + let limit = params.limit.unwrap_or(300); tracing::info!( limit = %limit, diff --git a/src/tools/stats.rs b/src/tools/stats.rs index 797d955..c84d435 100644 --- a/src/tools/stats.rs +++ b/src/tools/stats.rs @@ -13,7 +13,7 @@ use wazuh_client::{ClusterClient, LogsClient}; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct SearchManagerLogsParams { - #[schemars(description = "Maximum number of log entries to retrieve (default: 100)")] + #[schemars(description = "Maximum number of log entries to retrieve (default: 300)")] pub limit: Option, #[schemars(description = "Number of log entries to skip (default: 0)")] pub offset: Option, @@ -27,7 +27,7 @@ pub struct SearchManagerLogsParams { #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct GetManagerErrorLogsParams { - #[schemars(description = "Maximum number of error log entries to retrieve (default: 100)")] + #[schemars(description = "Maximum number of error log entries to retrieve (default: 300)")] pub limit: Option, } @@ -87,7 +87,7 @@ impl StatsTools { &self, params: SearchManagerLogsParams, ) -> Result { - let limit = params.limit.unwrap_or(100); + let limit = params.limit.unwrap_or(300); let offset = params.offset.unwrap_or(0); tracing::info!( @@ -151,7 +151,7 @@ impl StatsTools { &self, params: GetManagerErrorLogsParams, ) -> Result { - let limit = params.limit.unwrap_or(100); + let limit = params.limit.unwrap_or(300); tracing::info!(limit = %limit, "Retrieving Wazuh manager error logs"); diff --git a/src/tools/vulnerabilities.rs b/src/tools/vulnerabilities.rs index 5ec8da2..cb78e90 100644 --- a/src/tools/vulnerabilities.rs +++ b/src/tools/vulnerabilities.rs @@ -15,7 +15,7 @@ use super::ToolModule; #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct GetVulnerabilitiesSummaryParams { - #[schemars(description = "Maximum number of vulnerabilities to retrieve (default: 100)")] + #[schemars(description = "Maximum number of vulnerabilities to retrieve (default: 10000)")] pub limit: Option, #[schemars(description = "Agent ID to filter vulnerabilities by (required, e.g., \"0\", \"1\", \"001\")")] pub agent_id: String, @@ -29,6 +29,8 @@ pub struct GetVulnerabilitiesSummaryParams { pub struct GetCriticalVulnerabilitiesParams { #[schemars(description = "Agent ID to get critical vulnerabilities for (required, e.g., \"0\", \"1\", \"001\")")] pub agent_id: String, + #[schemars(description = "Maximum number of vulnerabilities to retrieve (default: 300)")] + pub limit: Option, } #[derive(Clone)] @@ -67,7 +69,7 @@ impl VulnerabilityTools { &self, params: GetVulnerabilitiesSummaryParams, ) -> Result { - let limit = params.limit.unwrap_or(100); + let limit = params.limit.unwrap_or(10000); let offset = 0; // Default offset, can be extended in future if needed let agent_id = match Self::format_agent_id(¶ms.agent_id) { @@ -91,42 +93,17 @@ impl VulnerabilityTools { let mut vulnerability_client = self.vulnerability_client.lock().await; - let vulnerabilities = if let Some(cve_filter) = ¶ms.cve { - match vulnerability_client - .get_agent_vulnerabilities( - &agent_id, - Some(1000), // Get more results to filter - Some(offset), - params - .severity - .as_deref() - .and_then(VulnerabilitySeverity::from_str), - ) - .await - { - Ok(all_vulns) => { - let filtered: Vec<_> = all_vulns - .into_iter() - .filter(|v| v.cve.to_lowercase().contains(&cve_filter.to_lowercase())) - .take(limit as usize) - .collect(); - Ok(filtered) - } - Err(e) => Err(e), - } - } else { - vulnerability_client - .get_agent_vulnerabilities( - &agent_id, - Some(limit), - Some(offset), - params - .severity - .as_deref() - .and_then(VulnerabilitySeverity::from_str), - ) - .await - }; + let vulnerabilities = vulnerability_client + .get_agent_vulnerabilities( + &agent_id, + Some(limit), + Some(offset), + params + .severity + .as_deref() + .and_then(VulnerabilitySeverity::from_str), + ) + .await; match vulnerabilities { Ok(vulnerabilities) => { @@ -268,6 +245,7 @@ impl VulnerabilityTools { &self, params: GetCriticalVulnerabilitiesParams, ) -> Result { + let limit = params.limit.unwrap_or(300); let agent_id = match Self::format_agent_id(¶ms.agent_id) { Ok(formatted_id) => formatted_id, Err(err_msg) => { @@ -287,7 +265,7 @@ impl VulnerabilityTools { let mut vulnerability_client = self.vulnerability_client.lock().await; match vulnerability_client - .get_critical_vulnerabilities(&agent_id) + .get_critical_vulnerabilities(&agent_id, Some(limit)) .await { Ok(vulnerabilities) => { @@ -408,5 +386,4 @@ impl VulnerabilityTools { } } -impl ToolModule for VulnerabilityTools {} - +impl ToolModule for VulnerabilityTools {} \ No newline at end of file diff --git a/tests/run_tests.sh b/tests/run_tests.sh index c44de8a..a3de3ac 100755 --- a/tests/run_tests.sh +++ b/tests/run_tests.sh @@ -37,47 +37,14 @@ echo "" echo "=== Running Integration Tests with Mock Wazuh ===" cargo test --test rmcp_integration_test -echo "" -echo "=== Manual MCP Server Testing ===" - -# Test the server with a simple MCP interaction -echo "Testing MCP server initialization..." - -# Create a temporary test script -cat > /tmp/test_mcp_server.sh << 'INNER_EOF' -#!/bin/bash - -# Start the server in background -WAZUH_HOST=mock.example.com RUST_LOG=error cargo run --bin mcp-server-wazuh & -SERVER_PID=$! - -# Give server time to start -sleep 1 - -# Test MCP initialization -echo '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}' | timeout 5s nc -l 0 2>/dev/null || { - # If nc doesn't work, try a different approach - echo "Testing server startup..." - sleep 2 -} - -# Clean up -kill $SERVER_PID 2>/dev/null || true -wait $SERVER_PID 2>/dev/null || true - -echo "Manual test completed" -INNER_EOF - -chmod +x /tmp/test_mcp_server.sh -/tmp/test_mcp_server.sh -rm /tmp/test_mcp_server.sh - echo "" echo "=== Testing Server Binary ===" echo "Verifying server binary can start and show help..." -# Test that the binary can start and show help -timeout 5s cargo run --bin mcp-server-wazuh -- --help || echo "Help command test completed" +# Test that the binary can start and show help. It should exit immediately. +cargo run --bin mcp-server-wazuh -- --help > /dev/null + +echo "Help command test completed" echo "" echo "=== All Tests Complete ==="