🎨 Formatted codebase

This commit is contained in:
Wouter de Bruijn 2025-06-17 09:21:59 +02:00
parent a57b51870f
commit 371e74fca8
No known key found for this signature in database
GPG Key ID: AC71F96733B92BFA
13 changed files with 394 additions and 317 deletions

View File

@ -1,10 +1,11 @@
""" """
Module for parsing configuration from the top level config.py file Module for parsing configuration from the top level config.py file
""" """
from pathlib import Path
from importlib import util from importlib import util
from os import environ, path
from logging import getLogger from logging import getLogger
from os import environ, path
from pathlib import Path
logger = getLogger(__name__) logger = getLogger(__name__)
@ -44,40 +45,40 @@ DEFAULT_CONFIG = {
"serial": "serialno_a", "serial": "serialno_a",
"device_type/model": "type", "device_type/model": "type",
"device_type/manufacturer/name": "vendor", "device_type/manufacturer/name": "vendor",
"oob_ip/address": "oob_ip" "oob_ip/address": "oob_ip",
}, },
"vm_inventory_map": { "vm_inventory_map": {
"status/label": "deployment_status", "status/label": "deployment_status",
"comments": "notes", "comments": "notes",
"name": "name" "name": "name",
}, },
"usermacro_sync": False, "usermacro_sync": False,
"device_usermacro_map": { "device_usermacro_map": {
"serial": "{$HW_SERIAL}", "serial": "{$HW_SERIAL}",
"role/name": "{$DEV_ROLE}", "role/name": "{$DEV_ROLE}",
"url": "{$NB_URL}", "url": "{$NB_URL}",
"id": "{$NB_ID}" "id": "{$NB_ID}",
}, },
"vm_usermacro_map": { "vm_usermacro_map": {
"memory": "{$TOTAL_MEMORY}", "memory": "{$TOTAL_MEMORY}",
"role/name": "{$DEV_ROLE}", "role/name": "{$DEV_ROLE}",
"url": "{$NB_URL}", "url": "{$NB_URL}",
"id": "{$NB_ID}" "id": "{$NB_ID}",
}, },
"tag_sync": False, "tag_sync": False,
"tag_lower": True, "tag_lower": True,
"tag_name": 'NetBox', "tag_name": "NetBox",
"tag_value": "name", "tag_value": "name",
"device_tag_map": { "device_tag_map": {
"site/name": "site", "site/name": "site",
"rack/name": "rack", "rack/name": "rack",
"platform/name": "target" "platform/name": "target",
}, },
"vm_tag_map": { "vm_tag_map": {
"site/name": "site", "site/name": "site",
"cluster/name": "cluster", "cluster/name": "cluster",
"platform/name": "target" "platform/name": "target",
} },
} }

View File

@ -5,12 +5,13 @@ Device specific handeling for NetBox to Zabbix
from copy import deepcopy from copy import deepcopy
from logging import getLogger from logging import getLogger
from re import search
from operator import itemgetter from operator import itemgetter
from re import search
from zabbix_utils import APIRequestError
from pynetbox import RequestError as NetboxRequestError from pynetbox import RequestError as NetboxRequestError
from zabbix_utils import APIRequestError
from modules.config import load_config
from modules.exceptions import ( from modules.exceptions import (
InterfaceConfigError, InterfaceConfigError,
SyncExternalError, SyncExternalError,
@ -22,10 +23,10 @@ from modules.interface import ZabbixInterface
from modules.tags import ZabbixTags from modules.tags import ZabbixTags
from modules.tools import field_mapper, remove_duplicates, sanatize_log_output from modules.tools import field_mapper, remove_duplicates, sanatize_log_output
from modules.usermacros import ZabbixUsermacros from modules.usermacros import ZabbixUsermacros
from modules.config import load_config
config = load_config() config = load_config()
class PhysicalDevice: class PhysicalDevice:
# pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-positional-arguments # pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-positional-arguments
""" """
@ -125,8 +126,8 @@ class PhysicalDevice:
self.nb, self.nb,
self.nb_api_version, self.nb_api_version,
logger=self.logger, logger=self.logger,
nested_sitegroup_flag=config['traverse_site_groups'], nested_sitegroup_flag=config["traverse_site_groups"],
nested_region_flag=config['traverse_regions'], nested_region_flag=config["traverse_regions"],
nb_groups=nb_site_groups, nb_groups=nb_site_groups,
nb_regions=nb_regions, nb_regions=nb_regions,
) )
@ -177,8 +178,6 @@ class PhysicalDevice:
self.logger.warning(e) self.logger.warning(e)
raise TemplateError(e) raise TemplateError(e)
def get_templates_context(self): def get_templates_context(self):
"""Get Zabbix templates from the device context""" """Get Zabbix templates from the device context"""
if "zabbix" not in self.config_context: if "zabbix" not in self.config_context:
@ -203,9 +202,11 @@ class PhysicalDevice:
# Set inventory mode. Default is disabled (see class init function). # Set inventory mode. Default is disabled (see class init function).
if config["inventory_mode"] == "disabled": if config["inventory_mode"] == "disabled":
if config["inventory_sync"]: if config["inventory_sync"]:
self.logger.error(f"Host {self.name}: Unable to map NetBox inventory to Zabbix. " self.logger.error(
"Inventory sync is enabled in " f"Host {self.name}: Unable to map NetBox inventory to Zabbix. "
"config but inventory mode is disabled.") "Inventory sync is enabled in "
"config but inventory mode is disabled."
)
return True return True
if config["inventory_mode"] == "manual": if config["inventory_mode"] == "manual":
self.inventory_mode = 0 self.inventory_mode = 0
@ -403,7 +404,7 @@ class PhysicalDevice:
macros = ZabbixUsermacros( macros = ZabbixUsermacros(
self.nb, self.nb,
self._usermacro_map(), self._usermacro_map(),
config['usermacro_sync'], config["usermacro_sync"],
logger=self.logger, logger=self.logger,
host=self.name, host=self.name,
) )
@ -421,10 +422,10 @@ class PhysicalDevice:
tags = ZabbixTags( tags = ZabbixTags(
self.nb, self.nb,
self._tag_map(), self._tag_map(),
config['tag_sync'], config["tag_sync"],
config['tag_lower'], config["tag_lower"],
tag_name=config['tag_name'], tag_name=config["tag_name"],
tag_value=config['tag_value'], tag_value=config["tag_value"],
logger=self.logger, logger=self.logger,
host=self.name, host=self.name,
) )
@ -604,7 +605,9 @@ class PhysicalDevice:
) )
self.logger.error(e) self.logger.error(e)
raise SyncExternalError(e) from None raise SyncExternalError(e) from None
self.logger.info(f"Host {self.name}: updated with data {sanatize_log_output(kwargs)}.") self.logger.info(
f"Host {self.name}: updated with data {sanatize_log_output(kwargs)}."
)
self.create_journal_entry("info", "Updated host in Zabbix with latest NB data.") self.create_journal_entry("info", "Updated host in Zabbix with latest NB data.")
def ConsistencyCheck( def ConsistencyCheck(
@ -615,7 +618,7 @@ class PhysicalDevice:
Checks if Zabbix object is still valid with NetBox parameters. Checks if Zabbix object is still valid with NetBox parameters.
""" """
# If group is found or if the hostgroup is nested # If group is found or if the hostgroup is nested
if not self.setZabbixGroupID(groups): # or len(self.hostgroups.split("/")) > 1: if not self.setZabbixGroupID(groups): # or len(self.hostgroups.split("/")) > 1:
if create_hostgroups: if create_hostgroups:
# Script is allowed to create a new hostgroup # Script is allowed to create a new hostgroup
new_groups = self.createZabbixHostgroup(groups) new_groups = self.createZabbixHostgroup(groups)
@ -632,7 +635,7 @@ class PhysicalDevice:
) )
self.logger.warning(e) self.logger.warning(e)
raise SyncInventoryError(e) raise SyncInventoryError(e)
#if self.group_ids: # if self.group_ids:
# self.group_ids.append(self.pri_group_id) # self.group_ids.append(self.pri_group_id)
# Prepare templates and proxy config # Prepare templates and proxy config
@ -704,8 +707,9 @@ class PhysicalDevice:
if str(self.zabbix.version).startswith(("6", "5")): if str(self.zabbix.version).startswith(("6", "5")):
group_dictname = "groups" group_dictname = "groups"
# Check if hostgroups match # Check if hostgroups match
if (sorted(host[group_dictname], key=itemgetter('groupid')) == if sorted(host[group_dictname], key=itemgetter("groupid")) == sorted(
sorted(self.group_ids, key=itemgetter('groupid'))): self.group_ids, key=itemgetter("groupid")
):
self.logger.debug(f"Host {self.name}: hostgroups in-sync.") self.logger.debug(f"Host {self.name}: hostgroups in-sync.")
else: else:
self.logger.warning(f"Host {self.name}: hostgroups OUT of sync.") self.logger.warning(f"Host {self.name}: hostgroups OUT of sync.")
@ -720,8 +724,10 @@ class PhysicalDevice:
# Check if a proxy has been defined # Check if a proxy has been defined
if self.zbxproxy: if self.zbxproxy:
# Check if proxy or proxy group is defined # Check if proxy or proxy group is defined
if (self.zbxproxy["idtype"] in host and if (
host[self.zbxproxy["idtype"]] == self.zbxproxy["id"]): self.zbxproxy["idtype"] in host
and host[self.zbxproxy["idtype"]] == self.zbxproxy["id"]
):
self.logger.debug(f"Host {self.name}: proxy in-sync.") self.logger.debug(f"Host {self.name}: proxy in-sync.")
# Backwards compatibility for Zabbix <= 6 # Backwards compatibility for Zabbix <= 6
elif "proxy_hostid" in host and host["proxy_hostid"] == self.zbxproxy["id"]: elif "proxy_hostid" in host and host["proxy_hostid"] == self.zbxproxy["id"]:
@ -788,21 +794,23 @@ class PhysicalDevice:
self.updateZabbixHost(inventory=self.inventory) self.updateZabbixHost(inventory=self.inventory)
# Check host usermacros # Check host usermacros
if config['usermacro_sync']: if config["usermacro_sync"]:
# Make a full copy synce we dont want to lose the original value # Make a full copy synce we dont want to lose the original value
# of secret type macros from Netbox # of secret type macros from Netbox
netbox_macros = deepcopy(self.usermacros) netbox_macros = deepcopy(self.usermacros)
# Set the sync bit # Set the sync bit
full_sync_bit = bool(str(config['usermacro_sync']).lower() == "full") full_sync_bit = bool(str(config["usermacro_sync"]).lower() == "full")
for macro in netbox_macros: for macro in netbox_macros:
# If the Macro is a secret and full sync is NOT activated # If the Macro is a secret and full sync is NOT activated
if macro["type"] == str(1) and not full_sync_bit: if macro["type"] == str(1) and not full_sync_bit:
# Remove the value as the Zabbix api does not return the value key # Remove the value as the Zabbix api does not return the value key
# This is required when you want to do a diff between both lists # This is required when you want to do a diff between both lists
macro.pop("value") macro.pop("value")
# Sort all lists # Sort all lists
def filter_with_macros(macro): def filter_with_macros(macro):
return macro["macro"] return macro["macro"]
host["macros"].sort(key=filter_with_macros) host["macros"].sort(key=filter_with_macros)
netbox_macros.sort(key=filter_with_macros) netbox_macros.sort(key=filter_with_macros)
# Check if both lists are the same # Check if both lists are the same
@ -814,7 +822,7 @@ class PhysicalDevice:
self.updateZabbixHost(macros=self.usermacros) self.updateZabbixHost(macros=self.usermacros)
# Check host tags # Check host tags
if config['tag_sync']: if config["tag_sync"]:
if remove_duplicates(host["tags"], sortkey="tag") == self.tags: if remove_duplicates(host["tags"], sortkey="tag") == self.tags:
self.logger.debug(f"Host {self.name}: tags in-sync.") self.logger.debug(f"Host {self.name}: tags in-sync.")
else: else:
@ -870,8 +878,10 @@ class PhysicalDevice:
try: try:
# API call to Zabbix # API call to Zabbix
self.zabbix.hostinterface.update(updates) self.zabbix.hostinterface.update(updates)
e = (f"Host {self.name}: updated interface " e = (
f"with data {sanatize_log_output(updates)}.") f"Host {self.name}: updated interface "
f"with data {sanatize_log_output(updates)}."
)
self.logger.info(e) self.logger.info(e)
self.create_journal_entry("info", e) self.create_journal_entry("info", e)
except APIRequestError as e: except APIRequestError as e:

View File

@ -1,6 +1,8 @@
"""A collection of tools used by several classes""" """A collection of tools used by several classes"""
from modules.exceptions import HostgroupError from modules.exceptions import HostgroupError
def convert_recordset(recordset): def convert_recordset(recordset):
"""Converts netbox RedcordSet to list of dicts.""" """Converts netbox RedcordSet to list of dicts."""
recordlist = [] recordlist = []
@ -101,7 +103,9 @@ def remove_duplicates(input_list, sortkey=None):
return output_list return output_list
def verify_hg_format(hg_format, device_cfs=None, vm_cfs=None, hg_type="dev", logger=None): def verify_hg_format(
hg_format, device_cfs=None, vm_cfs=None, hg_type="dev", logger=None
):
""" """
Verifies hostgroup field format Verifies hostgroup field format
""" """
@ -109,44 +113,51 @@ def verify_hg_format(hg_format, device_cfs=None, vm_cfs=None, hg_type="dev", log
device_cfs = [] device_cfs = []
if not vm_cfs: if not vm_cfs:
vm_cfs = [] vm_cfs = []
allowed_objects = {"dev": ["location", allowed_objects = {
"rack", "dev": [
"role", "location",
"manufacturer", "rack",
"region", "role",
"site", "manufacturer",
"site_group", "region",
"tenant", "site",
"tenant_group", "site_group",
"platform", "tenant",
"cluster"] "tenant_group",
,"vm": ["cluster_type", "platform",
"role", "cluster",
"manufacturer", ],
"region", "vm": [
"site", "cluster_type",
"site_group", "role",
"tenant", "manufacturer",
"tenant_group", "region",
"cluster", "site",
"device", "site_group",
"platform"] "tenant",
,"cfs": {"dev": [], "vm": []} "tenant_group",
} "cluster",
"device",
"platform",
],
"cfs": {"dev": [], "vm": []},
}
for cf in device_cfs: for cf in device_cfs:
allowed_objects['cfs']['dev'].append(cf.name) allowed_objects["cfs"]["dev"].append(cf.name)
for cf in vm_cfs: for cf in vm_cfs:
allowed_objects['cfs']['vm'].append(cf.name) allowed_objects["cfs"]["vm"].append(cf.name)
hg_objects = [] hg_objects = []
if isinstance(hg_format,list): if isinstance(hg_format, list):
for f in hg_format: for f in hg_format:
hg_objects = hg_objects + f.split("/") hg_objects = hg_objects + f.split("/")
else: else:
hg_objects = hg_format.split("/") hg_objects = hg_format.split("/")
hg_objects = sorted(set(hg_objects)) hg_objects = sorted(set(hg_objects))
for hg_object in hg_objects: for hg_object in hg_objects:
if (hg_object not in allowed_objects[hg_type] and if (
hg_object not in allowed_objects['cfs'][hg_type]): hg_object not in allowed_objects[hg_type]
and hg_object not in allowed_objects["cfs"][hg_type]
):
e = ( e = (
f"Hostgroup item {hg_object} is not valid. Make sure you" f"Hostgroup item {hg_object} is not valid. Make sure you"
" use valid items and separate them with '/'." " use valid items and separate them with '/'."

View File

@ -57,8 +57,10 @@ class ZabbixUsermacros:
macro["macro"] = str(macro_name) macro["macro"] = str(macro_name)
if isinstance(macro_properties, dict): if isinstance(macro_properties, dict):
if not "value" in macro_properties: if not "value" in macro_properties:
self.logger.warning(f"Host {self.name}: Usermacro {macro_name} has " self.logger.warning(
"no value in Netbox, skipping.") f"Host {self.name}: Usermacro {macro_name} has "
"no value in Netbox, skipping."
)
return False return False
macro["value"] = macro_properties["value"] macro["value"] = macro_properties["value"]
@ -83,8 +85,10 @@ class ZabbixUsermacros:
macro["description"] = "" macro["description"] = ""
else: else:
self.logger.warning(f"Host {self.name}: Usermacro {macro_name} " self.logger.warning(
"has no value, skipping.") f"Host {self.name}: Usermacro {macro_name} "
"has no value, skipping."
)
return False return False
else: else:
self.logger.error( self.logger.error(

View File

@ -1,10 +1,11 @@
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
"""Module that hosts all functions for virtual machine processing""" """Module that hosts all functions for virtual machine processing"""
from modules.config import load_config
from modules.device import PhysicalDevice from modules.device import PhysicalDevice
from modules.exceptions import InterfaceConfigError, SyncInventoryError, TemplateError from modules.exceptions import InterfaceConfigError, SyncInventoryError, TemplateError
from modules.hostgroups import Hostgroup from modules.hostgroups import Hostgroup
from modules.interface import ZabbixInterface from modules.interface import ZabbixInterface
from modules.config import load_config
# Load config # Load config
config = load_config() config = load_config()

View File

@ -11,6 +11,7 @@ from pynetbox import api
from pynetbox.core.query import RequestError as NBRequestError from pynetbox.core.query import RequestError as NBRequestError
from requests.exceptions import ConnectionError as RequestsConnectionError from requests.exceptions import ConnectionError as RequestsConnectionError
from zabbix_utils import APIRequestError, ProcessingError, ZabbixAPI from zabbix_utils import APIRequestError, ProcessingError, ZabbixAPI
from modules.config import load_config from modules.config import load_config
from modules.device import PhysicalDevice from modules.device import PhysicalDevice
from modules.exceptions import EnvironmentVarError, SyncError from modules.exceptions import EnvironmentVarError, SyncError
@ -83,14 +84,18 @@ def main(arguments):
device_cfs = list( device_cfs = list(
netbox.extras.custom_fields.filter(type="text", content_types="dcim.device") netbox.extras.custom_fields.filter(type="text", content_types="dcim.device")
) )
verify_hg_format(config["hostgroup_format"], verify_hg_format(
device_cfs=device_cfs, hg_type="dev", logger=logger) config["hostgroup_format"], device_cfs=device_cfs, hg_type="dev", logger=logger
)
if config["sync_vms"]: if config["sync_vms"]:
vm_cfs = list( vm_cfs = list(
netbox.extras.custom_fields.filter(type="text", netbox.extras.custom_fields.filter(
content_types="virtualization.virtualmachine") type="text", content_types="virtualization.virtualmachine"
)
)
verify_hg_format(
config["vm_hostgroup_format"], vm_cfs=vm_cfs, hg_type="vm", logger=logger
) )
verify_hg_format(config["vm_hostgroup_format"], vm_cfs=vm_cfs, hg_type="vm", logger=logger)
# Set Zabbix API # Set Zabbix API
try: try:
ssl_ctx = ssl.create_default_context() ssl_ctx = ssl.create_default_context()
@ -120,7 +125,8 @@ def main(arguments):
netbox_vms = [] netbox_vms = []
if config["sync_vms"]: if config["sync_vms"]:
netbox_vms = list( netbox_vms = list(
netbox.virtualization.virtual_machines.filter(**config["nb_vm_filter"])) netbox.virtualization.virtual_machines.filter(**config["nb_vm_filter"])
)
netbox_site_groups = convert_recordset((netbox.dcim.site_groups.all())) netbox_site_groups = convert_recordset((netbox.dcim.site_groups.all()))
netbox_regions = convert_recordset(netbox.dcim.regions.all()) netbox_regions = convert_recordset(netbox.dcim.regions.all())
netbox_journals = netbox.extras.journal_entries netbox_journals = netbox.extras.journal_entries
@ -141,15 +147,22 @@ def main(arguments):
# Go through all NetBox devices # Go through all NetBox devices
for nb_vm in netbox_vms: for nb_vm in netbox_vms:
try: try:
vm = VirtualMachine(nb_vm, zabbix, netbox_journals, nb_version, vm = VirtualMachine(
config["create_journal"], logger) nb_vm,
zabbix,
netbox_journals,
nb_version,
config["create_journal"],
logger,
)
logger.debug(f"Host {vm.name}: started operations on VM.") logger.debug(f"Host {vm.name}: started operations on VM.")
vm.set_vm_template() vm.set_vm_template()
# Check if a valid template has been found for this VM. # Check if a valid template has been found for this VM.
if not vm.zbx_template_names: if not vm.zbx_template_names:
continue continue
vm.set_hostgroup(config["vm_hostgroup_format"], vm.set_hostgroup(
netbox_site_groups, netbox_regions) config["vm_hostgroup_format"], netbox_site_groups, netbox_regions
)
# Check if a valid hostgroup has been found for this VM. # Check if a valid hostgroup has been found for this VM.
if not vm.hostgroups: if not vm.hostgroups:
continue continue
@ -200,16 +213,25 @@ def main(arguments):
for nb_device in netbox_devices: for nb_device in netbox_devices:
try: try:
# Set device instance set data such as hostgroup and template information. # Set device instance set data such as hostgroup and template information.
device = PhysicalDevice(nb_device, zabbix, netbox_journals, nb_version, device = PhysicalDevice(
config["create_journal"], logger) nb_device,
zabbix,
netbox_journals,
nb_version,
config["create_journal"],
logger,
)
logger.debug(f"Host {device.name}: started operations on device.") logger.debug(f"Host {device.name}: started operations on device.")
device.set_template(config["templates_config_context"], device.set_template(
config["templates_config_context_overrule"]) config["templates_config_context"],
config["templates_config_context_overrule"],
)
# Check if a valid template has been found for this VM. # Check if a valid template has been found for this VM.
if not device.zbx_template_names: if not device.zbx_template_names:
continue continue
device.set_hostgroup( device.set_hostgroup(
config["hostgroup_format"], netbox_site_groups, netbox_regions) config["hostgroup_format"], netbox_site_groups, netbox_regions
)
# Check if a valid hostgroup has been found for this VM. # Check if a valid hostgroup has been found for this VM.
if not device.hostgroups: if not device.hostgroups:
continue continue

View File

@ -1,13 +1,21 @@
"""Tests for configuration parsing in the modules.config module.""" """Tests for configuration parsing in the modules.config module."""
from unittest.mock import patch, MagicMock
import os import os
from modules.config import load_config, DEFAULT_CONFIG, load_config_file, load_env_variable from unittest.mock import MagicMock, patch
from modules.config import (
DEFAULT_CONFIG,
load_config,
load_config_file,
load_env_variable,
)
def test_load_config_defaults(): def test_load_config_defaults():
"""Test that load_config returns default values when no config file or env vars are present""" """Test that load_config returns default values when no config file or env vars are present"""
with patch('modules.config.load_config_file', return_value=DEFAULT_CONFIG.copy()), \ with patch(
patch('modules.config.load_env_variable', return_value=None): "modules.config.load_config_file", return_value=DEFAULT_CONFIG.copy()
), patch("modules.config.load_env_variable", return_value=None):
config = load_config() config = load_config()
assert config == DEFAULT_CONFIG assert config == DEFAULT_CONFIG
assert config["templates_config_context"] is False assert config["templates_config_context"] is False
@ -20,8 +28,9 @@ def test_load_config_file():
mock_config["templates_config_context"] = True mock_config["templates_config_context"] = True
mock_config["sync_vms"] = True mock_config["sync_vms"] = True
with patch('modules.config.load_config_file', return_value=mock_config), \ with patch("modules.config.load_config_file", return_value=mock_config), patch(
patch('modules.config.load_env_variable', return_value=None): "modules.config.load_env_variable", return_value=None
):
config = load_config() config = load_config()
assert config["templates_config_context"] is True assert config["templates_config_context"] is True
assert config["sync_vms"] is True assert config["sync_vms"] is True
@ -31,6 +40,7 @@ def test_load_config_file():
def test_load_env_variables(): def test_load_env_variables():
"""Test that load_config properly loads values from environment variables""" """Test that load_config properly loads values from environment variables"""
# Mock env variable loading to return values for specific keys # Mock env variable loading to return values for specific keys
def mock_load_env(key): def mock_load_env(key):
if key == "sync_vms": if key == "sync_vms":
@ -39,8 +49,9 @@ def test_load_env_variables():
return True return True
return None return None
with patch('modules.config.load_config_file', return_value=DEFAULT_CONFIG.copy()), \ with patch(
patch('modules.config.load_env_variable', side_effect=mock_load_env): "modules.config.load_config_file", return_value=DEFAULT_CONFIG.copy()
), patch("modules.config.load_env_variable", side_effect=mock_load_env):
config = load_config() config = load_config()
assert config["sync_vms"] is True assert config["sync_vms"] is True
assert config["create_journal"] is True assert config["create_journal"] is True
@ -60,8 +71,9 @@ def test_env_vars_override_config_file():
return True return True
return None return None
with patch('modules.config.load_config_file', return_value=mock_config), \ with patch("modules.config.load_config_file", return_value=mock_config), patch(
patch('modules.config.load_env_variable', side_effect=mock_load_env): "modules.config.load_env_variable", side_effect=mock_load_env
):
config = load_config() config = load_config()
# This should be overridden by the env var # This should be overridden by the env var
assert config["sync_vms"] is True assert config["sync_vms"] is True
@ -72,8 +84,9 @@ def test_env_vars_override_config_file():
def test_load_config_file_function(): def test_load_config_file_function():
"""Test the load_config_file function directly""" """Test the load_config_file function directly"""
# Test when the file exists # Test when the file exists
with patch('pathlib.Path.exists', return_value=True), \ with patch("pathlib.Path.exists", return_value=True), patch(
patch('importlib.util.spec_from_file_location') as mock_spec: "importlib.util.spec_from_file_location"
) as mock_spec:
# Setup the mock module with attributes # Setup the mock module with attributes
mock_module = MagicMock() mock_module = MagicMock()
mock_module.templates_config_context = True mock_module.templates_config_context = True
@ -85,7 +98,7 @@ def test_load_config_file_function():
mock_spec_instance.loader.exec_module = lambda x: None mock_spec_instance.loader.exec_module = lambda x: None
# Patch module_from_spec to return our mock module # Patch module_from_spec to return our mock module
with patch('importlib.util.module_from_spec', return_value=mock_module): with patch("importlib.util.module_from_spec", return_value=mock_module):
config = load_config_file(DEFAULT_CONFIG.copy()) config = load_config_file(DEFAULT_CONFIG.copy())
assert config["templates_config_context"] is True assert config["templates_config_context"] is True
assert config["sync_vms"] is True assert config["sync_vms"] is True
@ -93,7 +106,7 @@ def test_load_config_file_function():
def test_load_config_file_not_found(): def test_load_config_file_not_found():
"""Test load_config_file when the config file doesn't exist""" """Test load_config_file when the config file doesn't exist"""
with patch('pathlib.Path.exists', return_value=False): with patch("pathlib.Path.exists", return_value=False):
result = load_config_file(DEFAULT_CONFIG.copy()) result = load_config_file(DEFAULT_CONFIG.copy())
# Should return a dict equal to DEFAULT_CONFIG, not a new object # Should return a dict equal to DEFAULT_CONFIG, not a new object
assert result == DEFAULT_CONFIG assert result == DEFAULT_CONFIG
@ -127,8 +140,9 @@ def test_load_config_file_exception_handling():
"""Test that load_config_file handles exceptions gracefully""" """Test that load_config_file handles exceptions gracefully"""
# This test requires modifying the load_config_file function to handle exceptions # This test requires modifying the load_config_file function to handle exceptions
# For now, we're just checking that an exception is raised # For now, we're just checking that an exception is raised
with patch('pathlib.Path.exists', return_value=True), \ with patch("pathlib.Path.exists", return_value=True), patch(
patch('importlib.util.spec_from_file_location', side_effect=Exception("Import error")): "importlib.util.spec_from_file_location", side_effect=Exception("Import error")
):
# Since the current implementation doesn't handle exceptions, we should # Since the current implementation doesn't handle exceptions, we should
# expect an exception to be raised # expect an exception to be raised
try: try:

View File

@ -1,7 +1,10 @@
"""Tests for device deletion functionality in the PhysicalDevice class.""" """Tests for device deletion functionality in the PhysicalDevice class."""
import unittest import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from zabbix_utils import APIRequestError from zabbix_utils import APIRequestError
from modules.device import PhysicalDevice from modules.device import PhysicalDevice
from modules.exceptions import SyncExternalError from modules.exceptions import SyncExternalError
@ -38,14 +41,14 @@ class TestDeviceDeletion(unittest.TestCase):
self.mock_logger = MagicMock() self.mock_logger = MagicMock()
# Create PhysicalDevice instance with mocks # Create PhysicalDevice instance with mocks
with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): with patch("modules.device.config", {"device_cf": "zabbix_hostid"}):
self.device = PhysicalDevice( self.device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
journal=True, journal=True,
logger=self.mock_logger logger=self.mock_logger,
) )
def test_cleanup_successful_deletion(self): def test_cleanup_successful_deletion(self):
@ -58,12 +61,15 @@ class TestDeviceDeletion(unittest.TestCase):
self.device.cleanup() self.device.cleanup()
# Verify # Verify
self.mock_zabbix.host.get.assert_called_once_with(filter={'hostid': '456'}, output=[]) self.mock_zabbix.host.get.assert_called_once_with(
self.mock_zabbix.host.delete.assert_called_once_with('456') filter={"hostid": "456"}, output=[]
)
self.mock_zabbix.host.delete.assert_called_once_with("456")
self.mock_nb_device.save.assert_called_once() self.mock_nb_device.save.assert_called_once()
self.assertIsNone(self.mock_nb_device.custom_fields["zabbix_hostid"]) self.assertIsNone(self.mock_nb_device.custom_fields["zabbix_hostid"])
self.mock_logger.info.assert_called_with(f"Host {self.device.name}: " self.mock_logger.info.assert_called_with(
"Deleted host from Zabbix.") f"Host {self.device.name}: " "Deleted host from Zabbix."
)
def test_cleanup_device_already_deleted(self): def test_cleanup_device_already_deleted(self):
"""Test cleanup when device is already deleted from Zabbix.""" """Test cleanup when device is already deleted from Zabbix."""
@ -74,12 +80,15 @@ class TestDeviceDeletion(unittest.TestCase):
self.device.cleanup() self.device.cleanup()
# Verify # Verify
self.mock_zabbix.host.get.assert_called_once_with(filter={'hostid': '456'}, output=[]) self.mock_zabbix.host.get.assert_called_once_with(
filter={"hostid": "456"}, output=[]
)
self.mock_zabbix.host.delete.assert_not_called() self.mock_zabbix.host.delete.assert_not_called()
self.mock_nb_device.save.assert_called_once() self.mock_nb_device.save.assert_called_once()
self.assertIsNone(self.mock_nb_device.custom_fields["zabbix_hostid"]) self.assertIsNone(self.mock_nb_device.custom_fields["zabbix_hostid"])
self.mock_logger.info.assert_called_with( self.mock_logger.info.assert_called_with(
f"Host {self.device.name}: was already deleted from Zabbix. Removed link in NetBox.") f"Host {self.device.name}: was already deleted from Zabbix. Removed link in NetBox."
)
def test_cleanup_api_error(self): def test_cleanup_api_error(self):
"""Test cleanup when Zabbix API returns an error.""" """Test cleanup when Zabbix API returns an error."""
@ -92,15 +101,17 @@ class TestDeviceDeletion(unittest.TestCase):
self.device.cleanup() self.device.cleanup()
# Verify correct calls were made # Verify correct calls were made
self.mock_zabbix.host.get.assert_called_once_with(filter={'hostid': '456'}, output=[]) self.mock_zabbix.host.get.assert_called_once_with(
self.mock_zabbix.host.delete.assert_called_once_with('456') filter={"hostid": "456"}, output=[]
)
self.mock_zabbix.host.delete.assert_called_once_with("456")
self.mock_nb_device.save.assert_not_called() self.mock_nb_device.save.assert_not_called()
self.mock_logger.error.assert_called() self.mock_logger.error.assert_called()
def test_zeroize_cf(self): def test_zeroize_cf(self):
"""Test _zeroize_cf method that clears the custom field.""" """Test _zeroize_cf method that clears the custom field."""
# Execute # Execute
self.device._zeroize_cf() # pylint: disable=protected-access self.device._zeroize_cf() # pylint: disable=protected-access
# Verify # Verify
self.assertIsNone(self.mock_nb_device.custom_fields["zabbix_hostid"]) self.assertIsNone(self.mock_nb_device.custom_fields["zabbix_hostid"])
@ -136,14 +147,14 @@ class TestDeviceDeletion(unittest.TestCase):
def test_create_journal_entry_when_disabled(self): def test_create_journal_entry_when_disabled(self):
"""Test create_journal_entry when journaling is disabled.""" """Test create_journal_entry when journaling is disabled."""
# Setup - create device with journal=False # Setup - create device with journal=False
with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): with patch("modules.device.config", {"device_cf": "zabbix_hostid"}):
device = PhysicalDevice( device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
journal=False, # Disable journaling journal=False, # Disable journaling
logger=self.mock_logger logger=self.mock_logger,
) )
# Execute # Execute
@ -159,8 +170,10 @@ class TestDeviceDeletion(unittest.TestCase):
self.mock_zabbix.host.get.return_value = [{"hostid": "456"}] self.mock_zabbix.host.get.return_value = [{"hostid": "456"}]
# Execute # Execute
with patch.object(self.device, 'create_journal_entry') as mock_journal_entry: with patch.object(self.device, "create_journal_entry") as mock_journal_entry:
self.device.cleanup() self.device.cleanup()
# Verify # Verify
mock_journal_entry.assert_called_once_with("warning", "Deleted host from Zabbix") mock_journal_entry.assert_called_once_with(
"warning", "Deleted host from Zabbix"
)

View File

@ -1,8 +1,10 @@
"""Tests for the Hostgroup class in the hostgroups module.""" """Tests for the Hostgroup class in the hostgroups module."""
import unittest import unittest
from unittest.mock import MagicMock, patch, call from unittest.mock import MagicMock, call, patch
from modules.hostgroups import Hostgroup
from modules.exceptions import HostgroupError from modules.exceptions import HostgroupError
from modules.hostgroups import Hostgroup
class TestHostgroups(unittest.TestCase): class TestHostgroups(unittest.TestCase):
@ -17,27 +19,27 @@ class TestHostgroups(unittest.TestCase):
# Create mock device with all properties # Create mock device with all properties
self.mock_device = MagicMock() self.mock_device = MagicMock()
self.mock_device.name = "test-device" self.mock_device.name = "test-device"
# Set up site information # Set up site information
site = MagicMock() site = MagicMock()
site.name = "TestSite" site.name = "TestSite"
# Set up region information # Set up region information
region = MagicMock() region = MagicMock()
region.name = "TestRegion" region.name = "TestRegion"
# Ensure region string representation returns the name # Ensure region string representation returns the name
region.__str__.return_value = "TestRegion" region.__str__.return_value = "TestRegion"
site.region = region site.region = region
# Set up site group information # Set up site group information
site_group = MagicMock() site_group = MagicMock()
site_group.name = "TestSiteGroup" site_group.name = "TestSiteGroup"
# Ensure site group string representation returns the name # Ensure site group string representation returns the name
site_group.__str__.return_value = "TestSiteGroup" site_group.__str__.return_value = "TestSiteGroup"
site.group = site_group site.group = site_group
self.mock_device.site = site self.mock_device.site = site
# Set up role information (varies based on NetBox version) # Set up role information (varies based on NetBox version)
self.mock_device_role = MagicMock() self.mock_device_role = MagicMock()
self.mock_device_role.name = "TestRole" self.mock_device_role.name = "TestRole"
@ -45,7 +47,7 @@ class TestHostgroups(unittest.TestCase):
self.mock_device_role.__str__.return_value = "TestRole" self.mock_device_role.__str__.return_value = "TestRole"
self.mock_device.device_role = self.mock_device_role self.mock_device.device_role = self.mock_device_role
self.mock_device.role = self.mock_device_role self.mock_device.role = self.mock_device_role
# Set up tenant information # Set up tenant information
tenant = MagicMock() tenant = MagicMock()
tenant.name = "TestTenant" tenant.name = "TestTenant"
@ -57,45 +59,45 @@ class TestHostgroups(unittest.TestCase):
tenant_group.__str__.return_value = "TestTenantGroup" tenant_group.__str__.return_value = "TestTenantGroup"
tenant.group = tenant_group tenant.group = tenant_group
self.mock_device.tenant = tenant self.mock_device.tenant = tenant
# Set up platform information # Set up platform information
platform = MagicMock() platform = MagicMock()
platform.name = "TestPlatform" platform.name = "TestPlatform"
self.mock_device.platform = platform self.mock_device.platform = platform
# Device-specific properties # Device-specific properties
device_type = MagicMock() device_type = MagicMock()
manufacturer = MagicMock() manufacturer = MagicMock()
manufacturer.name = "TestManufacturer" manufacturer.name = "TestManufacturer"
device_type.manufacturer = manufacturer device_type.manufacturer = manufacturer
self.mock_device.device_type = device_type self.mock_device.device_type = device_type
location = MagicMock() location = MagicMock()
location.name = "TestLocation" location.name = "TestLocation"
# Ensure location string representation returns the name # Ensure location string representation returns the name
location.__str__.return_value = "TestLocation" location.__str__.return_value = "TestLocation"
self.mock_device.location = location self.mock_device.location = location
# Custom fields # Custom fields
self.mock_device.custom_fields = {"test_cf": "TestCF"} self.mock_device.custom_fields = {"test_cf": "TestCF"}
# *** Mock NetBox VM setup *** # *** Mock NetBox VM setup ***
# Create mock VM with all properties # Create mock VM with all properties
self.mock_vm = MagicMock() self.mock_vm = MagicMock()
self.mock_vm.name = "test-vm" self.mock_vm.name = "test-vm"
# Reuse site from device # Reuse site from device
self.mock_vm.site = site self.mock_vm.site = site
# Set up role for VM # Set up role for VM
self.mock_vm.role = self.mock_device_role self.mock_vm.role = self.mock_device_role
# Set up tenant for VM (same as device) # Set up tenant for VM (same as device)
self.mock_vm.tenant = tenant self.mock_vm.tenant = tenant
# Set up platform for VM (same as device) # Set up platform for VM (same as device)
self.mock_vm.platform = platform self.mock_vm.platform = platform
# VM-specific properties # VM-specific properties
cluster = MagicMock() cluster = MagicMock()
cluster.name = "TestCluster" cluster.name = "TestCluster"
@ -103,28 +105,28 @@ class TestHostgroups(unittest.TestCase):
cluster_type.name = "TestClusterType" cluster_type.name = "TestClusterType"
cluster.type = cluster_type cluster.type = cluster_type
self.mock_vm.cluster = cluster self.mock_vm.cluster = cluster
# Custom fields # Custom fields
self.mock_vm.custom_fields = {"test_cf": "TestCF"} self.mock_vm.custom_fields = {"test_cf": "TestCF"}
# Mock data for nesting tests # Mock data for nesting tests
self.mock_regions_data = [ self.mock_regions_data = [
{"name": "ParentRegion", "parent": None, "_depth": 0}, {"name": "ParentRegion", "parent": None, "_depth": 0},
{"name": "TestRegion", "parent": "ParentRegion", "_depth": 1} {"name": "TestRegion", "parent": "ParentRegion", "_depth": 1},
] ]
self.mock_groups_data = [ self.mock_groups_data = [
{"name": "ParentSiteGroup", "parent": None, "_depth": 0}, {"name": "ParentSiteGroup", "parent": None, "_depth": 0},
{"name": "TestSiteGroup", "parent": "ParentSiteGroup", "_depth": 1} {"name": "TestSiteGroup", "parent": "ParentSiteGroup", "_depth": 1},
] ]
def test_device_hostgroup_creation(self): def test_device_hostgroup_creation(self):
"""Test basic device hostgroup creation.""" """Test basic device hostgroup creation."""
hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger)
# Test the string representation # Test the string representation
self.assertEqual(str(hostgroup), "Hostgroup for dev test-device") self.assertEqual(str(hostgroup), "Hostgroup for dev test-device")
# Check format options were set correctly # Check format options were set correctly
self.assertEqual(hostgroup.format_options["site"], "TestSite") self.assertEqual(hostgroup.format_options["site"], "TestSite")
self.assertEqual(hostgroup.format_options["region"], "TestRegion") self.assertEqual(hostgroup.format_options["region"], "TestRegion")
@ -135,14 +137,14 @@ class TestHostgroups(unittest.TestCase):
self.assertEqual(hostgroup.format_options["platform"], "TestPlatform") self.assertEqual(hostgroup.format_options["platform"], "TestPlatform")
self.assertEqual(hostgroup.format_options["manufacturer"], "TestManufacturer") self.assertEqual(hostgroup.format_options["manufacturer"], "TestManufacturer")
self.assertEqual(hostgroup.format_options["location"], "TestLocation") self.assertEqual(hostgroup.format_options["location"], "TestLocation")
def test_vm_hostgroup_creation(self): def test_vm_hostgroup_creation(self):
"""Test basic VM hostgroup creation.""" """Test basic VM hostgroup creation."""
hostgroup = Hostgroup("vm", self.mock_vm, "4.0", self.mock_logger) hostgroup = Hostgroup("vm", self.mock_vm, "4.0", self.mock_logger)
# Test the string representation # Test the string representation
self.assertEqual(str(hostgroup), "Hostgroup for vm test-vm") self.assertEqual(str(hostgroup), "Hostgroup for vm test-vm")
# Check format options were set correctly # Check format options were set correctly
self.assertEqual(hostgroup.format_options["site"], "TestSite") self.assertEqual(hostgroup.format_options["site"], "TestSite")
self.assertEqual(hostgroup.format_options["region"], "TestRegion") self.assertEqual(hostgroup.format_options["region"], "TestRegion")
@ -153,78 +155,80 @@ class TestHostgroups(unittest.TestCase):
self.assertEqual(hostgroup.format_options["platform"], "TestPlatform") self.assertEqual(hostgroup.format_options["platform"], "TestPlatform")
self.assertEqual(hostgroup.format_options["cluster"], "TestCluster") self.assertEqual(hostgroup.format_options["cluster"], "TestCluster")
self.assertEqual(hostgroup.format_options["cluster_type"], "TestClusterType") self.assertEqual(hostgroup.format_options["cluster_type"], "TestClusterType")
def test_invalid_object_type(self): def test_invalid_object_type(self):
"""Test that an invalid object type raises an exception.""" """Test that an invalid object type raises an exception."""
with self.assertRaises(HostgroupError): with self.assertRaises(HostgroupError):
Hostgroup("invalid", self.mock_device, "4.0", self.mock_logger) Hostgroup("invalid", self.mock_device, "4.0", self.mock_logger)
def test_device_hostgroup_formats(self): def test_device_hostgroup_formats(self):
"""Test different hostgroup formats for devices.""" """Test different hostgroup formats for devices."""
hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger)
# Default format: site/manufacturer/role # Default format: site/manufacturer/role
default_result = hostgroup.generate() default_result = hostgroup.generate()
self.assertEqual(default_result, "TestSite/TestManufacturer/TestRole") self.assertEqual(default_result, "TestSite/TestManufacturer/TestRole")
# Custom format: site/region # Custom format: site/region
custom_result = hostgroup.generate("site/region") custom_result = hostgroup.generate("site/region")
self.assertEqual(custom_result, "TestSite/TestRegion") self.assertEqual(custom_result, "TestSite/TestRegion")
# Custom format: site/tenant/platform/location # Custom format: site/tenant/platform/location
complex_result = hostgroup.generate("site/tenant/platform/location") complex_result = hostgroup.generate("site/tenant/platform/location")
self.assertEqual(complex_result, "TestSite/TestTenant/TestPlatform/TestLocation") self.assertEqual(
complex_result, "TestSite/TestTenant/TestPlatform/TestLocation"
)
def test_vm_hostgroup_formats(self): def test_vm_hostgroup_formats(self):
"""Test different hostgroup formats for VMs.""" """Test different hostgroup formats for VMs."""
hostgroup = Hostgroup("vm", self.mock_vm, "4.0", self.mock_logger) hostgroup = Hostgroup("vm", self.mock_vm, "4.0", self.mock_logger)
# Default format: cluster/role # Default format: cluster/role
default_result = hostgroup.generate() default_result = hostgroup.generate()
self.assertEqual(default_result, "TestCluster/TestRole") self.assertEqual(default_result, "TestCluster/TestRole")
# Custom format: site/tenant # Custom format: site/tenant
custom_result = hostgroup.generate("site/tenant") custom_result = hostgroup.generate("site/tenant")
self.assertEqual(custom_result, "TestSite/TestTenant") self.assertEqual(custom_result, "TestSite/TestTenant")
# Custom format: cluster/cluster_type/platform # Custom format: cluster/cluster_type/platform
complex_result = hostgroup.generate("cluster/cluster_type/platform") complex_result = hostgroup.generate("cluster/cluster_type/platform")
self.assertEqual(complex_result, "TestCluster/TestClusterType/TestPlatform") self.assertEqual(complex_result, "TestCluster/TestClusterType/TestPlatform")
def test_device_netbox_version_differences(self): def test_device_netbox_version_differences(self):
"""Test hostgroup generation with different NetBox versions.""" """Test hostgroup generation with different NetBox versions."""
# NetBox v2.x # NetBox v2.x
hostgroup_v2 = Hostgroup("dev", self.mock_device, "2.11", self.mock_logger) hostgroup_v2 = Hostgroup("dev", self.mock_device, "2.11", self.mock_logger)
self.assertEqual(hostgroup_v2.format_options["role"], "TestRole") self.assertEqual(hostgroup_v2.format_options["role"], "TestRole")
# NetBox v3.x # NetBox v3.x
hostgroup_v3 = Hostgroup("dev", self.mock_device, "3.5", self.mock_logger) hostgroup_v3 = Hostgroup("dev", self.mock_device, "3.5", self.mock_logger)
self.assertEqual(hostgroup_v3.format_options["role"], "TestRole") self.assertEqual(hostgroup_v3.format_options["role"], "TestRole")
# NetBox v4.x (already tested in other methods) # NetBox v4.x (already tested in other methods)
def test_custom_field_lookup(self): def test_custom_field_lookup(self):
"""Test custom field lookup functionality.""" """Test custom field lookup functionality."""
hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger)
# Test custom field exists and is populated # Test custom field exists and is populated
cf_result = hostgroup.custom_field_lookup("test_cf") cf_result = hostgroup.custom_field_lookup("test_cf")
self.assertTrue(cf_result["result"]) self.assertTrue(cf_result["result"])
self.assertEqual(cf_result["cf"], "TestCF") self.assertEqual(cf_result["cf"], "TestCF")
# Test custom field doesn't exist # Test custom field doesn't exist
cf_result = hostgroup.custom_field_lookup("nonexistent_cf") cf_result = hostgroup.custom_field_lookup("nonexistent_cf")
self.assertFalse(cf_result["result"]) self.assertFalse(cf_result["result"])
self.assertIsNone(cf_result["cf"]) self.assertIsNone(cf_result["cf"])
def test_hostgroup_with_custom_field(self): def test_hostgroup_with_custom_field(self):
"""Test hostgroup generation including a custom field.""" """Test hostgroup generation including a custom field."""
hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger)
# Generate with custom field included # Generate with custom field included
result = hostgroup.generate("site/test_cf/role") result = hostgroup.generate("site/test_cf/role")
self.assertEqual(result, "TestSite/TestCF/TestRole") self.assertEqual(result, "TestSite/TestCF/TestRole")
def test_missing_hostgroup_format_item(self): def test_missing_hostgroup_format_item(self):
"""Test handling of missing hostgroup format items.""" """Test handling of missing hostgroup format items."""
# Create a device with minimal attributes # Create a device with minimal attributes
@ -234,31 +238,31 @@ class TestHostgroups(unittest.TestCase):
minimal_device.tenant = None minimal_device.tenant = None
minimal_device.platform = None minimal_device.platform = None
minimal_device.custom_fields = {} minimal_device.custom_fields = {}
# Create role # Create role
role = MagicMock() role = MagicMock()
role.name = "MinimalRole" role.name = "MinimalRole"
minimal_device.role = role minimal_device.role = role
# Create device_type with manufacturer # Create device_type with manufacturer
device_type = MagicMock() device_type = MagicMock()
manufacturer = MagicMock() manufacturer = MagicMock()
manufacturer.name = "MinimalManufacturer" manufacturer.name = "MinimalManufacturer"
device_type.manufacturer = manufacturer device_type.manufacturer = manufacturer
minimal_device.device_type = device_type minimal_device.device_type = device_type
# Create hostgroup # Create hostgroup
hostgroup = Hostgroup("dev", minimal_device, "4.0", self.mock_logger) hostgroup = Hostgroup("dev", minimal_device, "4.0", self.mock_logger)
# Generate with default format # Generate with default format
result = hostgroup.generate() result = hostgroup.generate()
# Site is missing, so only manufacturer and role should be included # Site is missing, so only manufacturer and role should be included
self.assertEqual(result, "MinimalManufacturer/MinimalRole") self.assertEqual(result, "MinimalManufacturer/MinimalRole")
# Test with invalid format # Test with invalid format
with self.assertRaises(HostgroupError): with self.assertRaises(HostgroupError):
hostgroup.generate("site/nonexistent/role") hostgroup.generate("site/nonexistent/role")
def test_hostgroup_missing_required_attributes(self): def test_hostgroup_missing_required_attributes(self):
"""Test handling when no valid hostgroup can be generated.""" """Test handling when no valid hostgroup can be generated."""
# Create a VM with minimal attributes that won't satisfy any format # Create a VM with minimal attributes that won't satisfy any format
@ -270,69 +274,70 @@ class TestHostgroups(unittest.TestCase):
minimal_vm.role = None minimal_vm.role = None
minimal_vm.cluster = None minimal_vm.cluster = None
minimal_vm.custom_fields = {} minimal_vm.custom_fields = {}
hostgroup = Hostgroup("vm", minimal_vm, "4.0", self.mock_logger) hostgroup = Hostgroup("vm", minimal_vm, "4.0", self.mock_logger)
# With default format of cluster/role, both are None, so should raise an error # With default format of cluster/role, both are None, so should raise an error
with self.assertRaises(HostgroupError): with self.assertRaises(HostgroupError):
hostgroup.generate() hostgroup.generate()
def test_nested_region_hostgroups(self): def test_nested_region_hostgroups(self):
"""Test hostgroup generation with nested regions.""" """Test hostgroup generation with nested regions."""
# Mock the build_path function to return a predictable result # Mock the build_path function to return a predictable result
with patch('modules.hostgroups.build_path') as mock_build_path: with patch("modules.hostgroups.build_path") as mock_build_path:
# Configure the mock to return a list of regions in the path # Configure the mock to return a list of regions in the path
mock_build_path.return_value = ["ParentRegion", "TestRegion"] mock_build_path.return_value = ["ParentRegion", "TestRegion"]
# Create hostgroup with nested regions enabled # Create hostgroup with nested regions enabled
hostgroup = Hostgroup( hostgroup = Hostgroup(
"dev", "dev",
self.mock_device, self.mock_device,
"4.0", "4.0",
self.mock_logger, self.mock_logger,
nested_region_flag=True, nested_region_flag=True,
nb_regions=self.mock_regions_data nb_regions=self.mock_regions_data,
) )
# Generate hostgroup with region # Generate hostgroup with region
result = hostgroup.generate("site/region/role") result = hostgroup.generate("site/region/role")
# Should include the parent region # Should include the parent region
self.assertEqual(result, "TestSite/ParentRegion/TestRegion/TestRole") self.assertEqual(result, "TestSite/ParentRegion/TestRegion/TestRole")
def test_nested_sitegroup_hostgroups(self): def test_nested_sitegroup_hostgroups(self):
"""Test hostgroup generation with nested site groups.""" """Test hostgroup generation with nested site groups."""
# Mock the build_path function to return a predictable result # Mock the build_path function to return a predictable result
with patch('modules.hostgroups.build_path') as mock_build_path: with patch("modules.hostgroups.build_path") as mock_build_path:
# Configure the mock to return a list of site groups in the path # Configure the mock to return a list of site groups in the path
mock_build_path.return_value = ["ParentSiteGroup", "TestSiteGroup"] mock_build_path.return_value = ["ParentSiteGroup", "TestSiteGroup"]
# Create hostgroup with nested site groups enabled # Create hostgroup with nested site groups enabled
hostgroup = Hostgroup( hostgroup = Hostgroup(
"dev", "dev",
self.mock_device, self.mock_device,
"4.0", "4.0",
self.mock_logger, self.mock_logger,
nested_sitegroup_flag=True, nested_sitegroup_flag=True,
nb_groups=self.mock_groups_data nb_groups=self.mock_groups_data,
) )
# Generate hostgroup with site_group # Generate hostgroup with site_group
result = hostgroup.generate("site/site_group/role") result = hostgroup.generate("site/site_group/role")
# Should include the parent site group # Should include the parent site group
self.assertEqual(result, "TestSite/ParentSiteGroup/TestSiteGroup/TestRole") self.assertEqual(result, "TestSite/ParentSiteGroup/TestSiteGroup/TestRole")
def test_list_formatoptions(self): def test_list_formatoptions(self):
"""Test the list_formatoptions method for debugging.""" """Test the list_formatoptions method for debugging."""
hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger)
# Patch sys.stdout to capture print output # Patch sys.stdout to capture print output
with patch('sys.stdout') as mock_stdout: with patch("sys.stdout") as mock_stdout:
hostgroup.list_formatoptions() hostgroup.list_formatoptions()
# Check that print was called with expected output # Check that print was called with expected output
calls = [call.write(f"The following options are available for host test-device"), calls = [
call.write('\n')] call.write(f"The following options are available for host test-device"),
call.write("\n"),
]
mock_stdout.assert_has_calls(calls, any_order=True) mock_stdout.assert_has_calls(calls, any_order=True)

View File

@ -1,7 +1,9 @@
"""Tests for the ZabbixInterface class in the interface module.""" """Tests for the ZabbixInterface class in the interface module."""
import unittest import unittest
from modules.interface import ZabbixInterface
from modules.exceptions import InterfaceConfigError from modules.exceptions import InterfaceConfigError
from modules.interface import ZabbixInterface
class TestZabbixInterface(unittest.TestCase): class TestZabbixInterface(unittest.TestCase):
@ -18,11 +20,7 @@ class TestZabbixInterface(unittest.TestCase):
"zabbix": { "zabbix": {
"interface_type": 2, "interface_type": 2,
"interface_port": "161", "interface_port": "161",
"snmp": { "snmp": {"version": 2, "community": "public", "bulk": 1},
"version": 2,
"community": "public",
"bulk": 1
}
} }
} }
@ -37,16 +35,13 @@ class TestZabbixInterface(unittest.TestCase):
"authpassphrase": "authpass123", "authpassphrase": "authpass123",
"privprotocol": "AES", "privprotocol": "AES",
"privpassphrase": "privpass123", "privpassphrase": "privpass123",
"contextname": "context1" "contextname": "context1",
} },
} }
} }
self.agent_context = { self.agent_context = {
"zabbix": { "zabbix": {"interface_type": 1, "interface_port": "10050"}
"interface_type": 1,
"interface_port": "10050"
}
} }
def test_init(self): def test_init(self):
@ -95,27 +90,27 @@ class TestZabbixInterface(unittest.TestCase):
# Test for agent type (1) # Test for agent type (1)
interface.interface["type"] = 1 interface.interface["type"] = 1
interface._set_default_port() # pylint: disable=protected-access interface._set_default_port() # pylint: disable=protected-access
self.assertEqual(interface.interface["port"], "10050") self.assertEqual(interface.interface["port"], "10050")
# Test for SNMP type (2) # Test for SNMP type (2)
interface.interface["type"] = 2 interface.interface["type"] = 2
interface._set_default_port() # pylint: disable=protected-access interface._set_default_port() # pylint: disable=protected-access
self.assertEqual(interface.interface["port"], "161") self.assertEqual(interface.interface["port"], "161")
# Test for IPMI type (3) # Test for IPMI type (3)
interface.interface["type"] = 3 interface.interface["type"] = 3
interface._set_default_port() # pylint: disable=protected-access interface._set_default_port() # pylint: disable=protected-access
self.assertEqual(interface.interface["port"], "623") self.assertEqual(interface.interface["port"], "623")
# Test for JMX type (4) # Test for JMX type (4)
interface.interface["type"] = 4 interface.interface["type"] = 4
interface._set_default_port() # pylint: disable=protected-access interface._set_default_port() # pylint: disable=protected-access
self.assertEqual(interface.interface["port"], "12345") self.assertEqual(interface.interface["port"], "12345")
# Test for unsupported type # Test for unsupported type
interface.interface["type"] = 99 interface.interface["type"] = 99
result = interface._set_default_port() # pylint: disable=protected-access result = interface._set_default_port() # pylint: disable=protected-access
self.assertFalse(result) self.assertFalse(result)
def test_set_snmp_v2(self): def test_set_snmp_v2(self):
@ -144,9 +139,13 @@ class TestZabbixInterface(unittest.TestCase):
self.assertEqual(interface.interface["details"]["securityname"], "snmpuser") self.assertEqual(interface.interface["details"]["securityname"], "snmpuser")
self.assertEqual(interface.interface["details"]["securitylevel"], "authPriv") self.assertEqual(interface.interface["details"]["securitylevel"], "authPriv")
self.assertEqual(interface.interface["details"]["authprotocol"], "SHA") self.assertEqual(interface.interface["details"]["authprotocol"], "SHA")
self.assertEqual(interface.interface["details"]["authpassphrase"], "authpass123") self.assertEqual(
interface.interface["details"]["authpassphrase"], "authpass123"
)
self.assertEqual(interface.interface["details"]["privprotocol"], "AES") self.assertEqual(interface.interface["details"]["privprotocol"], "AES")
self.assertEqual(interface.interface["details"]["privpassphrase"], "privpass123") self.assertEqual(
interface.interface["details"]["privpassphrase"], "privpass123"
)
self.assertEqual(interface.interface["details"]["contextname"], "context1") self.assertEqual(interface.interface["details"]["contextname"], "context1")
def test_set_snmp_no_snmp_config(self): def test_set_snmp_no_snmp_config(self):
@ -164,12 +163,7 @@ class TestZabbixInterface(unittest.TestCase):
"""Test set_snmp with unsupported SNMP version.""" """Test set_snmp with unsupported SNMP version."""
# Create context with invalid SNMP version # Create context with invalid SNMP version
context = { context = {
"zabbix": { "zabbix": {"interface_type": 2, "snmp": {"version": 4}} # Invalid version
"interface_type": 2,
"snmp": {
"version": 4 # Invalid version
}
}
} }
interface = ZabbixInterface(context, self.test_ip) interface = ZabbixInterface(context, self.test_ip)
interface.get_context() # Set the interface type interface.get_context() # Set the interface type
@ -184,9 +178,7 @@ class TestZabbixInterface(unittest.TestCase):
context = { context = {
"zabbix": { "zabbix": {
"interface_type": 2, "interface_type": 2,
"snmp": { "snmp": {"community": "public"}, # No version specified
"community": "public" # No version specified
}
} }
} }
interface = ZabbixInterface(context, self.test_ip) interface = ZabbixInterface(context, self.test_ip)
@ -214,7 +206,9 @@ class TestZabbixInterface(unittest.TestCase):
self.assertEqual(interface.interface["type"], "2") self.assertEqual(interface.interface["type"], "2")
self.assertEqual(interface.interface["port"], "161") self.assertEqual(interface.interface["port"], "161")
self.assertEqual(interface.interface["details"]["version"], "2") self.assertEqual(interface.interface["details"]["version"], "2")
self.assertEqual(interface.interface["details"]["community"], "{$SNMP_COMMUNITY}") self.assertEqual(
interface.interface["details"]["community"], "{$SNMP_COMMUNITY}"
)
self.assertEqual(interface.interface["details"]["bulk"], "1") self.assertEqual(interface.interface["details"]["bulk"], "1")
def test_set_default_agent(self): def test_set_default_agent(self):
@ -229,14 +223,7 @@ class TestZabbixInterface(unittest.TestCase):
def test_snmpv2_no_community(self): def test_snmpv2_no_community(self):
"""Test SNMPv2 with no community string specified.""" """Test SNMPv2 with no community string specified."""
# Create context with SNMPv2 but no community # Create context with SNMPv2 but no community
context = { context = {"zabbix": {"interface_type": 2, "snmp": {"version": 2}}}
"zabbix": {
"interface_type": 2,
"snmp": {
"version": 2
}
}
}
interface = ZabbixInterface(context, self.test_ip) interface = ZabbixInterface(context, self.test_ip)
interface.get_context() # Set the interface type interface.get_context() # Set the interface type
@ -244,4 +231,6 @@ class TestZabbixInterface(unittest.TestCase):
interface.set_snmp() interface.set_snmp()
# Should use default community string # Should use default community string
self.assertEqual(interface.interface["details"]["community"], "{$SNMP_COMMUNITY}") self.assertEqual(
interface.interface["details"]["community"], "{$SNMP_COMMUNITY}"
)

View File

@ -1,8 +1,10 @@
"""Tests for the PhysicalDevice class in the device module.""" """Tests for the PhysicalDevice class in the device module."""
import unittest import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from modules.device import PhysicalDevice from modules.device import PhysicalDevice
from modules.exceptions import TemplateError, SyncInventoryError from modules.exceptions import SyncInventoryError, TemplateError
class TestPhysicalDevice(unittest.TestCase): class TestPhysicalDevice(unittest.TestCase):
@ -34,24 +36,27 @@ class TestPhysicalDevice(unittest.TestCase):
self.mock_logger = MagicMock() self.mock_logger = MagicMock()
# Create PhysicalDevice instance with mocks # Create PhysicalDevice instance with mocks
with patch('modules.device.config', with patch(
{"device_cf": "zabbix_hostid", "modules.device.config",
"template_cf": "zabbix_template", {
"templates_config_context": False, "device_cf": "zabbix_hostid",
"templates_config_context_overrule": False, "template_cf": "zabbix_template",
"traverse_regions": False, "templates_config_context": False,
"traverse_site_groups": False, "templates_config_context_overrule": False,
"inventory_mode": "disabled", "traverse_regions": False,
"inventory_sync": False, "traverse_site_groups": False,
"device_inventory_map": {} "inventory_mode": "disabled",
}): "inventory_sync": False,
"device_inventory_map": {},
},
):
self.device = PhysicalDevice( self.device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
journal=True, journal=True,
logger=self.mock_logger logger=self.mock_logger,
) )
def test_init(self): def test_init(self):
@ -69,14 +74,14 @@ class TestPhysicalDevice(unittest.TestCase):
self.mock_nb_device.primary_ip = None self.mock_nb_device.primary_ip = None
# Creating device should raise SyncInventoryError # Creating device should raise SyncInventoryError
with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): with patch("modules.device.config", {"device_cf": "zabbix_hostid"}):
with self.assertRaises(SyncInventoryError): with self.assertRaises(SyncInventoryError):
PhysicalDevice( PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
def test_set_basics_with_special_characters(self): def test_set_basics_with_special_characters(self):
@ -86,8 +91,9 @@ class TestPhysicalDevice(unittest.TestCase):
self.mock_nb_device.name = "test-devïce" self.mock_nb_device.name = "test-devïce"
# We need to patch the search function to simulate finding special characters # We need to patch the search function to simulate finding special characters
with patch('modules.device.search') as mock_search, \ with patch("modules.device.search") as mock_search, patch(
patch('modules.device.config', {"device_cf": "zabbix_hostid"}): "modules.device.config", {"device_cf": "zabbix_hostid"}
):
# Make the search function return True to simulate special characters # Make the search function return True to simulate special characters
mock_search.return_value = True mock_search.return_value = True
@ -96,7 +102,7 @@ class TestPhysicalDevice(unittest.TestCase):
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# With the mocked search function, the name should be changed to NETBOX_ID format # With the mocked search function, the name should be changed to NETBOX_ID format
@ -110,19 +116,17 @@ class TestPhysicalDevice(unittest.TestCase):
"""Test get_templates_context with valid config.""" """Test get_templates_context with valid config."""
# Set up config_context with valid template data # Set up config_context with valid template data
self.mock_nb_device.config_context = { self.mock_nb_device.config_context = {
"zabbix": { "zabbix": {"templates": ["Template1", "Template2"]}
"templates": ["Template1", "Template2"]
}
} }
# Create device with the updated mock # Create device with the updated mock
with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): with patch("modules.device.config", {"device_cf": "zabbix_hostid"}):
device = PhysicalDevice( device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# Test that templates are returned correctly # Test that templates are returned correctly
@ -132,20 +136,16 @@ class TestPhysicalDevice(unittest.TestCase):
def test_get_templates_context_with_string(self): def test_get_templates_context_with_string(self):
"""Test get_templates_context with a string instead of list.""" """Test get_templates_context with a string instead of list."""
# Set up config_context with a string template # Set up config_context with a string template
self.mock_nb_device.config_context = { self.mock_nb_device.config_context = {"zabbix": {"templates": "Template1"}}
"zabbix": {
"templates": "Template1"
}
}
# Create device with the updated mock # Create device with the updated mock
with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): with patch("modules.device.config", {"device_cf": "zabbix_hostid"}):
device = PhysicalDevice( device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# Test that template is wrapped in a list # Test that template is wrapped in a list
@ -158,13 +158,13 @@ class TestPhysicalDevice(unittest.TestCase):
self.mock_nb_device.config_context = {} self.mock_nb_device.config_context = {}
# Create device with the updated mock # Create device with the updated mock
with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): with patch("modules.device.config", {"device_cf": "zabbix_hostid"}):
device = PhysicalDevice( device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# Test that TemplateError is raised # Test that TemplateError is raised
@ -177,13 +177,13 @@ class TestPhysicalDevice(unittest.TestCase):
self.mock_nb_device.config_context = {"zabbix": {}} self.mock_nb_device.config_context = {"zabbix": {}}
# Create device with the updated mock # Create device with the updated mock
with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): with patch("modules.device.config", {"device_cf": "zabbix_hostid"}):
device = PhysicalDevice( device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# Test that TemplateError is raised # Test that TemplateError is raised
@ -193,25 +193,25 @@ class TestPhysicalDevice(unittest.TestCase):
def test_set_template_with_config_context(self): def test_set_template_with_config_context(self):
"""Test set_template with templates_config_context=True.""" """Test set_template with templates_config_context=True."""
# Set up config_context with templates # Set up config_context with templates
self.mock_nb_device.config_context = { self.mock_nb_device.config_context = {"zabbix": {"templates": ["Template1"]}}
"zabbix": {
"templates": ["Template1"]
}
}
# Mock get_templates_context to return expected templates # Mock get_templates_context to return expected templates
with patch.object(PhysicalDevice, 'get_templates_context', return_value=["Template1"]): with patch.object(
with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): PhysicalDevice, "get_templates_context", return_value=["Template1"]
):
with patch("modules.device.config", {"device_cf": "zabbix_hostid"}):
device = PhysicalDevice( device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# Call set_template with prefer_config_context=True # Call set_template with prefer_config_context=True
result = device.set_template(prefer_config_context=True, overrule_custom=False) result = device.set_template(
prefer_config_context=True, overrule_custom=False
)
# Check result and template names # Check result and template names
self.assertTrue(result) self.assertTrue(result)
@ -223,20 +223,20 @@ class TestPhysicalDevice(unittest.TestCase):
config_patch = { config_patch = {
"device_cf": "zabbix_hostid", "device_cf": "zabbix_hostid",
"inventory_mode": "disabled", "inventory_mode": "disabled",
"inventory_sync": False "inventory_sync": False,
} }
with patch('modules.device.config', config_patch): with patch("modules.device.config", config_patch):
device = PhysicalDevice( device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# Call set_inventory with the config patch still active # Call set_inventory with the config patch still active
with patch('modules.device.config', config_patch): with patch("modules.device.config", config_patch):
result = device.set_inventory({}) result = device.set_inventory({})
# Check result # Check result
@ -250,20 +250,20 @@ class TestPhysicalDevice(unittest.TestCase):
config_patch = { config_patch = {
"device_cf": "zabbix_hostid", "device_cf": "zabbix_hostid",
"inventory_mode": "manual", "inventory_mode": "manual",
"inventory_sync": False "inventory_sync": False,
} }
with patch('modules.device.config', config_patch): with patch("modules.device.config", config_patch):
device = PhysicalDevice( device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# Call set_inventory with the config patch still active # Call set_inventory with the config patch still active
with patch('modules.device.config', config_patch): with patch("modules.device.config", config_patch):
result = device.set_inventory({}) result = device.set_inventory({})
# Check result # Check result
@ -276,20 +276,20 @@ class TestPhysicalDevice(unittest.TestCase):
config_patch = { config_patch = {
"device_cf": "zabbix_hostid", "device_cf": "zabbix_hostid",
"inventory_mode": "automatic", "inventory_mode": "automatic",
"inventory_sync": False "inventory_sync": False,
} }
with patch('modules.device.config', config_patch): with patch("modules.device.config", config_patch):
device = PhysicalDevice( device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# Call set_inventory with the config patch still active # Call set_inventory with the config patch still active
with patch('modules.device.config', config_patch): with patch("modules.device.config", config_patch):
result = device.set_inventory({}) result = device.set_inventory({})
# Check result # Check result
@ -303,38 +303,31 @@ class TestPhysicalDevice(unittest.TestCase):
"device_cf": "zabbix_hostid", "device_cf": "zabbix_hostid",
"inventory_mode": "manual", "inventory_mode": "manual",
"inventory_sync": True, "inventory_sync": True,
"device_inventory_map": { "device_inventory_map": {"name": "name", "serial": "serialno_a"},
"name": "name",
"serial": "serialno_a"
}
} }
with patch('modules.device.config', config_patch): with patch("modules.device.config", config_patch):
device = PhysicalDevice( device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# Create a mock device with the required attributes # Create a mock device with the required attributes
mock_device_data = { mock_device_data = {"name": "test-device", "serial": "ABC123"}
"name": "test-device",
"serial": "ABC123"
}
# Call set_inventory with the config patch still active # Call set_inventory with the config patch still active
with patch('modules.device.config', config_patch): with patch("modules.device.config", config_patch):
result = device.set_inventory(mock_device_data) result = device.set_inventory(mock_device_data)
# Check result # Check result
self.assertTrue(result) self.assertTrue(result)
self.assertEqual(device.inventory_mode, 0) # Manual mode self.assertEqual(device.inventory_mode, 0) # Manual mode
self.assertEqual(device.inventory, { self.assertEqual(
"name": "test-device", device.inventory, {"name": "test-device", "serialno_a": "ABC123"}
"serialno_a": "ABC123" )
})
def test_iscluster_true(self): def test_iscluster_true(self):
"""Test isCluster when device is part of a cluster.""" """Test isCluster when device is part of a cluster."""
@ -342,13 +335,13 @@ class TestPhysicalDevice(unittest.TestCase):
self.mock_nb_device.virtual_chassis = MagicMock() self.mock_nb_device.virtual_chassis = MagicMock()
# Create device with the updated mock # Create device with the updated mock
with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): with patch("modules.device.config", {"device_cf": "zabbix_hostid"}):
device = PhysicalDevice( device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# Check isCluster result # Check isCluster result
@ -360,26 +353,27 @@ class TestPhysicalDevice(unittest.TestCase):
self.mock_nb_device.virtual_chassis = None self.mock_nb_device.virtual_chassis = None
# Create device with the updated mock # Create device with the updated mock
with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): with patch("modules.device.config", {"device_cf": "zabbix_hostid"}):
device = PhysicalDevice( device = PhysicalDevice(
self.mock_nb_device, self.mock_nb_device,
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# Check isCluster result # Check isCluster result
self.assertFalse(device.isCluster()) self.assertFalse(device.isCluster())
def test_promote_master_device_primary(self): def test_promote_master_device_primary(self):
"""Test promoteMasterDevice when device is primary in cluster.""" """Test promoteMasterDevice when device is primary in cluster."""
# Set up virtual chassis with master device # Set up virtual chassis with master device
mock_vc = MagicMock() mock_vc = MagicMock()
mock_vc.name = "virtual-chassis-1" mock_vc.name = "virtual-chassis-1"
mock_master = MagicMock() mock_master = MagicMock()
mock_master.id = self.mock_nb_device.id # Set master ID to match the current device mock_master.id = (
self.mock_nb_device.id
) # Set master ID to match the current device
mock_vc.master = mock_master mock_vc.master = mock_master
self.mock_nb_device.virtual_chassis = mock_vc self.mock_nb_device.virtual_chassis = mock_vc
@ -389,7 +383,7 @@ class TestPhysicalDevice(unittest.TestCase):
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# Call promoteMasterDevice and check the result # Call promoteMasterDevice and check the result
@ -400,14 +394,15 @@ class TestPhysicalDevice(unittest.TestCase):
# Device name should be updated to virtual chassis name # Device name should be updated to virtual chassis name
self.assertEqual(device.name, "virtual-chassis-1") self.assertEqual(device.name, "virtual-chassis-1")
def test_promote_master_device_secondary(self): def test_promote_master_device_secondary(self):
"""Test promoteMasterDevice when device is secondary in cluster.""" """Test promoteMasterDevice when device is secondary in cluster."""
# Set up virtual chassis with a different master device # Set up virtual chassis with a different master device
mock_vc = MagicMock() mock_vc = MagicMock()
mock_vc.name = "virtual-chassis-1" mock_vc.name = "virtual-chassis-1"
mock_master = MagicMock() mock_master = MagicMock()
mock_master.id = self.mock_nb_device.id + 1 # Different ID than the current device mock_master.id = (
self.mock_nb_device.id + 1
) # Different ID than the current device
mock_vc.master = mock_master mock_vc.master = mock_master
self.mock_nb_device.virtual_chassis = mock_vc self.mock_nb_device.virtual_chassis = mock_vc
@ -417,7 +412,7 @@ class TestPhysicalDevice(unittest.TestCase):
self.mock_zabbix, self.mock_zabbix,
self.mock_nb_journal, self.mock_nb_journal,
"3.0", "3.0",
logger=self.mock_logger logger=self.mock_logger,
) )
# Call promoteMasterDevice and check the result # Call promoteMasterDevice and check the result

View File

@ -1,5 +1,6 @@
from modules.tools import sanatize_log_output from modules.tools import sanatize_log_output
def test_sanatize_log_output_secrets(): def test_sanatize_log_output_secrets():
data = { data = {
"macros": [ "macros": [
@ -11,6 +12,7 @@ def test_sanatize_log_output_secrets():
assert sanitized["macros"][0]["value"] == "********" assert sanitized["macros"][0]["value"] == "********"
assert sanitized["macros"][1]["value"] == "notsecret" assert sanitized["macros"][1]["value"] == "notsecret"
def test_sanatize_log_output_interface_secrets(): def test_sanatize_log_output_interface_secrets():
data = { data = {
"interfaceid": 123, "interfaceid": 123,
@ -19,8 +21,8 @@ def test_sanatize_log_output_interface_secrets():
"privpassphrase": "anothersecret", "privpassphrase": "anothersecret",
"securityname": "sensitiveuser", "securityname": "sensitiveuser",
"community": "public", "community": "public",
"other": "normalvalue" "other": "normalvalue",
} },
} }
sanitized = sanatize_log_output(data) sanitized = sanatize_log_output(data)
# Sensitive fields should be sanitized # Sensitive fields should be sanitized
@ -33,6 +35,7 @@ def test_sanatize_log_output_interface_secrets():
# interfaceid should be removed # interfaceid should be removed
assert "interfaceid" not in sanitized assert "interfaceid" not in sanitized
def test_sanatize_log_output_interface_macros(): def test_sanatize_log_output_interface_macros():
data = { data = {
"interfaceid": 123, "interfaceid": 123,
@ -41,7 +44,7 @@ def test_sanatize_log_output_interface_macros():
"privpassphrase": "{$SECRET_MACRO}", "privpassphrase": "{$SECRET_MACRO}",
"securityname": "{$USER_MACRO}", "securityname": "{$USER_MACRO}",
"community": "{$SNNMP_COMMUNITY}", "community": "{$SNNMP_COMMUNITY}",
} },
} }
sanitized = sanatize_log_output(data) sanitized = sanatize_log_output(data)
# Macro values should not be sanitized # Macro values should not be sanitized
@ -51,11 +54,13 @@ def test_sanatize_log_output_interface_macros():
assert sanitized["details"]["community"] == "{$SNNMP_COMMUNITY}" assert sanitized["details"]["community"] == "{$SNNMP_COMMUNITY}"
assert "interfaceid" not in sanitized assert "interfaceid" not in sanitized
def test_sanatize_log_output_plain_data(): def test_sanatize_log_output_plain_data():
data = {"foo": "bar", "baz": 123} data = {"foo": "bar", "baz": 123}
sanitized = sanatize_log_output(data) sanitized = sanatize_log_output(data)
assert sanitized == data assert sanitized == data
def test_sanatize_log_output_non_dict(): def test_sanatize_log_output_non_dict():
data = [1, 2, 3] data = [1, 2, 3]
sanitized = sanatize_log_output(data) sanitized = sanatize_log_output(data)

View File

@ -1,8 +1,10 @@
import unittest import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from modules.device import PhysicalDevice from modules.device import PhysicalDevice
from modules.usermacros import ZabbixUsermacros from modules.usermacros import ZabbixUsermacros
class DummyNB: class DummyNB:
def __init__(self, name="dummy", config_context=None, **kwargs): def __init__(self, name="dummy", config_context=None, **kwargs):
self.name = name self.name = name
@ -18,6 +20,7 @@ class DummyNB:
return self.config_context[key] return self.config_context[key]
raise KeyError(key) raise KeyError(key)
class TestUsermacroSync(unittest.TestCase): class TestUsermacroSync(unittest.TestCase):
def setUp(self): def setUp(self):
self.nb = DummyNB(serial="1234") self.nb = DummyNB(serial="1234")
@ -58,6 +61,7 @@ class TestUsermacroSync(unittest.TestCase):
self.assertIsInstance(device.usermacros, list) self.assertIsInstance(device.usermacros, list)
self.assertGreater(len(device.usermacros), 0) self.assertGreater(len(device.usermacros), 0)
class TestZabbixUsermacros(unittest.TestCase): class TestZabbixUsermacros(unittest.TestCase):
def setUp(self): def setUp(self):
self.nb = DummyNB() self.nb = DummyNB()
@ -78,7 +82,9 @@ class TestZabbixUsermacros(unittest.TestCase):
def test_render_macro_dict(self): def test_render_macro_dict(self):
macros = ZabbixUsermacros(self.nb, {}, False, logger=self.logger) macros = ZabbixUsermacros(self.nb, {}, False, logger=self.logger)
macro = macros.render_macro("{$FOO}", {"value": "bar", "type": "secret", "description": "desc"}) macro = macros.render_macro(
"{$FOO}", {"value": "bar", "type": "secret", "description": "desc"}
)
self.assertEqual(macro["macro"], "{$FOO}") self.assertEqual(macro["macro"], "{$FOO}")
self.assertEqual(macro["value"], "bar") self.assertEqual(macro["value"], "bar")
self.assertEqual(macro["type"], "1") self.assertEqual(macro["type"], "1")
@ -121,5 +127,6 @@ class TestZabbixUsermacros(unittest.TestCase):
self.assertEqual(len(result), 1) self.assertEqual(len(result), 1)
self.assertEqual(result[0]["macro"], "{$FOO}") self.assertEqual(result[0]["macro"], "{$FOO}")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()