mirror of
https://github.com/gbrigandi/mcp-server-wazuh.git
synced 2025-12-26 14:37:44 -06:00
feat: Refactor tools and upgrade wazuh-client
This commit introduces a major refactoring of the tool implementation by splitting the tools into separate modules based on their domain (agents, alerts, rules, stats, vulnerabilities). This improves modularity and maintainability. Key changes: - Upgraded wazuh-client to version 0.1.7 to leverage the new builder pattern for client instantiation. - Refactored the main WazuhToolsServer to delegate tool calls to the new domain-specific tool modules. - Created a tools module with submodules for each domain, each containing the relevant tool implementations and parameter structs. - Updated the default limit for most tools from 100 to 300, while the vulnerability summary limit is set to 10,000 to ensure comprehensive scans. - Removed a problematic manual test from the test script that was causing it to hang.
This commit is contained in:
@@ -15,7 +15,7 @@ use wazuh_client::{AgentsClient, Port as WazuhPort, VulnerabilityClient};
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetAgentsParams {
|
||||
#[schemars(description = "Maximum number of agents to retrieve (default: 100)")]
|
||||
#[schemars(description = "Maximum number of agents to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
#[schemars(
|
||||
description = "Agent status filter (active, disconnected, pending, never_connected)"
|
||||
@@ -39,7 +39,7 @@ pub struct GetAgentProcessesParams {
|
||||
description = "Agent ID to get processes for (required, e.g., \"0\", \"1\", \"001\")"
|
||||
)]
|
||||
pub agent_id: String,
|
||||
#[schemars(description = "Maximum number of processes to retrieve (default: 100)")]
|
||||
#[schemars(description = "Maximum number of processes to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
#[schemars(description = "Search string to filter processes by name or command (optional)")]
|
||||
pub search: Option<String>,
|
||||
@@ -51,7 +51,7 @@ pub struct GetAgentPortsParams {
|
||||
description = "Agent ID to get network ports for (required, e.g., \"001\", \"002\", \"003\")"
|
||||
)]
|
||||
pub agent_id: String,
|
||||
#[schemars(description = "Maximum number of ports to retrieve (default: 100)")]
|
||||
#[schemars(description = "Maximum number of ports to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
#[schemars(description = "Protocol to filter by (e.g., \"tcp\", \"udp\")")]
|
||||
pub protocol: String,
|
||||
@@ -80,7 +80,7 @@ impl AgentTools {
|
||||
&self,
|
||||
params: GetAgentsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(100);
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
|
||||
tracing::info!(
|
||||
limit = %limit,
|
||||
@@ -277,7 +277,7 @@ impl AgentTools {
|
||||
return Self::error_result(err_msg);
|
||||
}
|
||||
};
|
||||
let limit = params.limit.unwrap_or(100);
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
let offset = 0;
|
||||
|
||||
tracing::info!(
|
||||
@@ -401,7 +401,7 @@ impl AgentTools {
|
||||
return Self::error_result(err_msg);
|
||||
}
|
||||
};
|
||||
let limit = params.limit.unwrap_or(100);
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
let offset = 0; // Default offset
|
||||
|
||||
tracing::info!(
|
||||
|
||||
@@ -15,7 +15,7 @@ use super::ToolModule;
|
||||
/// Parameters for getting alert summary
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetAlertSummaryParams {
|
||||
#[schemars(description = "Maximum number of alerts to retrieve (default: 100)")]
|
||||
#[schemars(description = "Maximum number of alerts to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
}
|
||||
|
||||
@@ -38,21 +38,19 @@ impl AlertTools {
|
||||
&self,
|
||||
#[tool(aggr)] params: GetAlertSummaryParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(100);
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
|
||||
tracing::info!(limit = %limit, "Retrieving Wazuh alert summary");
|
||||
|
||||
match self.indexer_client.get_alerts().await {
|
||||
match self.indexer_client.get_alerts(Some(limit)).await {
|
||||
Ok(raw_alerts) => {
|
||||
let alerts_to_process: Vec<_> = raw_alerts.into_iter().take(limit as usize).collect();
|
||||
|
||||
if alerts_to_process.is_empty() {
|
||||
if raw_alerts.is_empty() {
|
||||
tracing::info!("No Wazuh alerts found to process. Returning standard message.");
|
||||
return Self::not_found_result("Wazuh alerts");
|
||||
}
|
||||
|
||||
let num_alerts_to_process = alerts_to_process.len();
|
||||
let mcp_content_items: Vec<Content> = alerts_to_process
|
||||
let num_alerts_to_process = raw_alerts.len();
|
||||
let mcp_content_items: Vec<Content> = raw_alerts
|
||||
.into_iter()
|
||||
.map(|alert_value| {
|
||||
let source = alert_value.get("_source").unwrap_or(&alert_value);
|
||||
|
||||
@@ -26,9 +26,12 @@ pub trait ToolModule {
|
||||
}
|
||||
|
||||
fn not_found_result(resource: &str) -> Result<CallToolResult, McpError> {
|
||||
Ok(CallToolResult::success(vec![Content::text(format!(
|
||||
"No {} found matching the specified criteria.", resource
|
||||
))]))
|
||||
let message = if resource == "Wazuh alerts" {
|
||||
"No Wazuh alerts found.".to_string()
|
||||
} else {
|
||||
format!("No {} found matching the specified criteria.", resource)
|
||||
};
|
||||
Ok(CallToolResult::success(vec![Content::text(message)]))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ use super::ToolModule;
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetRulesSummaryParams {
|
||||
#[schemars(description = "Maximum number of rules to retrieve (default: 100)")]
|
||||
#[schemars(description = "Maximum number of rules to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
#[schemars(description = "Rule level to filter by (optional)")]
|
||||
pub level: Option<u32>,
|
||||
@@ -43,7 +43,7 @@ impl RuleTools {
|
||||
&self,
|
||||
#[tool(aggr)] params: GetRulesSummaryParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(100);
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
|
||||
tracing::info!(
|
||||
limit = %limit,
|
||||
|
||||
@@ -13,7 +13,7 @@ use wazuh_client::{ClusterClient, LogsClient};
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct SearchManagerLogsParams {
|
||||
#[schemars(description = "Maximum number of log entries to retrieve (default: 100)")]
|
||||
#[schemars(description = "Maximum number of log entries to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
#[schemars(description = "Number of log entries to skip (default: 0)")]
|
||||
pub offset: Option<u32>,
|
||||
@@ -27,7 +27,7 @@ pub struct SearchManagerLogsParams {
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetManagerErrorLogsParams {
|
||||
#[schemars(description = "Maximum number of error log entries to retrieve (default: 100)")]
|
||||
#[schemars(description = "Maximum number of error log entries to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ impl StatsTools {
|
||||
&self,
|
||||
params: SearchManagerLogsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(100);
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
let offset = params.offset.unwrap_or(0);
|
||||
|
||||
tracing::info!(
|
||||
@@ -151,7 +151,7 @@ impl StatsTools {
|
||||
&self,
|
||||
params: GetManagerErrorLogsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(100);
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
|
||||
tracing::info!(limit = %limit, "Retrieving Wazuh manager error logs");
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ use super::ToolModule;
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetVulnerabilitiesSummaryParams {
|
||||
#[schemars(description = "Maximum number of vulnerabilities to retrieve (default: 100)")]
|
||||
#[schemars(description = "Maximum number of vulnerabilities to retrieve (default: 10000)")]
|
||||
pub limit: Option<u32>,
|
||||
#[schemars(description = "Agent ID to filter vulnerabilities by (required, e.g., \"0\", \"1\", \"001\")")]
|
||||
pub agent_id: String,
|
||||
@@ -29,6 +29,8 @@ pub struct GetVulnerabilitiesSummaryParams {
|
||||
pub struct GetCriticalVulnerabilitiesParams {
|
||||
#[schemars(description = "Agent ID to get critical vulnerabilities for (required, e.g., \"0\", \"1\", \"001\")")]
|
||||
pub agent_id: String,
|
||||
#[schemars(description = "Maximum number of vulnerabilities to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -67,7 +69,7 @@ impl VulnerabilityTools {
|
||||
&self,
|
||||
params: GetVulnerabilitiesSummaryParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(100);
|
||||
let limit = params.limit.unwrap_or(10000);
|
||||
let offset = 0; // Default offset, can be extended in future if needed
|
||||
|
||||
let agent_id = match Self::format_agent_id(¶ms.agent_id) {
|
||||
@@ -91,42 +93,17 @@ impl VulnerabilityTools {
|
||||
|
||||
let mut vulnerability_client = self.vulnerability_client.lock().await;
|
||||
|
||||
let vulnerabilities = if let Some(cve_filter) = ¶ms.cve {
|
||||
match vulnerability_client
|
||||
.get_agent_vulnerabilities(
|
||||
&agent_id,
|
||||
Some(1000), // Get more results to filter
|
||||
Some(offset),
|
||||
params
|
||||
.severity
|
||||
.as_deref()
|
||||
.and_then(VulnerabilitySeverity::from_str),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(all_vulns) => {
|
||||
let filtered: Vec<_> = all_vulns
|
||||
.into_iter()
|
||||
.filter(|v| v.cve.to_lowercase().contains(&cve_filter.to_lowercase()))
|
||||
.take(limit as usize)
|
||||
.collect();
|
||||
Ok(filtered)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
} else {
|
||||
vulnerability_client
|
||||
.get_agent_vulnerabilities(
|
||||
&agent_id,
|
||||
Some(limit),
|
||||
Some(offset),
|
||||
params
|
||||
.severity
|
||||
.as_deref()
|
||||
.and_then(VulnerabilitySeverity::from_str),
|
||||
)
|
||||
.await
|
||||
};
|
||||
let vulnerabilities = vulnerability_client
|
||||
.get_agent_vulnerabilities(
|
||||
&agent_id,
|
||||
Some(limit),
|
||||
Some(offset),
|
||||
params
|
||||
.severity
|
||||
.as_deref()
|
||||
.and_then(VulnerabilitySeverity::from_str),
|
||||
)
|
||||
.await;
|
||||
|
||||
match vulnerabilities {
|
||||
Ok(vulnerabilities) => {
|
||||
@@ -268,6 +245,7 @@ impl VulnerabilityTools {
|
||||
&self,
|
||||
params: GetCriticalVulnerabilitiesParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
let agent_id = match Self::format_agent_id(¶ms.agent_id) {
|
||||
Ok(formatted_id) => formatted_id,
|
||||
Err(err_msg) => {
|
||||
@@ -287,7 +265,7 @@ impl VulnerabilityTools {
|
||||
let mut vulnerability_client = self.vulnerability_client.lock().await;
|
||||
|
||||
match vulnerability_client
|
||||
.get_critical_vulnerabilities(&agent_id)
|
||||
.get_critical_vulnerabilities(&agent_id, Some(limit))
|
||||
.await
|
||||
{
|
||||
Ok(vulnerabilities) => {
|
||||
@@ -408,5 +386,4 @@ impl VulnerabilityTools {
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolModule for VulnerabilityTools {}
|
||||
|
||||
impl ToolModule for VulnerabilityTools {}
|
||||
Reference in New Issue
Block a user