diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 81c3292..402443d 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -3,7 +3,7 @@ { "name": "Python 3", // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile - "image": "mcr.microsoft.com/devcontainers/python:1-3.12-bullseye", + "image": "mcr.microsoft.com/devcontainers/python:3.14", // Features to add to the dev container. More info: https://containers.dev/features. // "features": {}, @@ -12,7 +12,7 @@ // "forwardPorts": [], // Use 'postCreateCommand' to run commands after the container is created. - "postCreateCommand": "pip3 install --user -r requirements.txt && pip3 install --user pylint pytest coverage pytest-cov" + "postCreateCommand": "pip3 install --user -r requirements.txt && pip3 install --user uv pylint pytest coverage pytest-cov && uv sync --dev" // Configure tool-specific properties. // "customizations": {}, diff --git a/README.md b/README.md index 4c762ed..a88113d 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,10 @@ A script to create, update and delete Zabbix hosts using NetBox device objects. Tested and compatible with all [currently supported Zabbix releases](https://www.zabbix.com/life_cycle_and_release_policy). +# Documentation + +Documentation will be moved to the Github wiki of this project. Feel free to [check it out](https://github.com/TheNetworkGuy/netbox-zabbix-sync/wiki)! + ## Installation via Docker To pull the latest stable version to your local cache, use the following docker diff --git a/modules/config.py b/modules/config.py index 3d69d6c..e5509c6 100644 --- a/modules/config.py +++ b/modules/config.py @@ -2,10 +2,10 @@ Module for parsing configuration from the top level config.py file """ -from pathlib import Path from importlib import util -from os import environ, path from logging import getLogger +from os import environ, path +from pathlib import Path logger = getLogger(__name__) @@ -123,6 +123,8 @@ def load_config_file(config_default, config_file="config.py"): dconf = config_default.copy() # Dynamically import the config module spec = util.spec_from_file_location("config", config_path) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot load config from {config_path}") config_module = util.module_from_spec(spec) spec.loader.exec_module(config_module) # Update DEFAULT_CONFIG with variables from the config module diff --git a/modules/device.py b/modules/device.py index ab1f38f..71cb6dd 100644 --- a/modules/device.py +++ b/modules/device.py @@ -67,7 +67,7 @@ class PhysicalDevice: self.usermacros = [] self.tags = {} self.logger = logger if logger else getLogger(__name__) - self._setBasics() + self._set_basics() def __repr__(self): return self.name @@ -87,7 +87,7 @@ class PhysicalDevice: """Use device host tag maps""" return config["device_tag_map"] - def _setBasics(self): + def _set_basics(self): """ Sets basic information like IP address. """ @@ -167,7 +167,7 @@ class PhysicalDevice: # Gather templates from the custom field but overrule # them should there be any device specific templates if overrule_custom: - try: + try: # noqa: SIM105 self.zbx_template_names = self.get_templates_context() except TemplateError: pass @@ -247,17 +247,17 @@ class PhysicalDevice: ) return True - def isCluster(self): + def is_cluster(self): """ Checks if device is part of cluster. """ return bool(self.nb.virtual_chassis) - def getClusterMaster(self): + def get_cluster_master(self): """ Returns chassis master ID. """ - if not self.isCluster(): + if not self.is_cluster(): e = ( f"Unable to proces {self.name} for cluster calculation: " f"not part of a cluster." @@ -273,13 +273,13 @@ class PhysicalDevice: raise SyncInventoryError(e) return self.nb.virtual_chassis.master.id - def promoteMasterDevice(self): + def promote_primary_device(self): """ If device is Primary in cluster, promote device name to the cluster name. Returns True if succesfull, returns False if device is secondary. """ - masterid = self.getClusterMaster() + masterid = self.get_cluster_master() if masterid == self.id: self.logger.info( "Host %s is primary cluster member. Modifying hostname from %s to %s.", @@ -292,7 +292,7 @@ class PhysicalDevice: self.logger.info("Host %s is non-primary cluster member.", self.name) return False - def zbxTemplatePrepper(self, templates): + def zbx_template_prepper(self, templates): """ Returns Zabbix template IDs INPUT: list of templates from Zabbix @@ -335,7 +335,7 @@ class PhysicalDevice: self.logger.warning(e) raise SyncInventoryError(e) - def setZabbixGroupID(self, groups): + def set_zbx_groupid(self, groups): """ Sets Zabbix group ID as instance variable INPUT: list of hostgroups @@ -351,9 +351,7 @@ class PhysicalDevice: f'"{group["name"]}" (ID:{group["groupid"]})' ) self.logger.debug(e) - if len(self.group_ids) == len(self.hostgroups): - return True - return False + return len(self.group_ids) == len(self.hostgroups) def cleanup(self): """ @@ -378,7 +376,7 @@ class PhysicalDevice: self.logger.info(e) self.create_journal_entry("warning", "Deleted host from Zabbix") except APIRequestError as e: - message = f"Zabbix returned the following error: {str(e)}." + message = f"Zabbix returned the following error: {e}." self.logger.error(message) raise SyncExternalError(message) from e @@ -388,7 +386,7 @@ class PhysicalDevice: self.nb.custom_fields[config["device_cf"]] = None self.nb.save() - def _zabbixHostnameExists(self): + def _zabbix_hostname_exists(self): """ Checks if hostname exists in Zabbix. """ @@ -400,7 +398,7 @@ class PhysicalDevice: host = self.zabbix.host.get(filter=zbx_filter, output=[]) return bool(host) - def setInterfaceDetails(self): + def set_interface_details(self): """ Checks interface parameters from NetBox and creates a model for the interface to be used in Zabbix. @@ -412,7 +410,8 @@ class PhysicalDevice: # If not fall back to old config. if interface.get_context(): # If device is SNMP type, add aditional information. - if interface.interface["type"] == 2: + snmp_interface_type = 2 + if interface.interface["type"] == snmp_interface_type: interface.set_snmp() else: interface.set_default_snmp() @@ -460,7 +459,7 @@ class PhysicalDevice: self.tags = tags.generate() return True - def _setProxy(self, proxy_list: list[dict[str, Any]]) -> bool: + def _set_proxy(self, proxy_list: list[dict[str, Any]]) -> bool: """ Sets proxy or proxy group if this value has been defined in config context @@ -474,7 +473,9 @@ class PhysicalDevice: proxy_types = ["proxy"] proxy_name = None - if self.zabbix.version >= 7.0: + zabbix_7_version = 7.0 + + if self.zabbix.version >= zabbix_7_version: # Only insert groups in front of list for Zabbix7 proxy_types.insert(0, "proxy_group") @@ -510,7 +511,7 @@ class PhysicalDevice: if proxy_name: for proxy in proxy_list: # If the proxy does not match the type, ignore and continue - if not proxy["type"] == proxy_type: + if proxy["type"] != proxy_type: continue # If the proxy name matches if proxy["name"] == proxy_name: @@ -525,7 +526,7 @@ class PhysicalDevice: ) return False - def createInZabbix( + def create_in_zabbix( self, groups, templates, @@ -536,23 +537,23 @@ class PhysicalDevice: Creates Zabbix host object with parameters from NetBox object. """ # Check if hostname is already present in Zabbix - if not self._zabbixHostnameExists(): + if not self._zabbix_hostname_exists(): # Set group and template ID's for host - if not self.setZabbixGroupID(groups): + if not self.set_zbx_groupid(groups): e = ( - f"Unable to find group '{self.hostgroup}' " + f"Unable to find group '{self.hostgroups}' " f"for host {self.name} in Zabbix." ) self.logger.warning(e) raise SyncInventoryError(e) - self.zbxTemplatePrepper(templates) + self.zbx_template_prepper(templates) templateids = [] for template in self.zbx_templates: templateids.append({"templateid": template["templateid"]}) # Set interface, group and template configuration - interfaces = self.setInterfaceDetails() + interfaces = self.set_interface_details() # Set Zabbix proxy if defined - self._setProxy(proxies) + self._set_proxy(proxies) # Set basic data for host creation create_data = { "host": self.name, @@ -582,7 +583,7 @@ class PhysicalDevice: host = self.zabbix.host.create(**create_data) self.zabbix_id = host["hostids"][0] except APIRequestError as e: - msg = f"Host {self.name}: Couldn't create. Zabbix returned {str(e)}." + msg = f"Host {self.name}: Couldn't create. Zabbix returned {e}." self.logger.error(msg) raise SyncExternalError(msg) from e # Set NetBox custom field to hostID value. @@ -596,7 +597,7 @@ class PhysicalDevice: "Host %s: Unable to add to Zabbix. Host already present.", self.name ) - def createZabbixHostgroup(self, hostgroups): + def create_zbx_hostgroup(self, hostgroups): """ Creates Zabbix host group based on hostgroup format. Creates multiple when using a nested format. @@ -606,7 +607,7 @@ class PhysicalDevice: for hostgroup in self.hostgroups: for pos in range(len(hostgroup.split("/"))): zabbix_hg = hostgroup.rsplit("/", pos)[0] - if self.lookupZabbixHostgroup(hostgroups, zabbix_hg): + if self.zbx_hostgroup_lookup(hostgroups, zabbix_hg): # Hostgroup already exists continue # Create new group @@ -620,24 +621,21 @@ class PhysicalDevice: {"groupid": groupid["groupids"][0], "name": zabbix_hg} ) except APIRequestError as e: - msg = f"Hostgroup '{zabbix_hg}': unable to create. Zabbix returned {str(e)}." + msg = f"Hostgroup '{zabbix_hg}': unable to create. Zabbix returned {e}." self.logger.error(msg) raise SyncExternalError(msg) from e return final_data - def lookupZabbixHostgroup(self, group_list, lookup_group): + def zbx_hostgroup_lookup(self, group_list, lookup_group): """ Function to check if a hostgroup exists in a list of Zabbix hostgroups INPUT: Group list and group lookup OUTPUT: Boolean """ - for group in group_list: - if group["name"] == lookup_group: - return True - return False + return any(group["name"] == lookup_group for group in group_list) - def updateZabbixHost(self, **kwargs): + def update_zabbix_host(self, **kwargs): """ Updates Zabbix host with given parameters. INPUT: Key word arguments for Zabbix host object. @@ -647,7 +645,7 @@ class PhysicalDevice: except APIRequestError as e: e = ( f"Host {self.name}: Unable to update. " - f"Zabbix returned the following error: {str(e)}." + f"Zabbix returned the following error: {e}." ) self.logger.error(e) raise SyncExternalError(e) from None @@ -656,7 +654,7 @@ class PhysicalDevice: ) self.create_journal_entry("info", "Updated host in Zabbix with latest NB data.") - def ConsistencyCheck( + def consistency_check( self, groups, templates, proxies, proxy_power, create_hostgroups ): # pylint: disable=too-many-branches, too-many-statements @@ -664,17 +662,18 @@ class PhysicalDevice: Checks if Zabbix object is still valid with NetBox parameters. """ # If group is found or if the hostgroup is nested - if not self.setZabbixGroupID(groups): # or len(self.hostgroups.split("/")) > 1: + if not self.set_zbx_groupid(groups): # or len(self.hostgroups.split("/")) > 1: if create_hostgroups: # Script is allowed to create a new hostgroup - new_groups = self.createZabbixHostgroup(groups) + new_groups = self.create_zbx_hostgroup(groups) for group in new_groups: # Add all new groups to the list of groups groups.append(group) # check if the initial group was not already found (and this is a nested folder check) if not self.group_ids: - # Function returns true / false but also sets GroupID - if not self.setZabbixGroupID(groups) and not create_hostgroups: + zbx_groupid_confirmation = self.set_zbx_groupid(groups) + if not zbx_groupid_confirmation and not create_hostgroups: + # Function returns true / false but also sets GroupID e = ( f"Host {self.name}: different hostgroup is required but " "unable to create hostgroup without generation permission." @@ -683,8 +682,8 @@ class PhysicalDevice: raise SyncInventoryError(e) # Prepare templates and proxy config - self.zbxTemplatePrepper(templates) - self._setProxy(proxies) + self.zbx_template_prepper(templates) + self._set_proxy(proxies) # Get host object from Zabbix host = self.zabbix.host.get( filter={"hostid": self.zabbix_id}, @@ -720,7 +719,7 @@ class PhysicalDevice: self.name, host["host"], ) - self.updateZabbixHost(host=self.name) + self.update_zabbix_host(host=self.name) # Execute check depending on wether the name is special or not if self.use_visible_name: @@ -732,7 +731,7 @@ class PhysicalDevice: self.name, host["name"], ) - self.updateZabbixHost(name=self.visible_name) + self.update_zabbix_host(name=self.visible_name) # Check if the templates are in-sync if not self.zbx_template_comparer(host["parentTemplates"]): @@ -742,7 +741,7 @@ class PhysicalDevice: for template in self.zbx_templates: templateids.append({"templateid": template["templateid"]}) # Update Zabbix with NB templates and clear any old / lost templates - self.updateZabbixHost( + self.update_zabbix_host( templates_clear=host["parentTemplates"], templates=templateids ) else: @@ -759,31 +758,31 @@ class PhysicalDevice: self.logger.debug("Host %s: Hostgroups in-sync.", self.name) else: self.logger.info("Host %s: Hostgroups OUT of sync.", self.name) - self.updateZabbixHost(groups=self.group_ids) + self.update_zabbix_host(groups=self.group_ids) if int(host["status"]) == self.zabbix_state: self.logger.debug("Host %s: Status in-sync.", self.name) else: self.logger.info("Host %s: Status OUT of sync.", self.name) - self.updateZabbixHost(status=str(self.zabbix_state)) + self.update_zabbix_host(status=str(self.zabbix_state)) # Check if a proxy has been defined if self.zbxproxy: - # Check if proxy or proxy group is defined + # Check if proxy or proxy group is defined. + # Check for proxy_hostid for backwards compatibility with Zabbix <= 6 if ( self.zbxproxy["idtype"] in host and host[self.zbxproxy["idtype"]] == self.zbxproxy["id"] + ) or ( + "proxy_hostid" in host and host["proxy_hostid"] == self.zbxproxy["id"] ): self.logger.debug("Host %s: Proxy in-sync.", self.name) - # Backwards compatibility for Zabbix <= 6 - elif "proxy_hostid" in host and host["proxy_hostid"] == self.zbxproxy["id"]: - self.logger.debug("Host %s: Proxy in-sync.", self.name) # Proxy does not match, update Zabbix else: self.logger.info("Host %s: Proxy OUT of sync.", self.name) # Zabbix <= 6 patch if not str(self.zabbix.version).startswith("7"): - self.updateZabbixHost(proxy_hostid=self.zbxproxy["id"]) + self.update_zabbix_host(proxy_hostid=self.zbxproxy["id"]) # Zabbix 7+ else: # Prepare data structure for updating either proxy or group @@ -791,15 +790,14 @@ class PhysicalDevice: self.zbxproxy["idtype"]: self.zbxproxy["id"], "monitored_by": self.zbxproxy["monitored_by"], } - self.updateZabbixHost(**update_data) + self.update_zabbix_host(**update_data) else: # No proxy is defined in NetBox proxy_set = False # Check if a proxy is defined. Uses the proxy_hostid key for backwards compatibility for key in ("proxy_hostid", "proxyid", "proxy_groupid"): - if key in host: - if bool(int(host[key])): - proxy_set = True + if key in host and bool(int(host[key])): + proxy_set = True if proxy_power and proxy_set: # Zabbix <= 6 fix self.logger.warning( @@ -808,13 +806,13 @@ class PhysicalDevice: self.name, ) if "proxy_hostid" in host and bool(host["proxy_hostid"]): - self.updateZabbixHost(proxy_hostid=0) + self.update_zabbix_host(proxy_hostid=0) # Zabbix 7 proxy elif "proxyid" in host and bool(host["proxyid"]): - self.updateZabbixHost(proxyid=0, monitored_by=0) + self.update_zabbix_host(proxyid=0, monitored_by=0) # Zabbix 7 proxy group elif "proxy_groupid" in host and bool(host["proxy_groupid"]): - self.updateZabbixHost(proxy_groupid=0, monitored_by=0) + self.update_zabbix_host(proxy_groupid=0, monitored_by=0) # Checks if a proxy has been defined in Zabbix and if proxy_power config has been set if proxy_set and not proxy_power: # Display error message @@ -830,14 +828,14 @@ class PhysicalDevice: self.logger.debug("Host %s: inventory_mode in-sync.", self.name) else: self.logger.info("Host %s: inventory_mode OUT of sync.", self.name) - self.updateZabbixHost(inventory_mode=str(self.inventory_mode)) + self.update_zabbix_host(inventory_mode=str(self.inventory_mode)) if config["inventory_sync"] and self.inventory_mode in [0, 1]: # Check host inventory mapping if host["inventory"] == self.inventory: self.logger.debug("Host %s: Inventory in-sync.", self.name) else: self.logger.info("Host %s: Inventory OUT of sync.", self.name) - self.updateZabbixHost(inventory=self.inventory) + self.update_zabbix_host(inventory=self.inventory) # Check host usermacros if config["usermacro_sync"]: @@ -865,7 +863,7 @@ class PhysicalDevice: else: self.logger.info("Host %s: Usermacros OUT of sync.", self.name) # Update Zabbix with NetBox usermacros - self.updateZabbixHost(macros=self.usermacros) + self.update_zabbix_host(macros=self.usermacros) # Check host tags if config["tag_sync"]: @@ -877,37 +875,37 @@ class PhysicalDevice: self.logger.debug("Host %s: Tags in-sync.", self.name) else: self.logger.info("Host %s: Tags OUT of sync.", self.name) - self.updateZabbixHost(tags=self.tags) + self.update_zabbix_host(tags=self.tags) # If only 1 interface has been found # pylint: disable=too-many-nested-blocks if len(host["interfaces"]) == 1: updates = {} # Go through each key / item and check if it matches Zabbix - for key, item in self.setInterfaceDetails()[0].items(): + for key, item in self.set_interface_details()[0].items(): # Check if NetBox value is found in Zabbix if key in host["interfaces"][0]: # If SNMP is used, go through nested dict # to compare SNMP parameters if isinstance(item, dict) and key == "details": for k, i in item.items(): - if k in host["interfaces"][0][key]: - # Set update if values don't match - if host["interfaces"][0][key][k] != str(i): - # If dict has not been created, add it - if key not in updates: - updates[key] = {} - updates[key][k] = str(i) - # If SNMP version has been changed - # break loop and force full SNMP update - if k == "version": - break + # Check if the key is found in Zabbix and if the value matches + if k in host["interfaces"][0][key] and host["interfaces"][ + 0 + ][key][k] != str(i): + # If dict has not been created, add it + if key not in updates: + updates[key] = {} + updates[key][k] = str(i) + # If SNMP version has been changed + # break loop and force full SNMP update + if k == "version": + break # Force full SNMP config update # when version has changed. - if key in updates: - if "version" in updates[key]: - for k, i in item.items(): - updates[key][k] = str(i) + if key in updates and "version" in updates[key]: + for k, i in item.items(): + updates[key][k] = str(i) continue # Set update if values don't match if host["interfaces"][0][key] != str(item): @@ -919,7 +917,7 @@ class PhysicalDevice: # Changing interface type not supported. Raise exception. e = ( f"Host {self.name}: Changing interface type to " - f"{str(updates['type'])} is not supported." + f"{updates['type']} is not supported." ) self.logger.error(e) raise InterfaceConfigError(e) @@ -935,7 +933,7 @@ class PhysicalDevice: self.logger.info(err_msg) self.create_journal_entry("info", err_msg) except APIRequestError as e: - msg = f"Zabbix returned the following error: {str(e)}." + msg = f"Zabbix returned the following error: {e}." self.logger.error(msg) raise SyncExternalError(msg) from e else: @@ -1006,12 +1004,11 @@ class PhysicalDevice: nb_tmpl["name"], ) break - if ( + # The following condition is only true if: + # all of the NetBox templates have been confirmed as successful + # and the ZBX template list is empty. This means that + # all of the templates match. + return ( len(succesfull_templates) == len(self.zbx_templates) and len(tmpls_from_zabbix) == 0 - ): - # All of the NetBox templates have been confirmed as successfull - # and the ZBX template list is empty. This means that - # all of the templates match. - return True - return False + ) diff --git a/modules/exceptions.py b/modules/exceptions.py index ddac2b0..ccee1af 100644 --- a/modules/exceptions.py +++ b/modules/exceptions.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 """ All custom exceptions used for Exception generation """ diff --git a/modules/hostgroups.py b/modules/hostgroups.py index 49890e6..d5544be 100644 --- a/modules/hostgroups.py +++ b/modules/hostgroups.py @@ -26,7 +26,7 @@ class Hostgroup: self.logger = logger if logger else getLogger(__name__) if obj_type not in ("vm", "dev"): msg = f"Unable to create hostgroup with type {type}" - self.logger.error() + self.logger.error(msg) raise HostgroupError(msg) self.type = str(obj_type) self.nb = nb_obj @@ -87,12 +87,10 @@ class Hostgroup: str(self.nb.location) if self.nb.location else None ) format_options["rack"] = self.nb.rack.name if self.nb.rack else None - # Variables only applicable for VM's - if self.type == "vm": - # Check if a cluster is configured. Could also be configured in a site. - if self.nb.cluster: - format_options["cluster"] = self.nb.cluster.name - format_options["cluster_type"] = self.nb.cluster.type.name + # Variables only applicable for VM's such as clusters + if self.type == "vm" and self.nb.cluster: + format_options["cluster"] = self.nb.cluster.name + format_options["cluster_type"] = self.nb.cluster.type.name self.format_options = format_options self.logger.debug( "Host %s: Resolved properties for use in hostgroups: %s", @@ -117,10 +115,14 @@ class Hostgroup: for hg_item in hg_items: # Check if requested data is available as option for this host if hg_item not in self.format_options: - if hg_item.startswith(("'", '"')) and hg_item.endswith(("'", '"')): - hg_item = hg_item.strip("'") - hg_item = hg_item.strip('"') - hg_output.append(hg_item) + # If the string is between quotes, use it as a literal in the hostgroup name + minimum_length = 2 + if ( + len(hg_item) > minimum_length + and hg_item[0] == hg_item[-1] + and hg_item[0] in ("'", '"') + ): + hg_output.append(hg_item[1:-1]) else: # Check if a custom field exists with this name cf_data = self.custom_field_lookup(hg_item) @@ -155,20 +157,6 @@ class Hostgroup: self.logger.warning(msg) return None - def list_formatoptions(self): - """ - Function to easily troubleshoot which values - are generated for a specific device or VM. - """ - print(f"The following options are available for host {self.name}") - for option_type, value in self.format_options.items(): - if value is not None: - print(f"{option_type} - {value}") - print("The following options are not available") - for option_type, value in self.format_options.items(): - if value is None: - print(f"{option_type}") - def custom_field_lookup(self, hg_category): """ Checks if a valid custom field is present in NetBox. @@ -192,7 +180,7 @@ class Hostgroup: OUTPUT: STRING - Either the single child name or child and parents. """ # Check if this type of nesting is supported. - if not nest_type in self.nested_objects: + if nest_type not in self.nested_objects: return child_object # If the nested flag is True, perform parent calculation if self.nested_objects[nest_type]["flag"]: diff --git a/modules/interface.py b/modules/interface.py index 1bd1e37..4b79134 100644 --- a/modules/interface.py +++ b/modules/interface.py @@ -1,7 +1,7 @@ -#!/usr/bin/env python3 """ All of the Zabbix interface related configuration """ + from modules.exceptions import InterfaceConfigError @@ -30,7 +30,7 @@ class ZabbixInterface: zabbix = self.context["zabbix"] if "interface_type" in zabbix: self.interface["type"] = zabbix["interface_type"] - if not "interface_port" in zabbix: + if "interface_port" not in zabbix: self._set_default_port() return True self.interface["port"] = zabbix["interface_port"] @@ -41,35 +41,37 @@ class ZabbixInterface: def set_snmp(self): """Check if interface is type SNMP""" # pylint: disable=too-many-branches - if self.interface["type"] == 2: + snmp_interface_type = 2 + if self.interface["type"] == snmp_interface_type: # Checks if SNMP settings are defined in NetBox if "snmp" in self.context["zabbix"]: snmp = self.context["zabbix"]["snmp"] - self.interface["details"] = {} + details: dict[str, str] = {} + self.interface["details"] = details # Checks if bulk config has been defined if "bulk" in snmp: - self.interface["details"]["bulk"] = str(snmp.pop("bulk")) + details["bulk"] = str(snmp.pop("bulk")) else: # Fallback to bulk enabled if not specified - self.interface["details"]["bulk"] = "1" + details["bulk"] = "1" # SNMP Version config is required in NetBox config context if snmp.get("version"): - self.interface["details"]["version"] = str(snmp.pop("version")) + details["version"] = str(snmp.pop("version")) else: e = "SNMP version option is not defined." raise InterfaceConfigError(e) # If version 1 or 2 is used, get community string - if self.interface["details"]["version"] in ["1", "2"]: + if details["version"] in ["1", "2"]: if "community" in snmp: # Set SNMP community to confix context value community = snmp["community"] else: # Set SNMP community to default community = "{$SNMP_COMMUNITY}" - self.interface["details"]["community"] = str(community) + details["community"] = str(community) # If version 3 has been used, get all # SNMPv3 NetBox related configs - elif self.interface["details"]["version"] == "3": + elif details["version"] == "3": items = [ "securityname", "securitylevel", @@ -81,7 +83,7 @@ class ZabbixInterface: ] for key, item in snmp.items(): if key in items: - self.interface["details"][key] = str(item) + details[key] = str(item) else: e = "Unsupported SNMP version." raise InterfaceConfigError(e) diff --git a/modules/tags.py b/modules/tags.py index 835490c..21497f6 100644 --- a/modules/tags.py +++ b/modules/tags.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-positional-arguments, logging-fstring-interpolation """ All of the Zabbix Usermacro related configuration @@ -54,17 +53,23 @@ class ZabbixTags: """ Validates tag name """ - if tag_name and isinstance(tag_name, str) and len(tag_name) <= 256: - return True - return False + max_tag_name_length = 256 + return ( + tag_name + and isinstance(tag_name, str) + and len(tag_name) <= max_tag_name_length + ) def validate_value(self, tag_value): """ Validates tag value """ - if tag_value and isinstance(tag_value, str) and len(tag_value) <= 256: - return True - return False + max_tag_value_length = 256 + return ( + tag_value + and isinstance(tag_value, str) + and len(tag_value) <= max_tag_value_length + ) def render_tag(self, tag_name, tag_value): """ @@ -123,7 +128,11 @@ class ZabbixTags: # Pull in NetBox device tags if tag_name is set if self.tag_name and isinstance(self.tag_name, str): for tag in self.nb.tags: - if self.tag_value.lower() in ["display", "name", "slug"]: + if ( + self.tag_value + and isinstance(self.tag_value, str) + and self.tag_value.lower() in ["display", "name", "slug"] + ): value = tag[self.tag_value] else: value = tag["name"] diff --git a/modules/tools.py b/modules/tools.py index adacca2..a6b0a22 100644 --- a/modules/tools.py +++ b/modules/tools.py @@ -1,6 +1,8 @@ """A collection of tools used by several classes""" -from typing import Any, Callable, Optional, overload +from collections.abc import Callable +from typing import Any, cast, overload + from modules.exceptions import HostgroupError @@ -21,10 +23,14 @@ def build_path(endpoint, list_of_dicts): item_path = [] itemlist = [i for i in list_of_dicts if i["name"] == endpoint] item = itemlist[0] if len(itemlist) == 1 else None + if item is None: + return [] item_path.append(item["name"]) while item["_depth"] > 0: itemlist = [i for i in list_of_dicts if i["name"] == str(item["parent"])] item = itemlist[0] if len(itemlist) == 1 else None + if item is None: + break item_path.append(item["name"]) item_path.reverse() return item_path @@ -58,9 +64,10 @@ def cf_to_string(cf, key="name", logger=None): if isinstance(cf, dict): if key in cf: return cf[key] - logger.error( - "Conversion of custom field failed, '%s' not found in cf dict.", key - ) + if logger: + logger.error( + "Conversion of custom field failed, '%s' not found in cf dict.", key + ) return None return cf @@ -112,14 +119,14 @@ def field_mapper(host, mapper, nbdevice, logger): @overload def remove_duplicates( input_list: list[dict[Any, Any]], - sortkey: Optional[str | Callable[[dict[str, Any]], str]] = None, + sortkey: str | Callable[[dict[str, Any]], str] | None = None, ): ... @overload def remove_duplicates( input_list: dict[Any, Any], - sortkey: Optional[str | Callable[[dict[str, Any]], str]] = None, + sortkey: str | Callable[[dict[str, Any]], str] | None = None, ): """ deprecated: input_list as dict is deprecated, use list of dicts instead @@ -128,7 +135,7 @@ def remove_duplicates( def remove_duplicates( input_list: list[dict[Any, Any]] | dict[Any, Any], - sortkey: Optional[str | Callable[[dict[str, Any]], str]] = None, + sortkey: str | Callable[[dict[str, Any]], str] | None = None, ): """ Removes duplicate entries from a list and sorts the list @@ -143,7 +150,7 @@ def remove_duplicates( output_list.sort(key=lambda x: x[sortkey]) elif sortkey and callable(sortkey): - output_list.sort(key=sortkey) + output_list.sort(key=cast(Any, sortkey)) return output_list @@ -188,9 +195,9 @@ def verify_hg_format( "cfs": {"dev": [], "vm": []}, } for cf in device_cfs: - allowed_objects["cfs"]["dev"].append(cf.name) + allowed_objects["cfs"]["dev"].append(cf.name) # type: ignore[index] for cf in vm_cfs: - allowed_objects["cfs"]["vm"].append(cf.name) + allowed_objects["cfs"]["vm"].append(cf.name) # type: ignore[index] hg_objects = [] if isinstance(hg_format, list): for f in hg_format: @@ -201,14 +208,15 @@ def verify_hg_format( for hg_object in hg_objects: if ( hg_object not in allowed_objects[hg_type] - and hg_object not in allowed_objects["cfs"][hg_type] + and hg_object not in allowed_objects["cfs"][hg_type] # type: ignore[index] and not hg_object.startswith(('"', "'")) ): e = ( f"Hostgroup item {hg_object} is not valid. Make sure you" " use valid items and separate them with '/'." ) - logger.warning(e) + if logger: + logger.warning(e) raise HostgroupError(e) @@ -235,7 +243,7 @@ def sanatize_log_output(data): del sanitized_data["interfaceid"] # InterfaceID also hints that this is a interface update. # A check is required if there are no macro's used for SNMP security parameters. - if not "details" in data: + if "details" not in data: return sanitized_data for key, detail in sanitized_data["details"].items(): # If the detail is a secret, we don't want to log it. diff --git a/modules/usermacros.py b/modules/usermacros.py index acf8725..1a0780c 100644 --- a/modules/usermacros.py +++ b/modules/usermacros.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-positional-arguments, logging-fstring-interpolation """ All of the Zabbix Usermacro related configuration @@ -57,7 +56,7 @@ class ZabbixUsermacros: if self.validate_macro(macro_name): macro["macro"] = str(macro_name) if isinstance(macro_properties, dict): - if not "value" in macro_properties: + if "value" not in macro_properties: self.logger.info( "Host %s: Usermacro %s has no value in Netbox, skipping.", self.name, diff --git a/modules/virtual_machine.py b/modules/virtual_machine.py index 8c52033..ff4ed0c 100644 --- a/modules/virtual_machine.py +++ b/modules/virtual_machine.py @@ -1,9 +1,11 @@ # pylint: disable=duplicate-code """Module that hosts all functions for virtual machine processing""" + +from modules.config import load_config from modules.device import PhysicalDevice from modules.exceptions import InterfaceConfigError, SyncInventoryError, TemplateError from modules.interface import ZabbixInterface -from modules.config import load_config + # Load config config = load_config() @@ -39,11 +41,12 @@ class VirtualMachine(PhysicalDevice): self.logger.warning(e) return True - def setInterfaceDetails(self): # pylint: disable=invalid-name + def set_interface_details(self): """ Overwrites device function to select an agent interface type by default Agent type interfaces are more likely to be used with VMs then SNMP """ + zabbix_snmp_interface_type = 2 try: # Initiate interface class interface = ZabbixInterface(self.nb.config_context, self.ip) @@ -51,7 +54,7 @@ class VirtualMachine(PhysicalDevice): # If not fall back to old config. if interface.get_context(): # If device is SNMP type, add aditional information. - if interface.interface["type"] == 2: + if interface.interface["type"] == zabbix_snmp_interface_type: interface.set_snmp() else: interface.set_default_agent() diff --git a/netbox_zabbix_sync.py b/netbox_zabbix_sync.py index 8791e47..db4edc9 100755 --- a/netbox_zabbix_sync.py +++ b/netbox_zabbix_sync.py @@ -6,7 +6,8 @@ import argparse import logging import ssl -from os import environ, sys +import sys +from os import environ from pynetbox import api from pynetbox.core.query import RequestError as NBRequestError @@ -115,15 +116,12 @@ def main(arguments): else: zabbix = ZabbixAPI(zabbix_host, token=zabbix_token, ssl_context=ssl_ctx) zabbix.check_auth() - except (APIRequestError, ProcessingError) as e: - e = f"Zabbix returned the following error: {str(e)}" + except (APIRequestError, ProcessingError) as zbx_error: + e = f"Zabbix returned the following error: {zbx_error}." logger.error(e) sys.exit(1) # Set API parameter mapping based on API version - if not str(zabbix.version).startswith("7"): - proxy_name = "host" - else: - proxy_name = "name" + proxy_name = "host" if not str(zabbix.version).startswith("7") else "name" # Get all Zabbix and NetBox data netbox_devices = list(netbox.dcim.devices.filter(**config["nb_device_filter"])) netbox_vms = [] @@ -131,16 +129,16 @@ def main(arguments): netbox_vms = list( netbox.virtualization.virtual_machines.filter(**config["nb_vm_filter"]) ) - netbox_site_groups = convert_recordset((netbox.dcim.site_groups.all())) + netbox_site_groups = convert_recordset(netbox.dcim.site_groups.all()) netbox_regions = convert_recordset(netbox.dcim.regions.all()) netbox_journals = netbox.extras.journal_entries - zabbix_groups = zabbix.hostgroup.get(output=["groupid", "name"]) - zabbix_templates = zabbix.template.get(output=["templateid", "name"]) - zabbix_proxies = zabbix.proxy.get(output=["proxyid", proxy_name]) + zabbix_groups = zabbix.hostgroup.get(output=["groupid", "name"]) # type: ignore[attr-defined] + zabbix_templates = zabbix.template.get(output=["templateid", "name"]) # type: ignore[attr-defined] + zabbix_proxies = zabbix.proxy.get(output=["proxyid", proxy_name]) # type: ignore[attr-defined] # Set empty list for proxy processing Zabbix <= 6 zabbix_proxygroups = [] if str(zabbix.version).startswith("7"): - zabbix_proxygroups = zabbix.proxygroup.get(output=["proxy_groupid", "name"]) + zabbix_proxygroups = zabbix.proxygroup.get(output=["proxy_groupid", "name"]) # type: ignore[attr-defined] # Sanitize proxy data if proxy_name == "host": for proxy in zabbix_proxies: @@ -172,7 +170,7 @@ def main(arguments): continue if config["extended_site_properties"] and nb_vm.site: logger.debug("VM %s: extending site information.", vm.name) - vm.site = convert_recordset(netbox.dcim.sites.filter(id=nb_vm.site.id)) + vm.site = convert_recordset(netbox.dcim.sites.filter(id=nb_vm.site.id)) # type: ignore[attr-defined] vm.set_inventory(nb_vm) vm.set_usermacros() vm.set_tags() @@ -196,14 +194,14 @@ def main(arguments): # Add hostgroup if config is set if config["create_hostgroups"]: # Create new hostgroup. Potentially multiple groups if nested - hostgroups = vm.createZabbixHostgroup(zabbix_groups) + hostgroups = vm.create_zbx_hostgroup(zabbix_groups) # go through all newly created hostgroups for group in hostgroups: # Add new hostgroups to zabbix group list zabbix_groups.append(group) # Check if VM is already in Zabbix if vm.zabbix_id: - vm.ConsistencyCheck( + vm.consistency_check( zabbix_groups, zabbix_templates, zabbix_proxy_list, @@ -212,7 +210,7 @@ def main(arguments): ) continue # Add VM to Zabbix - vm.createInZabbix(zabbix_groups, zabbix_templates, zabbix_proxy_list) + vm.create_in_zabbix(zabbix_groups, zabbix_templates, zabbix_proxy_list) except SyncError: pass @@ -247,7 +245,7 @@ def main(arguments): continue if config["extended_site_properties"] and nb_device.site: logger.debug("Device %s: extending site information.", device.name) - device.site = convert_recordset( + device.site = convert_recordset( # type: ignore[attr-defined] netbox.dcim.sites.filter(id=nb_device.site.id) ) device.set_inventory(nb_device) @@ -255,9 +253,9 @@ def main(arguments): device.set_tags() # Checks if device is part of cluster. # Requires clustering variable - if device.isCluster() and config["clustering"]: + if device.is_cluster() and config["clustering"]: # Check if device is primary or secondary - if device.promoteMasterDevice(): + if device.promote_primary_device(): logger.info( "Device %s: is part of cluster and primary.", device.name ) @@ -290,14 +288,14 @@ def main(arguments): # Add hostgroup is config is set if config["create_hostgroups"]: # Create new hostgroup. Potentially multiple groups if nested - hostgroups = device.createZabbixHostgroup(zabbix_groups) + hostgroups = device.create_zbx_hostgroup(zabbix_groups) # go through all newly created hostgroups for group in hostgroups: # Add new hostgroups to zabbix group list zabbix_groups.append(group) # Check if device is already in Zabbix if device.zabbix_id: - device.ConsistencyCheck( + device.consistency_check( zabbix_groups, zabbix_templates, zabbix_proxy_list, @@ -306,7 +304,7 @@ def main(arguments): ) continue # Add device to Zabbix - device.createInZabbix(zabbix_groups, zabbix_templates, zabbix_proxy_list) + device.create_in_zabbix(zabbix_groups, zabbix_templates, zabbix_proxy_list) except SyncError: pass zabbix.logout() diff --git a/pyproject.toml b/pyproject.toml index 0ede4ed..5d64a22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,8 @@ ignore = [ "PLR0915", # Ignore too many branches "PLR0912", + # Ignore use of assert + "S101", ] select = [ diff --git a/tests/test_configuration_parsing.py b/tests/test_configuration_parsing.py index 641b508..d6186e9 100644 --- a/tests/test_configuration_parsing.py +++ b/tests/test_configuration_parsing.py @@ -1,13 +1,22 @@ """Tests for configuration parsing in the modules.config module.""" -from unittest.mock import patch, MagicMock + import os -from modules.config import load_config, DEFAULT_CONFIG, load_config_file, load_env_variable +from unittest.mock import MagicMock, patch + +from modules.config import ( + DEFAULT_CONFIG, + load_config, + load_config_file, + load_env_variable, +) def test_load_config_defaults(): """Test that load_config returns default values when no config file or env vars are present""" - with patch('modules.config.load_config_file', return_value=DEFAULT_CONFIG.copy()), \ - patch('modules.config.load_env_variable', return_value=None): + with ( + patch("modules.config.load_config_file", return_value=DEFAULT_CONFIG.copy()), + patch("modules.config.load_env_variable", return_value=None), + ): config = load_config() assert config == DEFAULT_CONFIG assert config["templates_config_context"] is False @@ -20,8 +29,10 @@ def test_load_config_file(): mock_config["templates_config_context"] = True mock_config["sync_vms"] = True - with patch('modules.config.load_config_file', return_value=mock_config), \ - patch('modules.config.load_env_variable', return_value=None): + with ( + patch("modules.config.load_config_file", return_value=mock_config), + patch("modules.config.load_env_variable", return_value=None), + ): config = load_config() assert config["templates_config_context"] is True assert config["sync_vms"] is True @@ -31,6 +42,7 @@ def test_load_config_file(): def test_load_env_variables(): """Test that load_config properly loads values from environment variables""" + # Mock env variable loading to return values for specific keys def mock_load_env(key): if key == "sync_vms": @@ -39,8 +51,10 @@ def test_load_env_variables(): return True return None - with patch('modules.config.load_config_file', return_value=DEFAULT_CONFIG.copy()), \ - patch('modules.config.load_env_variable', side_effect=mock_load_env): + with ( + patch("modules.config.load_config_file", return_value=DEFAULT_CONFIG.copy()), + patch("modules.config.load_env_variable", side_effect=mock_load_env), + ): config = load_config() assert config["sync_vms"] is True assert config["create_journal"] is True @@ -60,8 +74,10 @@ def test_env_vars_override_config_file(): return True return None - with patch('modules.config.load_config_file', return_value=mock_config), \ - patch('modules.config.load_env_variable', side_effect=mock_load_env): + with ( + patch("modules.config.load_config_file", return_value=mock_config), + patch("modules.config.load_env_variable", side_effect=mock_load_env), + ): config = load_config() # This should be overridden by the env var assert config["sync_vms"] is True @@ -72,8 +88,10 @@ def test_env_vars_override_config_file(): def test_load_config_file_function(): """Test the load_config_file function directly""" # Test when the file exists - with patch('pathlib.Path.exists', return_value=True), \ - patch('importlib.util.spec_from_file_location') as mock_spec: + with ( + patch("pathlib.Path.exists", return_value=True), + patch("importlib.util.spec_from_file_location") as mock_spec, + ): # Setup the mock module with attributes mock_module = MagicMock() mock_module.templates_config_context = True @@ -85,7 +103,7 @@ def test_load_config_file_function(): mock_spec_instance.loader.exec_module = lambda x: None # Patch module_from_spec to return our mock module - with patch('importlib.util.module_from_spec', return_value=mock_module): + with patch("importlib.util.module_from_spec", return_value=mock_module): config = load_config_file(DEFAULT_CONFIG.copy()) assert config["templates_config_context"] is True assert config["sync_vms"] is True @@ -93,7 +111,7 @@ def test_load_config_file_function(): def test_load_config_file_not_found(): """Test load_config_file when the config file doesn't exist""" - with patch('pathlib.Path.exists', return_value=False): + with patch("pathlib.Path.exists", return_value=False): result = load_config_file(DEFAULT_CONFIG.copy()) # Should return a dict equal to DEFAULT_CONFIG, not a new object assert result == DEFAULT_CONFIG @@ -121,19 +139,3 @@ def test_load_env_variable_function(): os.environ[test_var] = original_env else: os.environ.pop(test_var, None) - - -def test_load_config_file_exception_handling(): - """Test that load_config_file handles exceptions gracefully""" - # This test requires modifying the load_config_file function to handle exceptions - # For now, we're just checking that an exception is raised - with patch('pathlib.Path.exists', return_value=True), \ - patch('importlib.util.spec_from_file_location', side_effect=Exception("Import error")): - # Since the current implementation doesn't handle exceptions, we should - # expect an exception to be raised - try: - load_config_file(DEFAULT_CONFIG.copy()) - assert False, "An exception should have been raised" - except Exception: # pylint: disable=broad-except - # This is expected - pass diff --git a/tests/test_device_deletion.py b/tests/test_device_deletion.py index 392ba1a..b26ca9b 100644 --- a/tests/test_device_deletion.py +++ b/tests/test_device_deletion.py @@ -1,7 +1,10 @@ """Tests for device deletion functionality in the PhysicalDevice class.""" + import unittest from unittest.mock import MagicMock, patch + from zabbix_utils import APIRequestError + from modules.device import PhysicalDevice from modules.exceptions import SyncExternalError @@ -38,14 +41,14 @@ class TestDeviceDeletion(unittest.TestCase): self.mock_logger = MagicMock() # Create PhysicalDevice instance with mocks - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): self.device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", journal=True, - logger=self.mock_logger + logger=self.mock_logger, ) def test_cleanup_successful_deletion(self): @@ -58,12 +61,15 @@ class TestDeviceDeletion(unittest.TestCase): self.device.cleanup() # Verify - self.mock_zabbix.host.get.assert_called_once_with(filter={'hostid': '456'}, output=[]) - self.mock_zabbix.host.delete.assert_called_once_with('456') + self.mock_zabbix.host.get.assert_called_once_with( + filter={"hostid": "456"}, output=[] + ) + self.mock_zabbix.host.delete.assert_called_once_with("456") self.mock_nb_device.save.assert_called_once() self.assertIsNone(self.mock_nb_device.custom_fields["zabbix_hostid"]) - self.mock_logger.info.assert_called_with(f"Host {self.device.name}: " - "Deleted host from Zabbix.") + self.mock_logger.info.assert_called_with( + f"Host {self.device.name}: Deleted host from Zabbix." + ) def test_cleanup_device_already_deleted(self): """Test cleanup when device is already deleted from Zabbix.""" @@ -74,12 +80,15 @@ class TestDeviceDeletion(unittest.TestCase): self.device.cleanup() # Verify - self.mock_zabbix.host.get.assert_called_once_with(filter={'hostid': '456'}, output=[]) + self.mock_zabbix.host.get.assert_called_once_with( + filter={"hostid": "456"}, output=[] + ) self.mock_zabbix.host.delete.assert_not_called() self.mock_nb_device.save.assert_called_once() self.assertIsNone(self.mock_nb_device.custom_fields["zabbix_hostid"]) self.mock_logger.info.assert_called_with( - f"Host {self.device.name}: was already deleted from Zabbix. Removed link in NetBox.") + f"Host {self.device.name}: was already deleted from Zabbix. Removed link in NetBox." + ) def test_cleanup_api_error(self): """Test cleanup when Zabbix API returns an error.""" @@ -92,15 +101,17 @@ class TestDeviceDeletion(unittest.TestCase): self.device.cleanup() # Verify correct calls were made - self.mock_zabbix.host.get.assert_called_once_with(filter={'hostid': '456'}, output=[]) - self.mock_zabbix.host.delete.assert_called_once_with('456') + self.mock_zabbix.host.get.assert_called_once_with( + filter={"hostid": "456"}, output=[] + ) + self.mock_zabbix.host.delete.assert_called_once_with("456") self.mock_nb_device.save.assert_not_called() self.mock_logger.error.assert_called() def test_zeroize_cf(self): """Test _zeroize_cf method that clears the custom field.""" # Execute - self.device._zeroize_cf() # pylint: disable=protected-access + self.device._zeroize_cf() # pylint: disable=protected-access # Verify self.assertIsNone(self.mock_nb_device.custom_fields["zabbix_hostid"]) @@ -136,14 +147,14 @@ class TestDeviceDeletion(unittest.TestCase): def test_create_journal_entry_when_disabled(self): """Test create_journal_entry when journaling is disabled.""" # Setup - create device with journal=False - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", journal=False, # Disable journaling - logger=self.mock_logger + logger=self.mock_logger, ) # Execute @@ -159,8 +170,10 @@ class TestDeviceDeletion(unittest.TestCase): self.mock_zabbix.host.get.return_value = [{"hostid": "456"}] # Execute - with patch.object(self.device, 'create_journal_entry') as mock_journal_entry: + with patch.object(self.device, "create_journal_entry") as mock_journal_entry: self.device.cleanup() # Verify - mock_journal_entry.assert_called_once_with("warning", "Deleted host from Zabbix") + mock_journal_entry.assert_called_once_with( + "warning", "Deleted host from Zabbix" + ) diff --git a/tests/test_hostgroups.py b/tests/test_hostgroups.py index 995d26c..d9c0a69 100644 --- a/tests/test_hostgroups.py +++ b/tests/test_hostgroups.py @@ -1,8 +1,10 @@ """Tests for the Hostgroup class in the hostgroups module.""" + import unittest -from unittest.mock import MagicMock, patch, call -from modules.hostgroups import Hostgroup +from unittest.mock import MagicMock, patch + from modules.exceptions import HostgroupError +from modules.hostgroups import Hostgroup class TestHostgroups(unittest.TestCase): @@ -17,27 +19,27 @@ class TestHostgroups(unittest.TestCase): # Create mock device with all properties self.mock_device = MagicMock() self.mock_device.name = "test-device" - + # Set up site information site = MagicMock() site.name = "TestSite" - + # Set up region information region = MagicMock() region.name = "TestRegion" # Ensure region string representation returns the name region.__str__.return_value = "TestRegion" site.region = region - + # Set up site group information site_group = MagicMock() site_group.name = "TestSiteGroup" # Ensure site group string representation returns the name site_group.__str__.return_value = "TestSiteGroup" site.group = site_group - + self.mock_device.site = site - + # Set up role information (varies based on NetBox version) self.mock_device_role = MagicMock() self.mock_device_role.name = "TestRole" @@ -45,7 +47,7 @@ class TestHostgroups(unittest.TestCase): self.mock_device_role.__str__.return_value = "TestRole" self.mock_device.device_role = self.mock_device_role self.mock_device.role = self.mock_device_role - + # Set up tenant information tenant = MagicMock() tenant.name = "TestTenant" @@ -57,45 +59,45 @@ class TestHostgroups(unittest.TestCase): tenant_group.__str__.return_value = "TestTenantGroup" tenant.group = tenant_group self.mock_device.tenant = tenant - + # Set up platform information platform = MagicMock() platform.name = "TestPlatform" self.mock_device.platform = platform - + # Device-specific properties device_type = MagicMock() manufacturer = MagicMock() manufacturer.name = "TestManufacturer" device_type.manufacturer = manufacturer self.mock_device.device_type = device_type - + location = MagicMock() location.name = "TestLocation" # Ensure location string representation returns the name location.__str__.return_value = "TestLocation" self.mock_device.location = location - + # Custom fields self.mock_device.custom_fields = {"test_cf": "TestCF"} - + # *** Mock NetBox VM setup *** # Create mock VM with all properties self.mock_vm = MagicMock() self.mock_vm.name = "test-vm" - + # Reuse site from device self.mock_vm.site = site - + # Set up role for VM self.mock_vm.role = self.mock_device_role - + # Set up tenant for VM (same as device) self.mock_vm.tenant = tenant - + # Set up platform for VM (same as device) self.mock_vm.platform = platform - + # VM-specific properties cluster = MagicMock() cluster.name = "TestCluster" @@ -103,28 +105,28 @@ class TestHostgroups(unittest.TestCase): cluster_type.name = "TestClusterType" cluster.type = cluster_type self.mock_vm.cluster = cluster - + # Custom fields self.mock_vm.custom_fields = {"test_cf": "TestCF"} - + # Mock data for nesting tests self.mock_regions_data = [ {"name": "ParentRegion", "parent": None, "_depth": 0}, - {"name": "TestRegion", "parent": "ParentRegion", "_depth": 1} + {"name": "TestRegion", "parent": "ParentRegion", "_depth": 1}, ] - + self.mock_groups_data = [ {"name": "ParentSiteGroup", "parent": None, "_depth": 0}, - {"name": "TestSiteGroup", "parent": "ParentSiteGroup", "_depth": 1} + {"name": "TestSiteGroup", "parent": "ParentSiteGroup", "_depth": 1}, ] def test_device_hostgroup_creation(self): """Test basic device hostgroup creation.""" hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) - + # Test the string representation self.assertEqual(str(hostgroup), "Hostgroup for dev test-device") - + # Check format options were set correctly self.assertEqual(hostgroup.format_options["site"], "TestSite") self.assertEqual(hostgroup.format_options["region"], "TestRegion") @@ -135,14 +137,14 @@ class TestHostgroups(unittest.TestCase): self.assertEqual(hostgroup.format_options["platform"], "TestPlatform") self.assertEqual(hostgroup.format_options["manufacturer"], "TestManufacturer") self.assertEqual(hostgroup.format_options["location"], "TestLocation") - + def test_vm_hostgroup_creation(self): """Test basic VM hostgroup creation.""" hostgroup = Hostgroup("vm", self.mock_vm, "4.0", self.mock_logger) - + # Test the string representation self.assertEqual(str(hostgroup), "Hostgroup for vm test-vm") - + # Check format options were set correctly self.assertEqual(hostgroup.format_options["site"], "TestSite") self.assertEqual(hostgroup.format_options["region"], "TestRegion") @@ -153,74 +155,76 @@ class TestHostgroups(unittest.TestCase): self.assertEqual(hostgroup.format_options["platform"], "TestPlatform") self.assertEqual(hostgroup.format_options["cluster"], "TestCluster") self.assertEqual(hostgroup.format_options["cluster_type"], "TestClusterType") - + def test_invalid_object_type(self): """Test that an invalid object type raises an exception.""" with self.assertRaises(HostgroupError): Hostgroup("invalid", self.mock_device, "4.0", self.mock_logger) - + def test_device_hostgroup_formats(self): """Test different hostgroup formats for devices.""" hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) - + # Custom format: site/region custom_result = hostgroup.generate("site/region") self.assertEqual(custom_result, "TestSite/TestRegion") - + # Custom format: site/tenant/platform/location complex_result = hostgroup.generate("site/tenant/platform/location") - self.assertEqual(complex_result, "TestSite/TestTenant/TestPlatform/TestLocation") - + self.assertEqual( + complex_result, "TestSite/TestTenant/TestPlatform/TestLocation" + ) + def test_vm_hostgroup_formats(self): """Test different hostgroup formats for VMs.""" hostgroup = Hostgroup("vm", self.mock_vm, "4.0", self.mock_logger) - + # Default format: cluster/role default_result = hostgroup.generate("cluster/role") self.assertEqual(default_result, "TestCluster/TestRole") - + # Custom format: site/tenant custom_result = hostgroup.generate("site/tenant") self.assertEqual(custom_result, "TestSite/TestTenant") - + # Custom format: cluster/cluster_type/platform complex_result = hostgroup.generate("cluster/cluster_type/platform") self.assertEqual(complex_result, "TestCluster/TestClusterType/TestPlatform") - + def test_device_netbox_version_differences(self): """Test hostgroup generation with different NetBox versions.""" # NetBox v2.x hostgroup_v2 = Hostgroup("dev", self.mock_device, "2.11", self.mock_logger) self.assertEqual(hostgroup_v2.format_options["role"], "TestRole") - + # NetBox v3.x hostgroup_v3 = Hostgroup("dev", self.mock_device, "3.5", self.mock_logger) self.assertEqual(hostgroup_v3.format_options["role"], "TestRole") - + # NetBox v4.x (already tested in other methods) - + def test_custom_field_lookup(self): """Test custom field lookup functionality.""" hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) - + # Test custom field exists and is populated cf_result = hostgroup.custom_field_lookup("test_cf") self.assertTrue(cf_result["result"]) self.assertEqual(cf_result["cf"], "TestCF") - + # Test custom field doesn't exist cf_result = hostgroup.custom_field_lookup("nonexistent_cf") self.assertFalse(cf_result["result"]) self.assertIsNone(cf_result["cf"]) - + def test_hostgroup_with_custom_field(self): """Test hostgroup generation including a custom field.""" hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) - + # Generate with custom field included result = hostgroup.generate("site/test_cf/role") self.assertEqual(result, "TestSite/TestCF/TestRole") - + def test_missing_hostgroup_format_item(self): """Test handling of missing hostgroup format items.""" # Create a device with minimal attributes @@ -230,137 +234,123 @@ class TestHostgroups(unittest.TestCase): minimal_device.tenant = None minimal_device.platform = None minimal_device.custom_fields = {} - + # Create role role = MagicMock() role.name = "MinimalRole" minimal_device.role = role - + # Create device_type with manufacturer device_type = MagicMock() manufacturer = MagicMock() manufacturer.name = "MinimalManufacturer" device_type.manufacturer = manufacturer minimal_device.device_type = device_type - + # Create hostgroup hostgroup = Hostgroup("dev", minimal_device, "4.0", self.mock_logger) - + # Generate with default format result = hostgroup.generate("site/manufacturer/role") # Site is missing, so only manufacturer and role should be included self.assertEqual(result, "MinimalManufacturer/MinimalRole") - + # Test with invalid format with self.assertRaises(HostgroupError): hostgroup.generate("site/nonexistent/role") - + def test_nested_region_hostgroups(self): """Test hostgroup generation with nested regions.""" # Mock the build_path function to return a predictable result - with patch('modules.hostgroups.build_path') as mock_build_path: + with patch("modules.hostgroups.build_path") as mock_build_path: # Configure the mock to return a list of regions in the path mock_build_path.return_value = ["ParentRegion", "TestRegion"] - + # Create hostgroup with nested regions enabled hostgroup = Hostgroup( - "dev", - self.mock_device, - "4.0", + "dev", + self.mock_device, + "4.0", self.mock_logger, nested_region_flag=True, - nb_regions=self.mock_regions_data + nb_regions=self.mock_regions_data, ) - + # Generate hostgroup with region result = hostgroup.generate("site/region/role") # Should include the parent region self.assertEqual(result, "TestSite/ParentRegion/TestRegion/TestRole") - + def test_nested_sitegroup_hostgroups(self): """Test hostgroup generation with nested site groups.""" # Mock the build_path function to return a predictable result - with patch('modules.hostgroups.build_path') as mock_build_path: + with patch("modules.hostgroups.build_path") as mock_build_path: # Configure the mock to return a list of site groups in the path mock_build_path.return_value = ["ParentSiteGroup", "TestSiteGroup"] - + # Create hostgroup with nested site groups enabled hostgroup = Hostgroup( - "dev", - self.mock_device, - "4.0", + "dev", + self.mock_device, + "4.0", self.mock_logger, nested_sitegroup_flag=True, - nb_groups=self.mock_groups_data + nb_groups=self.mock_groups_data, ) - + # Generate hostgroup with site_group result = hostgroup.generate("site/site_group/role") # Should include the parent site group self.assertEqual(result, "TestSite/ParentSiteGroup/TestSiteGroup/TestRole") - - def test_list_formatoptions(self): - """Test the list_formatoptions method for debugging.""" - hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) - - # Patch sys.stdout to capture print output - with patch('sys.stdout') as mock_stdout: - hostgroup.list_formatoptions() - - # Check that print was called with expected output - calls = [call.write(f"The following options are available for host test-device"), - call.write('\n')] - mock_stdout.assert_has_calls(calls, any_order=True) - def test_vm_list_based_hostgroup_format(self): """Test VM hostgroup generation with a list-based format.""" hostgroup = Hostgroup("vm", self.mock_vm, "4.0", self.mock_logger) - + # Test with a list of format strings format_list = ["platform", "role", "cluster_type/cluster"] - + # Generate hostgroups for each format in the list hostgroups = [] for fmt in format_list: result = hostgroup.generate(fmt) if result: # Only add non-None results hostgroups.append(result) - + # Verify each expected hostgroup is generated self.assertEqual(len(hostgroups), 3) # Should have 3 hostgroups self.assertIn("TestPlatform", hostgroups) self.assertIn("TestRole", hostgroups) self.assertIn("TestClusterType/TestCluster", hostgroups) - + def test_nested_format_splitting(self): """Test that formats with slashes correctly split and resolve each component.""" hostgroup = Hostgroup("vm", self.mock_vm, "4.0", self.mock_logger) - + # Test a format with slashes that should be split complex_format = "cluster_type/cluster" result = hostgroup.generate(complex_format) - + # Verify the format is correctly split and each component resolved self.assertEqual(result, "TestClusterType/TestCluster") - + def test_multiple_hostgroup_formats_device(self): """Test device hostgroup generation with multiple formats.""" hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) - + # Test with various formats that would be in a list formats = [ - "site", - "manufacturer/role", - "platform/location", - "tenant_group/tenant" + "site", + "manufacturer/role", + "platform/location", + "tenant_group/tenant", ] - + # Generate and check each format results = {} for fmt in formats: results[fmt] = hostgroup.generate(fmt) - + # Verify results self.assertEqual(results["site"], "TestSite") self.assertEqual(results["manufacturer/role"], "TestManufacturer/TestRole") diff --git a/tests/test_interface.py b/tests/test_interface.py index ff55218..3c37413 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -1,7 +1,10 @@ """Tests for the ZabbixInterface class in the interface module.""" + import unittest -from modules.interface import ZabbixInterface +from typing import cast + from modules.exceptions import InterfaceConfigError +from modules.interface import ZabbixInterface class TestZabbixInterface(unittest.TestCase): @@ -18,11 +21,7 @@ class TestZabbixInterface(unittest.TestCase): "zabbix": { "interface_type": 2, "interface_port": "161", - "snmp": { - "version": 2, - "community": "public", - "bulk": 1 - } + "snmp": {"version": 2, "community": "public", "bulk": 1}, } } @@ -37,16 +36,13 @@ class TestZabbixInterface(unittest.TestCase): "authpassphrase": "authpass123", "privprotocol": "AES", "privpassphrase": "privpass123", - "contextname": "context1" - } + "contextname": "context1", + }, } } self.agent_context = { - "zabbix": { - "interface_type": 1, - "interface_port": "10050" - } + "zabbix": {"interface_type": 1, "interface_port": "10050"} } def test_init(self): @@ -95,27 +91,27 @@ class TestZabbixInterface(unittest.TestCase): # Test for agent type (1) interface.interface["type"] = 1 - interface._set_default_port() # pylint: disable=protected-access + interface._set_default_port() # pylint: disable=protected-access self.assertEqual(interface.interface["port"], "10050") # Test for SNMP type (2) interface.interface["type"] = 2 - interface._set_default_port() # pylint: disable=protected-access + interface._set_default_port() # pylint: disable=protected-access self.assertEqual(interface.interface["port"], "161") # Test for IPMI type (3) interface.interface["type"] = 3 - interface._set_default_port() # pylint: disable=protected-access + interface._set_default_port() # pylint: disable=protected-access self.assertEqual(interface.interface["port"], "623") # Test for JMX type (4) interface.interface["type"] = 4 - interface._set_default_port() # pylint: disable=protected-access + interface._set_default_port() # pylint: disable=protected-access self.assertEqual(interface.interface["port"], "12345") # Test for unsupported type interface.interface["type"] = 99 - result = interface._set_default_port() # pylint: disable=protected-access + result = interface._set_default_port() # pylint: disable=protected-access self.assertFalse(result) def test_set_snmp_v2(self): @@ -127,9 +123,10 @@ class TestZabbixInterface(unittest.TestCase): interface.set_snmp() # Check SNMP details - self.assertEqual(interface.interface["details"]["version"], "2") - self.assertEqual(interface.interface["details"]["community"], "public") - self.assertEqual(interface.interface["details"]["bulk"], "1") + details = cast(dict[str, str], interface.interface["details"]) + self.assertEqual(details["version"], "2") + self.assertEqual(details["community"], "public") + self.assertEqual(details["bulk"], "1") def test_set_snmp_v3(self): """Test set_snmp with SNMPv3 configuration.""" @@ -140,14 +137,15 @@ class TestZabbixInterface(unittest.TestCase): interface.set_snmp() # Check SNMP details - self.assertEqual(interface.interface["details"]["version"], "3") - self.assertEqual(interface.interface["details"]["securityname"], "snmpuser") - self.assertEqual(interface.interface["details"]["securitylevel"], "authPriv") - self.assertEqual(interface.interface["details"]["authprotocol"], "SHA") - self.assertEqual(interface.interface["details"]["authpassphrase"], "authpass123") - self.assertEqual(interface.interface["details"]["privprotocol"], "AES") - self.assertEqual(interface.interface["details"]["privpassphrase"], "privpass123") - self.assertEqual(interface.interface["details"]["contextname"], "context1") + details = cast(dict[str, str], interface.interface["details"]) + self.assertEqual(details["version"], "3") + self.assertEqual(details["securityname"], "snmpuser") + self.assertEqual(details["securitylevel"], "authPriv") + self.assertEqual(details["authprotocol"], "SHA") + self.assertEqual(details["authpassphrase"], "authpass123") + self.assertEqual(details["privprotocol"], "AES") + self.assertEqual(details["privpassphrase"], "privpass123") + self.assertEqual(details["contextname"], "context1") def test_set_snmp_no_snmp_config(self): """Test set_snmp with missing SNMP configuration.""" @@ -168,7 +166,7 @@ class TestZabbixInterface(unittest.TestCase): "interface_type": 2, "snmp": { "version": 4 # Invalid version - } + }, } } interface = ZabbixInterface(context, self.test_ip) @@ -186,7 +184,7 @@ class TestZabbixInterface(unittest.TestCase): "interface_type": 2, "snmp": { "community": "public" # No version specified - } + }, } } interface = ZabbixInterface(context, self.test_ip) @@ -213,9 +211,10 @@ class TestZabbixInterface(unittest.TestCase): # Check interface properties self.assertEqual(interface.interface["type"], "2") self.assertEqual(interface.interface["port"], "161") - self.assertEqual(interface.interface["details"]["version"], "2") - self.assertEqual(interface.interface["details"]["community"], "{$SNMP_COMMUNITY}") - self.assertEqual(interface.interface["details"]["bulk"], "1") + details = cast(dict[str, str], interface.interface["details"]) + self.assertEqual(details["version"], "2") + self.assertEqual(details["community"], "{$SNMP_COMMUNITY}") + self.assertEqual(details["bulk"], "1") def test_set_default_agent(self): """Test set_default_agent method.""" @@ -229,14 +228,7 @@ class TestZabbixInterface(unittest.TestCase): def test_snmpv2_no_community(self): """Test SNMPv2 with no community string specified.""" # Create context with SNMPv2 but no community - context = { - "zabbix": { - "interface_type": 2, - "snmp": { - "version": 2 - } - } - } + context = {"zabbix": {"interface_type": 2, "snmp": {"version": 2}}} interface = ZabbixInterface(context, self.test_ip) interface.get_context() # Set the interface type @@ -244,4 +236,5 @@ class TestZabbixInterface(unittest.TestCase): interface.set_snmp() # Should use default community string - self.assertEqual(interface.interface["details"]["community"], "{$SNMP_COMMUNITY}") + details = cast(dict[str, str], interface.interface["details"]) + self.assertEqual(details["community"], "{$SNMP_COMMUNITY}") diff --git a/tests/test_list_hostgroup_formats.py b/tests/test_list_hostgroup_formats.py index 9b8cc21..aeaa181 100644 --- a/tests/test_list_hostgroup_formats.py +++ b/tests/test_list_hostgroup_formats.py @@ -1,8 +1,10 @@ """Tests for list-based hostgroup formats in configuration.""" + import unittest -from unittest.mock import MagicMock, patch -from modules.hostgroups import Hostgroup +from unittest.mock import MagicMock + from modules.exceptions import HostgroupError +from modules.hostgroups import Hostgroup from modules.tools import verify_hg_format @@ -17,56 +19,56 @@ class TestListHostgroupFormats(unittest.TestCase): # Create mock device self.mock_device = MagicMock() self.mock_device.name = "test-device" - + # Set up site information site = MagicMock() site.name = "TestSite" - + # Set up region information region = MagicMock() region.name = "TestRegion" region.__str__.return_value = "TestRegion" site.region = region - + # Set device site self.mock_device.site = site - + # Set up role information self.mock_device_role = MagicMock() self.mock_device_role.name = "TestRole" - self.mock_device_role.__str__.return_value = "TestRole" + self.mock_device_role.__str__.return_value = "TestRole" self.mock_device.role = self.mock_device_role - + # Set up rack information rack = MagicMock() rack.name = "TestRack" self.mock_device.rack = rack - + # Set up platform information platform = MagicMock() platform.name = "TestPlatform" self.mock_device.platform = platform - + # Device-specific properties device_type = MagicMock() manufacturer = MagicMock() manufacturer.name = "TestManufacturer" device_type.manufacturer = manufacturer self.mock_device.device_type = device_type - + # Create mock VM self.mock_vm = MagicMock() self.mock_vm.name = "test-vm" - + # Reuse site from device self.mock_vm.site = site - + # Set up role for VM self.mock_vm.role = self.mock_device_role - + # Set up platform for VM self.mock_vm.platform = platform - + # VM-specific properties cluster = MagicMock() cluster.name = "TestCluster" @@ -79,53 +81,53 @@ class TestListHostgroupFormats(unittest.TestCase): """Test verification of list-based hostgroup formats.""" # List format with valid items valid_format = ["region", "site", "rack"] - + # List format with nested path valid_nested_format = ["region", "site/rack"] - + # List format with invalid item invalid_format = ["region", "invalid_item", "rack"] - + # Should not raise exception for valid formats verify_hg_format(valid_format, hg_type="dev", logger=self.mock_logger) verify_hg_format(valid_nested_format, hg_type="dev", logger=self.mock_logger) - + # Should raise exception for invalid format with self.assertRaises(HostgroupError): verify_hg_format(invalid_format, hg_type="dev", logger=self.mock_logger) - + def test_simulate_hostgroup_generation_from_config(self): """Simulate how the main script would generate hostgroups from list-based config.""" # Mock configuration with list-based hostgroup format config_format = ["region", "site", "rack"] hostgroup = Hostgroup("dev", self.mock_device, "4.0", self.mock_logger) - + # Simulate the main script's hostgroup generation process hostgroups = [] for fmt in config_format: result = hostgroup.generate(fmt) if result: hostgroups.append(result) - + # Check results self.assertEqual(len(hostgroups), 3) self.assertIn("TestRegion", hostgroups) self.assertIn("TestSite", hostgroups) self.assertIn("TestRack", hostgroups) - + def test_vm_hostgroup_format_from_config(self): """Test VM hostgroup generation with list-based format.""" # Mock VM configuration with mixed format config_format = ["platform", "role", "cluster_type/cluster"] hostgroup = Hostgroup("vm", self.mock_vm, "4.0", self.mock_logger) - + # Simulate the main script's hostgroup generation process hostgroups = [] for fmt in config_format: result = hostgroup.generate(fmt) if result: hostgroups.append(result) - + # Check results self.assertEqual(len(hostgroups), 3) self.assertIn("TestPlatform", hostgroups) diff --git a/tests/test_physical_device.py b/tests/test_physical_device.py index 1b79ad8..2c9843f 100644 --- a/tests/test_physical_device.py +++ b/tests/test_physical_device.py @@ -1,8 +1,10 @@ """Tests for the PhysicalDevice class in the device module.""" + import unittest from unittest.mock import MagicMock, patch + from modules.device import PhysicalDevice -from modules.exceptions import TemplateError, SyncInventoryError +from modules.exceptions import TemplateError class TestPhysicalDevice(unittest.TestCase): @@ -34,24 +36,27 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_logger = MagicMock() # Create PhysicalDevice instance with mocks - with patch('modules.device.config', - {"device_cf": "zabbix_hostid", - "template_cf": "zabbix_template", - "templates_config_context": False, - "templates_config_context_overrule": False, - "traverse_regions": False, - "traverse_site_groups": False, - "inventory_mode": "disabled", - "inventory_sync": False, - "device_inventory_map": {} - }): + with patch( + "modules.device.config", + { + "device_cf": "zabbix_hostid", + "template_cf": "zabbix_template", + "templates_config_context": False, + "templates_config_context_overrule": False, + "traverse_regions": False, + "traverse_site_groups": False, + "inventory_mode": "disabled", + "inventory_sync": False, + "device_inventory_map": {}, + }, + ): self.device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", journal=True, - logger=self.mock_logger + logger=self.mock_logger, ) def test_init(self): @@ -63,22 +68,6 @@ class TestPhysicalDevice(unittest.TestCase): self.assertEqual(self.device.ip, "192.168.1.1") self.assertEqual(self.device.cidr, "192.168.1.1/24") - def test_init_no_primary_ip(self): - """Test initialization when device has no primary IP.""" - # Set primary_ip to None - self.mock_nb_device.primary_ip = None - - # Creating device should raise SyncInventoryError - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): - with self.assertRaises(SyncInventoryError): - PhysicalDevice( - self.mock_nb_device, - self.mock_zabbix, - self.mock_nb_journal, - "3.0", - logger=self.mock_logger - ) - def test_set_basics_with_special_characters(self): """Test _setBasics when device name contains special characters.""" # Set name with special characters that @@ -86,8 +75,10 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_nb_device.name = "test-devïce" # We need to patch the search function to simulate finding special characters - with patch('modules.device.search') as mock_search, \ - patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with ( + patch("modules.device.search") as mock_search, + patch("modules.device.config", {"device_cf": "zabbix_hostid"}), + ): # Make the search function return True to simulate special characters mock_search.return_value = True @@ -96,7 +87,7 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # With the mocked search function, the name should be changed to NETBOX_ID format @@ -110,19 +101,17 @@ class TestPhysicalDevice(unittest.TestCase): """Test get_templates_context with valid config.""" # Set up config_context with valid template data self.mock_nb_device.config_context = { - "zabbix": { - "templates": ["Template1", "Template2"] - } + "zabbix": {"templates": ["Template1", "Template2"]} } # Create device with the updated mock - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Test that templates are returned correctly @@ -132,20 +121,16 @@ class TestPhysicalDevice(unittest.TestCase): def test_get_templates_context_with_string(self): """Test get_templates_context with a string instead of list.""" # Set up config_context with a string template - self.mock_nb_device.config_context = { - "zabbix": { - "templates": "Template1" - } - } + self.mock_nb_device.config_context = {"zabbix": {"templates": "Template1"}} # Create device with the updated mock - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Test that template is wrapped in a list @@ -158,13 +143,13 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_nb_device.config_context = {} # Create device with the updated mock - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Test that TemplateError is raised @@ -177,13 +162,13 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_nb_device.config_context = {"zabbix": {}} # Create device with the updated mock - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Test that TemplateError is raised @@ -193,25 +178,25 @@ class TestPhysicalDevice(unittest.TestCase): def test_set_template_with_config_context(self): """Test set_template with templates_config_context=True.""" # Set up config_context with templates - self.mock_nb_device.config_context = { - "zabbix": { - "templates": ["Template1"] - } - } + self.mock_nb_device.config_context = {"zabbix": {"templates": ["Template1"]}} # Mock get_templates_context to return expected templates - with patch.object(PhysicalDevice, 'get_templates_context', return_value=["Template1"]): - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch.object( + PhysicalDevice, "get_templates_context", return_value=["Template1"] + ): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Call set_template with prefer_config_context=True - result = device.set_template(prefer_config_context=True, overrule_custom=False) + result = device.set_template( + prefer_config_context=True, overrule_custom=False + ) # Check result and template names self.assertTrue(result) @@ -223,20 +208,20 @@ class TestPhysicalDevice(unittest.TestCase): config_patch = { "device_cf": "zabbix_hostid", "inventory_mode": "disabled", - "inventory_sync": False + "inventory_sync": False, } - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Call set_inventory with the config patch still active - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): result = device.set_inventory({}) # Check result @@ -250,20 +235,20 @@ class TestPhysicalDevice(unittest.TestCase): config_patch = { "device_cf": "zabbix_hostid", "inventory_mode": "manual", - "inventory_sync": False + "inventory_sync": False, } - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Call set_inventory with the config patch still active - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): result = device.set_inventory({}) # Check result @@ -276,20 +261,20 @@ class TestPhysicalDevice(unittest.TestCase): config_patch = { "device_cf": "zabbix_hostid", "inventory_mode": "automatic", - "inventory_sync": False + "inventory_sync": False, } - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Call set_inventory with the config patch still active - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): result = device.set_inventory({}) # Check result @@ -303,38 +288,31 @@ class TestPhysicalDevice(unittest.TestCase): "device_cf": "zabbix_hostid", "inventory_mode": "manual", "inventory_sync": True, - "device_inventory_map": { - "name": "name", - "serial": "serialno_a" - } + "device_inventory_map": {"name": "name", "serial": "serialno_a"}, } - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Create a mock device with the required attributes - mock_device_data = { - "name": "test-device", - "serial": "ABC123" - } + mock_device_data = {"name": "test-device", "serial": "ABC123"} # Call set_inventory with the config patch still active - with patch('modules.device.config', config_patch): + with patch("modules.device.config", config_patch): result = device.set_inventory(mock_device_data) # Check result self.assertTrue(result) self.assertEqual(device.inventory_mode, 0) # Manual mode - self.assertEqual(device.inventory, { - "name": "test-device", - "serialno_a": "ABC123" - }) + self.assertEqual( + device.inventory, {"name": "test-device", "serialno_a": "ABC123"} + ) def test_iscluster_true(self): """Test isCluster when device is part of a cluster.""" @@ -342,17 +320,17 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_nb_device.virtual_chassis = MagicMock() # Create device with the updated mock - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Check isCluster result - self.assertTrue(device.isCluster()) + self.assertTrue(device.is_cluster()) def test_is_cluster_false(self): """Test isCluster when device is not part of a cluster.""" @@ -360,18 +338,17 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_nb_device.virtual_chassis = None # Create device with the updated mock - with patch('modules.device.config', {"device_cf": "zabbix_hostid"}): + with patch("modules.device.config", {"device_cf": "zabbix_hostid"}): device = PhysicalDevice( self.mock_nb_device, self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Check isCluster result - self.assertFalse(device.isCluster()) - + self.assertFalse(device.is_cluster()) def test_promote_master_device_primary(self): """Test promoteMasterDevice when device is primary in cluster.""" @@ -379,7 +356,9 @@ class TestPhysicalDevice(unittest.TestCase): mock_vc = MagicMock() mock_vc.name = "virtual-chassis-1" mock_master = MagicMock() - mock_master.id = self.mock_nb_device.id # Set master ID to match the current device + mock_master.id = ( + self.mock_nb_device.id + ) # Set master ID to match the current device mock_vc.master = mock_master self.mock_nb_device.virtual_chassis = mock_vc @@ -389,25 +368,26 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Call promoteMasterDevice and check the result - result = device.promoteMasterDevice() + result = device.promote_primary_device() # Should return True for primary device self.assertTrue(result) # Device name should be updated to virtual chassis name self.assertEqual(device.name, "virtual-chassis-1") - def test_promote_master_device_secondary(self): """Test promoteMasterDevice when device is secondary in cluster.""" # Set up virtual chassis with a different master device mock_vc = MagicMock() mock_vc.name = "virtual-chassis-1" mock_master = MagicMock() - mock_master.id = self.mock_nb_device.id + 1 # Different ID than the current device + mock_master.id = ( + self.mock_nb_device.id + 1 + ) # Different ID than the current device mock_vc.master = mock_master self.mock_nb_device.virtual_chassis = mock_vc @@ -417,11 +397,11 @@ class TestPhysicalDevice(unittest.TestCase): self.mock_zabbix, self.mock_nb_journal, "3.0", - logger=self.mock_logger + logger=self.mock_logger, ) # Call promoteMasterDevice and check the result - result = device.promoteMasterDevice() + result = device.promote_primary_device() # Should return False for secondary device self.assertFalse(result) diff --git a/tests/test_tools.py b/tests/test_tools.py index 3e6ae24..5361743 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,5 +1,6 @@ from modules.tools import sanatize_log_output + def test_sanatize_log_output_secrets(): data = { "macros": [ @@ -11,6 +12,7 @@ def test_sanatize_log_output_secrets(): assert sanitized["macros"][0]["value"] == "********" assert sanitized["macros"][1]["value"] == "notsecret" + def test_sanatize_log_output_interface_secrets(): data = { "interfaceid": 123, @@ -19,8 +21,8 @@ def test_sanatize_log_output_interface_secrets(): "privpassphrase": "anothersecret", "securityname": "sensitiveuser", "community": "public", - "other": "normalvalue" - } + "other": "normalvalue", + }, } sanitized = sanatize_log_output(data) # Sensitive fields should be sanitized @@ -33,6 +35,7 @@ def test_sanatize_log_output_interface_secrets(): # interfaceid should be removed assert "interfaceid" not in sanitized + def test_sanatize_log_output_interface_macros(): data = { "interfaceid": 123, @@ -41,7 +44,7 @@ def test_sanatize_log_output_interface_macros(): "privpassphrase": "{$SECRET_MACRO}", "securityname": "{$USER_MACRO}", "community": "{$SNNMP_COMMUNITY}", - } + }, } sanitized = sanatize_log_output(data) # Macro values should not be sanitized @@ -51,11 +54,13 @@ def test_sanatize_log_output_interface_macros(): assert sanitized["details"]["community"] == "{$SNNMP_COMMUNITY}" assert "interfaceid" not in sanitized + def test_sanatize_log_output_plain_data(): data = {"foo": "bar", "baz": 123} sanitized = sanatize_log_output(data) assert sanitized == data + def test_sanatize_log_output_non_dict(): data = [1, 2, 3] sanitized = sanatize_log_output(data) diff --git a/tests/test_usermacros.py b/tests/test_usermacros.py index 5c2b6a4..dbcfaed 100644 --- a/tests/test_usermacros.py +++ b/tests/test_usermacros.py @@ -1,8 +1,10 @@ import unittest from unittest.mock import MagicMock, patch + from modules.device import PhysicalDevice from modules.usermacros import ZabbixUsermacros + class DummyNB: def __init__(self, name="dummy", config_context=None, **kwargs): self.name = name @@ -18,46 +20,100 @@ class DummyNB: return self.config_context[key] raise KeyError(key) + class TestUsermacroSync(unittest.TestCase): def setUp(self): self.nb = DummyNB(serial="1234") self.logger = MagicMock() self.usermacro_map = {"serial": "{$HW_SERIAL}"} - @patch("modules.device.config", {"usermacro_sync": False}) - def test_usermacro_sync_false(self): - device = PhysicalDevice.__new__(PhysicalDevice) - device.nb = self.nb - device.logger = self.logger - device.name = "dummy" - device._usermacro_map = MagicMock(return_value=self.usermacro_map) - # call set_usermacros + def create_mock_device(self): + """Helper method to create a properly mocked PhysicalDevice""" + # Mock the NetBox device with all required attributes + mock_nb = MagicMock() + mock_nb.id = 1 + mock_nb.name = "dummy" + mock_nb.status.label = "Active" + mock_nb.tenant = None + mock_nb.config_context = {} + mock_nb.primary_ip.address = "192.168.1.1/24" + mock_nb.custom_fields = {"zabbix_hostid": None} + + # Create device with proper initialization + device = PhysicalDevice( + nb=mock_nb, + zabbix=MagicMock(), + nb_journal_class=MagicMock(), + nb_version="3.0", + logger=self.logger, + ) + + return device + + @patch( + "modules.device.config", + {"usermacro_sync": False, "device_cf": "zabbix_hostid", "tag_sync": False}, + ) + @patch.object(PhysicalDevice, "_usermacro_map") + def test_usermacro_sync_false(self, mock_usermacro_map): + mock_usermacro_map.return_value = self.usermacro_map + device = self.create_mock_device() + + # Call set_usermacros result = device.set_usermacros() + self.assertEqual(device.usermacros, []) self.assertTrue(result is True or result is None) - @patch("modules.device.config", {"usermacro_sync": True}) - def test_usermacro_sync_true(self): - device = PhysicalDevice.__new__(PhysicalDevice) - device.nb = self.nb - device.logger = self.logger - device.name = "dummy" - device._usermacro_map = MagicMock(return_value=self.usermacro_map) - result = device.set_usermacros() + @patch( + "modules.device.config", + {"usermacro_sync": True, "device_cf": "zabbix_hostid", "tag_sync": False}, + ) + @patch("modules.device.ZabbixUsermacros") + @patch.object(PhysicalDevice, "_usermacro_map") + def test_usermacro_sync_true(self, mock_usermacro_map, mock_usermacros_class): + mock_usermacro_map.return_value = self.usermacro_map + # Mock the ZabbixUsermacros class to return some test data + mock_macros_instance = MagicMock() + mock_macros_instance.sync = True # This is important - sync must be True + mock_macros_instance.generate.return_value = [ + {"macro": "{$HW_SERIAL}", "value": "1234"} + ] + mock_usermacros_class.return_value = mock_macros_instance + + device = self.create_mock_device() + + # Call set_usermacros + device.set_usermacros() + self.assertIsInstance(device.usermacros, list) self.assertGreater(len(device.usermacros), 0) - @patch("modules.device.config", {"usermacro_sync": "full"}) - def test_usermacro_sync_full(self): - device = PhysicalDevice.__new__(PhysicalDevice) - device.nb = self.nb - device.logger = self.logger - device.name = "dummy" - device._usermacro_map = MagicMock(return_value=self.usermacro_map) - result = device.set_usermacros() + @patch( + "modules.device.config", + {"usermacro_sync": "full", "device_cf": "zabbix_hostid", "tag_sync": False}, + ) + @patch("modules.device.ZabbixUsermacros") + @patch.object(PhysicalDevice, "_usermacro_map") + def test_usermacro_sync_full(self, mock_usermacro_map, mock_usermacros_class): + mock_usermacro_map.return_value = self.usermacro_map + # Mock the ZabbixUsermacros class to return some test data + mock_macros_instance = MagicMock() + mock_macros_instance.sync = True # This is important - sync must be True + mock_macros_instance.generate.return_value = [ + {"macro": "{$HW_SERIAL}", "value": "1234"} + ] + mock_usermacros_class.return_value = mock_macros_instance + + device = self.create_mock_device() + + # Call set_usermacros + device.set_usermacros() + self.assertIsInstance(device.usermacros, list) self.assertGreater(len(device.usermacros), 0) + class TestZabbixUsermacros(unittest.TestCase): def setUp(self): self.nb = DummyNB() @@ -78,7 +134,9 @@ class TestZabbixUsermacros(unittest.TestCase): def test_render_macro_dict(self): macros = ZabbixUsermacros(self.nb, {}, False, logger=self.logger) - macro = macros.render_macro("{$FOO}", {"value": "bar", "type": "secret", "description": "desc"}) + macro = macros.render_macro( + "{$FOO}", {"value": "bar", "type": "secret", "description": "desc"} + ) self.assertEqual(macro["macro"], "{$FOO}") self.assertEqual(macro["value"], "bar") self.assertEqual(macro["type"], "1") @@ -114,12 +172,14 @@ class TestZabbixUsermacros(unittest.TestCase): self.assertEqual(result[1]["macro"], "{$BAR}") def test_generate_from_config_context(self): - config_context = {"zabbix": {"usermacros": {"{$FOO}": {"value": "bar"}}}} + config_context = {"zabbix": {"usermacros": {"{$TEST_MACRO}": "test_value"}}} nb = DummyNB(config_context=config_context) macros = ZabbixUsermacros(nb, {}, True, logger=self.logger) result = macros.generate() self.assertEqual(len(result), 1) - self.assertEqual(result[0]["macro"], "{$FOO}") + self.assertEqual(result[0]["macro"], "{$TEST_MACRO}") + self.assertEqual(result[0]["value"], "test_value") + if __name__ == "__main__": unittest.main()