mirror of
https://github.com/gbrigandi/mcp-server-wazuh.git
synced 2025-07-13 15:14:48 -06:00
Compare commits
25 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
5f452bac9d | ||
![]() |
4e39ca6bba | ||
![]() |
6e7463382e | ||
![]() |
e160888906 | ||
![]() |
b9327e0c87 | ||
![]() |
2df4408f1a | ||
![]() |
77914b5097 | ||
![]() |
943fabddd3 | ||
![]() |
5254df7aa1 | ||
![]() |
cc6f0871ca | ||
![]() |
e95b3fb199 | ||
![]() |
168bb7b8f7 | ||
![]() |
d23eef6269 | ||
![]() |
47b33414cb | ||
![]() |
2fd14a0063 | ||
![]() |
d86a4e1136 | ||
![]() |
bfffdbfb52 | ||
![]() |
f9efb70f19 | ||
![]() |
13f93cc844 | ||
![]() |
87ac8a6695 | ||
![]() |
5574dcc43a | ||
![]() |
8b491fe84a | ||
![]() |
53c5f7417e | ||
![]() |
ccb3fce350 | ||
![]() |
2e401259ec |
46
.env.example
46
.env.example
@ -1,9 +1,39 @@
|
||||
# Wazuh API Configuration
|
||||
WAZUH_HOST=localhost
|
||||
WAZUH_PORT=55000
|
||||
WAZUH_USER=admin
|
||||
WAZUH_PASS=admin
|
||||
VERIFY_SSL=false
|
||||
# Wazuh MCP Server Environment Configuration Example
|
||||
#
|
||||
# Copy this file to .env and fill in your specific values.
|
||||
# Lines starting with # are comments.
|
||||
|
||||
# MCP Server Configuration
|
||||
MCP_SERVER_PORT=8000
|
||||
# Wazuh Manager API Configuration
|
||||
# Hostname or IP address of the Wazuh Manager API server.
|
||||
WAZUH_API_HOST=localhost
|
||||
# Port number for the Wazuh Manager API.
|
||||
WAZUH_API_PORT=55000
|
||||
# Username for Wazuh Manager API authentication.
|
||||
WAZUH_API_USERNAME=wazuh
|
||||
# Password for Wazuh Manager API authentication.
|
||||
WAZUH_API_PASSWORD=wazuh
|
||||
|
||||
# Wazuh Indexer API Configuration
|
||||
# Hostname or IP address of the Wazuh Indexer API server.
|
||||
WAZUH_INDEXER_HOST=localhost
|
||||
# Port number for the Wazuh Indexer API.
|
||||
WAZUH_INDEXER_PORT=9200
|
||||
# Username for Wazuh Indexer API authentication.
|
||||
WAZUH_INDEXER_USERNAME=admin
|
||||
# Password for Wazuh Indexer API authentication.
|
||||
WAZUH_INDEXER_PASSWORD=admin
|
||||
|
||||
# SSL Configuration for Wazuh Connections
|
||||
# Set to "true" to verify SSL certificates for Wazuh API and Indexer connections.
|
||||
# Set to "false" to disable SSL verification (not recommended for production).
|
||||
WAZUH_VERIFY_SSL=false
|
||||
|
||||
# Protocol for Wazuh Connections (Optional)
|
||||
# Overrides the default protocol used by the wazuh-client.
|
||||
# Typically "http" or "https". If not set, the client's default (usually https) will be used.
|
||||
# WAZUH_TEST_PROTOCOL=https
|
||||
|
||||
# Logging Configuration
|
||||
# Controls the log level for the application and its dependencies.
|
||||
# Examples: "info", "debug", "trace", "mcp_server_wazuh=debug,wazuh_client=info"
|
||||
RUST_LOG=info
|
||||
|
42
.github/workflows/release.yml
vendored
42
.github/workflows/release.yml
vendored
@ -5,6 +5,10 @@ on:
|
||||
tags:
|
||||
- 'v*' # Trigger on version tags like v0.1.0
|
||||
|
||||
permissions:
|
||||
contents: write # Needed to create releases
|
||||
packages: write # Needed to push to GitHub Container Registry
|
||||
|
||||
jobs:
|
||||
create_release:
|
||||
name: Create Release
|
||||
@ -73,3 +77,41 @@ jobs:
|
||||
asset_name: mcp-server-wazuh-${{ matrix.asset_name_suffix }}
|
||||
asset_content_type: application/octet-stream
|
||||
|
||||
build_docker:
|
||||
name: Build and Push Docker Image
|
||||
needs: create_release
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ghcr.io/gbrigandi/mcp-server-wazuh
|
||||
tags: |
|
||||
type=ref,event=tag
|
||||
type=raw,value=latest
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
|
30
Cargo.toml
30
Cargo.toml
@ -1,26 +1,29 @@
|
||||
[package]
|
||||
name = "mcp-server-wazuh"
|
||||
version = "0.1.0"
|
||||
version = "0.2.4"
|
||||
edition = "2021"
|
||||
description = "Wazuh SIEM MCP Server"
|
||||
authors = ["Gianluca Brigandi <gbrigand@gmail.com>"]
|
||||
license = "MIT"
|
||||
repository = "https://github.com/gbrigandi/mcp-server-wazuh"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.45", features = ["full"] }
|
||||
axum = "0.8"
|
||||
wazuh-client = "0.1.7"
|
||||
rmcp = { version = "0.1.5", features = ["server", "transport-io"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
anyhow = "1.0"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
|
||||
schemars = "0.8"
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
dotenv = "0.15"
|
||||
thiserror = "2.0"
|
||||
jsonwebtoken = "9.3"
|
||||
tower-http = { version = "0.6", features = ["trace"] }
|
||||
async-trait = "0.1"
|
||||
anyhow = "1.0"
|
||||
clap = { version = "4.5", features = ["derive", "env"] }
|
||||
chrono = "0.4.41"
|
||||
openssl-sys = { version = "0.9", features = ["vendored"] }
|
||||
|
||||
[dev-dependencies]
|
||||
mockito = "1.7"
|
||||
@ -30,8 +33,7 @@ uuid = { version = "1.16", features = ["v4"] }
|
||||
once_cell = "1.21"
|
||||
async-trait = "0.1"
|
||||
regex = "1.11"
|
||||
|
||||
[[bin]]
|
||||
name = "mcp_client_cli"
|
||||
path = "tests/mcp_client_cli.rs"
|
||||
tokio-test = "0.4"
|
||||
serde_json = "1.0"
|
||||
tempfile = "3.0"
|
||||
|
||||
|
@ -4,10 +4,10 @@ WORKDIR /usr/src/app
|
||||
COPY . .
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y pkg-config libssl-dev && \
|
||||
apt-get install -y pkg-config libssl-dev build-essential perl make && \
|
||||
cargo build --release
|
||||
|
||||
FROM debian:bullseye-slim
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y ca-certificates && \
|
||||
|
524
README.md
524
README.md
@ -1,27 +1,235 @@
|
||||
# Wazuh MCP Server
|
||||
# Wazuh MCP Server - Talk to your SIEM
|
||||
|
||||
A Rust-based server designed to bridge the gap between a Wazuh Security Information and Event Management (SIEM) system and applications requiring contextual security data, specifically tailored for the Claude Desktop Integration using the Model Context Protocol (MCP).
|
||||
|
||||
## Overview
|
||||
|
||||
Modern AI assistants like Claude can benefit significantly from real-time context about the user's environment. For security operations, this means providing relevant security alerts and events. Wazuh is a popular open-source SIEM, but its API output isn't directly consumable by systems expecting MCP format.
|
||||
Modern AI assistants like Claude can benefit significantly from real-time context about the user's security environment. The Wazuh MCP Server bridges this gap by providing comprehensive access to Wazuh SIEM data through natural language interactions.
|
||||
|
||||
This server transforms complex Wazuh API responses into MCP-compatible format, enabling AI assistants to access:
|
||||
|
||||
- **Security Alerts & Events** from the Wazuh Indexer for threat detection and incident response
|
||||
- **Agent Management & Monitoring** including health status, system processes, and network ports
|
||||
- **Vulnerability Assessment** data for risk management and patch prioritization
|
||||
- **Security Rules & Configuration** for detection optimization and compliance validation
|
||||
- **System Statistics & Performance** metrics for operational monitoring and audit trails
|
||||
- **Log Analysis & Forensics** capabilities for incident investigation and compliance reporting
|
||||
- **Cluster Health & Management** for infrastructure reliability and availability requirements
|
||||
- **Compliance Monitoring & Gap Analysis** for regulatory frameworks like PCI-DSS, HIPAA, SOX, and GDPR
|
||||
|
||||
Rather than requiring manual API calls or complex queries, security teams can now ask natural language questions like "Show me critical vulnerabilities on web servers," "What processes are running on agent 001?" or "Are we meeting PCI-DSS logging requirements?" and receive structured, actionable data from their Wazuh deployment.
|
||||
|
||||
This approach is particularly valuable for compliance teams who need to quickly assess security posture, identify gaps in monitoring coverage, validate rule effectiveness, and generate evidence for audit requirements across distributed infrastructure.
|
||||
|
||||

|
||||
|
||||
## Example Use Cases
|
||||
|
||||
The Wazuh MCP Server, by bridging Wazuh's security data with MCP-compatible applications, unlocks several powerful use cases:
|
||||
The Wazuh MCP Server provides direct access to Wazuh security data through natural language interactions, enabling several practical use cases:
|
||||
|
||||
* **Delegated Alert Triage:** Automate alert categorization and prioritization via AI, focusing analyst attention on critical events.
|
||||
* **Enhanced Alert Correlation:** Enrich alerts by correlating with CVEs, OSINT, and other threat intelligence for deeper context and risk assessment.
|
||||
* **Dynamic Security Visualizations:** Generate on-demand reports and visualizations of Wazuh data to answer specific security questions.
|
||||
* **Multilingual Security Operations:** Query Wazuh data and receive insights in multiple languages for global team accessibility.
|
||||
* **Natural Language Data Interaction:** Query Wazuh data using natural language for intuitive access to security information.
|
||||
* **Contextual Augmentation for Other Tools:** Use Wazuh data as context to enrich other MCP-enabled tools and AI assistants.
|
||||
### Security Alert Analysis
|
||||
* **Alert Triage and Investigation:** Query recent security alerts with `get_wazuh_alert_summary` to quickly identify and prioritize threats requiring immediate attention.
|
||||
* **Alert Pattern Recognition:** Analyze alert trends and patterns to identify recurring security issues or potential attack campaigns.
|
||||
|
||||
### Vulnerability Management
|
||||
* **Agent Vulnerability Assessment:** Use `get_wazuh_vulnerability_summary` and `get_wazuh_critical_vulnerabilities` to assess security posture of specific agents and prioritize patching efforts.
|
||||
* **Risk-Based Vulnerability Prioritization:** Correlate vulnerability data with agent criticality and exposure to focus remediation efforts.
|
||||
|
||||
### System Monitoring and Forensics
|
||||
* **Process Analysis:** Investigate running processes on agents using `get_wazuh_agent_processes` for threat hunting and system analysis.
|
||||
* **Network Security Assessment:** Monitor open ports and network services with `get_wazuh_agent_ports` to identify potential attack vectors.
|
||||
* **Agent Health Monitoring:** Track agent status and connectivity using `get_wazuh_running_agents` to ensure comprehensive security coverage.
|
||||
|
||||
### Security Operations Intelligence
|
||||
* **Rule Effectiveness Analysis:** Review and analyze security detection rules with `get_wazuh_rules_summary` to optimize detection capabilities.
|
||||
* **Manager Performance Monitoring:** Track system performance and statistics using tools like `get_wazuh_weekly_stats`, `get_wazuh_remoted_stats`, and `get_wazuh_log_collector_stats`.
|
||||
* **Cluster Health Management:** Monitor Wazuh cluster status with `get_wazuh_cluster_health` and `get_wazuh_cluster_nodes` for operational reliability.
|
||||
|
||||
### Incident Response and Forensics
|
||||
* **Log Analysis:** Search and analyze manager logs using `search_wazuh_manager_logs` and `get_wazuh_manager_error_logs` for incident investigation.
|
||||
* **Agent-Specific Investigation:** Combine multiple tools to build comprehensive profiles of specific agents during security incidents.
|
||||
* **Natural Language Security Queries:** Ask complex security questions in natural language and receive structured data from multiple Wazuh components.
|
||||
|
||||
### Operational Efficiency
|
||||
* **Automated Reporting:** Generate security reports and summaries through conversational interfaces without manual API calls.
|
||||
* **Cross-Component Analysis:** Correlate data from both Wazuh Indexer (alerts) and Wazuh Manager (agents, rules, vulnerabilities) for comprehensive security insights.
|
||||
* **Multilingual Security Operations:** Access Wazuh data and receive insights in multiple languages for global security teams.
|
||||
|
||||
### Threat Intelligence Gathering and Response
|
||||
|
||||
For enhanced threat intelligence capabilities, the Wazuh MCP Server can be combined with the **[Cortex MCP Server](https://github.com/gbrigandi/mcp-server-cortex/)** to create a powerful security analysis ecosystem.
|
||||
|
||||
**Enhanced Capabilities with Cortex Integration:**
|
||||
* **Artifact Analysis:** Automatically analyze suspicious files, URLs, domains, and IP addresses found in Wazuh alerts using Cortex's 140+ analyzers
|
||||
* **IOC Enrichment:** Enrich indicators of compromise (IOCs) from Wazuh alerts with threat intelligence from multiple sources including VirusTotal, Shodan, MISP, and more
|
||||
* **Automated Threat Hunting:** Combine Wazuh's detection capabilities with Cortex's analysis engines to automatically investigate and classify threats
|
||||
* **Multi-Source Intelligence:** Leverage analyzers for reputation checks, malware analysis, domain analysis, and behavioral analysis
|
||||
* **Response Orchestration:** Use analysis results to inform automated response actions and alert prioritization
|
||||
|
||||
**Example Workflow:**
|
||||
1. Wazuh detects a suspicious file hash or network connection in an alert
|
||||
2. The AI assistant automatically queries the Cortex MCP Server to analyze the artifact using multiple analyzers
|
||||
3. Results from VirusTotal, hybrid analysis, domain reputation, and other sources are correlated
|
||||
4. The combined intelligence provides context for incident response decisions
|
||||
5. Findings can be used to update Wazuh rules or trigger additional monitoring
|
||||
|
||||
## Requirements
|
||||
|
||||
- An MCP (Model Context Protocol) compatible LLM client (e.g., Claude Desktop)
|
||||
- A running Wazuh server (v4.12 recommended) with the API enabled and accessible.
|
||||
- Network connectivity between this server and the Wazuh API (if API interaction is used).
|
||||
|
||||
## Installation
|
||||
|
||||
### Option 1: Download Pre-built Binary (Recommended)
|
||||
|
||||
1. **Download the Binary:**
|
||||
* Go to the [Releases page](https://github.com/gbrigandi/mcp-server-wazuh/releases) of the `mcp-server-wazuh` GitHub repository.
|
||||
* Download the appropriate binary for your operating system (e.g., `mcp-server-wazuh-linux-amd64`, `mcp-server-wazuh-macos-amd64`, `mcp-server-wazuh-windows-amd64.exe`).
|
||||
* Make the downloaded binary executable (e.g., `chmod +x mcp-server-wazuh-linux-amd64`).
|
||||
* (Optional) Rename it to something simpler like `mcp-server-wazuh` and move it to a directory in your system's `PATH` for easier access.
|
||||
|
||||
### Option 2: Docker
|
||||
|
||||
1. **Pull the Docker Image:**
|
||||
```bash
|
||||
docker pull ghcr.io/gbrigandi/mcp-server-wazuh:latest
|
||||
```
|
||||
|
||||
### Option 3: Build from Source
|
||||
|
||||
1. **Prerequisites:**
|
||||
* Install Rust: [https://www.rust-lang.org/tools/install](https://www.rust-lang.org/tools/install)
|
||||
|
||||
2. **Build:**
|
||||
```bash
|
||||
git clone https://github.com/gbrigandi/mcp-server-wazuh.git
|
||||
cd mcp-server-wazuh
|
||||
cargo build --release
|
||||
```
|
||||
The binary will be available at `target/release/mcp-server-wazuh`.
|
||||
|
||||
### Configure Your LLM Client
|
||||
|
||||
The method for configuring your LLM client will vary depending on the client itself. For clients that support MCP (Model Context Protocol), you will typically need to point the client to the path of the `mcp-server-wazuh` executable.
|
||||
|
||||
**Example for Claude Desktop:**
|
||||
|
||||
Configure your `claude_desktop_config.json` file:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"wazuh": {
|
||||
"command": "/path/to/mcp-server-wazuh",
|
||||
"args": [],
|
||||
"env": {
|
||||
"WAZUH_API_HOST": "your_wazuh_manager_api_host",
|
||||
"WAZUH_API_PORT": "55000",
|
||||
"WAZUH_API_USERNAME": "your_wazuh_api_user",
|
||||
"WAZUH_API_PASSWORD": "your_wazuh_api_password",
|
||||
"WAZUH_INDEXER_HOST": "your_wazuh_indexer_host",
|
||||
"WAZUH_INDEXER_PORT": "9200",
|
||||
"WAZUH_INDEXER_USERNAME": "your_wazuh_indexer_user",
|
||||
"WAZUH_INDEXER_PASSWORD": "your_wazuh_indexer_password",
|
||||
"WAZUH_VERIFY_SSL": "false",
|
||||
"WAZUH_TEST_PROTOCOL": "https",
|
||||
"RUST_LOG": "info"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Replace `/path/to/mcp-server-wazuh` with the actual path to your binary and configure the environment variables as detailed in the [Configuration](#configuration) section.
|
||||
|
||||
Once configured, your LLM client should be able to launch and communicate with the `mcp-server-wazuh` to access Wazuh security data.
|
||||
|
||||
If using Docker, create a `.env` file with your Wazuh configuration:
|
||||
|
||||
```bash
|
||||
WAZUH_API_HOST=your_wazuh_manager_api_host
|
||||
WAZUH_API_PORT=55000
|
||||
WAZUH_API_USERNAME=your_wazuh_api_user
|
||||
WAZUH_API_PASSWORD=your_wazuh_api_password
|
||||
WAZUH_INDEXER_HOST=your_wazuh_indexer_host
|
||||
WAZUH_INDEXER_PORT=9200
|
||||
WAZUH_INDEXER_USERNAME=your_wazuh_indexer_user
|
||||
WAZUH_INDEXER_PASSWORD=your_wazuh_indexer_password
|
||||
WAZUH_VERIFY_SSL=false
|
||||
WAZUH_TEST_PROTOCOL=https
|
||||
RUST_LOG=info
|
||||
```
|
||||
|
||||
Configure your `claude_desktop_config.json` file:
|
||||
|
||||
```
|
||||
{
|
||||
"mcpServers": {
|
||||
"wazuh": {
|
||||
"command": "docker",
|
||||
"args": [
|
||||
"run", "--rm", "-i",
|
||||
"--env-file", "/path/to/your/.env",
|
||||
"ghcr.io/gbrigandi/mcp-server-wazuh:latest"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Configuration is managed through environment variables. A `.env` file can be placed in the project root for local development.
|
||||
|
||||
| Variable | Description | Default | Required |
|
||||
| ------------------------ | ------------------------------------------------------------------------------ | ----------- | -------- |
|
||||
| `WAZUH_API_HOST` | Hostname or IP address of the Wazuh Manager API server. | `localhost` | Yes |
|
||||
| `WAZUH_API_PORT` | Port number for the Wazuh Manager API. | `55000` | Yes |
|
||||
| `WAZUH_API_USERNAME` | Username for Wazuh Manager API authentication. | `wazuh` | Yes |
|
||||
| `WAZUH_API_PASSWORD` | Password for Wazuh Manager API authentication. | `wazuh` | Yes |
|
||||
| `WAZUH_INDEXER_HOST` | Hostname or IP address of the Wazuh Indexer API server. | `localhost` | Yes |
|
||||
| `WAZUH_INDEXER_PORT` | Port number for the Wazuh Indexer API. | `9200` | Yes |
|
||||
| `WAZUH_INDEXER_USERNAME` | Username for Wazuh Indexer API authentication. | `admin` | Yes |
|
||||
| `WAZUH_INDEXER_PASSWORD` | Password for Wazuh Indexer API authentication. | `admin` | Yes |
|
||||
| `WAZUH_VERIFY_SSL` | Set to `true` to verify SSL certificates for Wazuh API and Indexer connections. | `false` | No |
|
||||
| `WAZUH_TEST_PROTOCOL` | Protocol for Wazuh connections (e.g., "http", "https"). Overrides client default. | `https` | No |
|
||||
| `RUST_LOG` | Log level (e.g., `info`, `debug`, `trace`). | `info` | No |
|
||||
|
||||
**Note on `WAZUH_VERIFY_SSL`:** For production environments, it is strongly recommended to set `WAZUH_VERIFY_SSL=true` and ensure proper certificate validation for both Wazuh Manager API and Wazuh Indexer connections. Setting it to `false` disables certificate checks, which is insecure.
|
||||
The "Required: Yes" indicates that these variables are essential for the server to connect to the respective Wazuh components. While defaults are provided, they are unlikely to match a production or non-local setup.
|
||||
|
||||
## Building
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Install Rust: [https://www.rust-lang.org/tools/install](https://www.rust-lang.org/tools/install)
|
||||
- Install Docker and Docker Compose (optional, for containerized deployment): [https://docs.docker.com/get-docker/](https://docs.docker.com/get-docker/)
|
||||
|
||||
### Local Development
|
||||
|
||||
1. **Clone the repository:**
|
||||
```bash
|
||||
git clone https://github.com/gbrigandi/mcp-server-wazuh.git
|
||||
cd mcp-server-wazuh
|
||||
```
|
||||
2. **Configure (if using Wazuh API):**
|
||||
- Copy the example environment file: `cp .env.example .env`
|
||||
- Edit the `.env` file with your specific Wazuh API details (e.g. `WAZUH_API_HOST`, `WAZUH_API_PORT`).
|
||||
3. **Build:**
|
||||
```bash
|
||||
cargo build
|
||||
```
|
||||
4. **Run:**
|
||||
```bash
|
||||
cargo run
|
||||
# Or use the run script (which might set up stdio mode):
|
||||
# ./run.sh
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
The server primarily facilitates communication between an application (e.g., an IDE extension or CLI tool) and the Wazuh MCP Server itself via stdio. The server can then interact with the Wazuh API as needed.
|
||||
The server is built using the [rmcp](https://crates.io/crates/rmcp) framework and facilitates communication between MCP clients (e.g., Claude Desktop, IDE extensions) and the Wazuh MCP Server via stdio transport. The server interacts with the Wazuh Indexer and Wazuh Manager APIs to fetch security alerts and other data.
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
@ -62,86 +270,6 @@ sequenceDiagram
|
||||
|
||||
This stdio interaction allows for tight integration with local development tools or other applications that can manage child processes. An optional HTTP endpoint (`/mcp`) may also be available for clients that prefer polling.
|
||||
|
||||
## Features
|
||||
|
||||
- **Stdio Communication:** Interacts with client applications via `stdin` and `stdout` using the Model Context Protocol (MCP), suitable for integration with IDEs or CLI tools.
|
||||
- **Wazuh API Integration:** Connects to the Wazuh API to fetch security data. Handles authentication using configured credentials.
|
||||
- **Alert Retrieval:** Fetches alerts from the Wazuh API (e.g., can be configured to retrieve recent alerts).
|
||||
- **MCP Transformation:** Converts Wazuh alert JSON objects into MCP v1.0 compliant JSON messages. This includes:
|
||||
- Mapping Wazuh `rule.level` to MCP `severity` (e.g., 0-3 -> "low", 8-11 -> "high").
|
||||
- Extracting `rule.description`, `id`, `timestamp`, `agent` details, and the `data` payload.
|
||||
- Taking the first group from `rule.groups` as the MCP `category`.
|
||||
- Handling potential differences in Wazuh response structure (e.g., presence or absence of `_source` nesting).
|
||||
- Providing default values (e.g., "unknown_severity", "unknown_category", current time for invalid timestamps).
|
||||
- **Optional HTTP Server:** Can expose endpoints using the Axum web framework.
|
||||
- `/mcp`: Serves the transformed MCP messages.
|
||||
- `/health`: Provides a simple health check.
|
||||
- **Configuration:** Easily configurable via environment variables or a `.env` file.
|
||||
- **Containerization:** Includes a `Dockerfile` and `docker-compose.yml` for easy deployment.
|
||||
- **Logging:** Uses the `tracing` library for application logging (configurable via `RUST_LOG`).
|
||||
|
||||
## Requirements
|
||||
|
||||
- Rust (latest stable recommended, see `Cargo.toml` for specific dependencies)
|
||||
- A running Wazuh server (v4.x recommended) with the API enabled and accessible.
|
||||
- Network connectivity between this server and the Wazuh API (if API interaction is used).
|
||||
|
||||
## Configuration
|
||||
|
||||
Configuration is managed through environment variables. A `.env` file can be placed in the project root for local development.
|
||||
|
||||
| Variable | Description | Default | Required (for API) |
|
||||
| ----------------- | ------------------------------------------------- | ----------- | ------------------ |
|
||||
| `WAZUH_HOST` | Hostname or IP address of the Wazuh API server. | `localhost` | Yes |
|
||||
| `WAZUH_PORT` | Port number for the Wazuh API. | `9200` | Yes |
|
||||
| `WAZUH_USER` | Username for Wazuh API authentication. | `admin` | Yes |
|
||||
| `WAZUH_PASS` | Password for Wazuh API authentication. | `admin` | Yes |
|
||||
| `VERIFY_SSL` | Set to `true` to verify the Wazuh API's SSL cert. | `false` | No |
|
||||
| `MCP_SERVER_PORT` | Port for this MCP server to listen on (if HTTP enabled). | `8000` | No |
|
||||
| `RUST_LOG` | Log level (e.g., `info`, `debug`, `trace`). | `info` | No |
|
||||
|
||||
**Note on `VERIFY_SSL`:** For production environments using the Wazuh API, it is strongly recommended to set `VERIFY_SSL=true` and ensure proper certificate validation. Setting it to `false` disables certificate checks, which is insecure.
|
||||
|
||||
## Building and Running
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Install Rust: [https://www.rust-lang.org/tools/install](https://www.rust-lang.org/tools/install)
|
||||
- Install Docker and Docker Compose (optional, for containerized deployment): [https://docs.docker.com/get-docker/](https://docs.docker.com/get-docker/)
|
||||
|
||||
### Local Development
|
||||
|
||||
1. **Clone the repository:**
|
||||
```bash
|
||||
git clone https://github.com/yourusername/mcp-server-wazuh.git
|
||||
cd mcp-server-wazuh
|
||||
```
|
||||
2. **Configure (if using Wazuh API):**
|
||||
- Copy the example environment file: `cp .env.example .env`
|
||||
- Edit the `.env` file with your specific Wazuh API details (`WAZUH_HOST`, `WAZUH_PORT`, `WAZUH_USER`, `WAZUH_PASS`).
|
||||
3. **Build:**
|
||||
```bash
|
||||
cargo build
|
||||
```
|
||||
4. **Run:**
|
||||
```bash
|
||||
cargo run
|
||||
# Or use the run script (which might set up stdio mode):
|
||||
# ./run.sh
|
||||
```
|
||||
If the HTTP server is enabled, it will start listening on the port specified by `MCP_SERVER_PORT` (default 8000). Otherwise, it will operate in stdio mode.
|
||||
|
||||
### Docker Deployment
|
||||
|
||||
1. **Clone the repository** (if not already done).
|
||||
2. **Configure:** Ensure you have a `.env` file with your Wazuh credentials in the project root if using the API, or set the environment variables directly in the `docker-compose.yml` or your deployment environment.
|
||||
3. **Build and Run:**
|
||||
```bash
|
||||
docker-compose up --build -d
|
||||
```
|
||||
This will build the Docker image and start the container in detached mode.
|
||||
|
||||
## Stdio Mode Operation
|
||||
|
||||
The server communicates via `stdin` and `stdout` using JSON-RPC 2.0 messages, adhering to the Model Context Protocol (MCP).
|
||||
|
||||
@ -170,49 +298,22 @@ Example interaction flow:
|
||||
```
|
||||
|
||||
3. **Server sends `initialize` response to client via `stdout`:**
|
||||
(Capabilities shown are illustrative based on logs; actual capabilities might vary.)
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 0,
|
||||
"id": 1,
|
||||
"result": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"tools": {
|
||||
"supported": true,
|
||||
"definitions": [
|
||||
{
|
||||
"name": "wazuhAlerts",
|
||||
"description": "Retrieves the latest security alerts from the Wazuh SIEM.",
|
||||
"inputSchema": { "type": "object", "properties": {} },
|
||||
"outputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"alerts": {
|
||||
"type": "array",
|
||||
"description": "A list of simplified alert objects.",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": { "type": "string", "description": "The unique identifier of the alert." },
|
||||
"description": { "type": "string", "description": "The description of the rule that triggered the alert." }
|
||||
},
|
||||
"required": ["id", "description"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["alerts"]
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"resources": { "supported": true },
|
||||
"prompts": { "supported": true }
|
||||
"prompts": {},
|
||||
"resources": {},
|
||||
"tools": {}
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": "Wazuh MCP Server",
|
||||
"version": "0.1.0"
|
||||
}
|
||||
"name": "rmcp",
|
||||
"version": "0.1.5"
|
||||
},
|
||||
"instructions": "This server provides tools to interact with a Wazuh SIEM instance for security monitoring and analysis.\nAvailable tools:\n- 'get_wazuh_alert_summary': Retrieves a summary of Wazuh security alerts. Optionally takes 'limit' parameter to control the number of alerts returned (defaults to 100)."
|
||||
}
|
||||
}
|
||||
```
|
||||
@ -240,30 +341,24 @@ Example interaction flow:
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"id": 2,
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "wazuhAlerts",
|
||||
"description": "Retrieves the latest security alerts from the Wazuh SIEM.",
|
||||
"inputSchema": { "type": "object", "properties": {} },
|
||||
"outputSchema": {
|
||||
"type": "object",
|
||||
"name": "get_wazuh_alert_summary",
|
||||
"description": "Retrieves a summary of Wazuh security alerts. Returns formatted alert information including ID, timestamp, and description.",
|
||||
"inputSchema": {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"properties": {
|
||||
"alerts": {
|
||||
"type": "array",
|
||||
"description": "A list of simplified alert objects.",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": { "type": "string", "description": "The unique identifier of the alert." },
|
||||
"description": { "type": "string", "description": "The description of the rule that triggered the alert." }
|
||||
},
|
||||
"required": ["id", "description"]
|
||||
}
|
||||
"limit": {
|
||||
"description": "Maximum number of alerts to retrieve (default: 100)",
|
||||
"format": "uint32",
|
||||
"minimum": 0.0,
|
||||
"type": ["integer", "null"]
|
||||
}
|
||||
},
|
||||
"required": ["alerts"]
|
||||
"title": "GetAlertSummaryParams",
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
]
|
||||
@ -271,126 +366,77 @@ Example interaction flow:
|
||||
}
|
||||
```
|
||||
|
||||
7. **Client calls the `wazuhAlerts` tool by sending `tools/call` to server's `stdin`:**
|
||||
7. **Client calls the `get_wazuh_alert_summary` tool by sending `tools/call` to server's `stdin`:**
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "wazuhAlerts",
|
||||
"arguments": {}
|
||||
"name": "get_wazuh_alert_summary",
|
||||
"arguments": {
|
||||
"limit": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
8. **Server receives on `stdin`, processes the `wazuhAlerts` call (which involves querying the Wazuh API and transforming the data as described elsewhere in this README).**
|
||||
8. **Server receives on `stdin`, processes the `get_wazuh_alert_summary` call (which involves querying the Wazuh Indexer API and transforming the data).**
|
||||
|
||||
9. **Server sends `tools/call` response with transformed alerts to client via `stdout`:**
|
||||
(Alert content is illustrative and simplified.)
|
||||
9. **Server sends `tools/call` response with formatted alerts to client via `stdout`:**
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"id": 3,
|
||||
"result": {
|
||||
"alerts": [
|
||||
"content": [
|
||||
{
|
||||
"id": "1747091815.1212763",
|
||||
"description": "Attached USB Storage"
|
||||
"type": "text",
|
||||
"text": "Alert ID: 1747091815.1212763\nTime: 2024-01-15T10:30:45.123Z\nAgent: web-server-01\nLevel: 7\nDescription: Attached USB Storage"
|
||||
},
|
||||
{
|
||||
"id": "1747066333.1207112",
|
||||
"description": "New dpkg (Debian Package) installed."
|
||||
"type": "text",
|
||||
"text": "Alert ID: 1747066333.1207112\nTime: 2024-01-15T10:25:12.456Z\nAgent: database-server\nLevel: 5\nDescription: New dpkg (Debian Package) installed."
|
||||
}
|
||||
// ... other simplified alerts based on the tool's outputSchema
|
||||
]
|
||||
],
|
||||
"isError": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Running the All-in-One Demo (Wazuh + MCP Server)
|
||||
|
||||
For a complete local demo environment that includes Wazuh (Indexer, Manager, Dashboard) and the Wazuh MCP Server pre-configured to connect to it (for HTTP mode testing), you can use the `docker-compose.all-in-one.yml` file.
|
||||
|
||||
This setup is ideal for testing the end-to-end flow from Wazuh alerts to MCP messages via the HTTP interface.
|
||||
|
||||
**1. Launch the Environment:**
|
||||
|
||||
Navigate to the project root directory in your terminal and run:
|
||||
|
||||
```bash
|
||||
docker-compose -f docker-compose.all-in-one.yml up -d
|
||||
```
|
||||
|
||||
This command will:
|
||||
- Download the necessary Wazuh and OpenSearch images (if not already present).
|
||||
- Start the Wazuh Indexer, Wazuh Manager, and Wazuh Dashboard services.
|
||||
- Build and start the Wazuh MCP Server (in HTTP mode).
|
||||
- All services are configured to communicate with each other on an internal Docker network.
|
||||
|
||||
**2. Accessing Services:**
|
||||
|
||||
* **Wazuh Dashboard:**
|
||||
* URL: `https://localhost:8443` (Note: Uses HTTPS with a self-signed certificate, so your browser will likely show a warning).
|
||||
* Default Username: `admin`
|
||||
* Default Password: `AdminPassword123!` (This is set by `WAZUH_INITIAL_PASSWORD` in the `wazuh-indexer` service).
|
||||
|
||||
* **Wazuh MCP Server (HTTP Mode):**
|
||||
* The MCP server will be running and accessible on port `8000` by default (or the port specified by `MCP_SERVER_PORT` if you've set it as an environment variable on your host machine before running docker-compose).
|
||||
* Example MCP endpoint: `http://localhost:8000/mcp`
|
||||
* Example Health endpoint: `http://localhost:8000/health`
|
||||
* **Configuration:** The `mcp-server` service within `docker-compose.all-in-one.yml` is already configured with the necessary environment variables to connect to the `wazuh-manager` service:
|
||||
* `WAZUH_HOST=wazuh-manager`
|
||||
* `WAZUH_PORT=55000`
|
||||
* `WAZUH_USER=wazuh_user_demo`
|
||||
* `WAZUH_PASS=wazuh_password_demo`
|
||||
* `VERIFY_SSL=false`
|
||||
You do not need to set these in a separate `.env` file when using this all-in-one compose file, as they are defined directly in the service's environment.
|
||||
|
||||
**3. Stopping the Environment:**
|
||||
|
||||
To stop all services, run:
|
||||
|
||||
```bash
|
||||
docker-compose -f docker-compose.all-in-one.yml down
|
||||
```
|
||||
|
||||
To stop and remove volumes (deleting Wazuh data):
|
||||
|
||||
```bash
|
||||
docker-compose -f docker-compose.all-in-one.yml down -v
|
||||
```
|
||||
|
||||
This approach simplifies setup by bundling all necessary components and their configurations for HTTP mode testing.
|
||||
|
||||
## Claude Desktop Configuration
|
||||
|
||||
To integrate this server with the Claude Desktop application, you need to configure it in your `claude_desktop_config.json` file. Add an entry for the Wazuh server under `mcpServers` like the example below:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"wazuh": {
|
||||
"command": "/full/path/to/your/mcp-server-wazuh/target/release/mcp-server-wazuh",
|
||||
"args": [],
|
||||
"env": {
|
||||
"WAZUH_HOST": "wazuh.example.com",
|
||||
"WAZUH_PASS": "aVeryS3cureP@ssw0rd",
|
||||
"WAZUH_PORT": "9200",
|
||||
"RUST_LOG": "info,mcp_server_wazuh=debug"
|
||||
**Or, if no alerts are found:**
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "No Wazuh alerts found."
|
||||
}
|
||||
],
|
||||
"isError": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
```
|
||||
|
||||
**Configuration Notes:**
|
||||
|
||||
* **`command`**: The absolute path to your compiled `mcp-server-wazuh` executable (e.g., typically found in `target/release/mcp-server-wazuh` after a release build).
|
||||
* **`args`**: An array of arguments to pass to the command, if any.
|
||||
* **`env.WAZUH_HOST`**: The hostname or IP address of your Wazuh Indexer or API endpoint.
|
||||
* **`env.WAZUH_PASS`**: The password for authenticating with the Wazuh service.
|
||||
* **`env.WAZUH_PORT`**: The port number for the Wazuh service. Common ports are `9200` for direct Indexer access or `55000` for the Wazuh API. Adjust this according to your specific Wazuh setup and how this server is configured to connect.
|
||||
* **`env.RUST_LOG`**: Optional. Sets the logging level for the server. Example: `info,mcp_server_wazuh=debug` provides general info logging and debug level for this specific crate.
|
||||
**Or, if there's an error connecting to Wazuh:**
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Error retrieving alerts from Wazuh: HTTP request error: connection refused"
|
||||
}
|
||||
],
|
||||
"isError": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Development & Testing
|
||||
|
||||
|
6
glama.json
Normal file
6
glama.json
Normal file
@ -0,0 +1,6 @@
|
||||
{
|
||||
"$schema": "https://glama.ai/mcp/schemas/server.json",
|
||||
"maintainers": [
|
||||
"gbrigandi"
|
||||
]
|
||||
}
|
BIN
media/.DS_Store
vendored
Normal file
BIN
media/.DS_Store
vendored
Normal file
Binary file not shown.
10
run.sh
10
run.sh
@ -1,10 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ ! -f .env ]; then
|
||||
echo "No .env file found. Creating from .env.example..."
|
||||
cp .env.example .env
|
||||
echo "Please edit .env file with your configuration."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cargo run
|
@ -1,135 +0,0 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::logging_utils::{log_mcp_request, log_mcp_response};
|
||||
use crate::AppState; // Added
|
||||
|
||||
pub fn create_http_router(app_state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/health", get(health_check))
|
||||
.route("/mcp", get(get_mcp_data))
|
||||
.route("/mcp", post(post_mcp_data))
|
||||
.with_state(app_state)
|
||||
}
|
||||
|
||||
async fn health_check() -> impl IntoResponse {
|
||||
Json(json!({
|
||||
"status": "ok",
|
||||
"service": "wazuh-mcp-server",
|
||||
"timestamp": chrono::Utc::now().to_rfc3339()
|
||||
}))
|
||||
}
|
||||
|
||||
async fn get_mcp_data(
|
||||
State(app_state): State<Arc<AppState>>,
|
||||
) -> Result<Json<Vec<Value>>, ApiError> {
|
||||
info!("Handling GET /mcp request");
|
||||
|
||||
let wazuh_client = app_state.wazuh_client.lock().await;
|
||||
|
||||
match wazuh_client.get_alerts().await {
|
||||
Ok(alerts) => {
|
||||
// Transform Wazuh alerts to MCP messages
|
||||
let mcp_messages = alerts
|
||||
.iter()
|
||||
.map(|alert| {
|
||||
json!({
|
||||
"protocol_version": "1.0",
|
||||
"source": "Wazuh",
|
||||
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||||
"event_type": "alert",
|
||||
"context": alert,
|
||||
"metadata": {
|
||||
"integration": "Wazuh-MCP"
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Ok(Json(mcp_messages))
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error getting alerts from Wazuh: {}", e);
|
||||
Err(ApiError::InternalServerError(format!(
|
||||
"Failed to get alerts from Wazuh: {}",
|
||||
e
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn post_mcp_data(
|
||||
State(app_state): State<Arc<AppState>>,
|
||||
Json(payload): Json<Value>,
|
||||
) -> Result<Json<Vec<Value>>, ApiError> {
|
||||
info!("Handling POST /mcp request with payload");
|
||||
debug!("Payload: {:?}", payload);
|
||||
|
||||
// Log the incoming payload
|
||||
let request_str = serde_json::to_string(&payload).unwrap_or_else(|e| {
|
||||
error!(
|
||||
"Failed to serialize POST request payload for logging: {}",
|
||||
e
|
||||
);
|
||||
format!(
|
||||
"{{\"error\":\"Failed to serialize request payload: {}\"}}",
|
||||
e
|
||||
)
|
||||
});
|
||||
log_mcp_request(&request_str);
|
||||
|
||||
let result = get_mcp_data(State(app_state)).await;
|
||||
|
||||
let response_str = match &result {
|
||||
Ok(json_response) => serde_json::to_string(&json_response.0).unwrap_or_else(|e| {
|
||||
error!("Failed to serialize POST response for logging: {}", e);
|
||||
format!("{{\"error\":\"Failed to serialize response: {}\"}}", e)
|
||||
}),
|
||||
Err(api_error) => {
|
||||
let error_json_surrogate = json!({
|
||||
"error": format!("{:?}", api_error) // Or a more structured error
|
||||
});
|
||||
serde_json::to_string(&error_json_surrogate).unwrap_or_else(|e| {
|
||||
error!("Failed to serialize POST error response for logging: {}", e);
|
||||
format!(
|
||||
"{{\"error\":\"Failed to serialize error response: {}\"}}",
|
||||
e
|
||||
)
|
||||
})
|
||||
}
|
||||
};
|
||||
log_mcp_response(&response_str);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ApiError {
|
||||
BadRequest(String),
|
||||
NotFound(String),
|
||||
InternalServerError(String),
|
||||
}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, error_message) = match self {
|
||||
ApiError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
|
||||
ApiError::NotFound(msg) => (StatusCode::NOT_FOUND, msg),
|
||||
ApiError::InternalServerError(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
|
||||
};
|
||||
|
||||
let body = Json(json!({
|
||||
"error": error_message
|
||||
}));
|
||||
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
13
src/lib.rs
13
src/lib.rs
@ -1,13 +1,2 @@
|
||||
use crate::wazuh::client::WazuhIndexerClient;
|
||||
use tokio::sync::Mutex;
|
||||
pub mod tools;
|
||||
|
||||
pub mod http_service;
|
||||
pub mod logging_utils;
|
||||
pub mod mcp;
|
||||
pub mod stdio_service;
|
||||
pub mod wazuh;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AppState {
|
||||
pub wazuh_client: Mutex<WazuhIndexerClient>,
|
||||
}
|
||||
|
@ -1,27 +0,0 @@
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use tracing::error;
|
||||
|
||||
const REQUEST_LOG_FILE: &str = "mcp_requests.log";
|
||||
const RESPONSE_LOG_FILE: &str = "mcp_responses.log";
|
||||
|
||||
fn log_to_file(filename: &str, message: &str) {
|
||||
match OpenOptions::new().create(true).append(true).open(filename) {
|
||||
Ok(mut file) => {
|
||||
if let Err(e) = writeln!(file, "{}", message) {
|
||||
error!("Failed to write to {}: {}", filename, e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to open {} for appending: {}", filename, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn log_mcp_request(request_str: &str) {
|
||||
log_to_file(REQUEST_LOG_FILE, request_str);
|
||||
}
|
||||
|
||||
pub fn log_mcp_response(response_str: &str) {
|
||||
log_to_file(RESPONSE_LOG_FILE, response_str);
|
||||
}
|
601
src/main.rs
601
src/main.rs
@ -1,187 +1,452 @@
|
||||
//
|
||||
// Purpose:
|
||||
//
|
||||
// This Rust application implements an MCP (Model Context Protocol) server that acts as a
|
||||
// bridge to a Wazuh instance. It exposes various Wazuh functionalities as tools that can
|
||||
// be invoked by MCP clients (e.g., AI models, automation scripts).
|
||||
//
|
||||
// Architecture:
|
||||
// The application follows a modular design where the main MCP server delegates to
|
||||
// domain-specific tool modules, promoting separation of concerns and maintainability.
|
||||
//
|
||||
// Structure:
|
||||
// - `main()`: Entry point of the application. Initializes logging (tracing),
|
||||
// sets up the `WazuhToolsServer`, and starts the MCP server using stdio transport.
|
||||
//
|
||||
// - `WazuhToolsServer`: The core orchestrator struct that implements the `rmcp::ServerHandler` trait
|
||||
// and the `#[tool(tool_box)]` attribute. It acts as a facade that delegates tool calls to
|
||||
// specialized domain modules:
|
||||
// - Holds instances of domain-specific tool modules (AgentTools, AlertTools, RuleTools, etc.)
|
||||
// - Its methods, decorated with `#[tool(...)]`, define the MCP tool interface and delegate
|
||||
// to the appropriate domain module for actual implementation
|
||||
// - Manages the lifecycle and configuration of Wazuh client connections
|
||||
//
|
||||
// - Domain-Specific Tool Modules (in `tools/` package):
|
||||
// - `AlertTools` (`tools/alerts.rs`): Handles alert-related operations via Wazuh Indexer
|
||||
// - `RuleTools` (`tools/rules.rs`): Manages security rule queries via Wazuh Manager API
|
||||
// - `VulnerabilityTools` (`tools/vulnerabilities.rs`): Processes vulnerability data via Wazuh Manager API
|
||||
// - `AgentTools` (`tools/agents.rs`): Handles agent management and system information queries
|
||||
// - `StatsTools` (`tools/stats.rs`): Provides logging, statistics, and cluster health monitoring
|
||||
// Each module encapsulates:
|
||||
// - Domain-specific business logic and data formatting
|
||||
// - Parameter validation and error handling
|
||||
// - Client interaction patterns for their respective Wazuh APIs
|
||||
// - Rich output formatting with structured text and emojis
|
||||
//
|
||||
// - Tool Parameter Structs (e.g., `GetAlertSummaryParams`):
|
||||
// - These structs define the expected input parameters for each tool.
|
||||
// - They use `serde::Deserialize` for parsing input and `schemars::JsonSchema`
|
||||
// for generating a schema that MCP clients can use to understand how to call the tools.
|
||||
// - Located within their respective domain modules for better organization
|
||||
//
|
||||
// - `wazuh_client` crate:
|
||||
// - This external crate is used to interact with both the Wazuh Manager API and the Wazuh Indexer API.
|
||||
// - `WazuhClientFactory` is used to create specific clients (e.g., `WazuhIndexerClient`, `RulesClient`, `AgentsClient`, `LogsClient`, `ClusterClient`, `VulnerabilityClient`).
|
||||
// - Clients are wrapped in `Arc<Mutex<>>` for thread-safe access across async operations
|
||||
//
|
||||
// Workflow:
|
||||
// 1. Server starts and listens for MCP requests on stdio
|
||||
// 2. MCP client sends a `call_tool` request to `WazuhToolsServer`
|
||||
// 3. `WazuhToolsServer` routes the request to the appropriate domain-specific tool module
|
||||
// 4. The domain module validates parameters, interacts with the relevant Wazuh client, and formats results
|
||||
// 5. The result (success with formatted data or error) is packaged into a `CallToolResult`
|
||||
// and sent back to the MCP client via the main server
|
||||
//
|
||||
// Exposed Tools:
|
||||
// The server exposes a set of tools categorized by the Wazuh component they interact with:
|
||||
//
|
||||
// Alert Management (via AlertTools):
|
||||
// - `get_wazuh_alert_summary`: Retrieves a summary of security alerts from the Wazuh Indexer.
|
||||
//
|
||||
// Rule Management (via RuleTools):
|
||||
// - `get_wazuh_rules_summary`: Fetches security rules defined in the Wazuh Manager.
|
||||
//
|
||||
// Vulnerability Management (via VulnerabilityTools):
|
||||
// - `get_wazuh_vulnerability_summary`: Gets vulnerability scan results for a specific agent from the Wazuh Manager.
|
||||
// - `get_wazuh_critical_vulnerabilities`: Retrieves critical vulnerabilities for an agent.
|
||||
//
|
||||
// Agent Management (via AgentTools):
|
||||
// - `get_wazuh_agents`: Lists active and inactive agents connected to the Wazuh Manager.
|
||||
// - `get_wazuh_agent_processes`: Retrieves running processes on a specific agent (via Syscollector).
|
||||
// - `get_wazuh_agent_ports`: Lists open network ports on a specific agent (via Syscollector).
|
||||
//
|
||||
// Statistics and Monitoring (via StatsTools):
|
||||
// - `search_wazuh_manager_logs`: Searches logs generated by the Wazuh Manager.
|
||||
// - `get_wazuh_manager_error_logs`: Retrieves error-specific logs from the Wazuh Manager.
|
||||
// - `get_wazuh_log_collector_stats`: Gets log collection statistics for an agent.
|
||||
// - `get_wazuh_remoted_stats`: Fetches statistics from the Wazuh Manager's remoted daemon.
|
||||
// - `get_wazuh_weekly_stats`: Retrieves aggregated weekly statistics from the Wazuh Manager.
|
||||
// - `get_wazuh_cluster_health`: Checks the health status of the Wazuh Manager cluster.
|
||||
// - `get_wazuh_cluster_nodes`: Lists nodes participating in the Wazuh Manager cluster.
|
||||
//
|
||||
// (Detailed parameters and descriptions for each tool are available via the MCP `get_tools` command or in the server's `get_info` response.)
|
||||
//
|
||||
// Configuration:
|
||||
// The server requires the following environment variables to connect to the Wazuh instance:
|
||||
// - `WAZUH_API_HOST`: Hostname or IP address of the Wazuh API.
|
||||
// - `WAZUH_API_PORT`: Port number for the Wazuh API (default: 55000).
|
||||
// - `WAZUH_API_USERNAME`: Username for Wazuh API authentication.
|
||||
// - `WAZUH_API_PASSWORD`: Password for Wazuh API authentication.
|
||||
// - `WAZUH_INDEXER_HOST`: Hostname or IP address of the Wazuh Indexer.
|
||||
// - `WAZUH_INDEXER_PORT`: Port number for the Wazuh Indexer API (default: 9200).
|
||||
// - `WAZUH_INDEXER_USERNAME`: Username for Wazuh Indexer authentication.
|
||||
// - `WAZUH_INDEXER_PASSWORD`: Password for Wazuh Indexer authentication.
|
||||
// - `WAZUH_VERIFY_SSL`: Set to "true" to enable SSL certificate verification, "false" otherwise (default: false).
|
||||
// - `WAZUH_TEST_PROTOCOL`: (Optional) Protocol to use for Wazuh API/Indexer connections, e.g., "http" or "https" (default: "https").
|
||||
// Logging behavior is controlled by the `RUST_LOG` environment variable (e.g., `RUST_LOG=info,mcp_server_wazuh=debug`).
|
||||
|
||||
use clap::Parser;
|
||||
use dotenv::dotenv;
|
||||
use rmcp::{
|
||||
model::{CallToolResult, Implementation, ProtocolVersion, ServerCapabilities, ServerInfo},
|
||||
tool,
|
||||
transport::stdio,
|
||||
Error as McpError, ServerHandler, ServiceExt,
|
||||
};
|
||||
use std::env;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use dotenv::dotenv;
|
||||
use std::backtrace::Backtrace;
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
use tracing::{debug, error, info, Level};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use wazuh_client::WazuhClientFactory;
|
||||
|
||||
// Use components from the library crate
|
||||
use mcp_server_wazuh::http_service::create_http_router;
|
||||
use mcp_server_wazuh::stdio_service::run_stdio_service;
|
||||
use mcp_server_wazuh::wazuh::client::WazuhIndexerClient;
|
||||
use mcp_server_wazuh::AppState;
|
||||
mod tools;
|
||||
use tools::agents::{AgentTools, GetAgentPortsParams, GetAgentProcessesParams, GetAgentsParams};
|
||||
use tools::alerts::{AlertTools, GetAlertSummaryParams};
|
||||
use tools::rules::{GetRulesSummaryParams, RuleTools};
|
||||
use tools::stats::{
|
||||
GetClusterHealthParams, GetClusterNodesParams, GetLogCollectorStatsParams,
|
||||
GetManagerErrorLogsParams, GetRemotedStatsParams, GetWeeklyStatsParams,
|
||||
SearchManagerLogsParams, StatsTools,
|
||||
};
|
||||
use tools::vulnerabilities::{
|
||||
GetCriticalVulnerabilitiesParams, GetVulnerabilitiesSummaryParams, VulnerabilityTools,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
dotenv().ok();
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "mcp-server-wazuh")]
|
||||
#[command(about = "Wazuh SIEM MCP Server")]
|
||||
struct Args {
|
||||
// Currently only stdio transport is supported
|
||||
// Future versions may add HTTP-SSE transport
|
||||
}
|
||||
|
||||
// Set a custom panic hook to ensure panics are logged
|
||||
std::panic::set_hook(Box::new(|panic_info| {
|
||||
// Using eprintln directly as tracing might not be available or working during a panic
|
||||
eprintln!(
|
||||
"\n================================================================================\n"
|
||||
);
|
||||
eprintln!("PANIC OCCURRED IN MCP SERVER");
|
||||
eprintln!(
|
||||
"\n--------------------------------------------------------------------------------\n"
|
||||
);
|
||||
eprintln!("Panic Info: {:#?}", panic_info);
|
||||
eprintln!(
|
||||
"\n--------------------------------------------------------------------------------\n"
|
||||
);
|
||||
// Capture and print the backtrace
|
||||
// Requires RUST_BACKTRACE=1 (or full) to be set in the environment
|
||||
let backtrace = Backtrace::capture();
|
||||
eprintln!("Backtrace:\n{:?}", backtrace);
|
||||
eprintln!(
|
||||
"\n================================================================================\n"
|
||||
);
|
||||
#[derive(Clone)]
|
||||
struct WazuhToolsServer {
|
||||
agent_tools: AgentTools,
|
||||
alert_tools: AlertTools,
|
||||
rule_tools: RuleTools,
|
||||
stats_tools: StatsTools,
|
||||
vulnerability_tools: VulnerabilityTools,
|
||||
}
|
||||
|
||||
// If tracing is still operational, try to log with it too.
|
||||
// This might not always work if the panic is deep within tracing or stdio.
|
||||
error!(panic_info = %panic_info, backtrace = ?backtrace, "Global panic hook caught a panic");
|
||||
}));
|
||||
debug!("Custom panic hook set.");
|
||||
#[tool(tool_box)]
|
||||
impl WazuhToolsServer {
|
||||
fn new() -> Result<Self, anyhow::Error> {
|
||||
dotenv().ok();
|
||||
|
||||
// Configure tracing to output to stderr
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive(Level::DEBUG.into()))
|
||||
.init();
|
||||
let api_host = env::var("WAZUH_API_HOST").unwrap_or_else(|_| "localhost".to_string());
|
||||
let api_port: u16 = env::var("WAZUH_API_PORT")
|
||||
.unwrap_or_else(|_| "55000".to_string())
|
||||
.parse()
|
||||
.unwrap_or(55000);
|
||||
let api_username = env::var("WAZUH_API_USERNAME").unwrap_or_else(|_| "wazuh".to_string());
|
||||
let api_password = env::var("WAZUH_API_PASSWORD").unwrap_or_else(|_| "wazuh".to_string());
|
||||
|
||||
info!("Starting Wazuh MCP Server");
|
||||
let indexer_host =
|
||||
env::var("WAZUH_INDEXER_HOST").unwrap_or_else(|_| "localhost".to_string());
|
||||
let indexer_port: u16 = env::var("WAZUH_INDEXER_PORT")
|
||||
.unwrap_or_else(|_| "9200".to_string())
|
||||
.parse()
|
||||
.unwrap_or(9200);
|
||||
let indexer_username =
|
||||
env::var("WAZUH_INDEXER_USERNAME").unwrap_or_else(|_| "admin".to_string());
|
||||
let indexer_password =
|
||||
env::var("WAZUH_INDEXER_PASSWORD").unwrap_or_else(|_| "admin".to_string());
|
||||
|
||||
debug!("Loading environment variables...");
|
||||
let wazuh_host = env::var("WAZUH_HOST").unwrap_or_else(|_| "localhost".to_string());
|
||||
let wazuh_port = env::var("WAZUH_PORT")
|
||||
.unwrap_or_else(|_| "9200".to_string())
|
||||
.parse::<u16>()
|
||||
.expect("WAZUH_PORT must be a valid port number");
|
||||
let wazuh_user = env::var("WAZUH_USER").unwrap_or_else(|_| "admin".to_string());
|
||||
let wazuh_pass = env::var("WAZUH_PASS").unwrap_or_else(|_| "admin".to_string());
|
||||
let verify_ssl = env::var("VERIFY_SSL")
|
||||
.unwrap_or_else(|_| "false".to_string())
|
||||
.to_lowercase()
|
||||
== "true";
|
||||
let mcp_server_port = env::var("MCP_SERVER_PORT")
|
||||
.unwrap_or_else(|_| "8000".to_string())
|
||||
.parse::<u16>()
|
||||
.expect("MCP_SERVER_PORT must be a valid port number");
|
||||
let verify_ssl = env::var("WAZUH_VERIFY_SSL")
|
||||
.unwrap_or_else(|_| "false".to_string())
|
||||
.parse()
|
||||
.unwrap_or(false);
|
||||
|
||||
debug!(
|
||||
wazuh_host,
|
||||
wazuh_port,
|
||||
wazuh_user,
|
||||
// wazuh_pass is sensitive, avoid logging
|
||||
verify_ssl,
|
||||
mcp_server_port,
|
||||
"Environment variables loaded."
|
||||
);
|
||||
let test_protocol = env::var("WAZUH_TEST_PROTOCOL")
|
||||
.ok()
|
||||
.or_else(|| Some("https".to_string()));
|
||||
|
||||
info!("Initializing Wazuh API client...");
|
||||
let wazuh_client = WazuhIndexerClient::new(
|
||||
wazuh_host.clone(),
|
||||
wazuh_port,
|
||||
wazuh_user.clone(),
|
||||
wazuh_pass.clone(),
|
||||
verify_ssl,
|
||||
);
|
||||
let mut builder = WazuhClientFactory::builder()
|
||||
.api_host(api_host)
|
||||
.api_port(api_port)
|
||||
.api_credentials(&api_username, &api_password)
|
||||
.indexer_host(indexer_host)
|
||||
.indexer_port(indexer_port)
|
||||
.indexer_credentials(&indexer_username, &indexer_password)
|
||||
.verify_ssl(verify_ssl);
|
||||
|
||||
let app_state = Arc::new(AppState {
|
||||
wazuh_client: Mutex::new(wazuh_client),
|
||||
});
|
||||
debug!("AppState created.");
|
||||
|
||||
// Set up HTTP routes using the new http_service module
|
||||
info!("Setting up HTTP routes...");
|
||||
let app = create_http_router(app_state.clone());
|
||||
debug!("HTTP routes configured.");
|
||||
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], mcp_server_port));
|
||||
info!("Attempting to bind HTTP server to {}", addr);
|
||||
let listener = tokio::net::TcpListener::bind(addr)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
error!("Failed to bind to address {}: {}", addr, e);
|
||||
panic!("Failed to bind to address {}: {}", addr, e);
|
||||
});
|
||||
info!("Wazuh MCP Server listening on {}", addr);
|
||||
|
||||
// Spawn the stdio transport handler using the new stdio_service
|
||||
info!("Spawning stdio service handler...");
|
||||
let app_state_for_stdio = app_state.clone();
|
||||
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
|
||||
|
||||
let stdio_handle = tokio::spawn(async move {
|
||||
run_stdio_service(app_state_for_stdio, shutdown_tx).await;
|
||||
info!("run_stdio_service ASYNC TASK has completed its execution.");
|
||||
});
|
||||
|
||||
// Configure Axum with graceful shutdown
|
||||
let axum_shutdown_signal = async {
|
||||
shutdown_rx
|
||||
.await
|
||||
.map_err(|e| error!("Shutdown signal sender dropped: {}", e))
|
||||
.ok(); // Wait for the signal, log if sender is dropped
|
||||
info!("Graceful shutdown signal received for Axum server. Axum will now attempt to shut down.");
|
||||
};
|
||||
|
||||
info!("Starting Axum server with graceful shutdown.");
|
||||
let axum_task = tokio::spawn(async move {
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(axum_shutdown_signal)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
error!("Axum Server run error: {}", e);
|
||||
});
|
||||
info!("Axum server task has completed and shut down.");
|
||||
});
|
||||
|
||||
// Make handles mutable so select can take &mut
|
||||
let mut stdio_handle = stdio_handle;
|
||||
let mut axum_task = axum_task;
|
||||
|
||||
// Wait for either the stdio service or Axum server to complete.
|
||||
tokio::select! {
|
||||
biased; // Prioritize checking stdio_handle first if both are ready
|
||||
|
||||
stdio_res = &mut stdio_handle => {
|
||||
match stdio_res {
|
||||
Ok(_) => info!("Stdio service task completed. Axum's graceful shutdown should have been triggered if stdio initiated it."),
|
||||
Err(e) => error!("Stdio service task failed or panicked: {:?}", e),
|
||||
}
|
||||
// Stdio has finished. If it didn't send a shutdown signal (e.g. due to panic before sending),
|
||||
// Axum might still be running. The shutdown_tx being dropped will also trigger axum_shutdown_signal.
|
||||
info!("Waiting for Axum server to fully shut down after stdio completion...");
|
||||
match axum_task.await {
|
||||
Ok(_) => info!("Axum server task completed successfully after stdio completion."),
|
||||
Err(e) => error!("Axum server task failed or panicked after stdio completion: {:?}", e),
|
||||
}
|
||||
if let Some(protocol) = test_protocol {
|
||||
builder = builder.protocol(protocol);
|
||||
}
|
||||
|
||||
let wazuh_factory = builder.build();
|
||||
|
||||
axum_res = &mut axum_task => {
|
||||
match axum_res {
|
||||
Ok(_) => info!("Axum server task completed (possibly due to graceful shutdown or error)."),
|
||||
Err(e) => error!("Axum server task failed or panicked: {:?}", e),
|
||||
}
|
||||
// Axum has finished. The main function will now exit.
|
||||
// We should wait for stdio_handle to complete or be cancelled.
|
||||
info!("Axum finished. Waiting for stdio_handle to complete or be cancelled...");
|
||||
match stdio_handle.await {
|
||||
Ok(_) => info!("Stdio service task also completed after Axum finished."),
|
||||
Err(e) => {
|
||||
if e.is_cancelled() {
|
||||
info!("Stdio service task was cancelled after Axum finished (expected if main is exiting).");
|
||||
} else {
|
||||
error!("Stdio service task failed or panicked after Axum finished: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
let wazuh_indexer_client = wazuh_factory.create_indexer_client();
|
||||
let wazuh_rules_client = wazuh_factory.create_rules_client();
|
||||
let wazuh_vulnerability_client = wazuh_factory.create_vulnerability_client();
|
||||
let wazuh_agents_client = wazuh_factory.create_agents_client();
|
||||
let wazuh_logs_client = wazuh_factory.create_logs_client();
|
||||
let wazuh_cluster_client = wazuh_factory.create_cluster_client();
|
||||
|
||||
let indexer_client_arc = Arc::new(wazuh_indexer_client);
|
||||
let rules_client_arc = Arc::new(tokio::sync::Mutex::new(wazuh_rules_client));
|
||||
let vulnerability_client_arc =
|
||||
Arc::new(tokio::sync::Mutex::new(wazuh_vulnerability_client));
|
||||
let agents_client_arc = Arc::new(tokio::sync::Mutex::new(wazuh_agents_client));
|
||||
let logs_client_arc = Arc::new(tokio::sync::Mutex::new(wazuh_logs_client));
|
||||
let cluster_client_arc = Arc::new(tokio::sync::Mutex::new(wazuh_cluster_client));
|
||||
|
||||
let agent_tools =
|
||||
AgentTools::new(agents_client_arc.clone(), vulnerability_client_arc.clone());
|
||||
let alert_tools = AlertTools::new(indexer_client_arc.clone());
|
||||
let rule_tools = RuleTools::new(rules_client_arc.clone());
|
||||
let stats_tools = StatsTools::new(logs_client_arc.clone(), cluster_client_arc.clone());
|
||||
let vulnerability_tools = VulnerabilityTools::new(vulnerability_client_arc.clone());
|
||||
|
||||
Ok(Self {
|
||||
agent_tools,
|
||||
alert_tools,
|
||||
rule_tools,
|
||||
stats_tools,
|
||||
vulnerability_tools,
|
||||
})
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_alert_summary",
|
||||
description = "Retrieves a summary of Wazuh security alerts. Returns formatted alert information including ID, timestamp, and description."
|
||||
)]
|
||||
async fn get_wazuh_alert_summary(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetAlertSummaryParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.alert_tools.get_wazuh_alert_summary(params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_rules_summary",
|
||||
description = "Retrieves a summary of Wazuh security rules. Returns formatted rule information including ID, level, description, and groups. Supports filtering by level, group, and filename."
|
||||
)]
|
||||
async fn get_wazuh_rules_summary(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetRulesSummaryParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.rule_tools.get_wazuh_rules_summary(params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_vulnerability_summary",
|
||||
description = "Retrieves a summary of Wazuh vulnerability detections for a specific agent. Returns formatted vulnerability information including CVE ID, severity, detection time, and agent details. Supports filtering by severity level."
|
||||
)]
|
||||
async fn get_wazuh_vulnerability_summary(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetVulnerabilitiesSummaryParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.vulnerability_tools
|
||||
.get_wazuh_vulnerability_summary(params)
|
||||
.await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_critical_vulnerabilities",
|
||||
description = "Retrieves critical vulnerabilities for a specific Wazuh agent. Returns formatted vulnerability information including CVE ID, title, description, CVSS scores, and detection details. Only shows vulnerabilities with 'Critical' severity level."
|
||||
)]
|
||||
async fn get_wazuh_critical_vulnerabilities(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetCriticalVulnerabilitiesParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.vulnerability_tools
|
||||
.get_wazuh_critical_vulnerabilities(params)
|
||||
.await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_agents",
|
||||
description = "Retrieves a list of Wazuh agents with their current status and details. Returns formatted agent information including ID, name, IP, status, OS details, and last activity. Supports filtering by status, name, IP, group, OS platform, and version."
|
||||
)]
|
||||
async fn get_wazuh_agents(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetAgentsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.agent_tools.get_wazuh_agents(params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_agent_processes",
|
||||
description = "Retrieves a list of running processes for a specific Wazuh agent. Returns formatted process information including PID, name, state, user, and command. Supports filtering by process name/command."
|
||||
)]
|
||||
async fn get_wazuh_agent_processes(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetAgentProcessesParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.agent_tools.get_wazuh_agent_processes(params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_cluster_health",
|
||||
description = "Checks the health of the Wazuh cluster. Returns whether the cluster is enabled, running, and if nodes are connected."
|
||||
)]
|
||||
async fn get_wazuh_cluster_health(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetClusterHealthParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.stats_tools.get_wazuh_cluster_health(params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_cluster_nodes",
|
||||
description = "Retrieves a list of nodes in the Wazuh cluster. Returns formatted node information including name, type, version, IP, and status. Supports filtering by limit, offset, and node type."
|
||||
)]
|
||||
async fn get_wazuh_cluster_nodes(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetClusterNodesParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.stats_tools.get_wazuh_cluster_nodes(params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "search_wazuh_manager_logs",
|
||||
description = "Searches Wazuh manager logs. Returns formatted log entries including timestamp, tag, level, and description. Supports filtering by limit, offset, level, tag, and a search term."
|
||||
)]
|
||||
async fn search_wazuh_manager_logs(
|
||||
&self,
|
||||
#[tool(aggr)] params: SearchManagerLogsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.stats_tools.search_wazuh_manager_logs(params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_manager_error_logs",
|
||||
description = "Retrieves Wazuh manager error logs. Returns formatted log entries including timestamp, tag, level (error), and description."
|
||||
)]
|
||||
async fn get_wazuh_manager_error_logs(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetManagerErrorLogsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.stats_tools.get_wazuh_manager_error_logs(params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_log_collector_stats",
|
||||
description = "Retrieves log collector statistics for a specific Wazuh agent. Returns information about events processed, dropped, bytes, and target log files."
|
||||
)]
|
||||
async fn get_wazuh_log_collector_stats(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetLogCollectorStatsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.stats_tools.get_wazuh_log_collector_stats(params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_remoted_stats",
|
||||
description = "Retrieves statistics from the Wazuh remoted daemon. Returns information about queue size, TCP sessions, event counts, and message traffic."
|
||||
)]
|
||||
async fn get_wazuh_remoted_stats(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetRemotedStatsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.stats_tools.get_wazuh_remoted_stats(params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_agent_ports",
|
||||
description = "Retrieves a list of open network ports for a specific Wazuh agent. Returns formatted port information including local/remote IP and port, protocol, state, and associated process/PID. Supports filtering by protocol and state."
|
||||
)]
|
||||
async fn get_wazuh_agent_ports(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetAgentPortsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.agent_tools.get_wazuh_agent_ports(params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_weekly_stats",
|
||||
description = "Retrieves weekly statistics from the Wazuh manager. Returns a JSON object detailing various metrics aggregated over the past week."
|
||||
)]
|
||||
async fn get_wazuh_weekly_stats(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetWeeklyStatsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
self.stats_tools.get_wazuh_weekly_stats(params).await
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(tool_box)]
|
||||
impl ServerHandler for WazuhToolsServer {
|
||||
fn get_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
protocol_version: ProtocolVersion::V_2024_11_05,
|
||||
capabilities: ServerCapabilities::builder()
|
||||
.enable_prompts()
|
||||
.enable_resources()
|
||||
.enable_tools()
|
||||
.build(),
|
||||
server_info: Implementation::from_build_env(),
|
||||
instructions: Some(
|
||||
"This server provides tools to interact with a Wazuh SIEM instance for security monitoring and analysis.\n\
|
||||
Available tools:\n\
|
||||
- 'get_wazuh_alert_summary': Retrieves a summary of Wazuh security alerts. \
|
||||
Optionally takes 'limit' parameter to control the number of alerts returned (defaults to 100).\n\
|
||||
- 'get_wazuh_rules_summary': Retrieves a summary of Wazuh security rules. \
|
||||
Supports filtering by 'level', 'group', and 'filename' parameters, with 'limit' to control the number of rules returned (defaults to 100).\n\
|
||||
- 'get_wazuh_vulnerability_summary': Retrieves a summary of Wazuh vulnerability detections for a specific agent. \
|
||||
Requires an 'agent_id' parameter. This must be provided as a string, representing the numeric ID of the agent (e.g., \"0\", \"1\", \"12\", \"001\", \"012\"). The server will automatically format this string into a three-digit, zero-padded identifier. For instance, an input of \"0\" will be treated as \"000\", \"1\" as \"001\", and \"12\" as \"012\". Supports filtering by 'severity' and 'cve' parameters, with 'limit' to control the number of vulnerabilities returned (defaults to 100).\n\
|
||||
- 'get_wazuh_critical_vulnerabilities': Retrieves only critical vulnerabilities for a specific agent. \
|
||||
Requires an 'agent_id' parameter. This must be provided as a string, representing the numeric ID of the agent (e.g., \"0\", \"1\", \"12\", \"001\", \"012\"). The server will automatically format this string into a three-digit, zero-padded identifier. For instance, an input of \"0\" will be treated as \"000\", \"1\" as \"001\", and \"12\" as \"012\". Returns detailed information about vulnerabilities with 'Critical' severity level.\n\
|
||||
- 'get_wazuh_running_agents': Retrieves a list of Wazuh agents with their current status and details. \
|
||||
Supports filtering by 'status' (active, disconnected, pending, never_connected), 'name', 'ip', 'group', 'os_platform', and 'version' parameters, with 'limit' to control the number of agents returned (defaults to 100, status defaults to 'active').\n\
|
||||
- 'get_wazuh_agent_processes': Retrieves a list of running processes for a specific Wazuh agent. \
|
||||
Requires an 'agent_id' parameter (formatted as described for other agent-specific tools). Supports 'limit' (default 100) and 'search' (to filter by process name or command line) parameters.\n\
|
||||
- 'get_wazuh_agent_ports': Retrieves a list of open network ports for a specific Wazuh agent. \
|
||||
Requires an 'agent_id' parameter (formatted as described for other agent-specific tools). Supports 'limit' (default 100), 'protocol' (e.g., \"tcp\", \"udp\"), and 'state' (e.g., \"LISTENING\", \"ESTABLISHED\") parameters to filter the results. Note: State filtering is performed client-side by this server.\n\
|
||||
The 'state' parameter filters results:
|
||||
- If 'state' is 'LISTENING' (case-insensitive): Only ports explicitly in the 'LISTENING' state are returned. Ports with other states, no state, or an empty state string are filtered out.
|
||||
- If 'state' is any other value (e.g., 'ESTABLISHED'): Ports that are *not* in the 'LISTENING' state are returned. This includes ports with other defined states (like 'ESTABLISHED', 'TIME_WAIT', etc.) and ports that have *no state* defined. Ports with an empty state string are always filtered out.
|
||||
Note: State filtering is performed client-side by this server. \
|
||||
- 'search_wazuh_manager_logs': Searches Wazuh manager logs. \
|
||||
Optional parameters: 'limit' (default 100), 'offset' (default 0), 'level' (e.g., \"error\", \"info\"), 'tag' (e.g., \"wazuh-modulesd\"), 'search_term' (for free-text search in log descriptions).\n\
|
||||
- 'get_wazuh_manager_error_logs': Retrieves Wazuh manager error logs. \
|
||||
Optional parameter: 'limit' (default 100).\n\
|
||||
- 'get_wazuh_log_collector_stats': Retrieves log collector statistics for a specific Wazuh agent. \
|
||||
Requires an 'agent_id' parameter (formatted as described for other agent-specific tools). Returns detailed information for 'global' and 'interval' periods, including start/end times, and for each log file: location, events processed, bytes, and target-specific drop counts.\n\
|
||||
- 'get_wazuh_remoted_stats': Retrieves statistics from the Wazuh remoted daemon (manager-wide).\n\
|
||||
- 'get_wazuh_weekly_stats': Retrieves weekly statistics from the Wazuh manager. Returns a JSON object detailing various metrics aggregated over the past week. No parameters required.\n\
|
||||
- 'get_wazuh_cluster_health': Checks the health of the Wazuh cluster. Returns a textual summary of the cluster's health status (e.g., enabled, running, connected nodes). No parameters required.\n\
|
||||
- 'get_wazuh_cluster_nodes': Retrieves a list of nodes in the Wazuh cluster. \
|
||||
Optional parameters: 'limit' (max nodes, API default 500), 'offset' (default 0), 'node_type' (e.g., \"master\", \"worker\")."
|
||||
.to_string(),
|
||||
),
|
||||
}
|
||||
}
|
||||
info!("Main function is exiting.");
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let _args = Args::parse();
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::from_default_env()
|
||||
.add_directive(tracing::Level::DEBUG.into()),
|
||||
)
|
||||
.with_writer(std::io::stderr)
|
||||
.init();
|
||||
|
||||
tracing::info!("Starting Wazuh MCP Server...");
|
||||
|
||||
// Create an instance of our Wazuh tools server
|
||||
let server = WazuhToolsServer::new().expect("Error initializing Wazuh tools server");
|
||||
|
||||
tracing::info!("Using stdio transport");
|
||||
let service = server.serve(stdio()).await.inspect_err(|e| {
|
||||
tracing::error!("serving error: {:?}", e);
|
||||
})?;
|
||||
|
||||
service.waiting().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,425 +0,0 @@
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum McpClientError {
|
||||
#[error("HTTP request error: {0}")]
|
||||
HttpRequestError(#[from] reqwest::Error),
|
||||
|
||||
#[error("HTTP API error: status {status}, message: {message}")]
|
||||
HttpApiError {
|
||||
status: reqwest::StatusCode,
|
||||
message: String,
|
||||
},
|
||||
|
||||
#[error("JSON serialization/deserialization error: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
|
||||
#[error("IO error: {0}")]
|
||||
IoError(#[from] std::io::Error),
|
||||
|
||||
#[error("Failed to spawn child process: {0}")]
|
||||
ProcessSpawnError(String),
|
||||
|
||||
#[error("Child process stdin/stdout not available")]
|
||||
ProcessPipeError,
|
||||
|
||||
#[error("JSON-RPC error: code {code}, message: {message}, data: {data:?}")]
|
||||
JsonRpcError {
|
||||
code: i32,
|
||||
message: String,
|
||||
data: Option<Value>,
|
||||
},
|
||||
|
||||
#[error("Received unexpected JSON-RPC response: {0}")]
|
||||
UnexpectedResponse(String),
|
||||
|
||||
#[error("Operation timed out")]
|
||||
Timeout,
|
||||
|
||||
#[error("Operation not supported in current mode: {0}")]
|
||||
UnsupportedOperation(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpMessage {
|
||||
pub protocol_version: String,
|
||||
pub source: String,
|
||||
pub timestamp: String,
|
||||
pub event_type: String,
|
||||
pub context: Value,
|
||||
pub metadata: Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
struct JsonRpcRequest<T: Serialize> {
|
||||
jsonrpc: String,
|
||||
method: String,
|
||||
params: Option<T>,
|
||||
id: Value, // Changed from usize to Value
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct JsonRpcResponse<T> {
|
||||
jsonrpc: String,
|
||||
result: Option<T>,
|
||||
error: Option<JsonRpcErrorData>,
|
||||
id: Value, // Changed from usize to Value
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct JsonRpcErrorData {
|
||||
code: i32,
|
||||
message: String,
|
||||
data: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct ServerInfo {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct InitializeResult {
|
||||
pub protocol_version: String,
|
||||
pub server_info: ServerInfo,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait McpClientTrait {
|
||||
async fn initialize(&mut self) -> Result<InitializeResult, McpClientError>;
|
||||
async fn provide_context(
|
||||
&mut self,
|
||||
params: Option<Value>,
|
||||
) -> Result<Vec<McpMessage>, McpClientError>;
|
||||
async fn shutdown(&mut self) -> Result<(), McpClientError>;
|
||||
}
|
||||
|
||||
enum ClientMode {
|
||||
Http {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
},
|
||||
Stdio {
|
||||
stdin: ChildStdin,
|
||||
stdout: BufReader<ChildStdout>,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct McpClient {
|
||||
mode: ClientMode,
|
||||
child_process: Option<Child>,
|
||||
request_id_counter: AtomicUsize,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl McpClientTrait for McpClient {
|
||||
async fn initialize(&mut self) -> Result<InitializeResult, McpClientError> {
|
||||
match &mut self.mode {
|
||||
ClientMode::Http { .. } => Err(McpClientError::UnsupportedOperation(
|
||||
"initialize is not supported in HTTP mode".to_string(),
|
||||
)),
|
||||
ClientMode::Stdio { .. } => {
|
||||
let request_id = self.next_id();
|
||||
self.send_stdio_request("initialize", None::<()>, request_id)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn provide_context(
|
||||
&mut self,
|
||||
params: Option<Value>,
|
||||
) -> Result<Vec<McpMessage>, McpClientError> {
|
||||
match &mut self.mode {
|
||||
ClientMode::Http { client, base_url } => {
|
||||
let url = format!("{}/mcp", base_url);
|
||||
let request_builder = if let Some(p) = params {
|
||||
client.post(&url).json(&p)
|
||||
} else {
|
||||
client.get(&url)
|
||||
};
|
||||
let response = request_builder
|
||||
.send()
|
||||
.await
|
||||
.map_err(McpClientError::HttpRequestError)?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let message = response.text().await.unwrap_or_else(|_| {
|
||||
format!("Failed to get error body for status {}", status)
|
||||
});
|
||||
return Err(McpClientError::HttpApiError { status, message });
|
||||
}
|
||||
response
|
||||
.json::<Vec<McpMessage>>()
|
||||
.await
|
||||
.map_err(McpClientError::HttpRequestError)
|
||||
}
|
||||
ClientMode::Stdio { .. } => {
|
||||
let request_id = self.next_id();
|
||||
self.send_stdio_request("provideContext", params, request_id)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn shutdown(&mut self) -> Result<(), McpClientError> {
|
||||
match &mut self.mode {
|
||||
ClientMode::Http { .. } => Err(McpClientError::UnsupportedOperation(
|
||||
"shutdown is not supported in HTTP mode".to_string(),
|
||||
)),
|
||||
ClientMode::Stdio { .. } => {
|
||||
let request_id = self.next_id();
|
||||
// Attempt to send shutdown command, ignore error if server already closed pipe
|
||||
let _result: Result<Option<Value>, McpClientError> = self
|
||||
.send_stdio_request("shutdown", None::<()>, request_id)
|
||||
.await;
|
||||
// Always try to clean up the process
|
||||
self.close_stdio_process().await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl McpClient {
|
||||
pub fn new_http(base_url: String) -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
Self {
|
||||
mode: ClientMode::Http { client, base_url },
|
||||
child_process: None,
|
||||
request_id_counter: AtomicUsize::new(1),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn new_stdio(
|
||||
executable_path: &str,
|
||||
envs: Option<Vec<(String, String)>>,
|
||||
) -> Result<Self, McpClientError> {
|
||||
let mut command = Command::new(executable_path);
|
||||
command.stdin(std::process::Stdio::piped());
|
||||
command.stdout(std::process::Stdio::piped());
|
||||
command.stderr(std::process::Stdio::inherit()); // Pipe child's stderr to parent's stderr for visibility
|
||||
|
||||
if let Some(env_vars) = envs {
|
||||
for (key, value) in env_vars {
|
||||
command.env(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
let mut child = command
|
||||
.spawn()
|
||||
.map_err(|e| McpClientError::ProcessSpawnError(e.to_string()))?;
|
||||
|
||||
let stdin = child.stdin.take().ok_or(McpClientError::ProcessPipeError)?;
|
||||
let stdout = child
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or(McpClientError::ProcessPipeError)?;
|
||||
|
||||
Ok(Self {
|
||||
mode: ClientMode::Stdio {
|
||||
stdin,
|
||||
stdout: BufReader::new(stdout),
|
||||
},
|
||||
child_process: Some(child),
|
||||
request_id_counter: AtomicUsize::new(1),
|
||||
})
|
||||
}
|
||||
|
||||
fn next_id(&self) -> Value {
|
||||
Value::from(self.request_id_counter.fetch_add(1, Ordering::SeqCst))
|
||||
}
|
||||
|
||||
async fn send_stdio_request<P: Serialize, R: DeserializeOwned>(
|
||||
&mut self,
|
||||
method: &str,
|
||||
params: Option<P>,
|
||||
id: Value, // Added id parameter
|
||||
) -> Result<R, McpClientError> {
|
||||
// Removed: let request_id = self.next_id();
|
||||
let rpc_request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
method: method.to_string(),
|
||||
params,
|
||||
id: id.clone(), // Use the provided id
|
||||
};
|
||||
let request_json = serde_json::to_string(&rpc_request)? + "\n";
|
||||
|
||||
let (stdin, stdout) = match &mut self.mode {
|
||||
ClientMode::Stdio { stdin, stdout } => (stdin, stdout),
|
||||
ClientMode::Http { .. } => {
|
||||
return Err(McpClientError::UnsupportedOperation(
|
||||
"send_stdio_request is only for Stdio mode".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
stdin.write_all(request_json.as_bytes()).await?;
|
||||
stdin.flush().await?;
|
||||
|
||||
let mut response_json = String::new();
|
||||
match tokio::time::timeout(
|
||||
Duration::from_secs(10),
|
||||
stdout.read_line(&mut response_json),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(0)) => {
|
||||
return Err(McpClientError::IoError(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"Server closed stdout",
|
||||
)))
|
||||
}
|
||||
Ok(Ok(_)) => { /* continue */ }
|
||||
Ok(Err(e)) => return Err(McpClientError::IoError(e)),
|
||||
Err(_) => return Err(McpClientError::Timeout),
|
||||
}
|
||||
|
||||
let rpc_response: JsonRpcResponse<R> = serde_json::from_str(response_json.trim())?;
|
||||
|
||||
// Compare Value IDs. Note: Value implements PartialEq.
|
||||
if rpc_response.id != id {
|
||||
return Err(McpClientError::UnexpectedResponse(format!(
|
||||
"Mismatched request/response IDs. Expected {}, got {}. Response: '{}'",
|
||||
id, rpc_response.id, response_json
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(err_data) = rpc_response.error {
|
||||
return Err(McpClientError::JsonRpcError {
|
||||
code: err_data.code,
|
||||
message: err_data.message,
|
||||
data: err_data.data,
|
||||
});
|
||||
}
|
||||
|
||||
rpc_response.result.ok_or_else(|| {
|
||||
McpClientError::UnexpectedResponse("Missing result in JSON-RPC response".to_string())
|
||||
})
|
||||
}
|
||||
|
||||
async fn close_stdio_process(&mut self) -> Result<(), McpClientError> {
|
||||
if let Some(mut child) = self.child_process.take() {
|
||||
child.kill().await.map_err(McpClientError::IoError)?;
|
||||
let _ = child.wait().await; // Ensure process is reaped
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// New public method for sending generic JSON-RPC requests
|
||||
pub async fn send_json_rpc_request(
|
||||
&mut self,
|
||||
method: &str,
|
||||
params: Option<Value>,
|
||||
id: Value,
|
||||
) -> Result<Value, McpClientError> {
|
||||
match &mut self.mode {
|
||||
ClientMode::Http { .. } => Err(McpClientError::UnsupportedOperation(
|
||||
"Generic JSON-RPC calls are not supported in HTTP mode by this client.".to_string(),
|
||||
)),
|
||||
ClientMode::Stdio { .. } => {
|
||||
// R (result type) is Value for generic calls
|
||||
self.send_stdio_request(method, params, id).await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use httpmock::prelude::*;
|
||||
use serde_json::json;
|
||||
use tokio;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mcp_client_http_get_data() {
|
||||
// Renamed to be specific
|
||||
let server = MockServer::start();
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(GET).path("/mcp");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!([
|
||||
{
|
||||
"protocol_version": "1.0",
|
||||
"source": "Wazuh",
|
||||
"timestamp": "2023-05-01T12:00:00Z",
|
||||
"event_type": "alert",
|
||||
"context": {
|
||||
"id": "12345",
|
||||
"category": "intrusion_detection",
|
||||
"severity": "high",
|
||||
"description": "Test alert",
|
||||
"data": { "source_ip": "192.168.1.100" }
|
||||
},
|
||||
"metadata": { "integration": "Wazuh-MCP", "notes": "Test note" }
|
||||
}
|
||||
]));
|
||||
});
|
||||
|
||||
let mut client = McpClient::new_http(server.url("")); // Use new_http
|
||||
|
||||
let result = client.provide_context(None).await.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].protocol_version, "1.0");
|
||||
assert_eq!(result[0].source, "Wazuh");
|
||||
assert_eq!(result[0].event_type, "alert");
|
||||
|
||||
let context = &result[0].context;
|
||||
assert_eq!(context["id"], "12345");
|
||||
assert_eq!(context["category"], "intrusion_detection");
|
||||
assert_eq!(context["severity"], "high");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mcp_client_http_health_check_equivalent() {
|
||||
let server = MockServer::start();
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(GET).path("/health"); // Assuming /health is still the target for this test
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"status": "ok",
|
||||
"service": "wazuh-mcp-server",
|
||||
"timestamp": "2023-05-01T12:00:00Z"
|
||||
}));
|
||||
});
|
||||
|
||||
let client = McpClient::new_http(server.url(""));
|
||||
|
||||
// The new trait doesn't have a direct "check_health".
|
||||
// If `initialize` was to be used for HTTP health, it would be:
|
||||
// let result = client.initialize().await;
|
||||
// But initialize is Stdio-only. So this test needs to adapt or be removed
|
||||
// if there's no direct equivalent in the new trait for HTTP health.
|
||||
// For now, let's assume we might add a specific http_health method if needed,
|
||||
// or this test is demonstrating a capability that's no longer directly on the trait.
|
||||
// To make this test pass with current structure, we'd need a separate HTTP health method.
|
||||
// Let's simulate calling the /health endpoint directly if that's the intent.
|
||||
let http_client = reqwest::Client::new();
|
||||
let response = http_client.get(server.url("/health")).send().await.unwrap();
|
||||
assert_eq!(response.status(), reqwest::StatusCode::OK);
|
||||
let health_data: Value = response.json().await.unwrap();
|
||||
assert_eq!(health_data["status"], "ok");
|
||||
assert_eq!(health_data["service"], "wazuh-mcp-server");
|
||||
}
|
||||
|
||||
// TODO: Add tests for Stdio mode. This would require a mock executable
|
||||
// or a more complex test setup. For now, focusing on the client structure.
|
||||
}
|
@ -1,456 +0,0 @@
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::mcp::protocol::{error_codes, JsonRpcError, JsonRpcRequest, JsonRpcResponse};
|
||||
use crate::AppState;
|
||||
|
||||
#[derive(serde::Deserialize, Debug)]
|
||||
struct ToolCallParams {
|
||||
#[serde(rename = "name")]
|
||||
name: String,
|
||||
#[serde(rename = "arguments")]
|
||||
arguments: Option<Value>, // Input parameters for the specific tool
|
||||
#[serde(flatten)]
|
||||
_extra: std::collections::HashMap<String, Value>,
|
||||
}
|
||||
|
||||
pub struct McpServerCore {
|
||||
app_state: Arc<AppState>,
|
||||
}
|
||||
|
||||
impl McpServerCore {
|
||||
pub fn new(app_state: Arc<AppState>) -> Self {
|
||||
Self { app_state }
|
||||
}
|
||||
|
||||
pub async fn process_request(&self, request: JsonRpcRequest) -> String {
|
||||
info!("Processing request: method={}", request.method);
|
||||
|
||||
let response = match request.method.as_str() {
|
||||
"initialize" => self.handle_initialize(request).await,
|
||||
"shutdown" => self.handle_shutdown(request).await,
|
||||
"provideContext" => self.handle_provide_context(request).await,
|
||||
// Tool methods (prefix "tools/")
|
||||
"tools/list" => self.handle_list_tools(request).await,
|
||||
"tools/call" => self.handle_tool_call(request).await, // Use generic tool call handler
|
||||
// "tools/wazuhAlerts" => self.handle_wazuh_alerts_tool(request).await,
|
||||
// Resource methods (prefix "resources/")
|
||||
"resources/list" => self.handle_get_resources(request).await,
|
||||
"resources/read" => self.handle_read_resource(request).await,
|
||||
// Prompt methods (prefix "prompts/")
|
||||
"prompts/list" => self.handle_list_prompts(request).await,
|
||||
_ => {
|
||||
error!("Method not found: {}", request.method);
|
||||
self.create_error_response(
|
||||
error_codes::METHOD_NOT_FOUND,
|
||||
format!("Method '{}' not found", request.method),
|
||||
None,
|
||||
request.id.clone(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
pub fn handle_parse_error(&self, error: serde_json::Error, raw_request: &str) -> String {
|
||||
error!("Failed to parse JSON-RPC request: {}", error);
|
||||
|
||||
// Try to extract the ID from the raw request if possible
|
||||
let id = serde_json::from_str::<Value>(raw_request)
|
||||
.and_then(|v| {
|
||||
if let Some(id) = v.get("id") {
|
||||
Ok(id.clone())
|
||||
} else {
|
||||
// Use a different approach since custom is not available
|
||||
Err(serde_json::Error::io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"No ID field found",
|
||||
)))
|
||||
}
|
||||
})
|
||||
.unwrap_or(Value::Null);
|
||||
|
||||
self.create_error_response(
|
||||
error_codes::PARSE_ERROR,
|
||||
format!("Parse error: {}", error),
|
||||
None,
|
||||
id,
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_initialize(&self, request: JsonRpcRequest) -> String {
|
||||
debug!("Handling initialize request");
|
||||
|
||||
|
||||
let wazuh_alert_summary_tool = crate::mcp::protocol::ToolDefinition {
|
||||
name: "wazuhAlertSummary".to_string(),
|
||||
description: Some("Returns a text summary of all Wazuh alerts.".to_string()),
|
||||
// Define a minimal valid input schema (empty object)
|
||||
input_schema: Some(json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
})),
|
||||
// No output schema needed as per requirements
|
||||
output_schema: None,
|
||||
};
|
||||
|
||||
let result = crate::mcp::protocol::InitializeResult {
|
||||
protocol_version: "2024-11-05".to_string(),
|
||||
capabilities: crate::mcp::protocol::Capabilities {
|
||||
tools: crate::mcp::protocol::ToolCapability {
|
||||
supported: true,
|
||||
definitions: vec![wazuh_alert_summary_tool], // Only include wazuhAlertSummary tool
|
||||
},
|
||||
resources: crate::mcp::protocol::SupportedFeature { supported: true },
|
||||
prompts: crate::mcp::protocol::SupportedFeature { supported: true },
|
||||
},
|
||||
server_info: crate::mcp::protocol::ServerInfo {
|
||||
name: "Wazuh MCP Server".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
self.create_success_response(result, request.id)
|
||||
}
|
||||
|
||||
async fn handle_shutdown(&self, request: JsonRpcRequest) -> String {
|
||||
debug!("Handling shutdown request");
|
||||
self.create_success_response(Value::Null, request.id)
|
||||
}
|
||||
|
||||
async fn handle_provide_context(&self, request: JsonRpcRequest) -> String {
|
||||
debug!("Handling provideContext request");
|
||||
|
||||
let mut wazuh_client = self.app_state.wazuh_client.lock().await;
|
||||
|
||||
match wazuh_client.get_alerts().await {
|
||||
Ok(alerts) => {
|
||||
let mcp_messages: Vec<Value> = alerts
|
||||
.into_iter()
|
||||
.map(|alert| crate::mcp::transform::transform_to_mcp(alert, "alert".to_string()))
|
||||
.collect();
|
||||
|
||||
debug!("Transformed {} alerts into MCP messages for provideContext", mcp_messages.len());
|
||||
self.create_success_response(json!(mcp_messages), request.id)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error getting alerts from Wazuh for provideContext: {}", e);
|
||||
self.create_error_response(
|
||||
error_codes::INTERNAL_ERROR,
|
||||
format!("Failed to get alerts from Wazuh: {}", e),
|
||||
None,
|
||||
request.id,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_get_resources(&self, request: JsonRpcRequest) -> String {
|
||||
debug!("Handling getResources request");
|
||||
// Return an empty list for now
|
||||
let resources_result = crate::mcp::protocol::ResourcesListResult {
|
||||
resources: vec![],
|
||||
};
|
||||
|
||||
self.create_success_response(resources_result, request.id)
|
||||
}
|
||||
|
||||
async fn handle_read_resource(&self, request: JsonRpcRequest) -> String {
|
||||
debug!("Handling readResource request: {:?}", request.params);
|
||||
|
||||
#[derive(serde::Deserialize, Debug)]
|
||||
struct ReadResourceParams {
|
||||
uri: String,
|
||||
// We can add _meta here if needed later
|
||||
// _meta: Option<Value>,
|
||||
}
|
||||
|
||||
let params: ReadResourceParams = match request.params {
|
||||
Some(params_value) => match serde_json::from_value(params_value) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
error!("Failed to parse params for resources/read: {}", e);
|
||||
return self.create_error_response(
|
||||
error_codes::INVALID_PARAMS,
|
||||
format!("Invalid params for resources/read: {}", e),
|
||||
None,
|
||||
request.id,
|
||||
);
|
||||
}
|
||||
},
|
||||
None => {
|
||||
error!("Missing params for resources/read");
|
||||
return self.create_error_response(
|
||||
error_codes::INVALID_PARAMS,
|
||||
"Missing params for resources/read, 'uri' is required".to_string(),
|
||||
None,
|
||||
request.id,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
error!("Unsupported URI for resources/read: {}", params.uri);
|
||||
self.create_error_response(
|
||||
error_codes::INVALID_PARAMS,
|
||||
format!("Unsupported or unknown resource URI: {}", params.uri),
|
||||
None,
|
||||
request.id,
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_tool_call(&self, request: JsonRpcRequest) -> String {
|
||||
debug!("Handling tools/call request: {:?}", request.params);
|
||||
|
||||
|
||||
let params: ToolCallParams = match request.clone().params {
|
||||
Some(params_value) => match serde_json::from_value(params_value) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
error!("Failed to parse params for tools/call: {}", e);
|
||||
return self.create_error_response(
|
||||
error_codes::INVALID_PARAMS,
|
||||
format!("Invalid params for tools/call: {}", e),
|
||||
None,
|
||||
request.id,
|
||||
);
|
||||
}
|
||||
},
|
||||
None => {
|
||||
error!("Missing params for tools/call");
|
||||
return self.create_error_response(
|
||||
error_codes::INVALID_PARAMS,
|
||||
"Missing params for tools/call, 'name' and 'arguments' are required".to_string(),
|
||||
None,
|
||||
request.id,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
match params.name.as_str() {
|
||||
"wazuhAlertSummary" => {
|
||||
info!("Dispatching tools/call to wazuhAlertSummary handler");
|
||||
self.handle_wazuh_alert_summary_tool(request).await
|
||||
}
|
||||
// wazuhAlerts tool is disabled but we keep the handler code
|
||||
_ => {
|
||||
error!("Unsupported tool name requested via tools/call: {}", params.name);
|
||||
self.create_error_response(
|
||||
error_codes::METHOD_NOT_FOUND, // Or a more specific tool error code if available
|
||||
format!("Tool '{}' not found", params.name),
|
||||
None,
|
||||
request.id,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async fn handle_list_tools(&self, request: JsonRpcRequest) -> String {
|
||||
debug!("Handling tools/list request");
|
||||
|
||||
// Define the wazuhAlertSummary tool
|
||||
let wazuh_alert_summary_tool = crate::mcp::protocol::ToolDefinition {
|
||||
name: "wazuhAlertSummary".to_string(),
|
||||
description: Some("Returns a text summary of all Wazuh alerts.".to_string()),
|
||||
// Define a minimal valid input schema (empty object)
|
||||
input_schema: Some(json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
})),
|
||||
// No output schema needed as per requirements
|
||||
output_schema: None,
|
||||
};
|
||||
|
||||
let tools_list = crate::mcp::protocol::ToolsListResult {
|
||||
tools: vec![wazuh_alert_summary_tool],
|
||||
};
|
||||
|
||||
self.create_success_response(tools_list, request.id)
|
||||
}
|
||||
|
||||
|
||||
async fn handle_wazuh_alerts_tool(&self, request: JsonRpcRequest) -> String {
|
||||
debug!("Handling tools/wazuhAlerts request. Params: {:?}", request.params);
|
||||
|
||||
let wazuh_client = self.app_state.wazuh_client.lock().await;
|
||||
|
||||
match wazuh_client.get_alerts().await {
|
||||
Ok(raw_alerts) => {
|
||||
let simplified_alerts: Vec<Value> = raw_alerts
|
||||
.into_iter()
|
||||
.map(|alert| {
|
||||
let source = alert.get("_source").unwrap_or(&alert);
|
||||
|
||||
// Extract ID: Try _source.id first, then _id
|
||||
let id = source.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.or_else(|| alert.get("_id").and_then(|v| v.as_str()))
|
||||
.unwrap_or("") // Default to empty string if not found
|
||||
.to_string();
|
||||
|
||||
// Extract Description: Look in _source.rule.description
|
||||
let description = source.get("rule")
|
||||
.and_then(|r| r.get("description"))
|
||||
.and_then(|d| d.as_str())
|
||||
.unwrap_or("") // Default to empty string if not found
|
||||
.to_string();
|
||||
|
||||
json!({
|
||||
"id": id,
|
||||
"description": description,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
debug!("Processed {} alerts into simplified format.", simplified_alerts.len());
|
||||
|
||||
// Construct the final result with the "alerts" array
|
||||
let result = json!({
|
||||
"alerts": simplified_alerts,
|
||||
"text": "Hello World",
|
||||
});
|
||||
self.create_success_response(result, request.id)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error getting alerts from Wazuh for tools/wazuhAlerts: {}", e);
|
||||
self.create_error_response(
|
||||
error_codes::INTERNAL_ERROR,
|
||||
format!("Failed to get alerts from Wazuh: {}", e),
|
||||
None,
|
||||
request.id,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_wazuh_alert_summary_tool(&self, request: JsonRpcRequest) -> String {
|
||||
debug!("Handling tools/wazuhAlertSummary request. Params: {:?}", request.params);
|
||||
|
||||
let mut wazuh_client = self.app_state.wazuh_client.lock().await;
|
||||
|
||||
|
||||
match wazuh_client.get_alerts().await {
|
||||
Ok(raw_alerts) => {
|
||||
// Create a content item for each alert
|
||||
let content_items: Vec<Value> = if raw_alerts.is_empty() {
|
||||
// If no alerts, return a single "no alerts" message
|
||||
vec![json!({
|
||||
"type": "text",
|
||||
"text": "No Wazuh alerts found."
|
||||
})]
|
||||
} else {
|
||||
// Map each alert to a content item
|
||||
raw_alerts
|
||||
.into_iter()
|
||||
.map(|alert| {
|
||||
let source = alert.get("_source").unwrap_or(&alert);
|
||||
|
||||
// Extract alert ID
|
||||
let id = source.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.or_else(|| alert.get("_id").and_then(|v| v.as_str()))
|
||||
.unwrap_or("Unknown ID");
|
||||
|
||||
// Extract rule description
|
||||
let description = source.get("rule")
|
||||
.and_then(|r| r.get("description"))
|
||||
.and_then(|d| d.as_str())
|
||||
.unwrap_or("No description available");
|
||||
|
||||
// Extract timestamp if available
|
||||
let timestamp = source.get("timestamp")
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("Unknown time");
|
||||
|
||||
// Format the alert as a text entry and create a content item
|
||||
json!({
|
||||
"type": "text",
|
||||
"text": format!("Alert ID: {}\nTime: {}\nDescription: {}", id, timestamp, description)
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
debug!("Processed {} alerts into individual content items.", content_items.len());
|
||||
|
||||
// Construct the final result with the content array containing multiple text objects
|
||||
let result = json!({
|
||||
"content": content_items
|
||||
});
|
||||
|
||||
self.create_success_response(result, request.id)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error getting alerts from Wazuh for tools/wazuhAlertSummary: {}", e);
|
||||
self.create_error_response(
|
||||
error_codes::INTERNAL_ERROR,
|
||||
format!("Failed to get alerts from Wazuh: {}", e),
|
||||
None,
|
||||
request.id,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_list_prompts(&self, request: JsonRpcRequest) -> String {
|
||||
debug!("Handling prompts/list request");
|
||||
|
||||
// Define the single prompt according to the new structure
|
||||
let list_alerts_prompt = crate::mcp::protocol::PromptEntry {
|
||||
name: "list-wazuh-alerts".to_string(),
|
||||
description: Some("List the latest security alerts from Wazuh.".to_string()),
|
||||
arguments: vec![], // This prompt takes no arguments
|
||||
};
|
||||
|
||||
let prompts = vec![list_alerts_prompt];
|
||||
|
||||
let result = crate::mcp::protocol::PromptsListResult { prompts };
|
||||
|
||||
self.create_success_response(result, request.id)
|
||||
}
|
||||
|
||||
|
||||
fn create_success_response<T: serde::Serialize>(&self, result: T, id: Value) -> String {
|
||||
let response = JsonRpcResponse {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
result: Some(result),
|
||||
error: None,
|
||||
id,
|
||||
};
|
||||
|
||||
serde_json::to_string(&response).unwrap_or_else(|e| {
|
||||
error!("Failed to serialize JSON-RPC response: {}", e);
|
||||
format!(
|
||||
r#"{{"jsonrpc":"2.0","error":{{"code":-32603,"message":"Internal error: Failed to serialize response"}},"id":null}}"#
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn create_error_response(
|
||||
&self,
|
||||
code: i32,
|
||||
message: String,
|
||||
data: Option<Value>,
|
||||
id: Value,
|
||||
) -> String {
|
||||
let response = JsonRpcResponse::<Value> {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
result: None,
|
||||
error: Some(JsonRpcError {
|
||||
code,
|
||||
message,
|
||||
data,
|
||||
}),
|
||||
id,
|
||||
};
|
||||
|
||||
serde_json::to_string(&response).unwrap_or_else(|e| {
|
||||
error!("Failed to serialize JSON-RPC error response: {}", e);
|
||||
format!(
|
||||
r#"{{"jsonrpc":"2.0","error":{{"code":-32603,"message":"Internal error: Failed to serialize error response"}},"id":null}}"#
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
pub mod transform;
|
||||
pub mod client;
|
||||
pub mod protocol;
|
||||
pub mod mcp_server_core;
|
@ -1,127 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct JsonRpcRequest {
|
||||
pub jsonrpc: String,
|
||||
pub method: String,
|
||||
pub params: Option<Value>,
|
||||
pub id: Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct JsonRpcResponse<T: Serialize> {
|
||||
pub jsonrpc: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<T>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<JsonRpcError>,
|
||||
pub id: Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct JsonRpcError {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone)]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(rename = "inputSchema", skip_serializing_if = "Option::is_none")]
|
||||
pub input_schema: Option<Value>, // Added inputSchema
|
||||
#[serde(rename = "outputSchema", skip_serializing_if = "Option::is_none")]
|
||||
pub output_schema: Option<Value>, // Added outputSchema
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct ToolCapability {
|
||||
pub supported: bool,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub definitions: Vec<ToolDefinition>, // List available tools
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct SupportedFeature {
|
||||
pub supported: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct Capabilities {
|
||||
pub tools: ToolCapability, // Use the new structure
|
||||
pub resources: SupportedFeature,
|
||||
pub prompts: SupportedFeature,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct ServerInfo {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct InitializeResult {
|
||||
#[serde(rename = "protocolVersion")]
|
||||
pub protocol_version: String,
|
||||
pub capabilities: Capabilities,
|
||||
#[serde(rename = "serverInfo")]
|
||||
pub server_info: ServerInfo,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct ResourceEntry {
|
||||
pub uri: String,
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(rename = "mimeType", skip_serializing_if = "Option::is_none")]
|
||||
pub mime_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct ResourcesListResult {
|
||||
pub resources: Vec<ResourceEntry>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct ToolsListResult {
|
||||
pub tools: Vec<ToolDefinition>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone)]
|
||||
pub struct PromptArgument {
|
||||
pub name: String,
|
||||
pub required: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub default: Option<Value>, // Use Value for flexibility (string, bool, number)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>, // Optional description for the argument
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone)]
|
||||
pub struct PromptEntry {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub arguments: Vec<PromptArgument>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct PromptsListResult {
|
||||
pub prompts: Vec<PromptEntry>,
|
||||
}
|
||||
|
||||
pub mod error_codes {
|
||||
pub const PARSE_ERROR: i32 = -32700;
|
||||
pub const INVALID_REQUEST: i32 = -32600;
|
||||
pub const METHOD_NOT_FOUND: i32 = -32601;
|
||||
pub const INVALID_PARAMS: i32 = -32602;
|
||||
pub const INTERNAL_ERROR: i32 = -32603;
|
||||
pub const SERVER_ERROR_START: i32 = -32000;
|
||||
pub const SERVER_ERROR_END: i32 = -32099;
|
||||
}
|
@ -1,235 +0,0 @@
|
||||
use chrono::{DateTime, Utc, SecondsFormat};
|
||||
use serde_json::{json, Value};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
pub fn transform_to_mcp(event: Value, event_type: String) -> Value {
|
||||
debug!(?event, %event_type, "Entering transform_to_mcp");
|
||||
|
||||
let source_obj = event.get("_source").unwrap_or(&event);
|
||||
if event.get("_source").is_some() {
|
||||
debug!("Event contains '_source' field, using it for transformation.");
|
||||
} else {
|
||||
debug!("Event does not contain '_source' field, using the event root for transformation.");
|
||||
}
|
||||
|
||||
let id = source_obj.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.or_else(|| event.get("_id").and_then(|v| v.as_str()))
|
||||
.unwrap_or("unknown_id")
|
||||
.to_string();
|
||||
debug!(%id, "Transformed: id");
|
||||
|
||||
let default_rule = json!({});
|
||||
let rule = source_obj.get("rule").unwrap_or(&default_rule);
|
||||
if source_obj.get("rule").is_none() {
|
||||
debug!("Transformed: rule (defaulted to empty object)");
|
||||
} else {
|
||||
debug!(?rule, "Transformed: rule");
|
||||
}
|
||||
let category = rule.get("groups")
|
||||
.and_then(|g| g.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown_category")
|
||||
.to_string();
|
||||
debug!(%category, "Transformed: category");
|
||||
|
||||
let severity_level = rule.get("level").and_then(|v| v.as_u64());
|
||||
let severity = severity_level
|
||||
.map(|level| match level {
|
||||
0..=3 => "low",
|
||||
4..=7 => "medium",
|
||||
8..=11 => "high",
|
||||
_ => "critical",
|
||||
})
|
||||
.unwrap_or("unknown_severity")
|
||||
.to_string();
|
||||
debug!(?severity_level, %severity, "Transformed: severity");
|
||||
|
||||
let description = rule.get("description")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
debug!(%description, "Transformed: description");
|
||||
|
||||
let default_data = json!({});
|
||||
let data = source_obj.get("data").cloned().unwrap_or_else(|| {
|
||||
debug!("Transformed: data (defaulted to empty object)");
|
||||
default_data.clone()
|
||||
});
|
||||
if source_obj.get("data").is_some() {
|
||||
debug!(?data, "Transformed: data");
|
||||
}
|
||||
|
||||
|
||||
let default_agent = json!({});
|
||||
let agent = source_obj.get("agent").cloned().unwrap_or_else(|| {
|
||||
debug!("Transformed: agent (defaulted to empty object)");
|
||||
default_agent.clone()
|
||||
});
|
||||
if source_obj.get("agent").is_some() {
|
||||
debug!(?agent, "Transformed: agent");
|
||||
}
|
||||
|
||||
|
||||
let timestamp_str = source_obj.get("timestamp")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
debug!(%timestamp_str, "Attempting to parse timestamp");
|
||||
|
||||
let timestamp = DateTime::parse_from_rfc3339(timestamp_str)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.or_else(|_| DateTime::parse_from_str(timestamp_str, "%Y-%m-%dT%H:%M:%S%.fZ").map(|dt| dt.with_timezone(&Utc)))
|
||||
.unwrap_or_else(|_| {
|
||||
warn!("Failed to parse timestamp '{}' for alert ID '{}'. Using current time.", timestamp_str, id);
|
||||
Utc::now()
|
||||
});
|
||||
debug!(%timestamp, "Transformed: timestamp");
|
||||
|
||||
let notes = "Data fetched via Wazuh API".to_string();
|
||||
debug!(%notes, "Transformed: notes");
|
||||
|
||||
let mcp_message = json!({
|
||||
"protocolVersion": "1.0", // Match initialize response
|
||||
"source": "Wazuh",
|
||||
"timestamp": timestamp.to_rfc3339_opts(SecondsFormat::Secs, true),
|
||||
"event_type": event_type,
|
||||
"context": {
|
||||
"id": id,
|
||||
"category": category,
|
||||
"severity": severity,
|
||||
"description": description,
|
||||
"agent": agent,
|
||||
"data": data
|
||||
},
|
||||
"metadata": {
|
||||
"integration": "Wazuh-MCP",
|
||||
"notes": notes
|
||||
}
|
||||
});
|
||||
debug!(?mcp_message, "Exiting transform_to_mcp with result");
|
||||
mcp_message
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::TimeZone;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_transform_to_mcp_basic() {
|
||||
let event_time_str = "2023-10-27T10:30:00.123Z";
|
||||
let event_time = Utc.datetime_from_str(event_time_str, "%Y-%m-%dT%H:%M:%S%.fZ").unwrap();
|
||||
|
||||
let event = json!({
|
||||
"id": "alert1",
|
||||
"_id": "wazuh_alert_id_1",
|
||||
"timestamp": event_time_str,
|
||||
"rule": {
|
||||
"level": 10,
|
||||
"description": "High severity rule triggered",
|
||||
"id": "1002",
|
||||
"groups": ["gdpr", "pci_dss", "intrusion_detection"]
|
||||
},
|
||||
"agent": {
|
||||
"id": "001",
|
||||
"name": "server-db"
|
||||
},
|
||||
"data": {
|
||||
"srcip": "1.2.3.4",
|
||||
"dstport": "22"
|
||||
}
|
||||
});
|
||||
|
||||
let result = transform_to_mcp(event.clone(), "alert".to_string());
|
||||
|
||||
assert_eq!(result["protocol_version"], "1.0");
|
||||
assert_eq!(result["source"], "Wazuh");
|
||||
assert_eq!(result["event_type"], "alert");
|
||||
assert_eq!(result["timestamp"], event_time.to_rfc3339_opts(SecondsFormat::Secs, true));
|
||||
|
||||
let context = &result["context"];
|
||||
assert_eq!(context["id"], "alert1");
|
||||
assert_eq!(context["category"], "gdpr");
|
||||
assert_eq!(context["severity"], "high");
|
||||
assert_eq!(context["description"], "High severity rule triggered");
|
||||
assert_eq!(context["agent"]["name"], "server-db");
|
||||
assert_eq!(context["data"]["srcip"], "1.2.3.4");
|
||||
|
||||
let metadata = &result["metadata"];
|
||||
assert_eq!(metadata["integration"], "Wazuh-MCP");
|
||||
assert_eq!(metadata["notes"], "Data fetched via Wazuh API");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_to_mcp_with_source_nesting() {
|
||||
let event_time_str = "2023-10-27T11:00:00Z";
|
||||
let event_time = DateTime::parse_from_rfc3339(event_time_str).unwrap().with_timezone(&Utc);
|
||||
|
||||
let event = json!({
|
||||
"_index": "wazuh-alerts-4.x-2023.10.27",
|
||||
"_id": "alert_source_nested",
|
||||
"_source": {
|
||||
"id": "nested_alert_id",
|
||||
"timestamp": event_time_str,
|
||||
"rule": {
|
||||
"level": 5,
|
||||
"description": "Medium severity rule",
|
||||
"groups": ["system_audit"]
|
||||
},
|
||||
"agent": { "id": "002", "name": "web-server" },
|
||||
"data": { "command": "useradd test" }
|
||||
}
|
||||
});
|
||||
|
||||
let result = transform_to_mcp(event.clone(), "alert".to_string());
|
||||
assert_eq!(result["timestamp"], event_time.to_rfc3339_opts(SecondsFormat::Secs, true));
|
||||
let context = &result["context"];
|
||||
assert_eq!(context["id"], "nested_alert_id");
|
||||
assert_eq!(context["category"], "system_audit");
|
||||
assert_eq!(context["severity"], "medium");
|
||||
assert_eq!(context["description"], "Medium severity rule");
|
||||
assert_eq!(context["agent"]["name"], "web-server");
|
||||
assert_eq!(context["data"]["command"], "useradd test");
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_transform_to_mcp_with_defaults() {
|
||||
let event = json!({});
|
||||
let before_transform = Utc::now();
|
||||
let result = transform_to_mcp(event, "alert".to_string());
|
||||
let after_transform = Utc::now();
|
||||
|
||||
assert_eq!(result["context"]["id"], "unknown_id");
|
||||
assert_eq!(result["context"]["category"], "unknown_category");
|
||||
assert_eq!(result["context"]["severity"], "unknown_severity");
|
||||
assert_eq!(result["context"]["description"], "");
|
||||
assert!(result["context"]["data"].is_object());
|
||||
assert!(result["context"]["agent"].is_object());
|
||||
assert_eq!(result["metadata"]["notes"], "Data fetched via Wazuh API");
|
||||
|
||||
let result_ts_str = result["timestamp"].as_str().unwrap();
|
||||
let result_ts = DateTime::parse_from_rfc3339(result_ts_str).unwrap().with_timezone(&Utc);
|
||||
assert!(result_ts.timestamp() >= before_transform.timestamp() && result_ts.timestamp() <= after_transform.timestamp());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_timestamp_parsing_fallback() {
|
||||
let event = json!({
|
||||
"id": "ts_test",
|
||||
"timestamp": "invalid-timestamp-format",
|
||||
"rule": { "level": 3 },
|
||||
});
|
||||
let before_transform = Utc::now();
|
||||
let result = transform_to_mcp(event, "alert".to_string());
|
||||
let after_transform = Utc::now();
|
||||
|
||||
let result_ts_str = result["timestamp"].as_str().unwrap();
|
||||
let result_ts = DateTime::parse_from_rfc3339(result_ts_str).unwrap().with_timezone(&Utc);
|
||||
assert!(result_ts.timestamp() >= before_transform.timestamp() && result_ts.timestamp() <= after_transform.timestamp());
|
||||
assert_eq!(result["context"]["id"], "ts_test");
|
||||
assert_eq!(result["context"]["severity"], "low");
|
||||
}
|
||||
}
|
@ -1,186 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::sync::oneshot::Sender as OneshotSender;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::logging_utils::{log_mcp_request, log_mcp_response};
|
||||
use crate::mcp::mcp_server_core::McpServerCore;
|
||||
use crate::mcp::protocol::{error_codes, JsonRpcRequest};
|
||||
use crate::AppState;
|
||||
use serde_json::Value;
|
||||
|
||||
pub async fn run_stdio_service(app_state: Arc<AppState>, shutdown_tx: OneshotSender<()>) {
|
||||
info!("Starting MCP server in stdio mode...");
|
||||
let mut stdin_reader = BufReader::new(tokio::io::stdin());
|
||||
let mut stdout_writer = tokio::io::stdout();
|
||||
let mcp_core = McpServerCore::new(app_state);
|
||||
|
||||
let mut line_buffer = String::new();
|
||||
|
||||
debug!("run_stdio_service: Initialized readers/writers. Entering main loop.");
|
||||
|
||||
loop {
|
||||
debug!("stdio_service: Top of the loop. Clearing line buffer.");
|
||||
line_buffer.clear();
|
||||
debug!("stdio_service: About to read_line from stdin.");
|
||||
|
||||
let read_result = stdin_reader.read_line(&mut line_buffer).await;
|
||||
debug!(?read_result, "stdio_service: read_line completed.");
|
||||
|
||||
match read_result {
|
||||
Ok(0) => {
|
||||
debug!("stdio_service: read_line returned Ok(0) (EOF).");
|
||||
info!("Stdin closed (EOF), signaling shutdown and exiting stdio mode.");
|
||||
let _ = shutdown_tx.send(()); // Signal main to shutdown Axum
|
||||
debug!("stdio_service read 0 bytes, breaking loop.");
|
||||
break; // EOF
|
||||
}
|
||||
Ok(bytes_read) => {
|
||||
debug!(%bytes_read, "stdio_service: read_line returned Ok(bytes_read).");
|
||||
let request_str = line_buffer.trim();
|
||||
if request_str.is_empty() {
|
||||
debug!("Received empty line from stdin, continuing.");
|
||||
continue;
|
||||
}
|
||||
info!("Received from stdin (stdio_service): {}", request_str);
|
||||
log_mcp_request(request_str); // Log the raw incoming string
|
||||
|
||||
let parsed_value: Value = match serde_json::from_str(request_str) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
error!("JSON Parse Error: {}", e);
|
||||
let response_json = mcp_core.handle_parse_error(e, request_str);
|
||||
log_mcp_response(&response_json);
|
||||
info!("Sending parse error response to stdout: {}", response_json);
|
||||
let response_to_send = format!("{}\n", response_json);
|
||||
if let Err(write_err) =
|
||||
stdout_writer.write_all(response_to_send.as_bytes()).await
|
||||
{
|
||||
error!(
|
||||
"Error writing parse error response to stdout: {}",
|
||||
write_err
|
||||
);
|
||||
let _ = shutdown_tx.send(());
|
||||
break;
|
||||
}
|
||||
if let Err(flush_err) = stdout_writer.flush().await {
|
||||
error!("Error flushing stdout for parse error: {}", flush_err);
|
||||
let _ = shutdown_tx.send(());
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if parsed_value.get("id").is_none()
|
||||
|| parsed_value.get("id").map_or(false, |id| id.is_null())
|
||||
{
|
||||
// --- Handle Notification (No ID or ID is null) ---
|
||||
let method = parsed_value
|
||||
.get("method")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("");
|
||||
info!("Received Notification: method='{}'", method);
|
||||
|
||||
match method {
|
||||
"notifications/initialized" => {
|
||||
debug!("Client 'initialized' notification received. No action taken, no response sent.");
|
||||
}
|
||||
"exit" => {
|
||||
info!("'exit' notification received. Signaling shutdown immediately.");
|
||||
let _ = shutdown_tx.send(());
|
||||
return;
|
||||
}
|
||||
_ => {
|
||||
debug!(
|
||||
"Received unknown/unhandled notification method: '{}'. Ignoring.",
|
||||
method
|
||||
);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
let request_id = parsed_value.get("id").cloned().unwrap(); // We know ID exists and is not null here
|
||||
|
||||
match serde_json::from_value::<JsonRpcRequest>(parsed_value) {
|
||||
Ok(rpc_request) => {
|
||||
// --- Successfully parsed a Request ---
|
||||
let is_shutdown = rpc_request.method == "shutdown";
|
||||
let response_json = mcp_core.process_request(rpc_request).await;
|
||||
|
||||
// Log and send the response
|
||||
log_mcp_response(&response_json);
|
||||
info!("Sending response to stdout: {}", response_json);
|
||||
let response_to_send = format!("{}\n", response_json);
|
||||
|
||||
if let Err(e) =
|
||||
stdout_writer.write_all(response_to_send.as_bytes()).await
|
||||
{
|
||||
error!("Error writing response to stdout: {}", e);
|
||||
let _ = shutdown_tx.send(());
|
||||
break;
|
||||
}
|
||||
if let Err(e) = stdout_writer.flush().await {
|
||||
error!("Error flushing stdout: {}", e);
|
||||
let _ = shutdown_tx.send(());
|
||||
break;
|
||||
}
|
||||
|
||||
// Handle shutdown *after* sending the response
|
||||
if is_shutdown {
|
||||
debug!("'shutdown' request processed successfully. Signaling shutdown.");
|
||||
let _ = shutdown_tx.send(()); // Signal main to shutdown Axum
|
||||
return; // Exit the service loop
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Invalid JSON-RPC Request structure: {}", e);
|
||||
// Use the ID we extracted earlier
|
||||
let response_json = mcp_core.create_error_response(
|
||||
error_codes::INVALID_REQUEST,
|
||||
format!("Invalid Request structure: {}", e),
|
||||
None,
|
||||
request_id, // Use the ID from the original request
|
||||
);
|
||||
|
||||
log_mcp_response(&response_json);
|
||||
info!(
|
||||
"Sending invalid request error response to stdout: {}",
|
||||
response_json
|
||||
);
|
||||
let response_to_send = format!("{}\n", response_json);
|
||||
if let Err(write_err) =
|
||||
stdout_writer.write_all(response_to_send.as_bytes()).await
|
||||
{
|
||||
error!(
|
||||
"Error writing invalid request error response to stdout: {}",
|
||||
write_err
|
||||
);
|
||||
let _ = shutdown_tx.send(());
|
||||
break;
|
||||
}
|
||||
if let Err(flush_err) = stdout_writer.flush().await {
|
||||
error!(
|
||||
"Error flushing stdout for invalid request error: {}",
|
||||
flush_err
|
||||
);
|
||||
let _ = shutdown_tx.send(());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!(error = %e, "stdio_service: read_line returned Err.");
|
||||
error!("Error reading from stdin for stdio_service: {}", e);
|
||||
debug!("Signaling shutdown and breaking loop due to stdin read error.");
|
||||
let _ = shutdown_tx.send(()); // Signal main to shutdown Axum
|
||||
break;
|
||||
}
|
||||
}
|
||||
debug!("stdio_service: Bottom of the loop, before next iteration.");
|
||||
}
|
||||
|
||||
info!("run_stdio_service: Exited main loop. stdio_service task is finishing.");
|
||||
}
|
563
src/tools/agents.rs
Normal file
563
src/tools/agents.rs
Normal file
@ -0,0 +1,563 @@
|
||||
//! Agent tools module for Wazuh MCP Server
|
||||
//!
|
||||
//! This module contains all agent-related tool implementations including:
|
||||
//! - Agent listing and filtering
|
||||
//! - Agent process monitoring
|
||||
//! - Agent network port monitoring
|
||||
|
||||
use super::{ToolModule, ToolUtils};
|
||||
use reqwest::StatusCode;
|
||||
use rmcp::model::{CallToolResult, Content};
|
||||
use rmcp::Error as McpError;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use wazuh_client::{AgentsClient, Port as WazuhPort, VulnerabilityClient};
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetAgentsParams {
|
||||
#[schemars(description = "Maximum number of agents to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
#[schemars(
|
||||
description = "Agent status filter (active, disconnected, pending, never_connected)"
|
||||
)]
|
||||
pub status: String,
|
||||
#[schemars(description = "Agent name to search for (optional)")]
|
||||
pub name: Option<String>,
|
||||
#[schemars(description = "Agent IP address to filter by (optional)")]
|
||||
pub ip: Option<String>,
|
||||
#[schemars(description = "Agent group to filter by (optional)")]
|
||||
pub group: Option<String>,
|
||||
#[schemars(description = "Operating system platform to filter by (optional)")]
|
||||
pub os_platform: Option<String>,
|
||||
#[schemars(description = "Agent version to filter by (optional)")]
|
||||
pub version: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetAgentProcessesParams {
|
||||
#[schemars(
|
||||
description = "Agent ID to get processes for (required, e.g., \"0\", \"1\", \"001\")"
|
||||
)]
|
||||
pub agent_id: String,
|
||||
#[schemars(description = "Maximum number of processes to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
#[schemars(description = "Search string to filter processes by name or command (optional)")]
|
||||
pub search: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetAgentPortsParams {
|
||||
#[schemars(
|
||||
description = "Agent ID to get network ports for (required, e.g., \"001\", \"002\", \"003\")"
|
||||
)]
|
||||
pub agent_id: String,
|
||||
#[schemars(description = "Maximum number of ports to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
#[schemars(description = "Protocol to filter by (e.g., \"tcp\", \"udp\")")]
|
||||
pub protocol: String,
|
||||
#[schemars(description = "State to filter by (e.g., \"LISTENING\", \"ESTABLISHED\")")]
|
||||
pub state: String,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AgentTools {
|
||||
agents_client: Arc<Mutex<AgentsClient>>,
|
||||
vulnerability_client: Arc<Mutex<VulnerabilityClient>>,
|
||||
}
|
||||
|
||||
impl AgentTools {
|
||||
pub fn new(
|
||||
agents_client: Arc<Mutex<AgentsClient>>,
|
||||
vulnerability_client: Arc<Mutex<VulnerabilityClient>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
agents_client,
|
||||
vulnerability_client,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_wazuh_agents(
|
||||
&self,
|
||||
params: GetAgentsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
|
||||
tracing::info!(
|
||||
limit = %limit,
|
||||
status = ?params.status,
|
||||
name = ?params.name,
|
||||
ip = ?params.ip,
|
||||
group = ?params.group,
|
||||
os_platform = ?params.os_platform,
|
||||
version = ?params.version,
|
||||
"Retrieving Wazuh agents"
|
||||
);
|
||||
|
||||
let mut agents_client = self.agents_client.lock().await;
|
||||
|
||||
match agents_client
|
||||
.get_agents(
|
||||
Some(limit),
|
||||
None, // offset
|
||||
None, // select
|
||||
None, // sort
|
||||
None, // search
|
||||
Some(¶ms.status),
|
||||
None, // query
|
||||
None, // older_than
|
||||
params.os_platform.as_deref(),
|
||||
None, // os_version
|
||||
None, // os_name
|
||||
None, // manager_host
|
||||
params.version.as_deref(),
|
||||
params.group.as_deref(),
|
||||
None, // node_name
|
||||
params.name.as_deref(),
|
||||
params.ip.as_deref(),
|
||||
None, // register_ip
|
||||
None, // group_config_status
|
||||
None, // distinct
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(agents) => {
|
||||
if agents.is_empty() {
|
||||
tracing::info!(
|
||||
"No Wazuh agents found matching criteria. Returning standard message."
|
||||
);
|
||||
return Self::not_found_result(&format!(
|
||||
"Wazuh agents matching the specified criteria (status: {})",
|
||||
¶ms.status
|
||||
));
|
||||
}
|
||||
|
||||
let num_agents = agents.len();
|
||||
let mcp_content_items: Vec<Content> = agents
|
||||
.into_iter()
|
||||
.map(|agent| {
|
||||
let status_indicator = match agent.status.to_lowercase().as_str() {
|
||||
"active" => "🟢 ACTIVE",
|
||||
"disconnected" => "🔴 DISCONNECTED",
|
||||
"pending" => "🟡 PENDING",
|
||||
"never_connected" => "⚪ NEVER CONNECTED",
|
||||
_ => &agent.status,
|
||||
};
|
||||
|
||||
let ip_info = if let Some(ip) = &agent.ip {
|
||||
format!("\nIP: {}", ip)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let register_ip_info = if let Some(register_ip) = &agent.register_ip {
|
||||
if agent.ip.as_ref() != Some(register_ip) {
|
||||
format!("\nRegistered IP: {}", register_ip)
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let os_info = if let Some(os) = &agent.os {
|
||||
let mut os_parts = Vec::new();
|
||||
if let Some(name) = &os.name {
|
||||
os_parts.push(name.clone());
|
||||
}
|
||||
if let Some(version) = &os.version {
|
||||
os_parts.push(version.clone());
|
||||
}
|
||||
if let Some(arch) = &os.arch {
|
||||
os_parts.push(format!("({})", arch));
|
||||
}
|
||||
if !os_parts.is_empty() {
|
||||
format!("\nOS: {}", os_parts.join(" "))
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let version_info = if let Some(version) = &agent.version {
|
||||
format!("\nAgent Version: {}", version)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let group_info = if let Some(groups) = &agent.group {
|
||||
if !groups.is_empty() {
|
||||
format!("\nGroups: {}", groups.join(", "))
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let last_keep_alive_info =
|
||||
if let Some(last_keep_alive) = &agent.last_keep_alive {
|
||||
format!("\nLast Keep Alive: {}", last_keep_alive)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let date_add_info = if let Some(date_add) = &agent.date_add {
|
||||
format!("\nRegistered: {}", date_add)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let node_info = if let Some(node_name) = &agent.node_name {
|
||||
format!("\nNode: {}", node_name)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let config_status_info =
|
||||
if let Some(config_status) = &agent.group_config_status {
|
||||
let config_indicator = match config_status.to_lowercase().as_str() {
|
||||
"synced" => "✅ SYNCED",
|
||||
"not synced" => "❌ NOT SYNCED",
|
||||
_ => config_status,
|
||||
};
|
||||
format!("\nConfig Status: {}", config_indicator)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let agent_id_display = if agent.id == "000" {
|
||||
format!("{} (Wazuh Manager)", agent.id)
|
||||
} else {
|
||||
agent.id.clone()
|
||||
};
|
||||
|
||||
let formatted_text = format!(
|
||||
"Agent ID: {}\nName: {}\nStatus: {}{}{}{}{}{}{}{}{}{}",
|
||||
agent_id_display,
|
||||
agent.name,
|
||||
status_indicator,
|
||||
ip_info,
|
||||
register_ip_info,
|
||||
os_info,
|
||||
version_info,
|
||||
group_info,
|
||||
last_keep_alive_info,
|
||||
date_add_info,
|
||||
node_info,
|
||||
config_status_info
|
||||
);
|
||||
Content::text(formatted_text)
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!(
|
||||
"Successfully processed {} agents into {} MCP content items",
|
||||
num_agents,
|
||||
mcp_content_items.len()
|
||||
);
|
||||
Self::success_result(mcp_content_items)
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = Self::format_error("agents", "retrieving agents", &e);
|
||||
tracing::error!("{}", err_msg);
|
||||
Self::error_result(err_msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_wazuh_agent_processes(
|
||||
&self,
|
||||
params: GetAgentProcessesParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let agent_id = match ToolUtils::format_agent_id(¶ms.agent_id) {
|
||||
Ok(formatted_id) => formatted_id,
|
||||
Err(err_msg) => {
|
||||
tracing::error!("Error formatting agent_id for agent processes: {}", err_msg);
|
||||
return Self::error_result(err_msg);
|
||||
}
|
||||
};
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
let offset = 0;
|
||||
|
||||
tracing::info!(
|
||||
agent_id = %agent_id,
|
||||
limit = %limit,
|
||||
search = ?params.search,
|
||||
"Retrieving Wazuh agent processes"
|
||||
);
|
||||
|
||||
let mut vulnerability_client = self.vulnerability_client.lock().await;
|
||||
|
||||
match vulnerability_client
|
||||
.get_agent_processes(
|
||||
&agent_id,
|
||||
Some(limit),
|
||||
Some(offset),
|
||||
params.search.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(processes) => {
|
||||
if processes.is_empty() {
|
||||
tracing::info!("No processes found for agent {} with current filters. Returning standard message.", agent_id);
|
||||
return Self::not_found_result(&format!(
|
||||
"processes for agent {} matching the specified criteria",
|
||||
agent_id
|
||||
));
|
||||
}
|
||||
|
||||
let num_processes = processes.len();
|
||||
let mcp_content_items: Vec<Content> = processes
|
||||
.into_iter()
|
||||
.map(|process| {
|
||||
let mut details = vec![
|
||||
format!("PID: {}", process.pid),
|
||||
format!("Name: {}", process.name),
|
||||
];
|
||||
|
||||
if let Some(state) = &process.state {
|
||||
details.push(format!("State: {}", state));
|
||||
}
|
||||
if let Some(ppid) = &process.ppid {
|
||||
details.push(format!("PPID: {}", ppid));
|
||||
}
|
||||
if let Some(euser) = &process.euser {
|
||||
details.push(format!("User: {}", euser));
|
||||
}
|
||||
if let Some(cmd) = &process.cmd {
|
||||
details.push(format!("Command: {}", cmd));
|
||||
}
|
||||
if let Some(start_time_str) = &process.start_time {
|
||||
if let Ok(start_time_unix) = start_time_str.parse::<i64>() {
|
||||
// Assuming start_time is a Unix timestamp in seconds
|
||||
use chrono::DateTime;
|
||||
if let Some(dt) = DateTime::from_timestamp(start_time_unix, 0) {
|
||||
details.push(format!(
|
||||
"Start Time: {}",
|
||||
dt.format("%Y-%m-%d %H:%M:%S UTC")
|
||||
));
|
||||
} else {
|
||||
details.push(format!("Start Time: {} (raw)", start_time_str));
|
||||
}
|
||||
} else {
|
||||
// If it's not a simple number, print as is
|
||||
details.push(format!("Start Time: {}", start_time_str));
|
||||
}
|
||||
}
|
||||
if let Some(resident_mem) = process.resident {
|
||||
details.push(format!("Memory (Resident): {} KB", resident_mem / 1024));
|
||||
// Assuming resident is in bytes
|
||||
}
|
||||
if let Some(vm_size) = process.vm_size {
|
||||
details.push(format!("Memory (VM Size): {} KB", vm_size / 1024));
|
||||
// Assuming vm_size is in bytes
|
||||
}
|
||||
|
||||
Content::text(details.join("\n"))
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!(
|
||||
"Successfully processed {} processes for agent {} into {} MCP content items",
|
||||
num_processes,
|
||||
agent_id,
|
||||
mcp_content_items.len()
|
||||
);
|
||||
Self::success_result(mcp_content_items)
|
||||
}
|
||||
Err(e) => match e {
|
||||
wazuh_client::WazuhApiError::HttpError {
|
||||
status,
|
||||
message: _,
|
||||
url: _,
|
||||
} if status == StatusCode::NOT_FOUND => {
|
||||
tracing::info!("No process data found for agent {}. Syscollector might not have run or data is unavailable.", agent_id);
|
||||
Self::success_result(vec![Content::text(
|
||||
format!("No process data found for agent {}. The agent might not exist, syscollector data might be unavailable, or the agent is not active.", agent_id),
|
||||
)])
|
||||
}
|
||||
_ => {
|
||||
let err_msg = Self::format_error(
|
||||
"agent processes",
|
||||
&format!("retrieving processes for agent {}", agent_id),
|
||||
&e,
|
||||
);
|
||||
tracing::error!("{}", err_msg);
|
||||
Self::error_result(err_msg)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_wazuh_agent_ports(
|
||||
&self,
|
||||
params: GetAgentPortsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let agent_id = match ToolUtils::format_agent_id(¶ms.agent_id) {
|
||||
Ok(formatted_id) => formatted_id,
|
||||
Err(err_msg) => {
|
||||
tracing::error!("Error formatting agent_id for agent ports: {}", err_msg);
|
||||
return Self::error_result(err_msg);
|
||||
}
|
||||
};
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
let offset = 0; // Default offset
|
||||
|
||||
tracing::info!(
|
||||
agent_id = %agent_id,
|
||||
limit = %limit,
|
||||
protocol = ?params.protocol,
|
||||
state = ?params.state,
|
||||
"Retrieving Wazuh agent network ports"
|
||||
);
|
||||
|
||||
let mut vulnerability_client = self.vulnerability_client.lock().await;
|
||||
|
||||
// Note: The wazuh_client::VulnerabilityClient::get_agent_ports provided in the prompt
|
||||
// only supports filtering by protocol. If state filtering is needed, the client would need an update.
|
||||
// For now, we pass params.protocol and ignore params.state for the API call,
|
||||
// but we can filter by state client-side if necessary, or acknowledge this limitation.
|
||||
// The current wazuh-client `get_agent_ports` does not support state filtering directly in its parameters.
|
||||
// We will filter client-side for now if `params.state` is provided.
|
||||
match vulnerability_client
|
||||
.get_agent_ports(
|
||||
&agent_id,
|
||||
Some(limit * 2), // Fetch more to allow for client-side state filtering
|
||||
Some(offset),
|
||||
Some(¶ms.protocol),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(mut ports) => {
|
||||
let requested_state_is_listening =
|
||||
params.state.trim().eq_ignore_ascii_case("listening");
|
||||
|
||||
ports.retain(|port| {
|
||||
tracing::debug!(
|
||||
"Pre-filter port: {:?} (State: {:?}), requested_state_is_listening: {}",
|
||||
port.inode, // Using inode for a concise port identifier in log
|
||||
port.state,
|
||||
requested_state_is_listening
|
||||
);
|
||||
let result = match port.state.as_ref().map(|s| s.trim()) {
|
||||
Some(actual_port_state_str) => {
|
||||
// Port has a state string
|
||||
if actual_port_state_str.is_empty() {
|
||||
// Filter out ports where state is present but an empty string
|
||||
false
|
||||
} else if requested_state_is_listening {
|
||||
// User requested "listening": keep only if actual state is "listening"
|
||||
actual_port_state_str.eq_ignore_ascii_case("listening")
|
||||
} else {
|
||||
// User requested non-"listening": keep if actual state is not "listening"
|
||||
!actual_port_state_str.eq_ignore_ascii_case("listening")
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// Port has no state (port.state is None)
|
||||
if requested_state_is_listening {
|
||||
// If user wants "listening" ports, a port with no state is not a match.
|
||||
false
|
||||
} else {
|
||||
// If user wants non-"listening" ports, a port with no state is a match.
|
||||
true
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
tracing::debug!(
|
||||
"Post-filter decision for port: {:?}, Keep: {}",
|
||||
port.inode,
|
||||
result
|
||||
);
|
||||
result
|
||||
});
|
||||
|
||||
// Apply limit after client-side filtering
|
||||
ports.truncate(limit as usize);
|
||||
|
||||
if ports.is_empty() {
|
||||
tracing::info!("No network ports found for agent {} with current filters. Returning standard message.", agent_id);
|
||||
return Self::not_found_result(&format!(
|
||||
"network ports for agent {} matching the specified criteria",
|
||||
agent_id
|
||||
));
|
||||
}
|
||||
|
||||
let num_ports = ports.len();
|
||||
let mcp_content_items: Vec<Content> = ports
|
||||
.into_iter()
|
||||
.map(|port: WazuhPort| {
|
||||
// Explicitly type port
|
||||
let mut details = vec![
|
||||
format!("Protocol: {}", port.protocol),
|
||||
format!(
|
||||
"Local: {}:{}",
|
||||
port.local.ip.clone().unwrap_or("N/A".to_string()),
|
||||
port.local.port
|
||||
),
|
||||
];
|
||||
|
||||
if let Some(remote) = &port.remote {
|
||||
details.push(format!(
|
||||
"Remote: {}:{}",
|
||||
remote.ip.clone().unwrap_or("N/A".to_string()),
|
||||
remote.port
|
||||
));
|
||||
}
|
||||
if let Some(state) = &port.state {
|
||||
details.push(format!("State: {}", state));
|
||||
}
|
||||
if let Some(process_name) = &port.process {
|
||||
// process field in WazuhPort is Option<String>
|
||||
details.push(format!("Process Name: {}", process_name));
|
||||
}
|
||||
if let Some(pid) = port.pid {
|
||||
// pid field in WazuhPort is Option<u32>
|
||||
details.push(format!("PID: {}", pid));
|
||||
}
|
||||
if let Some(inode) = port.inode {
|
||||
details.push(format!("Inode: {}", inode));
|
||||
}
|
||||
if let Some(tx_queue) = port.tx_queue {
|
||||
details.push(format!("TX Queue: {}", tx_queue));
|
||||
}
|
||||
if let Some(rx_queue) = port.rx_queue {
|
||||
details.push(format!("RX Queue: {}", rx_queue));
|
||||
}
|
||||
|
||||
Content::text(details.join("\n"))
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!("Successfully processed {} network ports for agent {} into {} MCP content items", num_ports, agent_id, mcp_content_items.len());
|
||||
Self::success_result(mcp_content_items)
|
||||
}
|
||||
Err(e) => match e {
|
||||
wazuh_client::WazuhApiError::HttpError {
|
||||
status,
|
||||
message: _,
|
||||
url: _,
|
||||
} if status == StatusCode::NOT_FOUND => {
|
||||
tracing::info!("No network port data found for agent {}. Syscollector might not have run or data is unavailable.", agent_id);
|
||||
Self::success_result(vec![Content::text(
|
||||
format!("No network port data found for agent {}. The agent might not exist, syscollector data might be unavailable, or the agent is not active.", agent_id),
|
||||
)])
|
||||
}
|
||||
_ => {
|
||||
let err_msg = Self::format_error(
|
||||
"agent ports",
|
||||
&format!("retrieving network ports for agent {}", agent_id),
|
||||
&e,
|
||||
);
|
||||
tracing::error!("{}", err_msg);
|
||||
Self::error_result(err_msg)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolModule for AgentTools {}
|
||||
|
103
src/tools/alerts.rs
Normal file
103
src/tools/alerts.rs
Normal file
@ -0,0 +1,103 @@
|
||||
//! Wazuh Indexer alert tools
|
||||
//!
|
||||
//! This module contains tools for retrieving and analyzing Wazuh security alerts
|
||||
//! from the Wazuh Indexer.
|
||||
|
||||
use rmcp::{
|
||||
Error as McpError,
|
||||
model::{CallToolResult, Content},
|
||||
tool,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use wazuh_client::WazuhIndexerClient;
|
||||
use super::ToolModule;
|
||||
|
||||
/// Parameters for getting alert summary
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetAlertSummaryParams {
|
||||
#[schemars(description = "Maximum number of alerts to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
}
|
||||
|
||||
/// Alert tools implementation
|
||||
#[derive(Clone)]
|
||||
pub struct AlertTools {
|
||||
indexer_client: Arc<WazuhIndexerClient>,
|
||||
}
|
||||
|
||||
impl AlertTools {
|
||||
pub fn new(indexer_client: Arc<WazuhIndexerClient>) -> Self {
|
||||
Self { indexer_client }
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_alert_summary",
|
||||
description = "Retrieves a summary of Wazuh security alerts. Returns formatted alert information including ID, timestamp, and description."
|
||||
)]
|
||||
pub async fn get_wazuh_alert_summary(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetAlertSummaryParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
|
||||
tracing::info!(limit = %limit, "Retrieving Wazuh alert summary");
|
||||
|
||||
match self.indexer_client.get_alerts(Some(limit)).await {
|
||||
Ok(raw_alerts) => {
|
||||
if raw_alerts.is_empty() {
|
||||
tracing::info!("No Wazuh alerts found to process. Returning standard message.");
|
||||
return Self::not_found_result("Wazuh alerts");
|
||||
}
|
||||
|
||||
let num_alerts_to_process = raw_alerts.len();
|
||||
let mcp_content_items: Vec<Content> = raw_alerts
|
||||
.into_iter()
|
||||
.map(|alert_value| {
|
||||
let source = alert_value.get("_source").unwrap_or(&alert_value);
|
||||
|
||||
let id = source.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.or_else(|| alert_value.get("_id").and_then(|v| v.as_str()))
|
||||
.unwrap_or("Unknown ID");
|
||||
|
||||
let description = source.get("rule")
|
||||
.and_then(|r| r.get("description"))
|
||||
.and_then(|d| d.as_str())
|
||||
.unwrap_or("No description available");
|
||||
|
||||
let timestamp = source.get("timestamp")
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("Unknown time");
|
||||
|
||||
let agent_name = source.get("agent")
|
||||
.and_then(|a| a.get("name"))
|
||||
.and_then(|n| n.as_str())
|
||||
.unwrap_or("Unknown agent");
|
||||
|
||||
let rule_level = source.get("rule")
|
||||
.and_then(|r| r.get("level"))
|
||||
.and_then(|l| l.as_u64())
|
||||
.unwrap_or(0);
|
||||
|
||||
let formatted_text = format!(
|
||||
"Alert ID: {}\nTime: {}\nAgent: {}\nLevel: {}\nDescription: {}",
|
||||
id, timestamp, agent_name, rule_level, description
|
||||
);
|
||||
Content::text(formatted_text)
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!("Successfully processed {} alerts into {} MCP content items", num_alerts_to_process, mcp_content_items.len());
|
||||
Self::success_result(mcp_content_items)
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = Self::format_error("Indexer", "retrieving alerts", &e);
|
||||
tracing::error!("{}", err_msg);
|
||||
Self::error_result(err_msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolModule for AlertTools {}
|
||||
|
63
src/tools/mod.rs
Normal file
63
src/tools/mod.rs
Normal file
@ -0,0 +1,63 @@
|
||||
//! Tools module for Wazuh MCP Server
|
||||
//!
|
||||
//! This module contains all the tool implementations organized by Wazuh component domains.
|
||||
//! Each submodule handles a specific area of Wazuh functionality.
|
||||
|
||||
pub mod agents;
|
||||
pub mod alerts;
|
||||
pub mod rules;
|
||||
pub mod stats;
|
||||
pub mod vulnerabilities;
|
||||
|
||||
use rmcp::model::{CallToolResult, Content};
|
||||
use rmcp::Error as McpError;
|
||||
|
||||
pub trait ToolModule {
|
||||
fn format_error(component: &str, operation: &str, error: &dyn std::fmt::Display) -> String {
|
||||
format!("Error {} from Wazuh {}: {}", operation, component, error)
|
||||
}
|
||||
|
||||
fn success_result(content: Vec<Content>) -> Result<CallToolResult, McpError> {
|
||||
Ok(CallToolResult::success(content))
|
||||
}
|
||||
|
||||
fn error_result(message: String) -> Result<CallToolResult, McpError> {
|
||||
Ok(CallToolResult::error(vec![Content::text(message)]))
|
||||
}
|
||||
|
||||
fn not_found_result(resource: &str) -> Result<CallToolResult, McpError> {
|
||||
let message = if resource == "Wazuh alerts" {
|
||||
"No Wazuh alerts found.".to_string()
|
||||
} else {
|
||||
format!("No {} found matching the specified criteria.", resource)
|
||||
};
|
||||
Ok(CallToolResult::success(vec![Content::text(message)]))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ToolUtils;
|
||||
|
||||
impl ToolUtils {
|
||||
pub fn format_agent_id(agent_id_str: &str) -> Result<String, String> {
|
||||
// Attempt to parse as a number first
|
||||
if let Ok(num) = agent_id_str.parse::<u32>() {
|
||||
if num > 999 {
|
||||
Err(format!(
|
||||
"Agent ID '{}' is too large. Must be a number between 0 and 999.",
|
||||
agent_id_str
|
||||
))
|
||||
} else {
|
||||
Ok(format!("{:03}", num))
|
||||
}
|
||||
} else if agent_id_str.len() == 3 && agent_id_str.chars().all(|c| c.is_ascii_digit()) {
|
||||
// Already correctly formatted (e.g., "001")
|
||||
Ok(agent_id_str.to_string())
|
||||
} else {
|
||||
Err(format!(
|
||||
"Invalid agent_id format: '{}'. Must be a number (e.g., 1, 12) or a 3-digit string (e.g., 001, 012).",
|
||||
agent_id_str
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
142
src/tools/rules.rs
Normal file
142
src/tools/rules.rs
Normal file
@ -0,0 +1,142 @@
|
||||
//! Wazuh Manager rule tools
|
||||
//!
|
||||
//! This module contains tools for retrieving and analyzing Wazuh security rules
|
||||
//! from the Wazuh Manager.
|
||||
|
||||
use rmcp::{
|
||||
Error as McpError,
|
||||
model::{CallToolResult, Content},
|
||||
tool,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use wazuh_client::RulesClient;
|
||||
use super::ToolModule;
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetRulesSummaryParams {
|
||||
#[schemars(description = "Maximum number of rules to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
#[schemars(description = "Rule level to filter by (optional)")]
|
||||
pub level: Option<u32>,
|
||||
#[schemars(description = "Rule group to filter by (optional)")]
|
||||
pub group: Option<String>,
|
||||
#[schemars(description = "Filename to filter by (optional)")]
|
||||
pub filename: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RuleTools {
|
||||
rules_client: Arc<Mutex<RulesClient>>,
|
||||
}
|
||||
|
||||
impl RuleTools {
|
||||
pub fn new(rules_client: Arc<Mutex<RulesClient>>) -> Self {
|
||||
Self { rules_client }
|
||||
}
|
||||
|
||||
#[tool(
|
||||
name = "get_wazuh_rules_summary",
|
||||
description = "Retrieves a summary of Wazuh security rules. Returns formatted rule information including ID, level, description, and groups. Supports filtering by level, group, and filename."
|
||||
)]
|
||||
pub async fn get_wazuh_rules_summary(
|
||||
&self,
|
||||
#[tool(aggr)] params: GetRulesSummaryParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
|
||||
tracing::info!(
|
||||
limit = %limit,
|
||||
level = ?params.level,
|
||||
group = ?params.group,
|
||||
filename = ?params.filename,
|
||||
"Retrieving Wazuh rules summary"
|
||||
);
|
||||
|
||||
let mut rules_client = self.rules_client.lock().await;
|
||||
|
||||
match rules_client.get_rules(
|
||||
Some(limit),
|
||||
None, // offset
|
||||
params.level,
|
||||
params.group.as_deref(),
|
||||
params.filename.as_deref(),
|
||||
).await {
|
||||
Ok(rules) => {
|
||||
if rules.is_empty() {
|
||||
tracing::info!("No Wazuh rules found matching criteria. Returning standard message.");
|
||||
return Self::not_found_result("Wazuh rules matching the specified criteria");
|
||||
}
|
||||
|
||||
let num_rules = rules.len();
|
||||
let mcp_content_items: Vec<Content> = rules
|
||||
.into_iter()
|
||||
.map(|rule| {
|
||||
let groups_str = rule.groups.join(", ");
|
||||
|
||||
let compliance_info = {
|
||||
let mut compliance = Vec::new();
|
||||
if let Some(gdpr) = &rule.gdpr {
|
||||
if !gdpr.is_empty() {
|
||||
compliance.push(format!("GDPR: {}", gdpr.join(", ")));
|
||||
}
|
||||
}
|
||||
if let Some(hipaa) = &rule.hipaa {
|
||||
if !hipaa.is_empty() {
|
||||
compliance.push(format!("HIPAA: {}", hipaa.join(", ")));
|
||||
}
|
||||
}
|
||||
if let Some(pci) = &rule.pci_dss {
|
||||
if !pci.is_empty() {
|
||||
compliance.push(format!("PCI DSS: {}", pci.join(", ")));
|
||||
}
|
||||
}
|
||||
if let Some(nist) = &rule.nist_800_53 {
|
||||
if !nist.is_empty() {
|
||||
compliance.push(format!("NIST 800-53: {}", nist.join(", ")));
|
||||
}
|
||||
}
|
||||
if compliance.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("\nCompliance: {}", compliance.join(" | "))
|
||||
}
|
||||
};
|
||||
|
||||
let severity = match rule.level {
|
||||
0..=3 => "Low",
|
||||
4..=7 => "Medium",
|
||||
8..=12 => "High",
|
||||
13..=15 => "Critical",
|
||||
_ => "Unknown",
|
||||
};
|
||||
|
||||
let formatted_text = format!(
|
||||
"Rule ID: {}\nLevel: {} ({})\nDescription: {}\nGroups: {}\nFile: {}\nStatus: {}{}",
|
||||
rule.id,
|
||||
rule.level,
|
||||
severity,
|
||||
rule.description,
|
||||
groups_str,
|
||||
rule.filename,
|
||||
rule.status,
|
||||
compliance_info
|
||||
);
|
||||
Content::text(formatted_text)
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!("Successfully processed {} rules into {} MCP content items", num_rules, mcp_content_items.len());
|
||||
Self::success_result(mcp_content_items)
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = Self::format_error("Manager", "retrieving rules", &e);
|
||||
tracing::error!("{}", err_msg);
|
||||
Self::error_result(err_msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolModule for RuleTools {}
|
||||
|
493
src/tools/stats.rs
Normal file
493
src/tools/stats.rs
Normal file
@ -0,0 +1,493 @@
|
||||
//! Stats tools for Wazuh MCP Server
|
||||
//!
|
||||
//! This module contains tools for retrieving various statistics from Wazuh components,
|
||||
//! including manager logs, remoted daemon stats, log collector stats, and weekly statistics.
|
||||
|
||||
use super::{ToolModule, ToolUtils};
|
||||
use reqwest::StatusCode;
|
||||
use rmcp::model::{CallToolResult, Content};
|
||||
use rmcp::Error as McpError;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use wazuh_client::{ClusterClient, LogsClient};
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct SearchManagerLogsParams {
|
||||
#[schemars(description = "Maximum number of log entries to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
#[schemars(description = "Number of log entries to skip (default: 0)")]
|
||||
pub offset: Option<u32>,
|
||||
#[schemars(description = "Log level to filter by (e.g., \"error\", \"warning\", \"info\")")]
|
||||
pub level: String,
|
||||
#[schemars(description = "Log tag to filter by (e.g., \"wazuh-modulesd\") (optional)")]
|
||||
pub tag: Option<String>,
|
||||
#[schemars(description = "Search term to filter log descriptions (optional)")]
|
||||
pub search_term: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetManagerErrorLogsParams {
|
||||
#[schemars(description = "Maximum number of error log entries to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetLogCollectorStatsParams {
|
||||
#[schemars(
|
||||
description = "Agent ID to get log collector stats for (required, e.g., \"0\", \"1\", \"001\")"
|
||||
)]
|
||||
pub agent_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetRemotedStatsParams {
|
||||
// No parameters needed for remoted stats
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetWeeklyStatsParams {
|
||||
// No parameters needed for weekly stats
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetClusterHealthParams {
|
||||
// No parameters needed for cluster health
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetClusterNodesParams {
|
||||
#[schemars(
|
||||
description = "Maximum number of nodes to retrieve (optional, Wazuh API default is 500)"
|
||||
)]
|
||||
pub limit: Option<u32>,
|
||||
#[schemars(description = "Number of nodes to skip (offset) (optional, default: 0)")]
|
||||
pub offset: Option<u32>,
|
||||
#[schemars(description = "Filter by node type (e.g., 'master', 'worker') (optional)")]
|
||||
pub node_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StatsTools {
|
||||
logs_client: Arc<Mutex<LogsClient>>,
|
||||
cluster_client: Arc<Mutex<ClusterClient>>,
|
||||
}
|
||||
|
||||
impl StatsTools {
|
||||
pub fn new(
|
||||
logs_client: Arc<Mutex<LogsClient>>,
|
||||
cluster_client: Arc<Mutex<ClusterClient>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
logs_client,
|
||||
cluster_client,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn search_wazuh_manager_logs(
|
||||
&self,
|
||||
params: SearchManagerLogsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
let offset = params.offset.unwrap_or(0);
|
||||
|
||||
tracing::info!(
|
||||
limit = %limit,
|
||||
offset = %offset,
|
||||
level = ?params.level,
|
||||
tag = ?params.tag,
|
||||
search_term = ?params.search_term,
|
||||
"Searching Wazuh manager logs"
|
||||
);
|
||||
|
||||
let mut logs_client = self.logs_client.lock().await;
|
||||
|
||||
match logs_client
|
||||
.get_manager_logs(
|
||||
Some(limit),
|
||||
Some(offset),
|
||||
Some(¶ms.level),
|
||||
params.tag.as_deref(),
|
||||
params.search_term.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(log_entries) => {
|
||||
if log_entries.is_empty() {
|
||||
tracing::info!("No Wazuh manager logs found matching criteria. Returning standard message.");
|
||||
return Self::not_found_result("Wazuh manager logs");
|
||||
}
|
||||
|
||||
let num_logs = log_entries.len();
|
||||
let mcp_content_items: Vec<Content> = log_entries
|
||||
.into_iter()
|
||||
.map(|log_entry| {
|
||||
let formatted_text = format!(
|
||||
"Timestamp: {}\nTag: {}\nLevel: {}\nDescription: {}",
|
||||
log_entry.timestamp,
|
||||
log_entry.tag,
|
||||
log_entry.level,
|
||||
log_entry.description
|
||||
);
|
||||
Content::text(formatted_text)
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!(
|
||||
"Successfully processed {} manager log entries into {} MCP content items",
|
||||
num_logs,
|
||||
mcp_content_items.len()
|
||||
);
|
||||
Self::success_result(mcp_content_items)
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = Self::format_error("Manager", "searching logs", &e);
|
||||
tracing::error!("{}", err_msg);
|
||||
Self::error_result(err_msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_wazuh_manager_error_logs(
|
||||
&self,
|
||||
params: GetManagerErrorLogsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
|
||||
tracing::info!(limit = %limit, "Retrieving Wazuh manager error logs");
|
||||
|
||||
let mut logs_client = self.logs_client.lock().await;
|
||||
|
||||
match logs_client.get_error_logs(Some(limit)).await {
|
||||
Ok(log_entries) => {
|
||||
if log_entries.is_empty() {
|
||||
tracing::info!(
|
||||
"No Wazuh manager error logs found. Returning standard message."
|
||||
);
|
||||
return Self::not_found_result("Wazuh manager error logs");
|
||||
}
|
||||
|
||||
let num_logs = log_entries.len();
|
||||
let mcp_content_items: Vec<Content> = log_entries
|
||||
.into_iter()
|
||||
.map(|log_entry| {
|
||||
let formatted_text = format!(
|
||||
"Timestamp: {}\nTag: {}\nLevel: {}\nDescription: {}",
|
||||
log_entry.timestamp,
|
||||
log_entry.tag,
|
||||
log_entry.level,
|
||||
log_entry.description
|
||||
);
|
||||
Content::text(formatted_text)
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!(
|
||||
"Successfully processed {} manager error log entries into {} MCP content items",
|
||||
num_logs,
|
||||
mcp_content_items.len()
|
||||
);
|
||||
Self::success_result(mcp_content_items)
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = Self::format_error("Manager", "retrieving error logs", &e);
|
||||
tracing::error!("{}", err_msg);
|
||||
Self::error_result(err_msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_wazuh_log_collector_stats(
|
||||
&self,
|
||||
params: GetLogCollectorStatsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let agent_id = match ToolUtils::format_agent_id(¶ms.agent_id) {
|
||||
Ok(formatted_id) => formatted_id,
|
||||
Err(err_msg) => {
|
||||
tracing::error!(
|
||||
"Error formatting agent_id for log collector stats: {}",
|
||||
err_msg
|
||||
);
|
||||
return Self::error_result(err_msg);
|
||||
}
|
||||
};
|
||||
|
||||
tracing::info!(agent_id = %agent_id, "Retrieving Wazuh log collector stats");
|
||||
|
||||
let mut logs_client = self.logs_client.lock().await;
|
||||
|
||||
match logs_client.get_logcollector_stats(&agent_id).await {
|
||||
Ok(stats) => {
|
||||
// Helper closure to format a LogCollectorPeriod
|
||||
let format_period = |period_name: &str,
|
||||
period_data: &wazuh_client::logs::LogCollectorPeriod|
|
||||
-> String {
|
||||
let files_info: String = period_data
|
||||
.files
|
||||
.iter()
|
||||
.map(|file: &wazuh_client::logs::LogFile| {
|
||||
let targets_str: String = file
|
||||
.targets
|
||||
.iter()
|
||||
.map(|target: &wazuh_client::logs::LogTarget| {
|
||||
format!(
|
||||
" - Name: {}, Drops: {}",
|
||||
target.name, target.drops
|
||||
)
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
let targets_display = if targets_str.is_empty() {
|
||||
" (No specific targets with drops for this file)".to_string()
|
||||
} else {
|
||||
format!(" Targets:\n{}", targets_str)
|
||||
};
|
||||
format!(
|
||||
" - Location: {}\n Events: {}\n Bytes: {}\n{}",
|
||||
file.location, file.events, file.bytes, targets_display
|
||||
)
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n\n");
|
||||
|
||||
let files_display = if files_info.is_empty() {
|
||||
" (No files processed in this period)".to_string()
|
||||
} else {
|
||||
files_info
|
||||
};
|
||||
|
||||
format!(
|
||||
"{}:\n Start: {}\n End: {}\n Files:\n{}",
|
||||
period_name, period_data.start, period_data.end, files_display
|
||||
)
|
||||
};
|
||||
|
||||
let global_period_info = format_period("Global Period", &stats.global);
|
||||
let interval_period_info = format_period("Interval Period", &stats.interval);
|
||||
|
||||
let formatted_text = format!(
|
||||
"Log Collector Stats for Agent: {}\n\n{}\n\n{}",
|
||||
agent_id,
|
||||
global_period_info,
|
||||
interval_period_info
|
||||
);
|
||||
|
||||
tracing::info!(
|
||||
"Successfully retrieved and formatted log collector stats for agent {}",
|
||||
agent_id
|
||||
);
|
||||
Self::success_result(vec![Content::text(formatted_text)])
|
||||
}
|
||||
Err(e) => {
|
||||
// Check if the error is due to agent not found or stats not available
|
||||
if let wazuh_client::WazuhApiError::ApiError(msg) = &e {
|
||||
if msg.contains(&format!(
|
||||
"Log collector stats for agent {} not found",
|
||||
agent_id
|
||||
)) {
|
||||
tracing::info!("No log collector stats found for agent {} (API error). Returning standard message.", agent_id);
|
||||
return Self::success_result(vec![Content::text(
|
||||
format!("No log collector stats found for agent {}. The agent might not exist, stats are unavailable, or the agent is not active.", agent_id),
|
||||
)]);
|
||||
}
|
||||
if msg.contains("Agent Not Found") {
|
||||
tracing::info!(
|
||||
"Agent {} not found (API error). Returning standard message.",
|
||||
agent_id
|
||||
);
|
||||
return Self::success_result(vec![Content::text(format!(
|
||||
"Agent {} not found. Cannot retrieve log collector stats.",
|
||||
agent_id
|
||||
))]);
|
||||
}
|
||||
}
|
||||
if let wazuh_client::WazuhApiError::HttpError { status, .. } = &e {
|
||||
if *status == StatusCode::NOT_FOUND {
|
||||
tracing::info!("No log collector stats found for agent {} (HTTP 404). Agent might not exist or endpoint unavailable.", agent_id);
|
||||
return Self::success_result(vec![Content::text(
|
||||
format!("No log collector stats found for agent {}. The agent might not exist, stats are unavailable, or the agent is not active.", agent_id),
|
||||
)]);
|
||||
}
|
||||
}
|
||||
let err_msg = format!(
|
||||
"Error retrieving log collector stats for agent {} from Wazuh: {}",
|
||||
agent_id, e
|
||||
);
|
||||
tracing::error!("{}", err_msg);
|
||||
Self::error_result(err_msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_wazuh_remoted_stats(
|
||||
&self,
|
||||
_params: GetRemotedStatsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
tracing::info!("Retrieving Wazuh remoted stats");
|
||||
|
||||
let mut logs_client = self.logs_client.lock().await;
|
||||
|
||||
match logs_client.get_remoted_stats().await {
|
||||
Ok(stats) => {
|
||||
let formatted_text = format!(
|
||||
"Wazuh Remoted Statistics:\nQueue Size: {}\nTotal Queue Size: {}\nTCP Sessions: {}\nControl Message Count: {}\nDiscarded Message Count: {}\nMessages Sent (Bytes): {}\nBytes Received: {}\nDequeued After Close: {}",
|
||||
stats.queue_size,
|
||||
stats.total_queue_size,
|
||||
stats.tcp_sessions,
|
||||
stats.ctrl_msg_count,
|
||||
stats.discarded_count,
|
||||
stats.sent_bytes,
|
||||
stats.recv_bytes,
|
||||
stats.dequeued_after_close
|
||||
);
|
||||
|
||||
tracing::info!("Successfully retrieved remoted stats");
|
||||
Self::success_result(vec![Content::text(formatted_text)])
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = Self::format_error("Manager", "retrieving remoted stats", &e);
|
||||
tracing::error!("{}", err_msg);
|
||||
Self::error_result(err_msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_wazuh_weekly_stats(
|
||||
&self,
|
||||
_params: GetWeeklyStatsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
tracing::info!("Retrieving Wazuh weekly stats");
|
||||
|
||||
let mut logs_client = self.logs_client.lock().await;
|
||||
|
||||
match logs_client.get_weekly_stats().await {
|
||||
Ok(stats_value) => match serde_json::to_string_pretty(&stats_value) {
|
||||
Ok(formatted_json) => {
|
||||
tracing::info!("Successfully retrieved and formatted weekly stats.");
|
||||
Self::success_result(vec![Content::text(formatted_json)])
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = format!("Error formatting weekly stats JSON: {}", e);
|
||||
tracing::error!("{}", err_msg);
|
||||
Self::error_result(err_msg)
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
let err_msg = Self::format_error("Manager", "retrieving weekly stats", &e);
|
||||
tracing::error!("{}", err_msg);
|
||||
Self::error_result(err_msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_wazuh_cluster_health(
|
||||
&self,
|
||||
_params: GetClusterHealthParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
tracing::info!("Retrieving Wazuh cluster health");
|
||||
|
||||
let mut cluster_client = self.cluster_client.lock().await;
|
||||
|
||||
match cluster_client.is_cluster_healthy().await {
|
||||
Ok(is_healthy) => {
|
||||
let health_status_text = if is_healthy {
|
||||
"Cluster is healthy: Yes".to_string()
|
||||
} else {
|
||||
// To provide more context, we can fetch the basic status
|
||||
match cluster_client.get_cluster_status().await {
|
||||
Ok(status) => {
|
||||
let mut reasons = Vec::new();
|
||||
if !status.enabled.eq_ignore_ascii_case("yes") {
|
||||
reasons.push("cluster is not enabled");
|
||||
}
|
||||
if !status.running.eq_ignore_ascii_case("yes") {
|
||||
reasons.push("cluster is not running");
|
||||
}
|
||||
if status.enabled.eq_ignore_ascii_case("yes") && status.running.eq_ignore_ascii_case("yes") {
|
||||
match cluster_client.get_cluster_healthcheck().await {
|
||||
Ok(hc) if hc.n_connected_nodes == 0 => reasons.push("no nodes are connected"),
|
||||
Err(_) => reasons.push("healthcheck endpoint failed or reported issues"),
|
||||
_ => {} // Healthy implies connected nodes
|
||||
}
|
||||
}
|
||||
if reasons.is_empty() && !is_healthy {
|
||||
reasons.push("unknown reason, check detailed logs or healthcheck endpoint");
|
||||
}
|
||||
format!("Cluster is healthy: No. Reasons: {}", reasons.join("; "))
|
||||
}
|
||||
Err(_) => {
|
||||
"Cluster is healthy: No. Additionally, failed to retrieve basic cluster status for more details.".to_string()
|
||||
}
|
||||
}
|
||||
};
|
||||
tracing::info!(
|
||||
"Successfully retrieved cluster health: {}",
|
||||
health_status_text
|
||||
);
|
||||
Self::success_result(vec![Content::text(health_status_text)])
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = Self::format_error("Cluster", "retrieving health", &e);
|
||||
tracing::error!("{}", err_msg);
|
||||
Self::error_result(err_msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_wazuh_cluster_nodes(
|
||||
&self,
|
||||
params: GetClusterNodesParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
tracing::info!(
|
||||
limit = ?params.limit,
|
||||
offset = ?params.offset,
|
||||
node_type = ?params.node_type,
|
||||
"Retrieving Wazuh cluster nodes"
|
||||
);
|
||||
|
||||
let mut cluster_client = self.cluster_client.lock().await;
|
||||
|
||||
match cluster_client
|
||||
.get_cluster_nodes(params.limit, params.offset, params.node_type.as_deref())
|
||||
.await
|
||||
{
|
||||
Ok(nodes) => {
|
||||
if nodes.is_empty() {
|
||||
tracing::info!("No Wazuh cluster nodes found matching criteria. Returning standard message.");
|
||||
return Self::not_found_result("Wazuh cluster nodes");
|
||||
}
|
||||
|
||||
let num_nodes = nodes.len();
|
||||
let mcp_content_items: Vec<Content> = nodes
|
||||
.into_iter()
|
||||
.map(|node| {
|
||||
let status_indicator = match node.status.to_lowercase().as_str() {
|
||||
"connected" | "active" => "🟢 CONNECTED",
|
||||
"disconnected" => "🔴 DISCONNECTED",
|
||||
_ => &node.status,
|
||||
};
|
||||
let formatted_text = format!(
|
||||
"Node Name: {}\nType: {}\nVersion: {}\nIP: {}\nStatus: {}",
|
||||
node.name, node.node_type, node.version, node.ip, status_indicator
|
||||
);
|
||||
Content::text(formatted_text)
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!(
|
||||
"Successfully processed {} cluster nodes into {} MCP content items",
|
||||
num_nodes,
|
||||
mcp_content_items.len()
|
||||
);
|
||||
Self::success_result(mcp_content_items)
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = Self::format_error("Cluster", "retrieving nodes", &e);
|
||||
tracing::error!("{}", err_msg);
|
||||
Self::error_result(err_msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolModule for StatsTools {}
|
||||
|
389
src/tools/vulnerabilities.rs
Normal file
389
src/tools/vulnerabilities.rs
Normal file
@ -0,0 +1,389 @@
|
||||
//! Wazuh Manager vulnerability tools
|
||||
//!
|
||||
//! This module contains tools for retrieving and analyzing vulnerability information
|
||||
//! from the Wazuh Manager.
|
||||
|
||||
use rmcp::{
|
||||
Error as McpError,
|
||||
model::{CallToolResult, Content},
|
||||
schemars,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use wazuh_client::{VulnerabilityClient, VulnerabilitySeverity};
|
||||
use super::ToolModule;
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetVulnerabilitiesSummaryParams {
|
||||
#[schemars(description = "Maximum number of vulnerabilities to retrieve (default: 10000)")]
|
||||
pub limit: Option<u32>,
|
||||
#[schemars(description = "Agent ID to filter vulnerabilities by (required, e.g., \"0\", \"1\", \"001\")")]
|
||||
pub agent_id: String,
|
||||
#[schemars(description = "Severity level to filter by (Low, Medium, High, Critical) (optional)")]
|
||||
pub severity: Option<String>,
|
||||
#[schemars(description = "CVE ID to search for (optional)")]
|
||||
pub cve: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
|
||||
pub struct GetCriticalVulnerabilitiesParams {
|
||||
#[schemars(description = "Agent ID to get critical vulnerabilities for (required, e.g., \"0\", \"1\", \"001\")")]
|
||||
pub agent_id: String,
|
||||
#[schemars(description = "Maximum number of vulnerabilities to retrieve (default: 300)")]
|
||||
pub limit: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct VulnerabilityTools {
|
||||
vulnerability_client: Arc<Mutex<VulnerabilityClient>>,
|
||||
}
|
||||
|
||||
impl VulnerabilityTools {
|
||||
pub fn new(vulnerability_client: Arc<Mutex<VulnerabilityClient>>) -> Self {
|
||||
Self { vulnerability_client }
|
||||
}
|
||||
|
||||
fn format_agent_id(agent_id_str: &str) -> Result<String, String> {
|
||||
// Attempt to parse as a number first
|
||||
if let Ok(num) = agent_id_str.parse::<u32>() {
|
||||
if num > 999 {
|
||||
Err(format!(
|
||||
"Agent ID '{}' is too large. Must be a number between 0 and 999.",
|
||||
agent_id_str
|
||||
))
|
||||
} else {
|
||||
Ok(format!("{:03}", num))
|
||||
}
|
||||
} else if agent_id_str.len() == 3 && agent_id_str.chars().all(|c| c.is_ascii_digit()) {
|
||||
// Already correctly formatted (e.g., "001")
|
||||
Ok(agent_id_str.to_string())
|
||||
} else {
|
||||
Err(format!(
|
||||
"Invalid agent_id format: '{}'. Must be a number (e.g., 1, 12) or a 3-digit string (e.g., 001, 012).",
|
||||
agent_id_str
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_wazuh_vulnerability_summary(
|
||||
&self,
|
||||
params: GetVulnerabilitiesSummaryParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(10000);
|
||||
let offset = 0; // Default offset, can be extended in future if needed
|
||||
|
||||
let agent_id = match Self::format_agent_id(¶ms.agent_id) {
|
||||
Ok(formatted_id) => formatted_id,
|
||||
Err(err_msg) => {
|
||||
tracing::error!(
|
||||
"Error formatting agent_id for vulnerability summary: {}",
|
||||
err_msg
|
||||
);
|
||||
return Ok(CallToolResult::error(vec![Content::text(err_msg)]));
|
||||
}
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
limit = %limit,
|
||||
agent_id = %agent_id,
|
||||
severity = ?params.severity,
|
||||
cve = ?params.cve,
|
||||
"Retrieving Wazuh vulnerability summary"
|
||||
);
|
||||
|
||||
let mut vulnerability_client = self.vulnerability_client.lock().await;
|
||||
|
||||
let vulnerabilities = vulnerability_client
|
||||
.get_agent_vulnerabilities(
|
||||
&agent_id,
|
||||
Some(limit),
|
||||
Some(offset),
|
||||
params
|
||||
.severity
|
||||
.as_deref()
|
||||
.and_then(VulnerabilitySeverity::from_str),
|
||||
)
|
||||
.await;
|
||||
|
||||
match vulnerabilities {
|
||||
Ok(vulnerabilities) => {
|
||||
if vulnerabilities.is_empty() {
|
||||
tracing::info!("No Wazuh vulnerabilities found matching criteria. Returning standard message.");
|
||||
return Ok(CallToolResult::success(vec![Content::text(
|
||||
"No Wazuh vulnerabilities found matching the specified criteria.",
|
||||
)]));
|
||||
}
|
||||
|
||||
let num_vulnerabilities = vulnerabilities.len();
|
||||
let mcp_content_items: Vec<Content> = vulnerabilities
|
||||
.into_iter()
|
||||
.map(|vuln| {
|
||||
let severity_indicator = match vuln.severity {
|
||||
VulnerabilitySeverity::Critical => "🔴 CRITICAL",
|
||||
VulnerabilitySeverity::High => "🟠 HIGH",
|
||||
VulnerabilitySeverity::Medium => "🟡 MEDIUM",
|
||||
VulnerabilitySeverity::Low => "🟢 LOW",
|
||||
};
|
||||
|
||||
let published_info = if let Some(published) = &vuln.published {
|
||||
format!("\nPublished: {}", published)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let updated_info = if let Some(updated) = &vuln.updated {
|
||||
format!("\nUpdated: {}", updated)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let detection_time_info = if let Some(detection_time) = &vuln.detection_time
|
||||
{
|
||||
format!("\nDetection Time: {}", detection_time)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let agent_info = {
|
||||
let id_str = vuln.agent_id.as_deref();
|
||||
let name_str = vuln.agent_name.as_deref();
|
||||
|
||||
match (id_str, name_str) {
|
||||
(Some("000"), Some(name)) => {
|
||||
format!("\nAgent: {} (Wazuh Manager, ID: 000)", name)
|
||||
}
|
||||
(Some("000"), None) => {
|
||||
"\nAgent: Wazuh Manager (ID: 000)".to_string()
|
||||
}
|
||||
(Some(id), Some(name)) => format!("\nAgent: {} (ID: {})", name, id),
|
||||
(Some(id), None) => format!("\nAgent ID: {}", id),
|
||||
(None, Some(name)) => format!("\nAgent: {} (ID: Unknown)", name), // Should ideally not happen if ID is a primary key for agent context
|
||||
(None, None) => String::new(), // No agent information available
|
||||
}
|
||||
};
|
||||
|
||||
let cvss_info = if let Some(cvss) = &vuln.cvss {
|
||||
let mut cvss_parts = Vec::new();
|
||||
if let Some(cvss2) = &cvss.cvss2 {
|
||||
if let Some(score) = cvss2.base_score {
|
||||
cvss_parts.push(format!("CVSS2: {}", score));
|
||||
}
|
||||
}
|
||||
if let Some(cvss3) = &cvss.cvss3 {
|
||||
if let Some(score) = cvss3.base_score {
|
||||
cvss_parts.push(format!("CVSS3: {}", score));
|
||||
}
|
||||
}
|
||||
if !cvss_parts.is_empty() {
|
||||
format!("\nCVSS Scores: {}", cvss_parts.join(", "))
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let reference_info = if let Some(reference) = &vuln.reference {
|
||||
format!("\nReference: {}", reference)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let description = vuln
|
||||
.description
|
||||
.as_deref()
|
||||
.unwrap_or("No description available");
|
||||
|
||||
let formatted_text = format!(
|
||||
"CVE: {}\nSeverity: {}\nTitle: {}\nDescription: {}{}{}{}{}{}{}",
|
||||
vuln.cve,
|
||||
severity_indicator,
|
||||
vuln.title,
|
||||
description,
|
||||
published_info,
|
||||
updated_info,
|
||||
detection_time_info,
|
||||
agent_info,
|
||||
cvss_info,
|
||||
reference_info
|
||||
);
|
||||
Content::text(formatted_text)
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!(
|
||||
"Successfully processed {} vulnerabilities into {} MCP content items",
|
||||
num_vulnerabilities,
|
||||
mcp_content_items.len()
|
||||
);
|
||||
Ok(CallToolResult::success(mcp_content_items))
|
||||
}
|
||||
Err(e) => {
|
||||
use reqwest::StatusCode;
|
||||
match e {
|
||||
wazuh_client::WazuhApiError::HttpError {
|
||||
status,
|
||||
message: _,
|
||||
url: _,
|
||||
} if status == StatusCode::NOT_FOUND => {
|
||||
tracing::info!("No vulnerability summary found for agent {}. Returning standard message.", agent_id);
|
||||
return Ok(CallToolResult::success(vec![Content::text(format!(
|
||||
"No vulnerability summary found for agent {}.",
|
||||
agent_id
|
||||
))]));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
let err_msg = format!("Error retrieving vulnerabilities from Wazuh: {}", e);
|
||||
tracing::error!("{}", err_msg);
|
||||
Ok(CallToolResult::error(vec![Content::text(err_msg)]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_wazuh_critical_vulnerabilities(
|
||||
&self,
|
||||
params: GetCriticalVulnerabilitiesParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = params.limit.unwrap_or(300);
|
||||
let agent_id = match Self::format_agent_id(¶ms.agent_id) {
|
||||
Ok(formatted_id) => formatted_id,
|
||||
Err(err_msg) => {
|
||||
tracing::error!(
|
||||
"Error formatting agent_id for critical vulnerabilities: {}",
|
||||
err_msg
|
||||
);
|
||||
return Ok(CallToolResult::error(vec![Content::text(err_msg)]));
|
||||
}
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
agent_id = %agent_id,
|
||||
"Retrieving critical vulnerabilities for Wazuh agent"
|
||||
);
|
||||
|
||||
let mut vulnerability_client = self.vulnerability_client.lock().await;
|
||||
|
||||
match vulnerability_client
|
||||
.get_critical_vulnerabilities(&agent_id, Some(limit))
|
||||
.await
|
||||
{
|
||||
Ok(vulnerabilities) => {
|
||||
if vulnerabilities.is_empty() {
|
||||
tracing::info!("No critical vulnerabilities found for agent {}. Returning standard message.", agent_id);
|
||||
return Ok(CallToolResult::success(vec![Content::text(format!(
|
||||
"No critical vulnerabilities found for agent {}.",
|
||||
agent_id
|
||||
))]));
|
||||
}
|
||||
|
||||
let num_vulnerabilities = vulnerabilities.len();
|
||||
let mcp_content_items: Vec<Content> = vulnerabilities
|
||||
.into_iter()
|
||||
.map(|vuln| {
|
||||
let published_info = if let Some(published) = &vuln.published {
|
||||
format!("\nPublished: {}", published)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let updated_info = if let Some(updated) = &vuln.updated {
|
||||
format!("\nUpdated: {}", updated)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let detection_time_info = if let Some(detection_time) = &vuln.detection_time {
|
||||
format!("\nDetection Time: {}", detection_time)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let agent_info = if let Some(agent_name) = &vuln.agent_name {
|
||||
format!("\nAgent: {} (ID: {})", agent_name, vuln.agent_id.as_deref().unwrap_or("Unknown"))
|
||||
} else if let Some(agent_id) = &vuln.agent_id {
|
||||
format!("\nAgent ID: {}", agent_id)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let cvss_info = if let Some(cvss) = &vuln.cvss {
|
||||
let mut cvss_parts = Vec::new();
|
||||
if let Some(cvss2) = &cvss.cvss2 {
|
||||
if let Some(score) = cvss2.base_score {
|
||||
cvss_parts.push(format!("CVSS2: {}", score));
|
||||
}
|
||||
}
|
||||
if let Some(cvss3) = &cvss.cvss3 {
|
||||
if let Some(score) = cvss3.base_score {
|
||||
cvss_parts.push(format!("CVSS3: {}", score));
|
||||
}
|
||||
}
|
||||
if !cvss_parts.is_empty() {
|
||||
format!("\nCVSS Scores: {}", cvss_parts.join(", "))
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let reference_info = if let Some(reference) = &vuln.reference {
|
||||
format!("\nReference: {}", reference)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let description = vuln.description.as_deref().unwrap_or("No description available");
|
||||
|
||||
let formatted_text = format!(
|
||||
"🔴 CRITICAL VULNERABILITY\nCVE: {}\nTitle: {}\nDescription: {}{}{}{}{}{}{}",
|
||||
vuln.cve,
|
||||
vuln.title,
|
||||
description,
|
||||
published_info,
|
||||
updated_info,
|
||||
detection_time_info,
|
||||
agent_info,
|
||||
cvss_info,
|
||||
reference_info
|
||||
);
|
||||
Content::text(formatted_text)
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!(
|
||||
"Successfully processed {} critical vulnerabilities into {} MCP content items",
|
||||
num_vulnerabilities,
|
||||
mcp_content_items.len()
|
||||
);
|
||||
Ok(CallToolResult::success(mcp_content_items))
|
||||
}
|
||||
Err(e) => {
|
||||
use reqwest::StatusCode;
|
||||
match e {
|
||||
wazuh_client::WazuhApiError::HttpError {
|
||||
status,
|
||||
message: _,
|
||||
url: _,
|
||||
} if status == StatusCode::NOT_FOUND => {
|
||||
tracing::info!("No critical vulnerabilities found for agent {}. Returning standard message.", agent_id);
|
||||
return Ok(CallToolResult::success(vec![Content::text(format!(
|
||||
"No critical vulnerabilities found for agent {}.",
|
||||
agent_id
|
||||
))]));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
let err_msg = format!(
|
||||
"Error retrieving critical vulnerabilities from Wazuh for agent {}: {}",
|
||||
agent_id, e
|
||||
);
|
||||
tracing::error!("{}", err_msg);
|
||||
Ok(CallToolResult::error(vec![Content::text(err_msg)]))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolModule for VulnerabilityTools {}
|
@ -1,129 +0,0 @@
|
||||
use reqwest::{header, Client, Method};
|
||||
use serde_json::{json, Value};
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use super::error::WazuhApiError;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WazuhIndexerClient {
|
||||
username: String,
|
||||
password: String,
|
||||
base_url: String,
|
||||
http_client: Client,
|
||||
}
|
||||
|
||||
// Renamed impl block
|
||||
impl WazuhIndexerClient {
|
||||
pub fn new(
|
||||
host: String,
|
||||
indexer_port: u16,
|
||||
username: String,
|
||||
password: String,
|
||||
verify_ssl: bool,
|
||||
) -> Self {
|
||||
debug!(%host, indexer_port, %username, %verify_ssl, "Creating new WazuhIndexerClient");
|
||||
// Base URL now points to the Indexer
|
||||
let base_url = format!("https://{}:{}", host, indexer_port);
|
||||
debug!(%base_url, "Wazuh Indexer base URL set");
|
||||
|
||||
let http_client = Client::builder()
|
||||
.danger_accept_invalid_certs(!verify_ssl)
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
Self {
|
||||
username,
|
||||
password,
|
||||
base_url,
|
||||
http_client,
|
||||
}
|
||||
}
|
||||
|
||||
async fn make_indexer_request(
|
||||
&self,
|
||||
method: Method,
|
||||
endpoint: &str,
|
||||
body: Option<Value>,
|
||||
) -> Result<Value, WazuhApiError> {
|
||||
debug!(?method, %endpoint, ?body, "Making request to Wazuh Indexer");
|
||||
let url = format!("{}{}", self.base_url, endpoint);
|
||||
debug!(%url, "Constructed Indexer request URL");
|
||||
|
||||
let mut request_builder = self
|
||||
.http_client
|
||||
.request(method.clone(), &url)
|
||||
.basic_auth(&self.username, Some(&self.password)); // Use Basic Auth
|
||||
|
||||
if let Some(json_body) = &body {
|
||||
request_builder = request_builder
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.json(json_body);
|
||||
}
|
||||
debug!("Request builder configured with Basic Auth");
|
||||
|
||||
let response = request_builder.send().await?;
|
||||
let status = response.status();
|
||||
debug!(%status, "Received response from Indexer endpoint");
|
||||
|
||||
if !status.is_success() {
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error reading response body".to_string());
|
||||
error!(%url, %status, %error_text, "Indexer API request failed");
|
||||
// Provide more context in the error
|
||||
return Err(WazuhApiError::ApiError(format!(
|
||||
"Indexer request to {} failed with status {}: {}",
|
||||
url, status, error_text
|
||||
)));
|
||||
}
|
||||
|
||||
debug!("Indexer API request successful");
|
||||
response.json().await.map_err(|e| {
|
||||
error!("Failed to parse JSON response from Indexer: {}", e);
|
||||
WazuhApiError::RequestError(e) // Use appropriate error variant
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_alerts(&self) -> Result<Vec<Value>, WazuhApiError> {
|
||||
let endpoint = "/wazuh-alerts*/_search";
|
||||
let query_body = json!({
|
||||
"size": 100,
|
||||
"query": {
|
||||
"match_all": {}
|
||||
},
|
||||
});
|
||||
|
||||
debug!(%endpoint, ?query_body, "Preparing to get alerts from Wazuh Indexer");
|
||||
info!("Retrieving up to 100 alerts from Wazuh Indexer");
|
||||
|
||||
let response = self
|
||||
.make_indexer_request(Method::POST, endpoint, Some(query_body))
|
||||
.await?;
|
||||
|
||||
let hits = response
|
||||
.get("hits")
|
||||
.and_then(|h| h.get("hits"))
|
||||
.and_then(|h_array| h_array.as_array())
|
||||
.ok_or_else(|| {
|
||||
error!(
|
||||
?response,
|
||||
"Failed to find 'hits.hits' array in Indexer response"
|
||||
);
|
||||
WazuhApiError::ApiError("Indexer response missing 'hits.hits' array".to_string())
|
||||
})?;
|
||||
|
||||
let alerts: Vec<Value> = hits
|
||||
.iter()
|
||||
.filter_map(|hit| hit.get("_source").cloned())
|
||||
.collect();
|
||||
|
||||
debug!(
|
||||
"Successfully retrieved {} alerts from Indexer",
|
||||
alerts.len()
|
||||
);
|
||||
Ok(alerts)
|
||||
}
|
||||
}
|
@ -1,89 +0,0 @@
|
||||
use thiserror::Error;
|
||||
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub enum WazuhApiError {
|
||||
#[error("Failed to create HTTP client: {0}")]
|
||||
HttpClientCreationError(reqwest::Error),
|
||||
|
||||
#[error("HTTP request error: {0}")]
|
||||
RequestError(#[from] reqwest::Error),
|
||||
|
||||
#[error("JWT token not found in Wazuh API response")]
|
||||
JwtNotFound,
|
||||
|
||||
#[error("Wazuh API Authentication failed: {0}")]
|
||||
AuthenticationError(String),
|
||||
|
||||
#[error("Failed to decode JWT: {0}")]
|
||||
JwtDecodingError(#[from] jsonwebtoken::errors::Error),
|
||||
|
||||
#[error("JSON parsing error: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
|
||||
#[error("Wazuh API error: {0}")]
|
||||
ApiError(String),
|
||||
|
||||
#[error("Alert with ID '{0}' not found")]
|
||||
AlertNotFound(String),
|
||||
}
|
||||
|
||||
impl IntoResponse for WazuhApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, error_message) = match self {
|
||||
WazuhApiError::HttpClientCreationError(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Internal server error: {}", e),
|
||||
),
|
||||
WazuhApiError::RequestError(e) => {
|
||||
if e.is_connect() || e.is_timeout() {
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Could not connect to Wazuh API: {}", e),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Wazuh API request error: {}", e),
|
||||
)
|
||||
}
|
||||
}
|
||||
WazuhApiError::JwtNotFound => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to retrieve JWT token from Wazuh API response".to_string(),
|
||||
),
|
||||
WazuhApiError::AuthenticationError(msg) => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
format!("Wazuh API authentication failed: {}", msg),
|
||||
),
|
||||
WazuhApiError::JwtDecodingError(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to decode JWT token: {}", e),
|
||||
),
|
||||
WazuhApiError::JsonError(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to parse Wazuh API response: {}", e),
|
||||
),
|
||||
WazuhApiError::ApiError(msg) => {
|
||||
(StatusCode::BAD_GATEWAY, format!("Wazuh API error: {}", msg))
|
||||
}
|
||||
WazuhApiError::AlertNotFound(id) => (
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Alert with ID '{}' not found", id),
|
||||
),
|
||||
};
|
||||
|
||||
let body = Json(json!({
|
||||
"error": error_message,
|
||||
}));
|
||||
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
@ -1,2 +0,0 @@
|
||||
pub mod client;
|
||||
pub mod error;
|
141
tests/README.md
141
tests/README.md
@ -1,60 +1,133 @@
|
||||
# Wazuh MCP Server Tests
|
||||
|
||||
This directory contains tests for the Wazuh MCP Server, including end-to-end tests that simulate a client interacting with the server.
|
||||
This directory contains tests for the Wazuh MCP Server using the rmcp framework, including unit tests, integration tests with mock Wazuh API, and end-to-end MCP protocol tests.
|
||||
|
||||
## Test Files
|
||||
|
||||
- `e2e_client_test.rs`: End-to-end test for MCP client interacting with Wazuh MCP server
|
||||
- `integration_test.rs`: Integration test for Wazuh MCP Server with a mock Wazuh API
|
||||
- `mcp_client.rs`: Reusable MCP client implementation
|
||||
- `mcp_client_cli.rs`: Command-line tool for interacting with the MCP server
|
||||
- `rmcp_integration_test.rs`: Integration tests for the rmcp-based MCP server using a mock Wazuh API.
|
||||
- `mock_wazuh_server.rs`: Mock Wazuh API server implementation, used by the integration tests.
|
||||
- `mcp_stdio_test.rs`: Tests for MCP protocol communication via stdio, focusing on initialization, compliance, concurrent requests, and error handling for invalid/unsupported messages.
|
||||
- `run_tests.sh`: A shell script that automates running the various test suites.
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### 1. Mock Wazuh Server Tests
|
||||
Tests the MCP server with a mock Wazuh API to verify:
|
||||
- Tool registration and schema generation
|
||||
- Alert retrieval and formatting
|
||||
- Error handling for API failures
|
||||
- Parameter validation
|
||||
|
||||
### 2. MCP Protocol Tests
|
||||
Tests the MCP protocol implementation (primarily in `mcp_stdio_test.rs`):
|
||||
- Initialize handshake.
|
||||
- Tools listing (basic, without requiring a live Wazuh connection).
|
||||
- Handling of invalid JSON-RPC requests and unsupported methods.
|
||||
- Behavior with concurrent requests.
|
||||
- JSON-RPC 2.0 compliance.
|
||||
(Note: Full tool execution, like `tools/call`, is primarily tested in `rmcp_integration_test.rs` using the mock Wazuh server.)
|
||||
|
||||
### 3. Unit Tests
|
||||
Tests individual components and modules, typically run via `cargo test --lib`. These may include:
|
||||
- Wazuh client logic (e.g., authentication, request formation, response parsing).
|
||||
- Alert data transformation and formatting.
|
||||
- Internal error handling mechanisms and utility functions.
|
||||
|
||||
## Running the Tests
|
||||
|
||||
To run all tests:
|
||||
|
||||
### Run All Tests
|
||||
```bash
|
||||
cargo test
|
||||
```
|
||||
|
||||
To run a specific test:
|
||||
|
||||
### Run Specific Test Categories
|
||||
```bash
|
||||
cargo test --test e2e_client_test
|
||||
cargo test --test integration_test
|
||||
# Integration tests with mock Wazuh
|
||||
cargo test --test rmcp_integration_test
|
||||
|
||||
# MCP protocol tests
|
||||
cargo test --test mcp_stdio_test
|
||||
|
||||
# Unit tests
|
||||
cargo test --lib
|
||||
```
|
||||
|
||||
## Using the MCP Client CLI
|
||||
|
||||
The MCP Client CLI can be used to interact with the MCP server for testing purposes:
|
||||
|
||||
### Run Tests with Logging
|
||||
```bash
|
||||
# Build the CLI
|
||||
cargo build --bin mcp_client_cli
|
||||
|
||||
# Run the CLI
|
||||
MCP_SERVER_URL=http://localhost:8000 ./target/debug/mcp_client_cli get-data
|
||||
MCP_SERVER_URL=http://localhost:8000 ./target/debug/mcp_client_cli health
|
||||
MCP_SERVER_URL=http://localhost:8000 ./target/debug/mcp_client_cli query '{"severity": "high"}'
|
||||
RUST_LOG=debug cargo test -- --nocapture
|
||||
```
|
||||
|
||||
## Test Environment Variables
|
||||
|
||||
The tests use the following environment variables:
|
||||
The tests support the following environment variables:
|
||||
|
||||
- `MCP_SERVER_URL`: URL of the MCP server (default: http://localhost:8000)
|
||||
- `WAZUH_HOST`: Hostname of the Wazuh API server
|
||||
- `WAZUH_PORT`: Port of the Wazuh API server
|
||||
- `WAZUH_USER`: Username for Wazuh API authentication
|
||||
- `WAZUH_PASS`: Password for Wazuh API authentication
|
||||
- `VERIFY_SSL`: Whether to verify SSL certificates (default: false)
|
||||
- `RUST_LOG`: Log level for the tests (default: info)
|
||||
- `RUST_LOG`: Log level for tests (default: info)
|
||||
- `TEST_WAZUH_HOST`: Real Wazuh host for integration tests (optional)
|
||||
- `TEST_WAZUH_PORT`: Real Wazuh port for integration tests (optional)
|
||||
- `TEST_WAZUH_USER`: Real Wazuh username for integration tests (optional)
|
||||
- `TEST_WAZUH_PASS`: Real Wazuh password for integration tests (optional)
|
||||
|
||||
## Mock Wazuh API Server
|
||||
|
||||
The tests use a mock Wazuh API server to simulate the Wazuh API. The mock server provides:
|
||||
The mock server simulates a real Wazuh Indexer API with:
|
||||
|
||||
- Authentication endpoint: `/security/user/authenticate`
|
||||
- Alerts endpoint: `/wazuh-alerts-*_search`
|
||||
### Authentication Endpoint
|
||||
- `POST /security/user/authenticate`
|
||||
- Returns mock JWT token
|
||||
|
||||
The mock server returns predefined responses for these endpoints, allowing the tests to run without a real Wazuh API server.
|
||||
### Alerts Endpoint
|
||||
- `POST /wazuh-alerts-*/_search` (Note: The Wazuh API typically uses POST for search queries with a body)
|
||||
- Returns configurable mock alert data
|
||||
- Supports different scenarios (success, empty, error)
|
||||
|
||||
### Configurable Responses
|
||||
The mock server can be configured to return:
|
||||
- Successful responses with sample alerts
|
||||
- Empty responses (no alerts)
|
||||
- Error responses (500, 401, etc.)
|
||||
- Malformed responses for error testing
|
||||
|
||||
## Testing Without Real Wazuh
|
||||
|
||||
All tests can run without a real Wazuh instance by using the mock server. This allows for:
|
||||
|
||||
- **CI/CD Integration**: Tests run in any environment
|
||||
- **Deterministic Results**: Predictable test data
|
||||
- **Error Scenario Testing**: Simulate various failure modes
|
||||
- **Fast Execution**: No network dependencies
|
||||
|
||||
## Testing With a Real Wazuh Instance (Manual End-to-End)
|
||||
|
||||
The automated test suites (`cargo test`) use mock servers or no Wazuh connection. To perform end-to-end testing with a real Wazuh instance, you need to run the server application itself and interact with it manually or via a separate client.
|
||||
|
||||
1. **Set up your Wazuh environment:** Ensure you have a running Wazuh instance (Indexer/API).
|
||||
2. **Configure Environment Variables:** Set the standard runtime environment variables for the server to connect to your Wazuh instance:
|
||||
```bash
|
||||
export WAZUH_HOST="your-wazuh-indexer-host" # e.g., localhost or an IP address
|
||||
export WAZUH_PORT="9200" # Or your Wazuh Indexer port
|
||||
export WAZUH_USER="your-wazuh-api-user"
|
||||
export WAZUH_PASS="your-wazuh-api-password"
|
||||
export VERIFY_SSL="false" # Set to "true" if your Wazuh API uses a valid CA-signed SSL certificate
|
||||
# export RUST_LOG="debug" # For more detailed server logs
|
||||
```
|
||||
|
||||
## Manual Testing
|
||||
|
||||
### Using stdio directly
|
||||
The server communicates over stdin/stdout. You can send commands by piping them to the process:
|
||||
```bash
|
||||
# Example: Send an initialize request
|
||||
echo '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}' | cargo run --bin mcp-server-wazuh
|
||||
```
|
||||
|
||||
### Using the test script
|
||||
```bash
|
||||
# Run the provided test script
|
||||
./tests/run_tests.sh
|
||||
```
|
||||
|
||||
This script will:
|
||||
1. Start the MCP server with mock Wazuh configuration
|
||||
2. Send a series of MCP commands
|
||||
3. Verify responses
|
||||
4. Clean up processes
|
||||
|
@ -1,194 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use httpmock::prelude::*;
|
||||
use reqwest::Client;
|
||||
use serde_json::{json, Value};
|
||||
use std::process::{Child, Command};
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
use uuid::Uuid;
|
||||
|
||||
struct MockWazuhServer {
|
||||
server: MockServer,
|
||||
}
|
||||
|
||||
impl MockWazuhServer {
|
||||
fn new() -> Self {
|
||||
let server = MockServer::start();
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(POST)
|
||||
.path("/security/user/authenticate")
|
||||
.header("Authorization", "Basic YWRtaW46YWRtaW4=");
|
||||
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"jwt": "mock.jwt.token"
|
||||
}));
|
||||
});
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(GET)
|
||||
.path("/wazuh-alerts-*_search")
|
||||
.header("Authorization", "Bearer mock.jwt.token");
|
||||
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_source": {
|
||||
"id": "12345",
|
||||
"category": "intrusion_detection",
|
||||
"severity": "high",
|
||||
"description": "Possible intrusion attempt detected",
|
||||
"data": {
|
||||
"source_ip": "192.168.1.100",
|
||||
"destination_ip": "10.0.0.1",
|
||||
"port": 22
|
||||
},
|
||||
"notes": "Test alert"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
"id": "67890",
|
||||
"category": "malware",
|
||||
"severity": "critical",
|
||||
"description": "Malware detected on system",
|
||||
"data": {
|
||||
"file_path": "/tmp/malicious.exe",
|
||||
"hash": "abcdef123456",
|
||||
"signature": "EICAR-Test-File"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}));
|
||||
});
|
||||
|
||||
Self { server }
|
||||
}
|
||||
|
||||
fn url(&self) -> String {
|
||||
self.server.url("")
|
||||
}
|
||||
}
|
||||
|
||||
struct McpClient {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl McpClient {
|
||||
fn new(base_url: String) -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(10))
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
Self { client, base_url }
|
||||
}
|
||||
|
||||
async fn get_mcp_data(&self) -> Result<Vec<Value>> {
|
||||
let url = format!("{}/mcp", self.base_url);
|
||||
let response = self.client.get(&url).send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!("MCP request failed with status: {}", response.status());
|
||||
}
|
||||
|
||||
let data = response.json::<Vec<Value>>().await?;
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
async fn check_health(&self) -> Result<Value> {
|
||||
let url = format!("{}/health", self.base_url);
|
||||
let response = self.client.get(&url).send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!("Health check failed with status: {}", response.status());
|
||||
}
|
||||
|
||||
let data = response.json::<Value>().await?;
|
||||
Ok(data)
|
||||
}
|
||||
}
|
||||
|
||||
fn start_mcp_server(wazuh_url: &str, port: u16) -> Child {
|
||||
let server_id = Uuid::new_v4().to_string();
|
||||
let wazuh_host_port: Vec<&str> = wazuh_url.trim_start_matches("http://").split(':').collect();
|
||||
let wazuh_host = wazuh_host_port[0];
|
||||
let wazuh_port = wazuh_host_port[1];
|
||||
|
||||
Command::new("cargo")
|
||||
.args(["run", "--"])
|
||||
.env("WAZUH_HOST", wazuh_host)
|
||||
.env("WAZUH_PORT", wazuh_port)
|
||||
.env("WAZUH_USER", "admin")
|
||||
.env("WAZUH_PASS", "admin")
|
||||
.env("VERIFY_SSL", "false")
|
||||
.env("MCP_SERVER_PORT", port.to_string())
|
||||
.env("RUST_LOG", "info")
|
||||
.env("SERVER_ID", server_id)
|
||||
.spawn()
|
||||
.expect("Failed to start MCP server")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mcp_client_integration() -> Result<()> {
|
||||
let mock_wazuh = MockWazuhServer::new();
|
||||
let wazuh_url = mock_wazuh.url();
|
||||
|
||||
let mcp_port = 8765;
|
||||
let mut mcp_server = start_mcp_server(&wazuh_url, mcp_port);
|
||||
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
|
||||
let mcp_client = McpClient::new(format!("http://localhost:{}", mcp_port));
|
||||
|
||||
let health_data = mcp_client.check_health().await?;
|
||||
assert_eq!(health_data["status"], "ok");
|
||||
assert_eq!(health_data["service"], "wazuh-mcp-server");
|
||||
|
||||
let mcp_data = mcp_client.get_mcp_data().await?;
|
||||
|
||||
assert_eq!(mcp_data.len(), 2);
|
||||
|
||||
let first_message = &mcp_data[0];
|
||||
assert_eq!(first_message["protocol_version"], "1.0");
|
||||
assert_eq!(first_message["source"], "Wazuh");
|
||||
assert_eq!(first_message["event_type"], "alert");
|
||||
|
||||
let context = &first_message["context"];
|
||||
assert_eq!(context["id"], "12345");
|
||||
assert_eq!(context["category"], "intrusion_detection");
|
||||
assert_eq!(context["severity"], "high");
|
||||
assert_eq!(
|
||||
context["description"],
|
||||
"Possible intrusion attempt detected"
|
||||
);
|
||||
|
||||
let data = &context["data"];
|
||||
assert_eq!(data["source_ip"], "192.168.1.100");
|
||||
assert_eq!(data["destination_ip"], "10.0.0.1");
|
||||
assert_eq!(data["port"], 22);
|
||||
|
||||
let second_message = &mcp_data[1];
|
||||
let context = &second_message["context"];
|
||||
assert_eq!(context["id"], "67890");
|
||||
assert_eq!(context["category"], "malware");
|
||||
assert_eq!(context["severity"], "critical");
|
||||
assert_eq!(context["description"], "Malware detected on system");
|
||||
|
||||
let data = &context["data"];
|
||||
assert_eq!(data["file_path"], "/tmp/malicious.exe");
|
||||
assert_eq!(data["hash"], "abcdef123456");
|
||||
assert_eq!(data["signature"], "EICAR-Test-File");
|
||||
|
||||
mcp_server.kill().expect("Failed to kill MCP server");
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,420 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use httpmock::prelude::*;
|
||||
use once_cell::sync::Lazy;
|
||||
use serde_json::json;
|
||||
use std::net::TcpListener;
|
||||
use std::process::{Child, Command};
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
mod mcp_client;
|
||||
use mcp_client::{McpClient, McpClientTrait, McpMessage};
|
||||
|
||||
static TEST_MUTEX: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
|
||||
|
||||
fn find_available_port() -> u16 {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind to random port");
|
||||
let port = listener
|
||||
.local_addr()
|
||||
.expect("Failed to get local address")
|
||||
.port();
|
||||
drop(listener);
|
||||
port
|
||||
}
|
||||
|
||||
struct MockWazuhServer {
|
||||
server: MockServer,
|
||||
}
|
||||
|
||||
impl MockWazuhServer {
|
||||
fn new() -> Self {
|
||||
let server = MockServer::start();
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(POST).path("/security/user/authenticate");
|
||||
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"jwt": "mock.jwt.token"
|
||||
}));
|
||||
});
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(GET).path("/wazuh-alerts-*_search");
|
||||
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_source": {
|
||||
"id": "12345",
|
||||
"category": "intrusion_detection",
|
||||
"severity": "high",
|
||||
"description": "Possible intrusion attempt detected",
|
||||
"data": {
|
||||
"source_ip": "192.168.1.100",
|
||||
"destination_ip": "10.0.0.1",
|
||||
"port": 22
|
||||
},
|
||||
"notes": "Test alert"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
"id": "67890",
|
||||
"category": "malware",
|
||||
"severity": "critical",
|
||||
"description": "Malware detected on system",
|
||||
"data": {
|
||||
"file_path": "/tmp/malicious.exe",
|
||||
"hash": "abcdef123456",
|
||||
"signature": "EICAR-Test-File"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}));
|
||||
});
|
||||
|
||||
Self { server }
|
||||
}
|
||||
|
||||
fn url(&self) -> String {
|
||||
self.server.url("")
|
||||
}
|
||||
|
||||
fn host(&self) -> String {
|
||||
let url = self.url();
|
||||
let parts: Vec<&str> = url.trim_start_matches("http://").split(':').collect();
|
||||
parts[0].to_string()
|
||||
}
|
||||
|
||||
fn port(&self) -> u16 {
|
||||
let url = self.url();
|
||||
let parts: Vec<&str> = url.trim_start_matches("http://").split(':').collect();
|
||||
parts[1].parse().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn setup_mock_wazuh_server() -> MockServer {
|
||||
let server = MockServer::start();
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(POST).path("/security/user/authenticate");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({ "jwt": "mock.jwt.token" }));
|
||||
});
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(GET)
|
||||
.path_matches(Regex::new(r"/wazuh-alerts-.*_search").unwrap());
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_source": {
|
||||
"id": "12345",
|
||||
"timestamp": "2024-01-01T10:00:00.000Z",
|
||||
"rule": {
|
||||
"level": 9,
|
||||
"description": "Possible intrusion attempt detected",
|
||||
"groups": ["intrusion_detection", "pci_dss"]
|
||||
},
|
||||
"agent": { "id": "001", "name": "test-agent" },
|
||||
"data": {
|
||||
"source_ip": "192.168.1.100",
|
||||
"destination_ip": "10.0.0.1",
|
||||
"port": 22
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
"id": "67890",
|
||||
"timestamp": "2024-01-01T11:00:00.000Z",
|
||||
"rule": {
|
||||
"level": 12,
|
||||
"description": "Malware detected on system",
|
||||
"groups": ["malware"]
|
||||
},
|
||||
"agent": { "id": "002", "name": "another-agent" },
|
||||
"data": {
|
||||
"file_path": "/tmp/malicious.exe",
|
||||
"hash": "abcdef123456",
|
||||
"signature": "EICAR-Test-File"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}));
|
||||
});
|
||||
|
||||
server
|
||||
}
|
||||
|
||||
fn get_host_port(server: &MockServer) -> (String, u16) {
|
||||
let url = server.url("");
|
||||
let parts: Vec<&str> = url.trim_start_matches("http://").split(':').collect();
|
||||
let host = parts[0].to_string();
|
||||
let port = parts[1].parse().unwrap();
|
||||
(host, port)
|
||||
}
|
||||
|
||||
fn start_mcp_server(wazuh_host: &str, wazuh_port: u16, mcp_port: u16) -> Child {
|
||||
Command::new("cargo")
|
||||
.args(["run", "--"])
|
||||
.env("WAZUH_HOST", wazuh_host)
|
||||
.env("WAZUH_PORT", wazuh_port.to_string())
|
||||
.env("WAZUH_USER", "admin")
|
||||
.env("WAZUH_PASS", "admin")
|
||||
.env("VERIFY_SSL", "false")
|
||||
.env("MCP_SERVER_PORT", mcp_port.to_string())
|
||||
.env("RUST_LOG", "info")
|
||||
.spawn()
|
||||
.expect("Failed to start MCP server")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
async fn test_mcp_server_with_mock_wazuh() -> Result<()> {
|
||||
let _guard = TEST_MUTEX.lock().unwrap();
|
||||
|
||||
let mock_wazuh_server = setup_mock_wazuh_server();
|
||||
let (wazuh_host, wazuh_port) = get_host_port(&mock_wazuh_server);
|
||||
|
||||
let mcp_port = find_available_port();
|
||||
|
||||
let mut mcp_server = start_mcp_server(&wazuh_host, wazuh_port, mcp_port);
|
||||
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
|
||||
let mcp_client = McpClient::new(format!("http://localhost:{}", mcp_port));
|
||||
|
||||
let health_data = mcp_client.check_health().await?;
|
||||
assert_eq!(health_data["status"], "ok");
|
||||
assert_eq!(health_data["service"], "wazuh-mcp-server");
|
||||
|
||||
let mcp_data = mcp_client.get_mcp_data().await?;
|
||||
|
||||
assert_eq!(mcp_data.len(), 2);
|
||||
|
||||
let first_message: &McpMessage = &mcp_data[0];
|
||||
assert_eq!(first_message.protocol_version, "1.0");
|
||||
assert_eq!(first_message.source, "Wazuh");
|
||||
assert_eq!(first_message.event_type, "alert");
|
||||
|
||||
let context = &first_message.context;
|
||||
assert_eq!(context["id"], "12345");
|
||||
assert_eq!(context["category"], "intrusion_detection");
|
||||
assert_eq!(context["severity"], "high");
|
||||
assert_eq!(
|
||||
context["description"],
|
||||
"Possible intrusion attempt detected"
|
||||
);
|
||||
assert_eq!(context["agent"]["name"], "test-agent");
|
||||
|
||||
let data = &context["data"];
|
||||
assert_eq!(data["source_ip"], "192.168.1.100");
|
||||
assert_eq!(data["destination_ip"], "10.0.0.1");
|
||||
assert_eq!(data["port"], 22);
|
||||
|
||||
let second_message = &mcp_data[1];
|
||||
let context = &second_message.context;
|
||||
assert_eq!(context["id"], "67890");
|
||||
assert_eq!(context["category"], "malware");
|
||||
assert_eq!(context["severity"], "critical");
|
||||
assert_eq!(context["description"], "Malware detected on system");
|
||||
assert_eq!(context["agent"]["name"], "another-agent");
|
||||
|
||||
let data = &context["data"];
|
||||
assert_eq!(data["file_path"], "/tmp/malicious.exe");
|
||||
assert_eq!(data["hash"], "abcdef123456");
|
||||
assert_eq!(data["signature"], "EICAR-Test-File");
|
||||
|
||||
mcp_server.kill().expect("Failed to kill MCP server");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mcp_server_wazuh_api_error() -> Result<()> {
|
||||
let _guard = TEST_MUTEX.lock().unwrap();
|
||||
|
||||
let mock_wazuh_server = setup_mock_wazuh_server();
|
||||
let (wazuh_host, wazuh_port) = get_host_port(&mock_wazuh_server);
|
||||
|
||||
mock_wazuh_server.mock(|when, then| {
|
||||
when.method(GET)
|
||||
.path_matches(Regex::new(r"/wazuh-alerts-.*_search").unwrap());
|
||||
then.status(500)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({"error": "Wazuh internal error"}));
|
||||
});
|
||||
|
||||
let mcp_port = find_available_port();
|
||||
let mut mcp_server = start_mcp_server(&wazuh_host, wazuh_port, mcp_port);
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
|
||||
let mcp_client = McpClient::new(format!("http://localhost:{}", mcp_port));
|
||||
|
||||
let result = mcp_client.get_mcp_data().await;
|
||||
assert!(result.is_err());
|
||||
let err_string = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err_string.contains("500")
|
||||
|| err_string.contains("502")
|
||||
|| err_string.contains("API request failed")
|
||||
);
|
||||
|
||||
let health_result = mcp_client.check_health().await;
|
||||
assert!(health_result.is_ok());
|
||||
assert_eq!(health_result.unwrap()["status"], "ok");
|
||||
|
||||
mcp_server.kill().expect("Failed to kill MCP server");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mcp_client_error_handling() -> Result<()> {
|
||||
let _guard = TEST_MUTEX.lock().unwrap();
|
||||
|
||||
let server = MockServer::start();
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(GET).path("/mcp");
|
||||
then.status(500)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"error": "Internal server error"
|
||||
}));
|
||||
});
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(GET).path("/health");
|
||||
then.status(503)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"error": "Service unavailable"
|
||||
}));
|
||||
});
|
||||
|
||||
let client = McpClient::new(server.url(""));
|
||||
|
||||
let result = client.get_mcp_data().await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.to_string().contains("500") || err.to_string().contains("MCP request failed"));
|
||||
|
||||
let result = client.check_health().await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.to_string().contains("503") || err.to_string().contains("Health check failed"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mcp_server_missing_alert_data() -> Result<()> {
|
||||
let _guard = TEST_MUTEX.lock().unwrap();
|
||||
|
||||
let mock_wazuh_server = setup_mock_wazuh_server();
|
||||
let (wazuh_host, wazuh_port) = get_host_port(&mock_wazuh_server);
|
||||
|
||||
mock_wazuh_server.mock(|when, then| {
|
||||
when.method(GET)
|
||||
.path_matches(Regex::new(r"/wazuh-alerts-.*_search").unwrap());
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_source": {
|
||||
"id": "missing_all",
|
||||
"timestamp": "invalid-date-format"
|
||||
}
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
"id": "missing_rule_fields",
|
||||
"timestamp": "2024-05-05T11:00:00.000Z",
|
||||
"rule": { },
|
||||
"agent": { "id": "003", "name": "agent-minimal" },
|
||||
"data": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "no_source_nest",
|
||||
"timestamp": "2024-05-05T12:00:00.000Z",
|
||||
"rule": {
|
||||
"level": 2,
|
||||
"description": "Low severity event",
|
||||
"groups": ["low_sev"]
|
||||
},
|
||||
"agent": { "id": "004" },
|
||||
"data": { "info": "some data" }
|
||||
}
|
||||
]
|
||||
}
|
||||
}));
|
||||
});
|
||||
|
||||
let mcp_port = find_available_port();
|
||||
let mut mcp_server = start_mcp_server(&wazuh_host, wazuh_port, mcp_port);
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
|
||||
let mcp_client = McpClient::new(format!("http://localhost:{}", mcp_port));
|
||||
|
||||
let mcp_data = mcp_client.get_mcp_data().await?;
|
||||
assert_eq!(mcp_data.len(), 3);
|
||||
|
||||
let msg1 = &mcp_data[0];
|
||||
assert_eq!(msg1.context["id"], "missing_all");
|
||||
assert_eq!(msg1.context["category"], "unknown_category");
|
||||
assert_eq!(msg1.context["severity"], "unknown_severity");
|
||||
assert_eq!(msg1.context["description"], "");
|
||||
assert!(
|
||||
msg1.context["agent"].is_object() && msg1.context["agent"].as_object().unwrap().is_empty()
|
||||
);
|
||||
assert!(
|
||||
msg1.context["data"].is_object() && msg1.context["data"].as_object().unwrap().is_empty()
|
||||
);
|
||||
let ts1 = DateTime::parse_from_rfc3339(&msg1.timestamp)
|
||||
.unwrap()
|
||||
.with_timezone(&Utc);
|
||||
assert!((Utc::now() - ts1).num_seconds() < 5);
|
||||
|
||||
let msg2 = &mcp_data[1];
|
||||
assert_eq!(msg2.context["id"], "missing_rule_fields");
|
||||
assert_eq!(msg2.context["category"], "unknown_category");
|
||||
assert_eq!(msg2.context["severity"], "unknown_severity");
|
||||
assert_eq!(msg2.context["description"], "");
|
||||
assert_eq!(msg2.context["agent"]["name"], "agent-minimal");
|
||||
assert!(
|
||||
msg2.context["data"].is_object() && msg2.context["data"].as_object().unwrap().is_empty()
|
||||
);
|
||||
assert_eq!(msg2.timestamp, "2024-05-05T11:00:00Z");
|
||||
|
||||
let msg3 = &mcp_data[2];
|
||||
assert_eq!(msg3.context["id"], "no_source_nest");
|
||||
assert_eq!(msg3.context["category"], "low_sev");
|
||||
assert_eq!(msg3.context["severity"], "low");
|
||||
assert_eq!(msg3.context["description"], "Low severity event");
|
||||
assert_eq!(msg3.context["agent"]["id"], "004");
|
||||
assert!(msg3.context["agent"].get("name").is_none());
|
||||
assert_eq!(msg3.context["data"]["info"], "some data");
|
||||
assert_eq!(msg3.timestamp, "2024-05-05T12:00:00Z");
|
||||
|
||||
mcp_server.kill().expect("Failed to kill MCP server");
|
||||
Ok(())
|
||||
}
|
@ -1,180 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpMessage {
|
||||
pub protocol_version: String,
|
||||
pub source: String,
|
||||
pub timestamp: String,
|
||||
pub event_type: String,
|
||||
pub context: Value,
|
||||
pub metadata: Value,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait McpClientTrait {
|
||||
async fn get_mcp_data(&self) -> Result<Vec<McpMessage>>;
|
||||
|
||||
async fn check_health(&self) -> Result<Value>;
|
||||
|
||||
async fn query_mcp_data(&self, filters: Value) -> Result<Vec<McpMessage>>;
|
||||
}
|
||||
|
||||
pub struct McpClient {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl McpClient {
|
||||
pub fn new(base_url: String) -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
Self { client, base_url }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl McpClientTrait for McpClient {
|
||||
async fn get_mcp_data(&self) -> Result<Vec<McpMessage>> {
|
||||
let url = format!("{}/mcp", self.base_url);
|
||||
let response = self.client.get(&url).send().await?;
|
||||
|
||||
match response.status() {
|
||||
StatusCode::OK => {
|
||||
let data = response.json::<Vec<McpMessage>>().await?;
|
||||
Ok(data)
|
||||
}
|
||||
status => {
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
anyhow::bail!("MCP request failed with status {}: {}", status, error_text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_health(&self) -> Result<Value> {
|
||||
let url = format!("{}/health", self.base_url);
|
||||
let response = self.client.get(&url).send().await?;
|
||||
|
||||
match response.status() {
|
||||
StatusCode::OK => {
|
||||
let data = response.json::<Value>().await?;
|
||||
Ok(data)
|
||||
}
|
||||
status => {
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
anyhow::bail!("Health check failed with status {}: {}", status, error_text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn query_mcp_data(&self, filters: Value) -> Result<Vec<McpMessage>> {
|
||||
let url = format!("{}/mcp", self.base_url);
|
||||
let response = self.client.post(&url).json(&filters).send().await?;
|
||||
|
||||
match response.status() {
|
||||
StatusCode::OK => {
|
||||
let data = response.json::<Vec<McpMessage>>().await?;
|
||||
Ok(data)
|
||||
}
|
||||
status => {
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
anyhow::bail!("MCP query failed with status {}: {}", status, error_text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use httpmock::prelude::*;
|
||||
use serde_json::json;
|
||||
use tokio;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mcp_client_get_data() {
|
||||
let server = MockServer::start();
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(GET).path("/mcp");
|
||||
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!([
|
||||
{
|
||||
"protocol_version": "1.0",
|
||||
"source": "Wazuh",
|
||||
"timestamp": "2023-05-01T12:00:00Z",
|
||||
"event_type": "alert",
|
||||
"context": {
|
||||
"id": "12345",
|
||||
"category": "intrusion_detection",
|
||||
"severity": "high",
|
||||
"description": "Test alert",
|
||||
"data": {
|
||||
"source_ip": "192.168.1.100"
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
"integration": "Wazuh-MCP",
|
||||
"notes": "Test note"
|
||||
}
|
||||
}
|
||||
]));
|
||||
});
|
||||
|
||||
let client = McpClient::new(server.url(""));
|
||||
|
||||
let result = client.get_mcp_data().await.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].protocol_version, "1.0");
|
||||
assert_eq!(result[0].source, "Wazuh");
|
||||
assert_eq!(result[0].event_type, "alert");
|
||||
|
||||
let context = &result[0].context;
|
||||
assert_eq!(context["id"], "12345");
|
||||
assert_eq!(context["category"], "intrusion_detection");
|
||||
assert_eq!(context["severity"], "high");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mcp_client_health_check() {
|
||||
let server = MockServer::start();
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(GET).path("/health");
|
||||
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"status": "ok",
|
||||
"service": "wazuh-mcp-server",
|
||||
"timestamp": "2023-05-01T12:00:00Z"
|
||||
}));
|
||||
});
|
||||
|
||||
let client = McpClient::new(server.url(""));
|
||||
|
||||
let result = client.check_health().await.unwrap();
|
||||
|
||||
assert_eq!(result["status"], "ok");
|
||||
assert_eq!(result["service"], "wazuh-mcp-server");
|
||||
}
|
||||
}
|
@ -1,164 +0,0 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use clap::Parser;
|
||||
use serde_json::Value;
|
||||
use std::io::{self, Write}; // For stdout().flush() and stdin().read_line()
|
||||
|
||||
use mcp_server_wazuh::mcp::client::{McpClient, McpClientTrait};
|
||||
use serde::Deserialize; // For ParsedRequest
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(
|
||||
name = "mcp-client-cli",
|
||||
version = "0.1.0",
|
||||
about = "Interactive CLI for MCP server. Enter JSON-RPC requests, or 'health'/'quit'."
|
||||
)]
|
||||
struct CliArgs {
|
||||
#[clap(long, help = "Path to the MCP server executable for stdio mode.")]
|
||||
stdio_exe: Option<String>,
|
||||
|
||||
#[clap(
|
||||
long,
|
||||
env = "MCP_SERVER_URL",
|
||||
default_value = "http://localhost:8000",
|
||||
help = "URL of the MCP server for HTTP mode."
|
||||
)]
|
||||
http_url: String,
|
||||
}
|
||||
|
||||
// For parsing raw JSON request strings
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct ParsedRequest {
|
||||
// jsonrpc: String, // Not strictly needed for sending
|
||||
method: String,
|
||||
params: Option<Value>,
|
||||
id: Value, // ID can be string or number
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let cli_args = CliArgs::parse();
|
||||
let mut client: McpClient;
|
||||
let is_stdio_mode = cli_args.stdio_exe.is_some();
|
||||
|
||||
if let Some(ref exe_path) = cli_args.stdio_exe {
|
||||
println!("Using stdio mode with executable: {}", exe_path);
|
||||
client = McpClient::new_stdio(exe_path, None).await?;
|
||||
println!("Sending 'initialize' request to stdio server...");
|
||||
match client.initialize().await {
|
||||
Ok(init_result) => {
|
||||
println!("Initialization successful:");
|
||||
println!(" Protocol Version: {}", init_result.protocol_version);
|
||||
println!(" Server Name: {}", init_result.server_info.name);
|
||||
println!(" Server Version: {}", init_result.server_info.version);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Stdio Initialization failed: {}. You may need to send a raw 'initialize' JSON-RPC request or check server logs.", e);
|
||||
// Allow continuing, user might want to send raw init or other commands.
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!("Using HTTP mode with URL: {}", cli_args.http_url);
|
||||
client = McpClient::new_http(cli_args.http_url.clone());
|
||||
// No automatic initialize for HTTP mode as per McpClientTrait.
|
||||
// `initialize` is typically a stdio-specific concept in MCP.
|
||||
}
|
||||
|
||||
println!("\nInteractive MCP Client. Enter a JSON-RPC request, 'health' (HTTP only), or 'quit'.");
|
||||
println!("Press CTRL-D for EOF to exit.");
|
||||
|
||||
let mut input_buffer = String::new();
|
||||
loop {
|
||||
input_buffer.clear();
|
||||
print!("mcp> ");
|
||||
io::stdout().flush().map_err(|e| anyhow!("Failed to flush stdout: {}", e))?;
|
||||
|
||||
match io::stdin().read_line(&mut input_buffer) {
|
||||
Ok(0) => { // EOF (Ctrl-D)
|
||||
println!("\nEOF detected. Exiting.");
|
||||
break;
|
||||
}
|
||||
Ok(_) => {
|
||||
let line = input_buffer.trim();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if line.eq_ignore_ascii_case("quit") {
|
||||
println!("Exiting.");
|
||||
break;
|
||||
}
|
||||
|
||||
if line.eq_ignore_ascii_case("health") {
|
||||
if is_stdio_mode {
|
||||
println!("'health' command is intended for HTTP mode. For stdio, you would need to send a specific JSON-RPC request if the server supports a health method via stdio.");
|
||||
} else {
|
||||
println!("Checking server health (HTTP GET to /health)...");
|
||||
let health_url = format!("{}/health", cli_args.http_url); // Use the parsed http_url
|
||||
match reqwest::get(&health_url).await {
|
||||
Ok(response) => {
|
||||
let status = response.status();
|
||||
let response_text = response.text().await.unwrap_or_else(|_| "Failed to read response body".to_string());
|
||||
if status.is_success() {
|
||||
match serde_json::from_str::<Value>(&response_text) {
|
||||
Ok(json_val) => println!("Health response ({}):\n{}", status, serde_json::to_string_pretty(&json_val).unwrap_or_else(|_| response_text.clone())),
|
||||
Err(_) => println!("Health response ({}):\n{}", status, response_text),
|
||||
}
|
||||
} else {
|
||||
eprintln!("Health check failed with status: {}", status);
|
||||
eprintln!("Response: {}", response_text);
|
||||
}
|
||||
}
|
||||
Err(e) => eprintln!("Health check request failed: {}", e),
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Assume it's a JSON-RPC request
|
||||
println!("Attempting to send as JSON-RPC: {}", line);
|
||||
match serde_json::from_str::<ParsedRequest>(line) {
|
||||
Ok(parsed_req) => {
|
||||
match client
|
||||
.send_json_rpc_request(
|
||||
&parsed_req.method,
|
||||
parsed_req.params.clone(),
|
||||
parsed_req.id.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response_value) => {
|
||||
println!(
|
||||
"Server Response: {}",
|
||||
serde_json::to_string_pretty(&response_value).unwrap_or_else(
|
||||
|e_pretty| format!("Failed to pretty-print response ({}): {:?}", e_pretty, response_value)
|
||||
)
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error processing JSON-RPC request '{}': {}", line, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to parse input as a JSON-RPC request: {}. Input: '{}'", e, line);
|
||||
eprintln!("Please enter a valid JSON-RPC request string, 'health', or 'quit'.");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error reading input: {}. Exiting.", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if is_stdio_mode {
|
||||
println!("Sending 'shutdown' request to stdio server...");
|
||||
match client.shutdown().await {
|
||||
Ok(_) => println!("Shutdown command acknowledged by server."),
|
||||
Err(e) => eprintln!("Error during shutdown: {}. Server might have already exited or closed the connection.", e),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
365
tests/mcp_stdio_test.rs
Normal file
365
tests/mcp_stdio_test.rs
Normal file
@ -0,0 +1,365 @@
|
||||
//! Tests for MCP protocol communication via stdio
|
||||
//!
|
||||
//! These tests verify the basic MCP protocol implementation without
|
||||
//! requiring a Wazuh connection.
|
||||
|
||||
use std::process::{Command, Stdio};
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
struct McpStdioClient {
|
||||
child: std::process::Child,
|
||||
stdin: std::process::ChildStdin,
|
||||
stdout: BufReader<std::process::ChildStdout>,
|
||||
}
|
||||
|
||||
impl McpStdioClient {
|
||||
fn start() -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let mut child = Command::new("cargo")
|
||||
.args(["run", "--bin", "mcp-server-wazuh"])
|
||||
.env("WAZUH_API_HOST", "nonexistent.example.com") // Use non-existent host
|
||||
.env("WAZUH_API_PORT", "9999")
|
||||
.env("WAZUH_API_USER", "test")
|
||||
.env("WAZUH_API_PASS", "test")
|
||||
.env("WAZUH_INDEXER_HOST", "nonexistent.example.com") // Use non-existent host
|
||||
.env("WAZUH_INDEXER_PORT", "8888")
|
||||
.env("WAZUH_INDEXER_USER", "test")
|
||||
.env("WAZUH_INDEXER_PASS", "test")
|
||||
.env("VERIFY_SSL", "false")
|
||||
.env("RUST_LOG", "error") // Minimize logging noise
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::inherit()) // Changed from Stdio::null() to inherit stderr
|
||||
.spawn()?;
|
||||
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = BufReader::new(child.stdout.take().unwrap());
|
||||
|
||||
Ok(McpStdioClient {
|
||||
child,
|
||||
stdin,
|
||||
stdout,
|
||||
})
|
||||
}
|
||||
|
||||
fn send_message(&mut self, message: &Value) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let message_str = serde_json::to_string(message)?;
|
||||
writeln!(self.stdin, "{}", message_str)?;
|
||||
self.stdin.flush()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_response(&mut self) -> Result<Value, Box<dyn std::error::Error>> {
|
||||
let mut line = String::new();
|
||||
self.stdout.read_line(&mut line)?;
|
||||
let response: Value = serde_json::from_str(&line.trim())?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
fn send_and_receive(&mut self, message: &Value) -> Result<Value, Box<dyn std::error::Error>> {
|
||||
self.send_message(message)?;
|
||||
self.read_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for McpStdioClient {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.child.kill();
|
||||
let _ = self.child.wait();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mcp_protocol_initialization() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let mut client = McpStdioClient::start()?;
|
||||
|
||||
// Give the server time to start
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Test initialize request
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "test-client",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let response = client.send_and_receive(&init_request)?;
|
||||
|
||||
// Verify JSON-RPC 2.0 compliance
|
||||
assert_eq!(response["jsonrpc"], "2.0");
|
||||
assert_eq!(response["id"], 1);
|
||||
assert!(response["result"].is_object());
|
||||
assert!(response["error"].is_null());
|
||||
|
||||
// Verify MCP initialize response structure
|
||||
let result = &response["result"];
|
||||
assert_eq!(result["protocolVersion"], "2024-11-05");
|
||||
assert!(result["capabilities"].is_object());
|
||||
assert!(result["serverInfo"].is_object());
|
||||
|
||||
// Verify server info
|
||||
let server_info = &result["serverInfo"];
|
||||
assert!(server_info["name"].is_string());
|
||||
assert!(server_info["version"].is_string());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mcp_tools_list_without_wazuh() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let mut client = McpStdioClient::start()?;
|
||||
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Initialize first
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test", "version": "1.0"}
|
||||
}
|
||||
});
|
||||
client.send_and_receive(&init_request)?;
|
||||
|
||||
// Send initialized notification
|
||||
let initialized = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
});
|
||||
client.send_message(&initialized)?;
|
||||
|
||||
// Request tools list
|
||||
let tools_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
let response = client.send_and_receive(&tools_request)?;
|
||||
|
||||
// Verify response structure
|
||||
assert_eq!(response["jsonrpc"], "2.0");
|
||||
assert_eq!(response["id"], 2);
|
||||
assert!(response["result"].is_object());
|
||||
|
||||
let result = &response["result"];
|
||||
assert!(result["tools"].is_array());
|
||||
|
||||
let tools = result["tools"].as_array().unwrap();
|
||||
assert!(!tools.is_empty());
|
||||
|
||||
// Verify tool structure
|
||||
for tool in tools {
|
||||
assert!(tool["name"].is_string());
|
||||
assert!(tool["description"].is_string());
|
||||
assert!(tool["inputSchema"].is_object());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalid_json_rpc_request() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let mut client = McpStdioClient::start()?;
|
||||
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// 1. Initialize the connection first
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test-client", "version": "1.0.0"}
|
||||
}
|
||||
});
|
||||
let _init_response = client.send_and_receive(&init_request)?; // Read and ignore/assert init response
|
||||
// assert!(_init_response["result"].is_object()); // Optional: assert successful init
|
||||
|
||||
// 2. Send initialized notification
|
||||
let initialized_notification = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
});
|
||||
client.send_message(&initialized_notification)?;
|
||||
|
||||
// 3. Send the invalid JSON-RPC request (missing required fields)
|
||||
let invalid_request = json!({
|
||||
// "jsonrpc": "2.0", // Missing jsonrpc field to make it invalid
|
||||
"id": 2, // Use a new ID
|
||||
"method": "some_method_that_might_not_exist"
|
||||
});
|
||||
client.send_message(&invalid_request)?;
|
||||
|
||||
// 4. The server currently closes the connection upon such an invalid request (see logs:
|
||||
// `ERROR rmcp::transport::io ... serde error ...` followed by `input stream terminated`).
|
||||
// Therefore, subsequent requests should fail. This test verifies this behavior.
|
||||
// Ideally, the server might send a JSON-RPC error and keep the connection open,
|
||||
// but that would require changes to the server's error handling logic.
|
||||
|
||||
// 5. Attempt to send a subsequent valid request.
|
||||
let list_tools_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3, // New ID
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
let result = client.send_and_receive(&list_tools_request);
|
||||
|
||||
// Assert that the operation failed, indicating the connection was likely closed.
|
||||
assert!(result.is_err(), "Server should have closed the connection after the invalid request, leading to an error here.");
|
||||
|
||||
// Optionally, check the error type more specifically if needed, e.g., for EOF.
|
||||
if let Err(e) = result {
|
||||
let error_message = e.to_string().to_lowercase();
|
||||
assert!(
|
||||
error_message.contains("eof") || error_message.contains("broken pipe") || error_message.contains("connection reset"),
|
||||
"Expected EOF, broken pipe, or connection reset error, but got: {}", e
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unsupported_method() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let mut client = McpStdioClient::start()?;
|
||||
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Initialize first
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test", "version": "1.0"}
|
||||
}
|
||||
});
|
||||
client.send_and_receive(&init_request)?;
|
||||
|
||||
// Send initialized notification
|
||||
let initialized_notification = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
});
|
||||
client.send_message(&initialized_notification)?;
|
||||
|
||||
let unsupported_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
|
||||
"method": "unsupported/method"
|
||||
// Omitting "params": {} as it might be causing deserialization issues
|
||||
// in rmcp for unknown methods. The JSON-RPC spec allows params to be omitted.
|
||||
});
|
||||
|
||||
// Send the unsupported request. We don't expect a valid JSON-RPC response.
|
||||
// Instead, the server is likely to close the connection due to deserialization issues
|
||||
// in rmcp when encountering an unknown method, as it cannot match it to a known JsonRpcMessage variant.
|
||||
client.send_message(&unsupported_request)?;
|
||||
|
||||
// Attempt to send a subsequent valid request to confirm the connection was dropped.
|
||||
let list_tools_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3, // Use a new ID
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
let result = client.send_and_receive(&list_tools_request);
|
||||
|
||||
// Assert that the operation failed, indicating the connection was likely closed.
|
||||
assert!(result.is_err(), "Server should have closed the connection after the unsupported method request, leading to an error here.");
|
||||
|
||||
// Optionally, check the error type more specifically if needed, e.g., for EOF.
|
||||
if let Err(e) = result {
|
||||
let error_message = e.to_string().to_lowercase();
|
||||
assert!(
|
||||
error_message.contains("eof") || error_message.contains("broken pipe") || error_message.contains("connection reset"),
|
||||
"Expected EOF, broken pipe, or connection reset error, but got: {}", e
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_requests() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let mut client = McpStdioClient::start()?;
|
||||
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Initialize
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test", "version": "1.0"}
|
||||
}
|
||||
});
|
||||
client.send_and_receive(&init_request)?;
|
||||
|
||||
// Send initialized notification
|
||||
let initialized = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
});
|
||||
client.send_message(&initialized)?;
|
||||
|
||||
// Send multiple requests with different IDs
|
||||
let request1 = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 10,
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
let request2 = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 20,
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
// Send both requests
|
||||
client.send_message(&request1)?;
|
||||
client.send_message(&request2)?;
|
||||
|
||||
// Read both responses
|
||||
let response1 = client.read_response()?;
|
||||
let response2 = client.read_response()?;
|
||||
|
||||
// Responses should maintain request IDs (though order might vary)
|
||||
let ids: Vec<i64> = vec![
|
||||
response1["id"].as_i64().unwrap(),
|
||||
response2["id"].as_i64().unwrap(),
|
||||
];
|
||||
|
||||
assert!(ids.contains(&10));
|
||||
assert!(ids.contains(&20));
|
||||
|
||||
Ok(())
|
||||
}
|
340
tests/mock_wazuh_server.rs
Normal file
340
tests/mock_wazuh_server.rs
Normal file
@ -0,0 +1,340 @@
|
||||
//! Mock Wazuh API server for testing
|
||||
//!
|
||||
//! This module provides a configurable mock server that simulates the Wazuh Indexer API
|
||||
//! for testing purposes. It supports various response scenarios including success,
|
||||
//! empty results, and error conditions.
|
||||
|
||||
use httpmock::prelude::*;
|
||||
use serde_json::json;
|
||||
|
||||
pub struct MockWazuhServer {
|
||||
server: MockServer,
|
||||
}
|
||||
|
||||
impl MockWazuhServer {
|
||||
pub fn new() -> Self {
|
||||
let server = MockServer::start();
|
||||
|
||||
// Setup default authentication endpoint
|
||||
server.mock(|when, then| {
|
||||
when.method(POST).path("/security/user/authenticate");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"jwt": "mock.jwt.token.eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
}));
|
||||
});
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(POST)
|
||||
.path_matches(Regex::new(r"/wazuh-alerts.*/_search").unwrap());
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(Self::sample_alerts_response());
|
||||
});
|
||||
|
||||
Self { server }
|
||||
}
|
||||
|
||||
pub fn with_empty_alerts() -> Self {
|
||||
let server = MockServer::start();
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(POST).path("/security/user/authenticate");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"jwt": "mock.jwt.token"
|
||||
}));
|
||||
});
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(POST)
|
||||
.path_matches(Regex::new(r"/wazuh-alerts.*/_search").unwrap());
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"hits": {
|
||||
"hits": []
|
||||
}
|
||||
}));
|
||||
});
|
||||
|
||||
Self { server }
|
||||
}
|
||||
|
||||
pub fn with_auth_error() -> Self {
|
||||
let server = MockServer::start();
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(POST).path("/security/user/authenticate");
|
||||
then.status(401)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"error": "Invalid credentials"
|
||||
}));
|
||||
});
|
||||
|
||||
Self { server }
|
||||
}
|
||||
|
||||
pub fn with_alerts_error() -> Self {
|
||||
let server = MockServer::start();
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(POST).path("/security/user/authenticate");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"jwt": "mock.jwt.token"
|
||||
}));
|
||||
});
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(POST)
|
||||
.path_matches(Regex::new(r"/wazuh-alerts.*/_search").unwrap());
|
||||
then.status(500)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"error": "Internal server error"
|
||||
}));
|
||||
});
|
||||
|
||||
Self { server }
|
||||
}
|
||||
|
||||
pub fn with_malformed_alerts() -> Self {
|
||||
let server = MockServer::start();
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(POST).path("/security/user/authenticate");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"jwt": "mock.jwt.token"
|
||||
}));
|
||||
});
|
||||
|
||||
server.mock(|when, then| {
|
||||
when.method(POST)
|
||||
.path_matches(Regex::new(r"/wazuh-alerts.*/_search").unwrap());
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_source": {
|
||||
"id": "missing_fields",
|
||||
"timestamp": "invalid-date-format"
|
||||
// Missing rule, agent, etc.
|
||||
}
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
"id": "partial_data",
|
||||
"timestamp": "2024-01-15T10:30:45.123Z",
|
||||
"rule": {
|
||||
"level": 5
|
||||
// Missing description
|
||||
},
|
||||
"agent": {
|
||||
"id": "001"
|
||||
// Missing name
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}));
|
||||
});
|
||||
|
||||
Self { server }
|
||||
}
|
||||
|
||||
pub fn url(&self) -> String {
|
||||
self.server.url("")
|
||||
}
|
||||
|
||||
pub fn host(&self) -> String {
|
||||
let url = self.url();
|
||||
let parts: Vec<&str> = url.trim_start_matches("http://").split(':').collect();
|
||||
parts[0].to_string()
|
||||
}
|
||||
|
||||
pub fn port(&self) -> u16 {
|
||||
let url = self.url();
|
||||
let parts: Vec<&str> = url.trim_start_matches("http://").split(':').collect();
|
||||
parts[1].parse().unwrap()
|
||||
}
|
||||
|
||||
fn sample_alerts_response() -> serde_json::Value {
|
||||
json!({
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_source": {
|
||||
"id": "1747091815.1212763",
|
||||
"timestamp": "2024-01-15T10:30:45.123Z",
|
||||
"rule": {
|
||||
"level": 7,
|
||||
"description": "Attached USB Storage",
|
||||
"groups": ["usb", "pci_dss"]
|
||||
},
|
||||
"agent": {
|
||||
"id": "001",
|
||||
"name": "web-server-01"
|
||||
},
|
||||
"data": {
|
||||
"device": "/dev/sdb1",
|
||||
"mount_point": "/media/usb"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
"id": "1747066333.1207112",
|
||||
"timestamp": "2024-01-15T10:25:12.456Z",
|
||||
"rule": {
|
||||
"level": 5,
|
||||
"description": "New dpkg (Debian Package) installed.",
|
||||
"groups": ["package_management", "debian"]
|
||||
},
|
||||
"agent": {
|
||||
"id": "002",
|
||||
"name": "database-server"
|
||||
},
|
||||
"data": {
|
||||
"package": "nginx",
|
||||
"version": "1.18.0-6ubuntu14.4"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
"id": "1747055444.1205998",
|
||||
"timestamp": "2024-01-15T10:20:33.789Z",
|
||||
"rule": {
|
||||
"level": 12,
|
||||
"description": "Multiple authentication failures",
|
||||
"groups": ["authentication_failed", "pci_dss"]
|
||||
},
|
||||
"agent": {
|
||||
"id": "003",
|
||||
"name": "ssh-gateway"
|
||||
},
|
||||
"data": {
|
||||
"source_ip": "192.168.1.100",
|
||||
"user": "admin",
|
||||
"attempts": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MockWazuhServer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_server_creation() {
|
||||
let mock_server = MockWazuhServer::new();
|
||||
assert!(!mock_server.url().is_empty());
|
||||
assert!(!mock_server.host().is_empty());
|
||||
assert!(mock_server.port() > 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_server_auth_endpoint() {
|
||||
let mock_server = MockWazuhServer::new();
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let response = client
|
||||
.post(&format!("{}/security/user/authenticate", mock_server.url()))
|
||||
.json(&json!({"username": "admin", "password": "admin"}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), 200);
|
||||
let body: serde_json::Value = response.json().await.unwrap();
|
||||
assert!(body.get("jwt").is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_server_alerts_endpoint() {
|
||||
let mock_server = MockWazuhServer::new();
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let response = client
|
||||
.post(&format!("{}/wazuh-alerts*/_search", mock_server.url()))
|
||||
.json(&json!({"query": {"match_all": {}}}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), 200);
|
||||
let body: serde_json::Value = response.json().await.unwrap();
|
||||
assert!(body.get("hits").is_some());
|
||||
let hits = body["hits"]["hits"].as_array().unwrap();
|
||||
assert!(!hits.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_empty_alerts_server() {
|
||||
let mock_server = MockWazuhServer::with_empty_alerts();
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let response = client
|
||||
.post(&format!("{}/wazuh-alerts*/_search", mock_server.url()))
|
||||
.json(&json!({"query": {"match_all": {}}}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), 200);
|
||||
let body: serde_json::Value = response.json().await.unwrap();
|
||||
let hits = body["hits"]["hits"].as_array().unwrap();
|
||||
assert!(hits.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auth_error_server() {
|
||||
let mock_server = MockWazuhServer::with_auth_error();
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let response = client
|
||||
.post(&format!("{}/security/user/authenticate", mock_server.url()))
|
||||
.json(&json!({"username": "admin", "password": "wrong"}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), 401);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_alerts_error_server() {
|
||||
let mock_server = MockWazuhServer::with_alerts_error();
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let response = client
|
||||
.post(&format!("{}/wazuh-alerts*/_search", mock_server.url()))
|
||||
.json(&json!({"query": {"match_all": {}}}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), 500);
|
||||
}
|
||||
}
|
546
tests/rmcp_integration_test.rs
Normal file
546
tests/rmcp_integration_test.rs
Normal file
@ -0,0 +1,546 @@
|
||||
//! Integration tests for the rmcp-based Wazuh MCP Server
|
||||
//!
|
||||
//! These tests verify the MCP server functionality using a mock Wazuh API server.
|
||||
//! Tests cover tool registration, parameter validation, alert retrieval, and error handling.
|
||||
|
||||
use std::process::{Child, Command, Stdio};
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
use serde_json::{json, Value};
|
||||
use once_cell::sync::Lazy;
|
||||
use std::sync::Mutex;
|
||||
|
||||
mod mock_wazuh_server;
|
||||
use mock_wazuh_server::MockWazuhServer;
|
||||
|
||||
static TEST_MUTEX: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
|
||||
|
||||
struct McpServerProcess {
|
||||
child: Child,
|
||||
stdin: std::process::ChildStdin,
|
||||
stdout: BufReader<std::process::ChildStdout>,
|
||||
}
|
||||
|
||||
impl McpServerProcess {
|
||||
fn start_with_mock_wazuh(mock_server: &MockWazuhServer) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let mut child = Command::new("cargo")
|
||||
.args(["run", "--bin", "mcp-server-wazuh"])
|
||||
.env("WAZUH_INDEXER_HOST", mock_server.host())
|
||||
.env("WAZUH_INDEXER_PORT", mock_server.port().to_string())
|
||||
.env("WAZUH_INDEXER_USER", "admin")
|
||||
.env("WAZUH_INDEXER_PASS", "admin")
|
||||
.env("VERIFY_SSL", "false")
|
||||
.env("WAZUH_TEST_PROTOCOL", "http")
|
||||
.env("RUST_LOG", "warn") // Reduce noise in tests
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::inherit()) // Inherit stderr to see server logs
|
||||
.spawn()?;
|
||||
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = BufReader::new(child.stdout.take().unwrap());
|
||||
|
||||
Ok(McpServerProcess {
|
||||
child,
|
||||
stdin,
|
||||
stdout,
|
||||
})
|
||||
}
|
||||
|
||||
fn send_message(&mut self, message: &Value) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let message_str = serde_json::to_string(message)?;
|
||||
writeln!(self.stdin, "{}", message_str)?;
|
||||
self.stdin.flush()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_response(&mut self) -> Result<Value, Box<dyn std::error::Error>> {
|
||||
let mut line = String::new();
|
||||
self.stdout.read_line(&mut line)?;
|
||||
let response: Value = serde_json::from_str(line.trim())?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
fn send_and_receive(&mut self, message: &Value) -> Result<Value, Box<dyn std::error::Error>> {
|
||||
self.send_message(message)?;
|
||||
self.read_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for McpServerProcess {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.child.kill();
|
||||
let _ = self.child.wait();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mcp_server_initialization() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _guard = TEST_MUTEX.lock().unwrap();
|
||||
|
||||
let mock_server = MockWazuhServer::new();
|
||||
let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?;
|
||||
|
||||
// Give the server time to start
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Send initialize request
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "test-client",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let response = mcp_server.send_and_receive(&init_request)?;
|
||||
|
||||
// Verify response structure
|
||||
assert_eq!(response["jsonrpc"], "2.0");
|
||||
assert_eq!(response["id"], 1);
|
||||
assert!(response["result"].is_object());
|
||||
|
||||
let result = &response["result"];
|
||||
assert_eq!(result["protocolVersion"], "2024-11-05");
|
||||
assert!(result["capabilities"].is_object());
|
||||
assert!(result["serverInfo"].is_object());
|
||||
assert!(result["instructions"].is_string());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tools_list() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _guard = TEST_MUTEX.lock().unwrap();
|
||||
|
||||
let mock_server = MockWazuhServer::new();
|
||||
let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?;
|
||||
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Initialize first
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test", "version": "1.0"}
|
||||
}
|
||||
});
|
||||
mcp_server.send_and_receive(&init_request)?;
|
||||
|
||||
// Send initialized notification
|
||||
let initialized = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
});
|
||||
mcp_server.send_message(&initialized)?;
|
||||
|
||||
// Request tools list
|
||||
let tools_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
let response = mcp_server.send_and_receive(&tools_request)?;
|
||||
|
||||
// Verify response
|
||||
assert_eq!(response["jsonrpc"], "2.0");
|
||||
assert_eq!(response["id"], 2);
|
||||
assert!(response["result"]["tools"].is_array());
|
||||
|
||||
let tools = response["result"]["tools"].as_array().unwrap();
|
||||
assert!(!tools.is_empty());
|
||||
|
||||
// Check for our Wazuh alert summary tool
|
||||
let alert_tool = tools.iter()
|
||||
.find(|tool| tool["name"] == "get_wazuh_alert_summary")
|
||||
.expect("get_wazuh_alert_summary tool should be present");
|
||||
|
||||
assert!(alert_tool["description"].is_string());
|
||||
assert!(alert_tool["inputSchema"].is_object());
|
||||
|
||||
// Verify input schema structure
|
||||
let input_schema = &alert_tool["inputSchema"];
|
||||
assert_eq!(input_schema["type"], "object");
|
||||
assert!(input_schema["properties"].is_object());
|
||||
assert!(input_schema["properties"]["limit"].is_object());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_wazuh_alert_summary_success() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _guard = TEST_MUTEX.lock().unwrap();
|
||||
|
||||
let mock_server = MockWazuhServer::new();
|
||||
let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?;
|
||||
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Initialize
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test", "version": "1.0"}
|
||||
}
|
||||
});
|
||||
mcp_server.send_and_receive(&init_request)?;
|
||||
|
||||
// Send initialized notification
|
||||
let initialized = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
});
|
||||
mcp_server.send_message(&initialized)?;
|
||||
|
||||
// Call the tool
|
||||
let tool_call = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "get_wazuh_alert_summary",
|
||||
"arguments": {
|
||||
"limit": 2
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let response = mcp_server.send_and_receive(&tool_call)?;
|
||||
|
||||
// Verify response structure
|
||||
assert_eq!(response["jsonrpc"], "2.0");
|
||||
assert_eq!(response["id"], 3);
|
||||
assert!(response["result"].is_object());
|
||||
|
||||
let result = &response["result"];
|
||||
assert!(result["content"].is_array());
|
||||
assert_eq!(result["isError"], false);
|
||||
|
||||
let content = result["content"].as_array().unwrap();
|
||||
assert!(!content.is_empty());
|
||||
|
||||
// Verify content format
|
||||
for item in content {
|
||||
assert_eq!(item["type"], "text");
|
||||
assert!(item["text"].is_string());
|
||||
|
||||
let text = item["text"].as_str().unwrap();
|
||||
assert!(text.contains("Alert ID:"));
|
||||
assert!(text.contains("Time:"));
|
||||
assert!(text.contains("Agent:"));
|
||||
assert!(text.contains("Level:"));
|
||||
assert!(text.contains("Description:"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_wazuh_alert_summary_empty_results() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _guard = TEST_MUTEX.lock().unwrap();
|
||||
|
||||
let mock_server = MockWazuhServer::with_empty_alerts();
|
||||
let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?;
|
||||
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Initialize
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test", "version": "1.0"}
|
||||
}
|
||||
});
|
||||
mcp_server.send_and_receive(&init_request)?;
|
||||
|
||||
// Send initialized notification
|
||||
let initialized = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
});
|
||||
mcp_server.send_message(&initialized)?;
|
||||
|
||||
// Call the tool
|
||||
let tool_call = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "get_wazuh_alert_summary",
|
||||
"arguments": {}
|
||||
}
|
||||
});
|
||||
|
||||
let response = mcp_server.send_and_receive(&tool_call)?;
|
||||
|
||||
// Verify response
|
||||
assert_eq!(response["jsonrpc"], "2.0");
|
||||
assert_eq!(response["id"], 3);
|
||||
|
||||
let result = &response["result"];
|
||||
assert!(result["content"].is_array());
|
||||
assert_eq!(result["isError"], false);
|
||||
|
||||
let content = result["content"].as_array().unwrap();
|
||||
assert_eq!(content.len(), 1);
|
||||
assert_eq!(content[0]["type"], "text");
|
||||
assert_eq!(content[0]["text"], "No Wazuh alerts found.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_wazuh_alert_summary_api_error() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _guard = TEST_MUTEX.lock().unwrap();
|
||||
|
||||
let mock_server = MockWazuhServer::with_alerts_error();
|
||||
let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?;
|
||||
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Initialize
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test", "version": "1.0"}
|
||||
}
|
||||
});
|
||||
mcp_server.send_and_receive(&init_request)?;
|
||||
|
||||
// Send initialized notification
|
||||
let initialized = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
});
|
||||
mcp_server.send_message(&initialized)?;
|
||||
|
||||
// Call the tool
|
||||
let tool_call = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "get_wazuh_alert_summary",
|
||||
"arguments": {
|
||||
"limit": 5
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let response = mcp_server.send_and_receive(&tool_call)?;
|
||||
|
||||
// Verify error response
|
||||
assert_eq!(response["jsonrpc"], "2.0");
|
||||
assert_eq!(response["id"], 3);
|
||||
|
||||
let result = &response["result"];
|
||||
assert!(result["content"].is_array());
|
||||
assert_eq!(result["isError"], true);
|
||||
|
||||
let content = result["content"].as_array().unwrap();
|
||||
assert_eq!(content.len(), 1);
|
||||
assert_eq!(content[0]["type"], "text");
|
||||
|
||||
let error_text = content[0]["text"].as_str().unwrap();
|
||||
assert!(error_text.contains("Error retrieving alerts from Wazuh"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalid_tool_call() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _guard = TEST_MUTEX.lock().unwrap();
|
||||
|
||||
let mock_server = MockWazuhServer::new();
|
||||
let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?;
|
||||
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Initialize
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test", "version": "1.0"}
|
||||
}
|
||||
});
|
||||
mcp_server.send_and_receive(&init_request)?;
|
||||
|
||||
// Send initialized notification
|
||||
let initialized = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
});
|
||||
mcp_server.send_message(&initialized)?;
|
||||
|
||||
// Call non-existent tool
|
||||
let tool_call = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "non_existent_tool",
|
||||
"arguments": {}
|
||||
}
|
||||
});
|
||||
|
||||
let response = mcp_server.send_and_receive(&tool_call)?;
|
||||
|
||||
// Should get an error response
|
||||
assert_eq!(response["jsonrpc"], "2.0");
|
||||
assert_eq!(response["id"], 3);
|
||||
assert!(response["error"].is_object());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parameter_validation() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _guard = TEST_MUTEX.lock().unwrap();
|
||||
|
||||
let mock_server = MockWazuhServer::new();
|
||||
let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?;
|
||||
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Initialize
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test", "version": "1.0"}
|
||||
}
|
||||
});
|
||||
mcp_server.send_and_receive(&init_request)?;
|
||||
|
||||
// Send initialized notification
|
||||
let initialized = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
});
|
||||
mcp_server.send_message(&initialized)?;
|
||||
|
||||
// Test with invalid parameter type (string instead of number)
|
||||
let tool_call = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "get_wazuh_alert_summary",
|
||||
"arguments": {
|
||||
"limit": "invalid"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let response = mcp_server.send_and_receive(&tool_call)?;
|
||||
|
||||
// Should get an error response for invalid parameters
|
||||
assert_eq!(response["jsonrpc"], "2.0");
|
||||
assert_eq!(response["id"], 3);
|
||||
// The response might be an error or a successful response with error content
|
||||
// depending on how rmcp handles parameter validation
|
||||
assert!(response["error"].is_object() ||
|
||||
(response["result"]["isError"] == true));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_malformed_alert_data_handling() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _guard = TEST_MUTEX.lock().unwrap();
|
||||
|
||||
let mock_server = MockWazuhServer::with_malformed_alerts();
|
||||
let mut mcp_server = McpServerProcess::start_with_mock_wazuh(&mock_server)?;
|
||||
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Initialize
|
||||
let init_request = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test", "version": "1.0"}
|
||||
}
|
||||
});
|
||||
mcp_server.send_and_receive(&init_request)?;
|
||||
|
||||
// Send initialized notification
|
||||
let initialized = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
});
|
||||
mcp_server.send_message(&initialized)?;
|
||||
|
||||
// Call the tool
|
||||
let tool_call = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "get_wazuh_alert_summary",
|
||||
"arguments": {
|
||||
"limit": 5
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let response = mcp_server.send_and_receive(&tool_call)?;
|
||||
|
||||
// Should handle malformed data gracefully
|
||||
assert_eq!(response["jsonrpc"], "2.0");
|
||||
assert_eq!(response["id"], 3);
|
||||
|
||||
let result = &response["result"];
|
||||
assert!(result["content"].is_array());
|
||||
// Should not error out, but handle missing fields gracefully
|
||||
assert_eq!(result["isError"], false);
|
||||
|
||||
let content = result["content"].as_array().unwrap();
|
||||
assert!(!content.is_empty());
|
||||
|
||||
// Verify that missing fields are handled with defaults
|
||||
for item in content {
|
||||
assert_eq!(item["type"], "text");
|
||||
let text = item["text"].as_str().unwrap();
|
||||
// Should contain default values for missing fields
|
||||
assert!(text.contains("Alert ID:"));
|
||||
assert!(text.contains("Unknown") || text.contains("missing_fields") || text.contains("partial_data"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,47 +1,67 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "Running all tests..."
|
||||
cargo test
|
||||
# Test script for Wazuh MCP Server (rmcp-based)
|
||||
# This script runs various tests to ensure the server is working correctly
|
||||
|
||||
echo "Building MCP client CLI..."
|
||||
cargo build --bin mcp_client_cli
|
||||
set -e
|
||||
|
||||
echo "Building main server binary for stdio CLI tests..."
|
||||
# Build the server executable that mcp_client_cli will run
|
||||
cargo build --bin mcp-server-wazuh # Output: target/debug/mcp-server-wazuh
|
||||
echo "Starting Wazuh MCP Server tests (rmcp-based)..."
|
||||
|
||||
echo "Testing MCP client CLI in stdio mode..."
|
||||
# Set test environment variables
|
||||
export RUST_LOG=info
|
||||
|
||||
echo "Executing: ./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh initialize"
|
||||
./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh initialize
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "CLI 'initialize' command failed!"
|
||||
exit 1
|
||||
fi
|
||||
echo "Environment variables set:"
|
||||
echo " RUST_LOG: $RUST_LOG"
|
||||
|
||||
echo "Executing: ./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh provideContext"
|
||||
./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh provideContext
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "CLI 'provideContext' command failed!"
|
||||
exit 1
|
||||
fi
|
||||
# Function to cleanup background processes
|
||||
cleanup() {
|
||||
echo "Cleaning up..."
|
||||
if [ ! -z "$SERVER_PID" ]; then
|
||||
kill $SERVER_PID 2>/dev/null || true
|
||||
wait $SERVER_PID 2>/dev/null || true
|
||||
fi
|
||||
}
|
||||
|
||||
# Example of provideContext with empty JSON params (optional to uncomment and test)
|
||||
# echo "Executing: ./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh provideContext '{}'"
|
||||
# ./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh provideContext '{}'
|
||||
# if [ $? -ne 0 ]; then
|
||||
# echo "CLI 'provideContext {}' command failed!"
|
||||
# exit 1
|
||||
# fi
|
||||
# Set trap to cleanup on exit
|
||||
trap cleanup EXIT
|
||||
|
||||
echo "Executing: ./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh shutdown"
|
||||
./target/debug/mcp_client_cli --stdio-exe ./target/debug/mcp-server-wazuh shutdown
|
||||
if [ $? -ne 0 ]; then
|
||||
# Shutdown might return an error if the server closes the pipe before the client fully processes the response,
|
||||
# but the primary goal is that the server process is terminated.
|
||||
# For this script, we'll be lenient on shutdown's exit code for now,
|
||||
# as long as initialize and provideContext worked.
|
||||
echo "CLI 'shutdown' command executed (non-zero exit code is sometimes expected if server closes pipe quickly)."
|
||||
fi
|
||||
echo ""
|
||||
echo "=== Running Unit Tests ==="
|
||||
cargo test --lib
|
||||
|
||||
echo "MCP client CLI stdio tests completed."
|
||||
echo ""
|
||||
echo "=== Running MCP Protocol Tests ==="
|
||||
cargo test --test mcp_stdio_test
|
||||
|
||||
echo ""
|
||||
echo "=== Running Integration Tests with Mock Wazuh ==="
|
||||
cargo test --test rmcp_integration_test
|
||||
|
||||
echo ""
|
||||
echo "=== Testing Server Binary ==="
|
||||
echo "Verifying server binary can start and show help..."
|
||||
|
||||
# Test that the binary can start and show help. It should exit immediately.
|
||||
cargo run --bin mcp-server-wazuh -- --help > /dev/null
|
||||
|
||||
echo "Help command test completed"
|
||||
|
||||
echo ""
|
||||
echo "=== All Tests Complete ==="
|
||||
echo ""
|
||||
echo "Test Summary:"
|
||||
echo "✓ Unit tests for library components"
|
||||
echo "✓ Wazuh client tests with mock HTTP server"
|
||||
echo "✓ MCP protocol tests via stdio"
|
||||
echo "✓ Integration tests with mock Wazuh API"
|
||||
echo "✓ Server binary startup test"
|
||||
echo ""
|
||||
echo "To test manually with a real Wazuh instance:"
|
||||
echo " export WAZUH_HOST=your-wazuh-host"
|
||||
echo " export WAZUH_PORT=9200"
|
||||
echo " export WAZUH_USER=admin"
|
||||
echo " export WAZUH_PASS=your-password"
|
||||
echo " cargo run --bin mcp-server-wazuh"
|
||||
echo ""
|
||||
echo "Then send MCP commands via stdin, for example:"
|
||||
echo ' echo '"'"'{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}'"'"' | cargo run --bin mcp-server-wazuh'
|
||||
|
Loading…
Reference in New Issue
Block a user