diff --git a/src/main.rs b/src/main.rs index 4883392..5069db9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -49,7 +49,7 @@ use std::env; use clap::Parser; use dotenv::dotenv; -use wazuh_client::{WazuhClientFactory, WazuhIndexerClient}; +use wazuh_client::{WazuhClientFactory, WazuhIndexerClient, RulesClient}; #[derive(Parser, Debug)] #[command(name = "mcp-server-wazuh")] @@ -65,11 +65,24 @@ struct GetAlertSummaryParams { limit: Option, } +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct GetRulesSummaryParams { + #[schemars(description = "Maximum number of rules to retrieve (default: 100)")] + limit: Option, + #[schemars(description = "Rule level to filter by (optional)")] + level: Option, + #[schemars(description = "Rule group to filter by (optional)")] + group: Option, + #[schemars(description = "Filename to filter by (optional)")] + filename: Option, +} + #[derive(Clone)] struct WazuhToolsServer { #[allow(dead_code)] // Kept for future expansion to other Wazuh clients wazuh_factory: Arc, wazuh_indexer_client: Arc, + wazuh_rules_client: Arc>, } #[tool(tool_box)] @@ -103,26 +116,26 @@ impl WazuhToolsServer { let protocol = env::var("WAZUH_TEST_PROTOCOL").unwrap_or_else(|_| "https".to_string()); tracing::debug!(?protocol, "Using Wazuh protocol for client from WAZUH_TEST_PROTOCOL or default"); - // Create the factory with both API and Indexer configurations let wazuh_factory = WazuhClientFactory::new( - wazuh_host.clone(), // API host - wazuh_api_port, // API port - wazuh_user.clone(), // API username - wazuh_pass.clone(), // API password - wazuh_host, // Indexer host (same as API host) - wazuh_port, // Indexer port (use WAZUH_PORT for backward compatibility) - wazuh_user, // Indexer username - wazuh_pass, // Indexer password + wazuh_host.clone(), + wazuh_api_port, + wazuh_user.clone(), + wazuh_pass.clone(), + wazuh_host, + wazuh_port, + wazuh_user, + wazuh_pass, verify_ssl, Some(protocol), ); - // Create the indexer client using the factory let wazuh_indexer_client = wazuh_factory.create_indexer_client(); + let wazuh_rules_client = wazuh_factory.create_rules_client(); Ok(Self { wazuh_factory: Arc::new(wazuh_factory), wazuh_indexer_client: Arc::new(wazuh_indexer_client), + wazuh_rules_client: Arc::new(tokio::sync::Mutex::new(wazuh_rules_client)), }) } @@ -201,6 +214,110 @@ impl WazuhToolsServer { } } } + + #[tool( + name = "get_wazuh_rules_summary", + description = "Retrieves a summary of Wazuh security rules. Returns formatted rule information including ID, level, description, and groups. Supports filtering by level, group, and filename." + )] + async fn get_wazuh_rules_summary( + &self, + #[tool(aggr)] params: GetRulesSummaryParams, + ) -> Result { + let limit = params.limit.unwrap_or(100); + + tracing::info!( + limit = %limit, + level = ?params.level, + group = ?params.group, + filename = ?params.filename, + "Retrieving Wazuh rules summary" + ); + + let mut rules_client = self.wazuh_rules_client.lock().await; + + match rules_client.get_rules( + Some(limit), + None, // offset + params.level, + params.group.as_deref(), + params.filename.as_deref(), + ).await { + Ok(rules) => { + if rules.is_empty() { + tracing::info!("No Wazuh rules found matching criteria. Returning standard message."); + return Ok(CallToolResult::success(vec![Content::text( + "No Wazuh rules found matching the specified criteria.", + )])); + } + + let num_rules = rules.len(); + let mcp_content_items: Vec = rules + .into_iter() + .map(|rule| { + let groups_str = rule.groups.join(", "); + + let compliance_info = { + let mut compliance = Vec::new(); + if let Some(gdpr) = &rule.gdpr { + if !gdpr.is_empty() { + compliance.push(format!("GDPR: {}", gdpr.join(", "))); + } + } + if let Some(hipaa) = &rule.hipaa { + if !hipaa.is_empty() { + compliance.push(format!("HIPAA: {}", hipaa.join(", "))); + } + } + if let Some(pci) = &rule.pci_dss { + if !pci.is_empty() { + compliance.push(format!("PCI DSS: {}", pci.join(", "))); + } + } + if let Some(nist) = &rule.nist_800_53 { + if !nist.is_empty() { + compliance.push(format!("NIST 800-53: {}", nist.join(", "))); + } + } + if compliance.is_empty() { + String::new() + } else { + format!("\nCompliance: {}", compliance.join(" | ")) + } + }; + + let severity = match rule.level { + 0..=3 => "Low", + 4..=7 => "Medium", + 8..=12 => "High", + 13..=15 => "Critical", + _ => "Unknown", + }; + + let formatted_text = format!( + "Rule ID: {}\nLevel: {} ({})\nDescription: {}\nGroups: {}\nFile: {}\nStatus: {}{}", + rule.id, + rule.level, + severity, + rule.description, + groups_str, + rule.filename, + rule.status, + compliance_info + ); + Content::text(formatted_text) + }) + .collect(); + + tracing::info!("Successfully processed {} rules into {} MCP content items", num_rules, mcp_content_items.len()); + Ok(CallToolResult::success(mcp_content_items)) + } + Err(e) => { + let err_msg = format!("Error retrieving rules from Wazuh: {}", e); + tracing::error!("{}", err_msg); + Ok(CallToolResult::error(vec![Content::text(err_msg)])) + } + } + } } #[tool(tool_box)] @@ -218,7 +335,9 @@ impl ServerHandler for WazuhToolsServer { "This server provides tools to interact with a Wazuh SIEM instance for security monitoring and analysis.\n\ Available tools:\n\ - 'get_wazuh_alert_summary': Retrieves a summary of Wazuh security alerts. \ - Optionally takes 'limit' parameter to control the number of alerts returned (defaults to 100)." + Optionally takes 'limit' parameter to control the number of alerts returned (defaults to 100).\n\ + - 'get_wazuh_rules_summary': Retrieves a summary of Wazuh security rules. \ + Supports filtering by 'level', 'group', and 'filename' parameters, with 'limit' to control the number of rules returned (defaults to 100)." .to_string(), ), }