diff --git a/modules/config.py b/modules/config.py index a105892..9bb8e3f 100644 --- a/modules/config.py +++ b/modules/config.py @@ -1,10 +1,11 @@ """ Module for parsing configuration from the top level config.py file """ -from pathlib import Path + from importlib import util -from os import environ, path from logging import getLogger +from os import environ, path +from pathlib import Path logger = getLogger(__name__) @@ -44,40 +45,40 @@ DEFAULT_CONFIG = { "serial": "serialno_a", "device_type/model": "type", "device_type/manufacturer/name": "vendor", - "oob_ip/address": "oob_ip" + "oob_ip/address": "oob_ip", }, "vm_inventory_map": { "status/label": "deployment_status", "comments": "notes", - "name": "name" + "name": "name", }, "usermacro_sync": False, "device_usermacro_map": { "serial": "{$HW_SERIAL}", "role/name": "{$DEV_ROLE}", "url": "{$NB_URL}", - "id": "{$NB_ID}" + "id": "{$NB_ID}", }, "vm_usermacro_map": { "memory": "{$TOTAL_MEMORY}", "role/name": "{$DEV_ROLE}", "url": "{$NB_URL}", - "id": "{$NB_ID}" + "id": "{$NB_ID}", }, "tag_sync": False, "tag_lower": True, - "tag_name": 'NetBox', + "tag_name": "NetBox", "tag_value": "name", "device_tag_map": { "site/name": "site", "rack/name": "rack", - "platform/name": "target" + "platform/name": "target", }, "vm_tag_map": { "site/name": "site", "cluster/name": "cluster", - "platform/name": "target" - } + "platform/name": "target", + }, } diff --git a/modules/device.py b/modules/device.py index e61cede..d28e670 100644 --- a/modules/device.py +++ b/modules/device.py @@ -5,12 +5,13 @@ Device specific handeling for NetBox to Zabbix from copy import deepcopy from logging import getLogger -from re import search from operator import itemgetter +from re import search -from zabbix_utils import APIRequestError from pynetbox import RequestError as NetboxRequestError +from zabbix_utils import APIRequestError +from modules.config import load_config from modules.exceptions import ( InterfaceConfigError, SyncExternalError, @@ -22,10 +23,10 @@ from modules.interface import ZabbixInterface from modules.tags import ZabbixTags from modules.tools import field_mapper, remove_duplicates, sanatize_log_output from modules.usermacros import ZabbixUsermacros -from modules.config import load_config config = load_config() + class PhysicalDevice: # pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-positional-arguments """ @@ -125,8 +126,8 @@ class PhysicalDevice: self.nb, self.nb_api_version, logger=self.logger, - nested_sitegroup_flag=config['traverse_site_groups'], - nested_region_flag=config['traverse_regions'], + nested_sitegroup_flag=config["traverse_site_groups"], + nested_region_flag=config["traverse_regions"], nb_groups=nb_site_groups, nb_regions=nb_regions, ) @@ -177,8 +178,6 @@ class PhysicalDevice: self.logger.warning(e) raise TemplateError(e) - - def get_templates_context(self): """Get Zabbix templates from the device context""" if "zabbix" not in self.config_context: @@ -203,9 +202,11 @@ class PhysicalDevice: # Set inventory mode. Default is disabled (see class init function). if config["inventory_mode"] == "disabled": if config["inventory_sync"]: - self.logger.error(f"Host {self.name}: Unable to map NetBox inventory to Zabbix. " - "Inventory sync is enabled in " - "config but inventory mode is disabled.") + self.logger.error( + f"Host {self.name}: Unable to map NetBox inventory to Zabbix. " + "Inventory sync is enabled in " + "config but inventory mode is disabled." + ) return True if config["inventory_mode"] == "manual": self.inventory_mode = 0 @@ -403,7 +404,7 @@ class PhysicalDevice: macros = ZabbixUsermacros( self.nb, self._usermacro_map(), - config['usermacro_sync'], + config["usermacro_sync"], logger=self.logger, host=self.name, ) @@ -421,10 +422,10 @@ class PhysicalDevice: tags = ZabbixTags( self.nb, self._tag_map(), - config['tag_sync'], - config['tag_lower'], - tag_name=config['tag_name'], - tag_value=config['tag_value'], + config["tag_sync"], + config["tag_lower"], + tag_name=config["tag_name"], + tag_value=config["tag_value"], logger=self.logger, host=self.name, ) @@ -604,7 +605,9 @@ class PhysicalDevice: ) self.logger.error(e) raise SyncExternalError(e) from None - self.logger.info(f"Host {self.name}: updated with data {sanatize_log_output(kwargs)}.") + self.logger.info( + f"Host {self.name}: updated with data {sanatize_log_output(kwargs)}." + ) self.create_journal_entry("info", "Updated host in Zabbix with latest NB data.") def ConsistencyCheck( @@ -615,7 +618,7 @@ class PhysicalDevice: Checks if Zabbix object is still valid with NetBox parameters. """ # If group is found or if the hostgroup is nested - if not self.setZabbixGroupID(groups): # or len(self.hostgroups.split("/")) > 1: + if not self.setZabbixGroupID(groups): # or len(self.hostgroups.split("/")) > 1: if create_hostgroups: # Script is allowed to create a new hostgroup new_groups = self.createZabbixHostgroup(groups) @@ -632,7 +635,7 @@ class PhysicalDevice: ) self.logger.warning(e) raise SyncInventoryError(e) - #if self.group_ids: + # if self.group_ids: # self.group_ids.append(self.pri_group_id) # Prepare templates and proxy config @@ -704,8 +707,9 @@ class PhysicalDevice: if str(self.zabbix.version).startswith(("6", "5")): group_dictname = "groups" # Check if hostgroups match - if (sorted(host[group_dictname], key=itemgetter('groupid')) == - sorted(self.group_ids, key=itemgetter('groupid'))): + if sorted(host[group_dictname], key=itemgetter("groupid")) == sorted( + self.group_ids, key=itemgetter("groupid") + ): self.logger.debug(f"Host {self.name}: hostgroups in-sync.") else: self.logger.warning(f"Host {self.name}: hostgroups OUT of sync.") @@ -720,8 +724,10 @@ class PhysicalDevice: # Check if a proxy has been defined if self.zbxproxy: # Check if proxy or proxy group is defined - if (self.zbxproxy["idtype"] in host and - host[self.zbxproxy["idtype"]] == self.zbxproxy["id"]): + if ( + self.zbxproxy["idtype"] in host + and host[self.zbxproxy["idtype"]] == self.zbxproxy["id"] + ): self.logger.debug(f"Host {self.name}: proxy in-sync.") # Backwards compatibility for Zabbix <= 6 elif "proxy_hostid" in host and host["proxy_hostid"] == self.zbxproxy["id"]: @@ -788,21 +794,23 @@ class PhysicalDevice: self.updateZabbixHost(inventory=self.inventory) # Check host usermacros - if config['usermacro_sync']: + if config["usermacro_sync"]: # Make a full copy synce we dont want to lose the original value # of secret type macros from Netbox netbox_macros = deepcopy(self.usermacros) # Set the sync bit - full_sync_bit = bool(str(config['usermacro_sync']).lower() == "full") + full_sync_bit = bool(str(config["usermacro_sync"]).lower() == "full") for macro in netbox_macros: # If the Macro is a secret and full sync is NOT activated if macro["type"] == str(1) and not full_sync_bit: # Remove the value as the Zabbix api does not return the value key # This is required when you want to do a diff between both lists macro.pop("value") + # Sort all lists def filter_with_macros(macro): return macro["macro"] + host["macros"].sort(key=filter_with_macros) netbox_macros.sort(key=filter_with_macros) # Check if both lists are the same @@ -814,7 +822,7 @@ class PhysicalDevice: self.updateZabbixHost(macros=self.usermacros) # Check host tags - if config['tag_sync']: + if config["tag_sync"]: if remove_duplicates(host["tags"], sortkey="tag") == self.tags: self.logger.debug(f"Host {self.name}: tags in-sync.") else: @@ -870,8 +878,10 @@ class PhysicalDevice: try: # API call to Zabbix self.zabbix.hostinterface.update(updates) - e = (f"Host {self.name}: updated interface " - f"with data {sanatize_log_output(updates)}.") + e = ( + f"Host {self.name}: updated interface " + f"with data {sanatize_log_output(updates)}." + ) self.logger.info(e) self.create_journal_entry("info", e) except APIRequestError as e: diff --git a/modules/tools.py b/modules/tools.py index 823410e..13ba05d 100644 --- a/modules/tools.py +++ b/modules/tools.py @@ -1,6 +1,8 @@ """A collection of tools used by several classes""" + from modules.exceptions import HostgroupError + def convert_recordset(recordset): """Converts netbox RedcordSet to list of dicts.""" recordlist = [] @@ -101,7 +103,9 @@ def remove_duplicates(input_list, sortkey=None): return output_list -def verify_hg_format(hg_format, device_cfs=None, vm_cfs=None, hg_type="dev", logger=None): +def verify_hg_format( + hg_format, device_cfs=None, vm_cfs=None, hg_type="dev", logger=None +): """ Verifies hostgroup field format """ @@ -109,44 +113,51 @@ def verify_hg_format(hg_format, device_cfs=None, vm_cfs=None, hg_type="dev", log device_cfs = [] if not vm_cfs: vm_cfs = [] - allowed_objects = {"dev": ["location", - "rack", - "role", - "manufacturer", - "region", - "site", - "site_group", - "tenant", - "tenant_group", - "platform", - "cluster"] - ,"vm": ["cluster_type", - "role", - "manufacturer", - "region", - "site", - "site_group", - "tenant", - "tenant_group", - "cluster", - "device", - "platform"] - ,"cfs": {"dev": [], "vm": []} - } + allowed_objects = { + "dev": [ + "location", + "rack", + "role", + "manufacturer", + "region", + "site", + "site_group", + "tenant", + "tenant_group", + "platform", + "cluster", + ], + "vm": [ + "cluster_type", + "role", + "manufacturer", + "region", + "site", + "site_group", + "tenant", + "tenant_group", + "cluster", + "device", + "platform", + ], + "cfs": {"dev": [], "vm": []}, + } for cf in device_cfs: - allowed_objects['cfs']['dev'].append(cf.name) + allowed_objects["cfs"]["dev"].append(cf.name) for cf in vm_cfs: - allowed_objects['cfs']['vm'].append(cf.name) + allowed_objects["cfs"]["vm"].append(cf.name) hg_objects = [] - if isinstance(hg_format,list): + if isinstance(hg_format, list): for f in hg_format: hg_objects = hg_objects + f.split("/") else: hg_objects = hg_format.split("/") hg_objects = sorted(set(hg_objects)) for hg_object in hg_objects: - if (hg_object not in allowed_objects[hg_type] and - hg_object not in allowed_objects['cfs'][hg_type]): + if ( + hg_object not in allowed_objects[hg_type] + and hg_object not in allowed_objects["cfs"][hg_type] + ): e = ( f"Hostgroup item {hg_object} is not valid. Make sure you" " use valid items and separate them with '/'." diff --git a/modules/usermacros.py b/modules/usermacros.py index 6d396c8..0719106 100644 --- a/modules/usermacros.py +++ b/modules/usermacros.py @@ -57,8 +57,10 @@ class ZabbixUsermacros: macro["macro"] = str(macro_name) if isinstance(macro_properties, dict): if not "value" in macro_properties: - self.logger.warning(f"Host {self.name}: Usermacro {macro_name} has " - "no value in Netbox, skipping.") + self.logger.warning( + f"Host {self.name}: Usermacro {macro_name} has " + "no value in Netbox, skipping." + ) return False macro["value"] = macro_properties["value"] @@ -83,8 +85,10 @@ class ZabbixUsermacros: macro["description"] = "" else: - self.logger.warning(f"Host {self.name}: Usermacro {macro_name} " - "has no value, skipping.") + self.logger.warning( + f"Host {self.name}: Usermacro {macro_name} " + "has no value, skipping." + ) return False else: self.logger.error( diff --git a/modules/virtual_machine.py b/modules/virtual_machine.py index e0f7abb..59a4325 100644 --- a/modules/virtual_machine.py +++ b/modules/virtual_machine.py @@ -1,10 +1,11 @@ # pylint: disable=duplicate-code """Module that hosts all functions for virtual machine processing""" +from modules.config import load_config from modules.device import PhysicalDevice from modules.exceptions import InterfaceConfigError, SyncInventoryError, TemplateError from modules.hostgroups import Hostgroup from modules.interface import ZabbixInterface -from modules.config import load_config + # Load config config = load_config() diff --git a/netbox_zabbix_sync.py b/netbox_zabbix_sync.py index d9ff71b..e8b779f 100755 --- a/netbox_zabbix_sync.py +++ b/netbox_zabbix_sync.py @@ -11,6 +11,7 @@ from pynetbox import api from pynetbox.core.query import RequestError as NBRequestError from requests.exceptions import ConnectionError as RequestsConnectionError from zabbix_utils import APIRequestError, ProcessingError, ZabbixAPI + from modules.config import load_config from modules.device import PhysicalDevice from modules.exceptions import EnvironmentVarError, SyncError @@ -83,14 +84,18 @@ def main(arguments): device_cfs = list( netbox.extras.custom_fields.filter(type="text", content_types="dcim.device") ) - verify_hg_format(config["hostgroup_format"], - device_cfs=device_cfs, hg_type="dev", logger=logger) + verify_hg_format( + config["hostgroup_format"], device_cfs=device_cfs, hg_type="dev", logger=logger + ) if config["sync_vms"]: vm_cfs = list( - netbox.extras.custom_fields.filter(type="text", - content_types="virtualization.virtualmachine") + netbox.extras.custom_fields.filter( + type="text", content_types="virtualization.virtualmachine" + ) + ) + verify_hg_format( + config["vm_hostgroup_format"], vm_cfs=vm_cfs, hg_type="vm", logger=logger ) - verify_hg_format(config["vm_hostgroup_format"], vm_cfs=vm_cfs, hg_type="vm", logger=logger) # Set Zabbix API try: ssl_ctx = ssl.create_default_context() @@ -120,7 +125,8 @@ def main(arguments): netbox_vms = [] if config["sync_vms"]: netbox_vms = list( - netbox.virtualization.virtual_machines.filter(**config["nb_vm_filter"])) + netbox.virtualization.virtual_machines.filter(**config["nb_vm_filter"]) + ) netbox_site_groups = convert_recordset((netbox.dcim.site_groups.all())) netbox_regions = convert_recordset(netbox.dcim.regions.all()) netbox_journals = netbox.extras.journal_entries @@ -141,15 +147,22 @@ def main(arguments): # Go through all NetBox devices for nb_vm in netbox_vms: try: - vm = VirtualMachine(nb_vm, zabbix, netbox_journals, nb_version, - config["create_journal"], logger) + vm = VirtualMachine( + nb_vm, + zabbix, + netbox_journals, + nb_version, + config["create_journal"], + logger, + ) logger.debug(f"Host {vm.name}: started operations on VM.") vm.set_vm_template() # Check if a valid template has been found for this VM. if not vm.zbx_template_names: continue - vm.set_hostgroup(config["vm_hostgroup_format"], - netbox_site_groups, netbox_regions) + vm.set_hostgroup( + config["vm_hostgroup_format"], netbox_site_groups, netbox_regions + ) # Check if a valid hostgroup has been found for this VM. if not vm.hostgroups: continue @@ -200,16 +213,25 @@ def main(arguments): for nb_device in netbox_devices: try: # Set device instance set data such as hostgroup and template information. - device = PhysicalDevice(nb_device, zabbix, netbox_journals, nb_version, - config["create_journal"], logger) + device = PhysicalDevice( + nb_device, + zabbix, + netbox_journals, + nb_version, + config["create_journal"], + logger, + ) logger.debug(f"Host {device.name}: started operations on device.") - device.set_template(config["templates_config_context"], - config["templates_config_context_overrule"]) + device.set_template( + config["templates_config_context"], + config["templates_config_context_overrule"], + ) # Check if a valid template has been found for this VM. if not device.zbx_template_names: continue device.set_hostgroup( - config["hostgroup_format"], netbox_site_groups, netbox_regions) + config["hostgroup_format"], netbox_site_groups, netbox_regions + ) # Check if a valid hostgroup has been found for this VM. if not device.hostgroups: continue diff --git a/tests/test_configuration_parsing.py b/tests/test_configuration_parsing.py index 641b508..6bf9c89 100644 --- a/tests/test_configuration_parsing.py +++ b/tests/test_configuration_parsing.py @@ -1,13 +1,21 @@ """Tests for configuration parsing in the modules.config module.""" -from unittest.mock import patch, MagicMock + import os -from modules.config import load_config, DEFAULT_CONFIG, load_config_file, load_env_variable +from unittest.mock import MagicMock, patch + +from modules.config import ( + DEFAULT_CONFIG, + load_config, + load_config_file, + load_env_variable, +) def test_load_config_defaults(): """Test that load_config returns default values when no config file or env vars are present""" - with patch('modules.config.load_config_file', return_value=DEFAULT_CONFIG.copy()), \ - patch('modules.config.load_env_variable', return_value=None): + with patch( + "modules.config.load_config_file", return_value=DEFAULT_CONFIG.copy() + ), patch("modules.config.load_env_variable", return_value=None): config = load_config() assert config == DEFAULT_CONFIG assert config["templates_config_context"] is False @@ -20,8 +28,9 @@ def test_load_config_file(): mock_config["templates_config_context"] = True mock_config["sync_vms"] = True - with patch('modules.config.load_config_file', return_value=mock_config), \ - patch('modules.config.load_env_variable', return_value=None): + with patch("modules.config.load_config_file", return_value=mock_config), patch( + "modules.config.load_env_variable", return_value=None + ): config = load_config() assert config["templates_config_context"] is True assert config["sync_vms"] is True @@ -31,6 +40,7 @@ def test_load_config_file(): def test_load_env_variables(): """Test that load_config properly loads values from environment variables""" + # Mock env variable loading to return values for specific keys def mock_load_env(key): if key == "sync_vms": @@ -39,8 +49,9 @@ def test_load_env_variables(): return True return None - with patch('modules.config.load_config_file', return_value=DEFAULT_CONFIG.copy()), \ - patch('modules.config.load_env_variable', side_effect=mock_load_env): + with patch( + "modules.config.load_config_file", return_value=DEFAULT_CONFIG.copy() + ), patch("modules.config.load_env_variable", side_effect=mock_load_env): config = load_config() assert config["sync_vms"] is True assert config["create_journal"] is True @@ -60,8 +71,9 @@ def test_env_vars_override_config_file(): return True return None - with patch('modules.config.load_config_file', return_value=mock_config), \ - patch('modules.config.load_env_variable', side_effect=mock_load_env): + with patch("modules.config.load_config_file", return_value=mock_config), patch( + "modules.config.load_env_variable", side_effect=mock_load_env + ): config = load_config() # This should be overridden by the env var assert config["sync_vms"] is True @@ -72,8 +84,9 @@ def test_env_vars_override_config_file(): def test_load_config_file_function(): """Test the load_config_file function directly""" # Test when the file exists - with patch('pathlib.Path.exists', return_value=True), \ - patch('importlib.util.spec_from_file_location') as mock_spec: + with patch("pathlib.Path.exists", return_value=True), patch( + "importlib.util.spec_from_file_location" + ) as mock_spec: # Setup the mock module with attributes mock_module = MagicMock() mock_module.templates_config_context = True @@ -85,7 +98,7 @@ def test_load_config_file_function(): mock_spec_instance.loader.exec_module = lambda x: None # Patch module_from_spec to return our mock module - with patch('importlib.util.module_from_spec', return_value=mock_module): + with patch("importlib.util.module_from_spec", return_value=mock_module): config = load_config_file(DEFAULT_CONFIG.copy()) assert config["templates_config_context"] is True assert config["sync_vms"] is True @@ -93,7 +106,7 @@ def test_load_config_file_function(): def test_load_config_file_not_found(): """Test load_config_file when the config file doesn't exist""" - with patch('pathlib.Path.exists', return_value=False): + with patch("pathlib.Path.exists", return_value=False): result = load_config_file(DEFAULT_CONFIG.copy()) # Should return a dict equal to DEFAULT_CONFIG, not a new object assert result == DEFAULT_CONFIG @@ -127,8 +140,9 @@ def test_load_config_file_exception_handling(): """Test that load_config_file handles exceptions gracefully""" # This test requires modifying the load_config_file function to handle exceptions # For now, we're just checking that an exception is raised - with patch('pathlib.Path.exists', return_value=True), \ - patch('importlib.util.spec_from_file_location', side_effect=Exception("Import error")): + with patch("pathlib.Path.exists", return_value=True), patch( + "importlib.util.spec_from_file_location", side_effect=Exception("Import error") + ): # Since the current implementation doesn't handle exceptions, we should # expect an exception to be raised try: diff --git a/tests/test_device_deletion.py b/tests/test_device_deletion.py index 392ba1a..2e1126d 100644 --- a/tests/test_device_deletion.py +++ b/tests/test_device_deletion.py @@ -1,7 +1,10 @@ """Tests for device deletion functionality in the PhysicalDevice class.""" + import unittest from unittest.mock import MagicMock, patch + from zabbix_utils import APIRequestError + from modules.device import PhysicalDevice from modules.exceptions import SyncExternalError @@ -38,14 +41,14 @@ class TestDeviceDeletion(unittest.TestCase): self.mock_logger = MagicMock() # Create PhysicalDevice instance with mocks - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): self.device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", journal=True, - logger=self.mock_logger + logger=self.mock_logger, ) def test_cleanup_successful_deletion(self): @@ -58,12 +61,15 @@ class TestDeviceDeletion(unittest.TestCase): self.device.cleanup() # Verify - self.mock_zabbix.host.get.assert_called_once_with(filter={'hostid': '456'}, output=[]) - self.mock_zabbix.host.delete.assert_called_once_with('456') + self.mock_zabbix.host.get.assert_called_once_with( + filter={"hostid": "456"}, output=[] + ) + self.mock_zabbix.host.delete.assert_called_once_with("456") self.mock_nb_device.save.assert_called_once() self.assertIsNone(self.mock_nb_device.custom_fields["zabbix_hostid"]) - self.mock_logger.info.assert_called_with(f"Host {self.device.name}: " - "Deleted host from Zabbix.") + self.mock_logger.info.assert_called_with( + f"Host {self.device.name}: " "Deleted host from Zabbix." + ) def test_cleanup_device_already_deleted(self): """Test cleanup when device is already deleted from Zabbix.""" @@ -74,12 +80,15 @@ class TestDeviceDeletion(unittest.TestCase): self.device.cleanup() # Verify - self.mock_zabbix.host.get.assert_called_once_with(filter={'hostid': '456'}, output=[]) + self.mock_zabbix.host.get.assert_called_once_with( + filter={"hostid": "456"}, output=[] + ) self.mock_zabbix.host.delete.assert_not_called() self.mock_nb_device.save.assert_called_once() self.assertIsNone(self.mock_nb_device.custom_fields["zabbix_hostid"]) self.mock_logger.info.assert_called_with( - f"Host {self.device.name}: was already deleted from Zabbix. Removed link in NetBox.") + f"Host {self.device.name}: was already deleted from Zabbix. Removed link in NetBox." + ) def test_cleanup_api_error(self): """Test cleanup when Zabbix API returns an error.""" @@ -92,15 +101,17 @@ class TestDeviceDeletion(unittest.TestCase): self.device.cleanup() # Verify correct calls were made - self.mock_zabbix.host.get.assert_called_once_with(filter={'hostid': '456'}, output=[]) - self.mock_zabbix.host.delete.assert_called_once_with('456') + self.mock_zabbix.host.get.assert_called_once_with( + filter={"hostid": "456"}, output=[] + ) + self.mock_zabbix.host.delete.assert_called_once_with("456") self.mock_nb_device.save.assert_not_called() self.mock_logger.error.assert_called() def test_zeroize_cf(self): """Test _zeroize_cf method that clears the custom field.""" # Execute - self.device._zeroize_cf() # pylint: disable=protected-access + self.device._zeroize_cf() # pylint: disable=protected-access # Verify self.assertIsNone(self.mock_nb_device.custom_fields["zabbix_hostid"]) @@ -136,14 +147,14 @@ class TestDeviceDeletion(unittest.TestCase): def test_create_journal_entry_when_disabled(self): """Test create_journal_entry when journaling is disabled.""" # Setup - create device with journal=False - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", journal=False, # Disable journaling - logger=self.mock_logger + logger=self.mock_logger, ) # Execute @@ -159,8 +170,10 @@ class TestDeviceDeletion(unittest.TestCase): self.mock_zabbix.host.get.return_value = [{"hostid": "456"}] # Execute - with patch.object(self.device, 'create_journal_entry') as mock_journal_entry: + with patch.object(self.device, "create_journal_entry") as mock_journal_entry: self.device.cleanup() # Verify - mock_journal_entry.assert_called_once_with("warning", "Deleted host from Zabbix") + mock_journal_entry.assert_called_once_with( + "warning", "Deleted host from Zabbix" + ) diff --git a/tests/test_hostgroups.py b/tests/test_hostgroups.py index 1e652ec..6e79b20 100644 --- a/tests/test_hostgroups.py +++ b/tests/test_hostgroups.py @@ -1,8 +1,10 @@ """Tests for the Hostgroup class in the hostgroups module.""" + import unittest -from unittest.mock import MagicMock, patch, call -from modules.hostgroups import Hostgroup +from unittest.mock import MagicMock, call, patch + from modules.exceptions import HostgroupError +from modules.hostgroups import Hostgroup class TestHostgroups(unittest.TestCase): @@ -17,27 +19,27 @@ class TestHostgroups(unittest.TestCase): # Create mock device with all properties self.mock_device = MagicMock() self.mock_device.name = "test-device" - + # Set up site information site = MagicMock() site.name = "TestSite" - + # Set up region information region = MagicMock() region.name = "TestRegion" # Ensure region string representation returns the name region.__str__.return_value = "TestRegion" site.region = region - + # Set up site group information site_group = MagicMock() site_group.name = "TestSiteGroup" # Ensure site group string representation returns the name site_group.__str__.return_value = "TestSiteGroup" site.group = site_group - + self.mock_device.site = site - + # Set up role information (varies based on NetBox version) self.mock_device_role = MagicMock() self.mock_device_role.name = "TestRole" @@ -45,7 +47,7 @@ class TestHostgroups(unittest.TestCase): self.mock_device_role.__str__.return_value = "TestRole" self.mock_device.device_role = self.mock_device_role self.mock_device.role = self.mock_device_role - + # Set up tenant information tenant = MagicMock() tenant.name = "TestTenant" @@ -57,45 +59,45 @@ class TestHostgroups(unittest.TestCase): tenant_group.__str__.return_value = "TestTenantGroup" tenant.group = tenant_group self.mock_device.tenant = tenant - + # Set up platform information platform = MagicMock() platform.name = "TestPlatform" self.mock_device.platform = platform - + # Device-specific properties device_type = MagicMock() manufacturer = MagicMock() manufacturer.name = "TestManufacturer" device_type.manufacturer = manufacturer self.mock_device.device_type = device_type - + location = MagicMock() location.name = "TestLocation" # Ensure location string representation returns the name location.__str__.return_value = "TestLocation" self.mock_device.location = location - + # Custom fields self.mock_device.custom_fields = {"test_cf": "TestCF"} - + # *** Mock NetBox VM setup *** # Create mock VM with all properties self.mock_vm = MagicMock() self.mock_vm.name = "test-vm" - + # Reuse site from device self.mock_vm.site = site - + # Set up role for VM self.mock_vm.role = self.mock_device_role - + # Set up tenant for VM (same as device) self.mock_vm.tenant = tenant - + # Set up platform for VM (same as device) self.mock_vm.platform = platform - + # VM-specific properties cluster = MagicMock() cluster.name = "TestCluster" @@ -103,28 +105,28 @@ class TestHostgroups(unittest.TestCase): cluster_type.name = "TestClusterType" cluster.type = cluster_type self.mock_vm.cluster = cluster - + # Custom fields self.mock_vm.custom_fields = {"test_cf": "TestCF"} - + # Mock data for nesting tests self.mock_regions_data = [ {"name": "ParentRegion", "parent": None, "_depth": 0}, - {"name": "TestRegion", "parent": "ParentRegion", "_depth": 1} + {"name": "TestRegion", "parent": "ParentRegion", "_depth": 1}, ] - + self.mock_groups_data = [ {"name": "ParentSiteGroup", "parent": None, "_depth": 0}, - {"name": "TestSiteGroup", "parent": "ParentSiteGroup", "_depth": 1} + {"name": "TestSiteGroup", "parent": "ParentSiteGroup", "_depth": 1}, ] def test_device_hostgroup_creation(self): """Test basic device hostgroup creation.""" hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) - + # Test the string representation self.assertEqual(str(hostgroup), "Hostgroup for dev test-device") - + # Check format options were set correctly self.assertEqual(hostgroup.format_options["site"], "TestSite") self.assertEqual(hostgroup.format_options["region"], "TestRegion") @@ -135,14 +137,14 @@ class TestHostgroups(unittest.TestCase): self.assertEqual(hostgroup.format_options["platform"], "TestPlatform") self.assertEqual(hostgroup.format_options["manufacturer"], "TestManufacturer") self.assertEqual(hostgroup.format_options["location"], "TestLocation") - + def test_vm_hostgroup_creation(self): """Test basic VM hostgroup creation.""" hostgroup = Hostgroup("vm", self.mock_vm, "4.0", self.mock_logger) - + # Test the string representation self.assertEqual(str(hostgroup), "Hostgroup for vm test-vm") - + # Check format options were set correctly self.assertEqual(hostgroup.format_options["site"], "TestSite") self.assertEqual(hostgroup.format_options["region"], "TestRegion") @@ -153,78 +155,80 @@ class TestHostgroups(unittest.TestCase): self.assertEqual(hostgroup.format_options["platform"], "TestPlatform") self.assertEqual(hostgroup.format_options["cluster"], "TestCluster") self.assertEqual(hostgroup.format_options["cluster_type"], "TestClusterType") - + def test_invalid_object_type(self): """Test that an invalid object type raises an exception.""" with self.assertRaises(HostgroupError): Hostgroup("invalid", self.mock_device, "4.0", self.mock_logger) - + def test_device_hostgroup_formats(self): """Test different hostgroup formats for devices.""" hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) - + # Default format: site/manufacturer/role default_result = hostgroup.generate() self.assertEqual(default_result, "TestSite/TestManufacturer/TestRole") - + # Custom format: site/region custom_result = hostgroup.generate("site/region") self.assertEqual(custom_result, "TestSite/TestRegion") - + # Custom format: site/tenant/platform/location complex_result = hostgroup.generate("site/tenant/platform/location") - self.assertEqual(complex_result, "TestSite/TestTenant/TestPlatform/TestLocation") - + self.assertEqual( + complex_result, "TestSite/TestTenant/TestPlatform/TestLocation" + ) + def test_vm_hostgroup_formats(self): """Test different hostgroup formats for VMs.""" hostgroup = Hostgroup("vm", self.mock_vm, "4.0", self.mock_logger) - + # Default format: cluster/role default_result = hostgroup.generate() self.assertEqual(default_result, "TestCluster/TestRole") - + # Custom format: site/tenant custom_result = hostgroup.generate("site/tenant") self.assertEqual(custom_result, "TestSite/TestTenant") - + # Custom format: cluster/cluster_type/platform complex_result = hostgroup.generate("cluster/cluster_type/platform") self.assertEqual(complex_result, "TestCluster/TestClusterType/TestPlatform") - + def test_device_netbox_version_differences(self): """Test hostgroup generation with different NetBox versions.""" # NetBox v2.x hostgroup_v2 = Hostgroup("dev", self.mock_device, "2.11", self.mock_logger) self.assertEqual(hostgroup_v2.format_options["role"], "TestRole") - + # NetBox v3.x hostgroup_v3 = Hostgroup("dev", self.mock_device, "3.5", self.mock_logger) self.assertEqual(hostgroup_v3.format_options["role"], "TestRole") - + # NetBox v4.x (already tested in other methods) - + def test_custom_field_lookup(self): """Test custom field lookup functionality.""" hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) - + # Test custom field exists and is populated cf_result = hostgroup.custom_field_lookup("test_cf") self.assertTrue(cf_result["result"]) self.assertEqual(cf_result["cf"], "TestCF") - + # Test custom field doesn't exist cf_result = hostgroup.custom_field_lookup("nonexistent_cf") self.assertFalse(cf_result["result"]) self.assertIsNone(cf_result["cf"]) - + def test_hostgroup_with_custom_field(self): """Test hostgroup generation including a custom field.""" hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) - + # Generate with custom field included result = hostgroup.generate("site/test_cf/role") self.assertEqual(result, "TestSite/TestCF/TestRole") - + def test_missing_hostgroup_format_item(self): """Test handling of missing hostgroup format items.""" # Create a device with minimal attributes @@ -234,31 +238,31 @@ class TestHostgroups(unittest.TestCase): minimal_device.tenant = None minimal_device.platform = None minimal_device.custom_fields = {} - + # Create role role = MagicMock() role.name = "MinimalRole" minimal_device.role = role - + # Create device_type with manufacturer device_type = MagicMock() manufacturer = MagicMock() manufacturer.name = "MinimalManufacturer" device_type.manufacturer = manufacturer minimal_device.device_type = device_type - + # Create hostgroup hostgroup = Hostgroup("dev", minimal_device, "4.0", self.mock_logger) - + # Generate with default format result = hostgroup.generate() # Site is missing, so only manufacturer and role should be included self.assertEqual(result, "MinimalManufacturer/MinimalRole") - + # Test with invalid format with self.assertRaises(HostgroupError): hostgroup.generate("site/nonexistent/role") - + def test_hostgroup_missing_required_attributes(self): """Test handling when no valid hostgroup can be generated.""" # Create a VM with minimal attributes that won't satisfy any format @@ -270,69 +274,70 @@ class TestHostgroups(unittest.TestCase): minimal_vm.role = None minimal_vm.cluster = None minimal_vm.custom_fields = {} - + hostgroup = Hostgroup("vm", minimal_vm, "4.0", self.mock_logger) - + # With default format of cluster/role, both are None, so should raise an error with self.assertRaises(HostgroupError): hostgroup.generate() - + def test_nested_region_hostgroups(self): """Test hostgroup generation with nested regions.""" # Mock the build_path function to return a predictable result - with patch('modules.hostgroups.build_path') as mock_build_path: + with patch("modules.hostgroups.build_path") as mock_build_path: # Configure the mock to return a list of regions in the path mock_build_path.return_value = ["ParentRegion", "TestRegion"] - + # Create hostgroup with nested regions enabled hostgroup = Hostgroup( - "dev", - self.mock_device, - "4.0", + "dev", + self.mock_device, + "4.0", self.mock_logger, nested_region_flag=True, - nb_regions=self.mock_regions_data + nb_regions=self.mock_regions_data, ) - + # Generate hostgroup with region result = hostgroup.generate("site/region/role") # Should include the parent region self.assertEqual(result, "TestSite/ParentRegion/TestRegion/TestRole") - + def test_nested_sitegroup_hostgroups(self): """Test hostgroup generation with nested site groups.""" # Mock the build_path function to return a predictable result - with patch('modules.hostgroups.build_path') as mock_build_path: + with patch("modules.hostgroups.build_path") as mock_build_path: # Configure the mock to return a list of site groups in the path mock_build_path.return_value = ["ParentSiteGroup", "TestSiteGroup"] - + # Create hostgroup with nested site groups enabled hostgroup = Hostgroup( - "dev", - self.mock_device, - "4.0", + "dev", + self.mock_device, + "4.0", self.mock_logger, nested_sitegroup_flag=True, - nb_groups=self.mock_groups_data + nb_groups=self.mock_groups_data, ) - + # Generate hostgroup with site_group result = hostgroup.generate("site/site_group/role") # Should include the parent site group self.assertEqual(result, "TestSite/ParentSiteGroup/TestSiteGroup/TestRole") - def test_list_formatoptions(self): """Test the list_formatoptions method for debugging.""" hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) - + # Patch sys.stdout to capture print output - with patch('sys.stdout') as mock_stdout: + with patch("sys.stdout") as mock_stdout: hostgroup.list_formatoptions() - + # Check that print was called with expected output - calls = [call.write(f"The following options are available for host test-device"), - call.write('\n')] + calls = [ + call.write(f"The following options are available for host test-device"), + call.write("\n"), + ] mock_stdout.assert_has_calls(calls, any_order=True) diff --git a/tests/test_interface.py b/tests/test_interface.py index ff55218..4f2debd 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -1,7 +1,9 @@ """Tests for the ZabbixInterface class in the interface module.""" + import unittest -from modules.interface import ZabbixInterface + from modules.exceptions import InterfaceConfigError +from modules.interface import ZabbixInterface class TestZabbixInterface(unittest.TestCase): @@ -18,11 +20,7 @@ class TestZabbixInterface(unittest.TestCase): "zabbix": { "interface_type": 2, "interface_port": "161", - "snmp": { - "version": 2, - "community": "public", - "bulk": 1 - } + "snmp": {"version": 2, "community": "public", "bulk": 1}, } } @@ -37,16 +35,13 @@ class TestZabbixInterface(unittest.TestCase): "authpassphrase": "authpass123", "privprotocol": "AES", "privpassphrase": "privpass123", - "contextname": "context1" - } + "contextname": "context1", + }, } } self.agent_context = { - "zabbix": { - "interface_type": 1, - "interface_port": "10050" - } + "zabbix": {"interface_type": 1, "interface_port": "10050"} } def test_init(self): @@ -95,27 +90,27 @@ class TestZabbixInterface(unittest.TestCase): # Test for agent type (1) interface.interface["type"] = 1 - interface._set_default_port() # pylint: disable=protected-access + interface._set_default_port() # pylint: disable=protected-access self.assertEqual(interface.interface["port"], "10050") # Test for SNMP type (2) interface.interface["type"] = 2 - interface._set_default_port() # pylint: disable=protected-access + interface._set_default_port() # pylint: disable=protected-access self.assertEqual(interface.interface["port"], "161") # Test for IPMI type (3) interface.interface["type"] = 3 - interface._set_default_port() # pylint: disable=protected-access + interface._set_default_port() # pylint: disable=protected-access self.assertEqual(interface.interface["port"], "623") # Test for JMX type (4) interface.interface["type"] = 4 - interface._set_default_port() # pylint: disable=protected-access + interface._set_default_port() # pylint: disable=protected-access self.assertEqual(interface.interface["port"], "12345") # Test for unsupported type interface.interface["type"] = 99 - result = interface._set_default_port() # pylint: disable=protected-access + result = interface._set_default_port() # pylint: disable=protected-access self.assertFalse(result) def test_set_snmp_v2(self): @@ -144,9 +139,13 @@ class TestZabbixInterface(unittest.TestCase): self.assertEqual(interface.interface["details"]["securityname"], "snmpuser") self.assertEqual(interface.interface["details"]["securitylevel"], "authPriv") self.assertEqual(interface.interface["details"]["authprotocol"], "SHA") - self.assertEqual(interface.interface["details"]["authpassphrase"], "authpass123") + self.assertEqual( + interface.interface["details"]["authpassphrase"], "authpass123" + ) self.assertEqual(interface.interface["details"]["privprotocol"], "AES") - self.assertEqual(interface.interface["details"]["privpassphrase"], "privpass123") + self.assertEqual( + interface.interface["details"]["privpassphrase"], "privpass123" + ) self.assertEqual(interface.interface["details"]["contextname"], "context1") def test_set_snmp_no_snmp_config(self): @@ -164,12 +163,7 @@ class TestZabbixInterface(unittest.TestCase): """Test set_snmp with unsupported SNMP version.""" # Create context with invalid SNMP version context = { - "zabbix": { - "interface_type": 2, - "snmp": { - "version": 4 # Invalid version - } - } + "zabbix": {"interface_type": 2, "snmp": {"version": 4}} # Invalid version } interface = ZabbixInterface(context, self.test_ip) interface.get_context() # Set the interface type @@ -184,9 +178,7 @@ class TestZabbixInterface(unittest.TestCase): context = { "zabbix": { "interface_type": 2, - "snmp": { - "community": "public" # No version specified - } + "snmp": {"community": "public"}, # No version specified } } interface = ZabbixInterface(context, self.test_ip) @@ -214,7 +206,9 @@ class TestZabbixInterface(unittest.TestCase): self.assertEqual(interface.interface["type"], "2") self.assertEqual(interface.interface["port"], "161") self.assertEqual(interface.interface["details"]["version"], "2") - self.assertEqual(interface.interface["details"]["community"], "{$SNMP_COMMUNITY}") + self.assertEqual( + interface.interface["details"]["community"], "{$SNMP_COMMUNITY}" + ) self.assertEqual(interface.interface["details"]["bulk"], "1") def test_set_default_agent(self): @@ -229,14 +223,7 @@ class TestZabbixInterface(unittest.TestCase): def test_snmpv2_no_community(self): """Test SNMPv2 with no community string specified.""" # Create context with SNMPv2 but no community - context = { - "zabbix": { - "interface_type": 2, - "snmp": { - "version": 2 - } - } - } + context = {"zabbix": {"interface_type": 2, "snmp": {"version": 2}}} interface = ZabbixInterface(context, self.test_ip) interface.get_context() # Set the interface type @@ -244,4 +231,6 @@ class TestZabbixInterface(unittest.TestCase): interface.set_snmp() # Should use default community string - self.assertEqual(interface.interface["details"]["community"], "{$SNMP_COMMUNITY}") + self.assertEqual( + interface.interface["details"]["community"], "{$SNMP_COMMUNITY}" + ) diff --git a/tests/test_physical_device.py b/tests/test_physical_device.py index 1b79ad8..dcf2dfa 100644 --- a/tests/test_physical_device.py +++ b/tests/test_physical_device.py @@ -1,8 +1,10 @@ """Tests for the PhysicalDevice class in the device module.""" + import unittest from unittest.mock import MagicMock, patch + from modules.device import PhysicalDevice -from modules.exceptions import TemplateError, SyncInventoryError +from modules.exceptions import SyncInventoryError, TemplateError class TestPhysicalDevice(unittest.TestCase): @@ -34,24 +36,27 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_logger = MagicMock() # Create PhysicalDevice instance with mocks - with patch('modules.device.config', - {"device_cf": "zabbix_hostid", - "template_cf": "zabbix_template", - "templates_config_context": False, - "templates_config_context_overrule": False, - "traverse_regions": False, - "traverse_site_groups": False, - "inventory_mode": "disabled", - "inventory_sync": False, - "device_inventory_map": {} - }): + with patch( + "modules.device.config", + { + "device_cf": "zabbix_hostid", + "template_cf": "zabbix_template", + "templates_config_context": False, + "templates_config_context_overrule": False, + "traverse_regions": False, + "traverse_site_groups": False, + "inventory_mode": "disabled", + "inventory_sync": False, + "device_inventory_map": {}, + }, + ): self.device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", journal=True, - logger=self.mock_logger + logger=self.mock_logger, ) def test_init(self): @@ -69,14 +74,14 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_nb_device.primary_ip = None # Creating device should raise SyncInventoryError - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): with self.assertRaises(SyncInventoryError): PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) def test_set_basics_with_special_characters(self): @@ -86,8 +91,9 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_nb_device.name = "test-devïce" # We need to patch the search function to simulate finding special characters - with patch('modules.device.search') as mock_search, \ - patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.search") as mock_search, patch( + "modules.device.config", {"device_cf": "zabbix_hostid"} + ): # Make the search function return True to simulate special characters mock_search.return_value = True @@ -96,7 +102,7 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # With the mocked search function, the name should be changed to NETBOX_ID format @@ -110,19 +116,17 @@ class TestPhysicalDevice(unittest.TestCase): """Test get_templates_context with valid config.""" # Set up config_context with valid template data self.mock_nb_device.config_context = { - "zabbix": { - "templates": ["Template1", "Template2"] - } + "zabbix": {"templates": ["Template1", "Template2"]} } # Create device with the updated mock - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Test that templates are returned correctly @@ -132,20 +136,16 @@ class TestPhysicalDevice(unittest.TestCase): def test_get_templates_context_with_string(self): """Test get_templates_context with a string instead of list.""" # Set up config_context with a string template - self.mock_nb_device.config_context = { - "zabbix": { - "templates": "Template1" - } - } + self.mock_nb_device.config_context = {"zabbix": {"templates": "Template1"}} # Create device with the updated mock - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Test that template is wrapped in a list @@ -158,13 +158,13 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_nb_device.config_context = {} # Create device with the updated mock - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Test that TemplateError is raised @@ -177,13 +177,13 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_nb_device.config_context = {"zabbix": {}} # Create device with the updated mock - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Test that TemplateError is raised @@ -193,25 +193,25 @@ class TestPhysicalDevice(unittest.TestCase): def test_set_template_with_config_context(self): """Test set_template with templates_config_context=True.""" # Set up config_context with templates - self.mock_nb_device.config_context = { - "zabbix": { - "templates": ["Template1"] - } - } + self.mock_nb_device.config_context = {"zabbix": {"templates": ["Template1"]}} # Mock get_templates_context to return expected templates - with patch.object(PhysicalDevice, 'get_templates_context', return_value=["Template1"]): - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch.object( + PhysicalDevice, "get_templates_context", return_value=["Template1"] + ): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Call set_template with prefer_config_context=True - result = device.set_template(prefer_config_context=True, overrule_custom=False) + result = device.set_template( + prefer_config_context=True, overrule_custom=False + ) # Check result and template names self.assertTrue(result) @@ -223,20 +223,20 @@ class TestPhysicalDevice(unittest.TestCase): config_patch = { "device_cf": "zabbix_hostid", "inventory_mode": "disabled", - "inventory_sync": False + "inventory_sync": False, } - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Call set_inventory with the config patch still active - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): result = device.set_inventory({}) # Check result @@ -250,20 +250,20 @@ class TestPhysicalDevice(unittest.TestCase): config_patch = { "device_cf": "zabbix_hostid", "inventory_mode": "manual", - "inventory_sync": False + "inventory_sync": False, } - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Call set_inventory with the config patch still active - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): result = device.set_inventory({}) # Check result @@ -276,20 +276,20 @@ class TestPhysicalDevice(unittest.TestCase): config_patch = { "device_cf": "zabbix_hostid", "inventory_mode": "automatic", - "inventory_sync": False + "inventory_sync": False, } - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Call set_inventory with the config patch still active - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): result = device.set_inventory({}) # Check result @@ -303,38 +303,31 @@ class TestPhysicalDevice(unittest.TestCase): "device_cf": "zabbix_hostid", "inventory_mode": "manual", "inventory_sync": True, - "device_inventory_map": { - "name": "name", - "serial": "serialno_a" - } + "device_inventory_map": {"name": "name", "serial": "serialno_a"}, } - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Create a mock device with the required attributes - mock_device_data = { - "name": "test-device", - "serial": "ABC123" - } + mock_device_data = {"name": "test-device", "serial": "ABC123"} # Call set_inventory with the config patch still active - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): result = device.set_inventory(mock_device_data) # Check result self.assertTrue(result) self.assertEqual(device.inventory_mode, 0) # Manual mode - self.assertEqual(device.inventory, { - "name": "test-device", - "serialno_a": "ABC123" - }) + self.assertEqual( + device.inventory, {"name": "test-device", "serialno_a": "ABC123"} + ) def test_iscluster_true(self): """Test isCluster when device is part of a cluster.""" @@ -342,13 +335,13 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_nb_device.virtual_chassis = MagicMock() # Create device with the updated mock - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Check isCluster result @@ -360,26 +353,27 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_nb_device.virtual_chassis = None # Create device with the updated mock - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Check isCluster result self.assertFalse(device.isCluster()) - def test_promote_master_device_primary(self): """Test promoteMasterDevice when device is primary in cluster.""" # Set up virtual chassis with master device mock_vc = MagicMock() mock_vc.name = "virtual-chassis-1" mock_master = MagicMock() - mock_master.id = self.mock_nb_device.id # Set master ID to match the current device + mock_master.id = ( + self.mock_nb_device.id + ) # Set master ID to match the current device mock_vc.master = mock_master self.mock_nb_device.virtual_chassis = mock_vc @@ -389,7 +383,7 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Call promoteMasterDevice and check the result @@ -400,14 +394,15 @@ class TestPhysicalDevice(unittest.TestCase): # Device name should be updated to virtual chassis name self.assertEqual(device.name, "virtual-chassis-1") - def test_promote_master_device_secondary(self): """Test promoteMasterDevice when device is secondary in cluster.""" # Set up virtual chassis with a different master device mock_vc = MagicMock() mock_vc.name = "virtual-chassis-1" mock_master = MagicMock() - mock_master.id = self.mock_nb_device.id + 1 # Different ID than the current device + mock_master.id = ( + self.mock_nb_device.id + 1 + ) # Different ID than the current device mock_vc.master = mock_master self.mock_nb_device.virtual_chassis = mock_vc @@ -417,7 +412,7 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Call promoteMasterDevice and check the result diff --git a/tests/test_tools.py b/tests/test_tools.py index 3e6ae24..5361743 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,5 +1,6 @@ from modules.tools import sanatize_log_output + def test_sanatize_log_output_secrets(): data = { "macros": [ @@ -11,6 +12,7 @@ def test_sanatize_log_output_secrets(): assert sanitized["macros"][0]["value"] == "********" assert sanitized["macros"][1]["value"] == "notsecret" + def test_sanatize_log_output_interface_secrets(): data = { "interfaceid": 123, @@ -19,8 +21,8 @@ def test_sanatize_log_output_interface_secrets(): "privpassphrase": "anothersecret", "securityname": "sensitiveuser", "community": "public", - "other": "normalvalue" - } + "other": "normalvalue", + }, } sanitized = sanatize_log_output(data) # Sensitive fields should be sanitized @@ -33,6 +35,7 @@ def test_sanatize_log_output_interface_secrets(): # interfaceid should be removed assert "interfaceid" not in sanitized + def test_sanatize_log_output_interface_macros(): data = { "interfaceid": 123, @@ -41,7 +44,7 @@ def test_sanatize_log_output_interface_macros(): "privpassphrase": "{$SECRET_MACRO}", "securityname": "{$USER_MACRO}", "community": "{$SNNMP_COMMUNITY}", - } + }, } sanitized = sanatize_log_output(data) # Macro values should not be sanitized @@ -51,11 +54,13 @@ def test_sanatize_log_output_interface_macros(): assert sanitized["details"]["community"] == "{$SNNMP_COMMUNITY}" assert "interfaceid" not in sanitized + def test_sanatize_log_output_plain_data(): data = {"foo": "bar", "baz": 123} sanitized = sanatize_log_output(data) assert sanitized == data + def test_sanatize_log_output_non_dict(): data = [1, 2, 3] sanitized = sanatize_log_output(data) diff --git a/tests/test_usermacros.py b/tests/test_usermacros.py index 28305af..164a370 100644 --- a/tests/test_usermacros.py +++ b/tests/test_usermacros.py @@ -1,8 +1,10 @@ import unittest from unittest.mock import MagicMock, patch + from modules.device import PhysicalDevice from modules.usermacros import ZabbixUsermacros + class DummyNB: def __init__(self, name="dummy", config_context=None, **kwargs): self.name = name @@ -18,6 +20,7 @@ class DummyNB: return self.config_context[key] raise KeyError(key) + class TestUsermacroSync(unittest.TestCase): def setUp(self): self.nb = DummyNB(serial="1234") @@ -58,6 +61,7 @@ class TestUsermacroSync(unittest.TestCase): self.assertIsInstance(device.usermacros, list) self.assertGreater(len(device.usermacros), 0) + class TestZabbixUsermacros(unittest.TestCase): def setUp(self): self.nb = DummyNB() @@ -78,7 +82,9 @@ class TestZabbixUsermacros(unittest.TestCase): def test_render_macro_dict(self): macros = ZabbixUsermacros(self.nb, {}, False, logger=self.logger) - macro = macros.render_macro("{$FOO}", {"value": "bar", "type": "secret", "description": "desc"}) + macro = macros.render_macro( + "{$FOO}", {"value": "bar", "type": "secret", "description": "desc"} + ) self.assertEqual(macro["macro"], "{$FOO}") self.assertEqual(macro["value"], "bar") self.assertEqual(macro["type"], "1") @@ -121,5 +127,6 @@ class TestZabbixUsermacros(unittest.TestCase): self.assertEqual(len(result), 1) self.assertEqual(result[0]["macro"], "{$FOO}") + if __name__ == "__main__": unittest.main()