diff --git a/docs/additional-features/custom-scripts.md b/docs/additional-features/custom-scripts.md index c4dffb4b9..6fac5b63d 100644 --- a/docs/additional-features/custom-scripts.md +++ b/docs/additional-features/custom-scripts.md @@ -124,7 +124,7 @@ Arbitrary text of any length. Renders as multi-line text input field. Stored a numeric integer. Options include: -* `min_value:` - Minimum value +* `min_value` - Minimum value * `max_value` - Maximum value ### BooleanVar @@ -158,9 +158,20 @@ A NetBox object. The list of available objects is defined by the queryset parame An uploaded file. Note that uploaded files are present in memory only for the duration of the script's execution: They will not be save for future use. +### IPAddressVar + +An IPv4 or IPv6 address, without a mask. Returns a `netaddr.IPAddress` object. + +### IPAddressWithMaskVar + +An IPv4 or IPv6 address with a mask. Returns a `netaddr.IPNetwork` object which includes the mask. + ### IPNetworkVar -An IPv4 or IPv6 network with a mask. +An IPv4 or IPv6 network with a mask. Returns a `netaddr.IPNetwork` object. Two attributes are available to validate the provided mask: + +* `min_prefix_length` - Minimum length of the mask (default: none) +* `max_prefix_length` - Maximum length of the mask (default: none) ### Default Options diff --git a/docs/release-notes/version-2.7.md b/docs/release-notes/version-2.7.md index bc9f22cf9..997ab751e 100644 --- a/docs/release-notes/version-2.7.md +++ b/docs/release-notes/version-2.7.md @@ -3,6 +3,13 @@ ## Enhancements * [#568](https://github.com/netbox-community/netbox/issues/568) - Allow custom fields to be imported and exported using CSV +* [#3310](https://github.com/netbox-community/netbox/issues/3310) - Pre-select site/rack for B side when creating a new cable +* [#3509](https://github.com/netbox-community/netbox/issues/3509) - Add IP address variables for custom scripts + +## Bug Fixes + +* [#3983](https://github.com/netbox-community/netbox/issues/3983) - Permit the creation of multiple unnamed devices +* [#3989](https://github.com/netbox-community/netbox/issues/3989) - Correct HTTP content type assignment for webhooks --- diff --git a/netbox/dcim/forms.py b/netbox/dcim/forms.py index a5e8a782f..da4134eed 100644 --- a/netbox/dcim/forms.py +++ b/netbox/dcim/forms.py @@ -3168,6 +3168,11 @@ class ConnectCableToDeviceForm(BootstrapMixin, ChainedFieldsMixin, forms.ModelFo 'termination_b_site', 'termination_b_rack', 'termination_b_device', 'termination_b_id', 'type', 'status', 'label', 'color', 'length', 'length_unit', ] + widgets = { + 'status': StaticSelect2, + 'type': StaticSelect2, + 'length_unit': StaticSelect2, + } class ConnectCableToConsolePortForm(ConnectCableToDeviceForm): @@ -3363,6 +3368,11 @@ class CableForm(BootstrapMixin, forms.ModelForm): fields = [ 'type', 'status', 'label', 'color', 'length', 'length_unit', ] + widgets = { + 'status': StaticSelect2, + 'type': StaticSelect2, + 'length_unit': StaticSelect2, + } class CableCSVForm(forms.ModelForm): diff --git a/netbox/dcim/models/__init__.py b/netbox/dcim/models/__init__.py index d1b596c22..1c9c8682d 100644 --- a/netbox/dcim/models/__init__.py +++ b/netbox/dcim/models/__init__.py @@ -1445,10 +1445,11 @@ class Device(ChangeLoggedModel, ConfigContextModel, CustomFieldModel): # Check for a duplicate name on a device assigned to the same Site and no Tenant. This is necessary # because Django does not consider two NULL fields to be equal, and thus will not trigger a violation # of the uniqueness constraint without manual intervention. - if self.tenant is None and Device.objects.exclude(pk=self.pk).filter(name=self.name, tenant__isnull=True): - raise ValidationError({ - 'name': 'A device with this name already exists.' - }) + if self.name and self.tenant is None: + if Device.objects.exclude(pk=self.pk).filter(name=self.name, tenant__isnull=True): + raise ValidationError({ + 'name': 'A device with this name already exists.' + }) super().validate_unique(exclude) diff --git a/netbox/dcim/tests/test_api.py b/netbox/dcim/tests/test_api.py index a515df13c..a3a072bc9 100644 --- a/netbox/dcim/tests/test_api.py +++ b/netbox/dcim/tests/test_api.py @@ -4,6 +4,7 @@ from netaddr import IPNetwork from rest_framework import status from circuits.models import Circuit, CircuitTermination, CircuitType, Provider +from dcim.api import serializers from dcim.choices import * from dcim.constants import * from dcim.models import ( @@ -595,6 +596,21 @@ class RackTest(APITestCase): self.assertEqual(response.data['count'], 42) + def test_get_rack_elevation(self): + + url = reverse('dcim-api:rack-elevation', kwargs={'pk': self.rack1.pk}) + response = self.client.get(url, **self.header) + + self.assertEqual(response.data['count'], 42) + + def test_get_rack_elevation_svg(self): + + url = '{}?render=svg'.format(reverse('dcim-api:rack-elevation', kwargs={'pk': self.rack1.pk})) + response = self.client.get(url, **self.header) + + self.assertHttpStatus(response, status.HTTP_200_OK) + self.assertEqual(response.get('Content-Type'), 'image/svg+xml') + def test_list_racks(self): url = reverse('dcim-api:rack-list') @@ -1900,6 +1916,31 @@ class DeviceTest(APITestCase): self.assertEqual(response.data['device_role']['id'], self.devicerole1.pk) self.assertEqual(response.data['cluster']['id'], self.cluster1.pk) + def test_get_device_graphs(self): + + device_ct = ContentType.objects.get_for_model(Device) + self.graph1 = Graph.objects.create( + type=device_ct, + name='Test Graph 1', + source='http://example.com/graphs.py?device={{ obj.name }}&foo=1' + ) + self.graph2 = Graph.objects.create( + type=device_ct, + name='Test Graph 2', + source='http://example.com/graphs.py?device={{ obj.name }}&foo=2' + ) + self.graph3 = Graph.objects.create( + type=device_ct, + name='Test Graph 3', + source='http://example.com/graphs.py?device={{ obj.name }}&foo=3' + ) + + url = reverse('dcim-api:device-graphs', kwargs={'pk': self.device1.pk}) + response = self.client.get(url, **self.header) + + self.assertEqual(len(response.data), 3) + self.assertEqual(response.data[0]['embed_url'], 'http://example.com/graphs.py?device=Test Device 1&foo=1') + def test_list_devices(self): url = reverse('dcim-api:device-list') @@ -2134,6 +2175,31 @@ class ConsolePortTest(APITestCase): self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) self.assertEqual(ConsolePort.objects.count(), 2) + def test_trace_consoleport(self): + + peer_device = Device.objects.create( + site=Site.objects.first(), + device_type=DeviceType.objects.first(), + device_role=DeviceRole.objects.first(), + name='Peer Device' + ) + console_server_port = ConsoleServerPort.objects.create( + device=peer_device, + name='Console Server Port 1' + ) + cable = Cable(termination_a=self.consoleport1, termination_b=console_server_port, label='Cable 1') + cable.save() + + url = reverse('dcim-api:consoleport-trace', kwargs={'pk': self.consoleport1.pk}) + response = self.client.get(url, **self.header) + + self.assertHttpStatus(response, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + segment1 = response.data[0] + self.assertEqual(segment1[0]['name'], self.consoleport1.name) + self.assertEqual(segment1[1]['label'], cable.label) + self.assertEqual(segment1[2]['name'], console_server_port.name) + class ConsoleServerPortTest(APITestCase): @@ -2245,6 +2311,31 @@ class ConsoleServerPortTest(APITestCase): self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) self.assertEqual(ConsoleServerPort.objects.count(), 2) + def test_trace_consoleserverport(self): + + peer_device = Device.objects.create( + site=Site.objects.first(), + device_type=DeviceType.objects.first(), + device_role=DeviceRole.objects.first(), + name='Peer Device' + ) + console_port = ConsolePort.objects.create( + device=peer_device, + name='Console Port 1' + ) + cable = Cable(termination_a=self.consoleserverport1, termination_b=console_port, label='Cable 1') + cable.save() + + url = reverse('dcim-api:consoleserverport-trace', kwargs={'pk': self.consoleserverport1.pk}) + response = self.client.get(url, **self.header) + + self.assertHttpStatus(response, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + segment1 = response.data[0] + self.assertEqual(segment1[0]['name'], self.consoleserverport1.name) + self.assertEqual(segment1[1]['label'], cable.label) + self.assertEqual(segment1[2]['name'], console_port.name) + class PowerPortTest(APITestCase): @@ -2358,6 +2449,31 @@ class PowerPortTest(APITestCase): self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) self.assertEqual(PowerPort.objects.count(), 2) + def test_trace_powerport(self): + + peer_device = Device.objects.create( + site=Site.objects.first(), + device_type=DeviceType.objects.first(), + device_role=DeviceRole.objects.first(), + name='Peer Device' + ) + power_outlet = PowerOutlet.objects.create( + device=peer_device, + name='Power Outlet 1' + ) + cable = Cable(termination_a=self.powerport1, termination_b=power_outlet, label='Cable 1') + cable.save() + + url = reverse('dcim-api:powerport-trace', kwargs={'pk': self.powerport1.pk}) + response = self.client.get(url, **self.header) + + self.assertHttpStatus(response, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + segment1 = response.data[0] + self.assertEqual(segment1[0]['name'], self.powerport1.name) + self.assertEqual(segment1[1]['label'], cable.label) + self.assertEqual(segment1[2]['name'], power_outlet.name) + class PowerOutletTest(APITestCase): @@ -2469,6 +2585,31 @@ class PowerOutletTest(APITestCase): self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) self.assertEqual(PowerOutlet.objects.count(), 2) + def test_trace_poweroutlet(self): + + peer_device = Device.objects.create( + site=Site.objects.first(), + device_type=DeviceType.objects.first(), + device_role=DeviceRole.objects.first(), + name='Peer Device' + ) + power_port = PowerPort.objects.create( + device=peer_device, + name='Power Port 1' + ) + cable = Cable(termination_a=self.poweroutlet1, termination_b=power_port, label='Cable 1') + cable.save() + + url = reverse('dcim-api:poweroutlet-trace', kwargs={'pk': self.poweroutlet1.pk}) + response = self.client.get(url, **self.header) + + self.assertHttpStatus(response, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + segment1 = response.data[0] + self.assertEqual(segment1[0]['name'], self.poweroutlet1.name) + self.assertEqual(segment1[1]['label'], cable.label) + self.assertEqual(segment1[2]['name'], power_port.name) + class InterfaceTest(APITestCase): @@ -2673,6 +2814,262 @@ class InterfaceTest(APITestCase): self.assertEqual(Interface.objects.count(), 2) +class FrontPortTest(APITestCase): + + def setUp(self): + + super().setUp() + + site = Site.objects.create(name='Test Site 1', slug='test-site-1') + manufacturer = Manufacturer.objects.create(name='Test Manufacturer 1', slug='test-manufacturer-1') + devicetype = DeviceType.objects.create( + manufacturer=manufacturer, model='Test Device Type 1', slug='test-device-type-1' + ) + devicerole = DeviceRole.objects.create( + name='Test Device Role 1', slug='test-device-role-1', color='ff0000' + ) + self.device = Device.objects.create( + device_type=devicetype, device_role=devicerole, name='Test Device 1', site=site + ) + rear_ports = RearPort.objects.bulk_create(( + RearPort(device=self.device, name='Rear Port 1', type=PortTypeChoices.TYPE_8P8C), + RearPort(device=self.device, name='Rear Port 2', type=PortTypeChoices.TYPE_8P8C), + RearPort(device=self.device, name='Rear Port 3', type=PortTypeChoices.TYPE_8P8C), + RearPort(device=self.device, name='Rear Port 4', type=PortTypeChoices.TYPE_8P8C), + RearPort(device=self.device, name='Rear Port 5', type=PortTypeChoices.TYPE_8P8C), + RearPort(device=self.device, name='Rear Port 6', type=PortTypeChoices.TYPE_8P8C), + )) + self.frontport1 = FrontPort.objects.create(device=self.device, name='Front Port 1', type=PortTypeChoices.TYPE_8P8C, rear_port=rear_ports[0]) + self.frontport3 = FrontPort.objects.create(device=self.device, name='Front Port 2', type=PortTypeChoices.TYPE_8P8C, rear_port=rear_ports[1]) + self.frontport1 = FrontPort.objects.create(device=self.device, name='Front Port 3', type=PortTypeChoices.TYPE_8P8C, rear_port=rear_ports[2]) + + def test_get_frontport(self): + + url = reverse('dcim-api:frontport-detail', kwargs={'pk': self.frontport1.pk}) + response = self.client.get(url, **self.header) + + self.assertEqual(response.data['name'], self.frontport1.name) + + def test_list_frontports(self): + + url = reverse('dcim-api:frontport-list') + response = self.client.get(url, **self.header) + + self.assertEqual(response.data['count'], 3) + + def test_list_frontports_brief(self): + + url = reverse('dcim-api:frontport-list') + response = self.client.get('{}?brief=1'.format(url), **self.header) + + self.assertEqual( + sorted(response.data['results'][0]), + ['cable', 'device', 'id', 'name', 'url'] + ) + + def test_create_frontport(self): + + rear_port = RearPort.objects.get(name='Rear Port 4') + data = { + 'device': self.device.pk, + 'name': 'Front Port 4', + 'type': PortTypeChoices.TYPE_8P8C, + 'rear_port': rear_port.pk, + 'rear_port_position': 1, + } + + url = reverse('dcim-api:frontport-list') + response = self.client.post(url, data, format='json', **self.header) + + self.assertHttpStatus(response, status.HTTP_201_CREATED) + self.assertEqual(FrontPort.objects.count(), 4) + frontport4 = FrontPort.objects.get(pk=response.data['id']) + self.assertEqual(frontport4.device_id, data['device']) + self.assertEqual(frontport4.name, data['name']) + + def test_create_frontport_bulk(self): + + rear_ports = RearPort.objects.filter(frontports__isnull=True) + data = [ + { + 'device': self.device.pk, + 'name': 'Front Port 4', + 'type': PortTypeChoices.TYPE_8P8C, + 'rear_port': rear_ports[0].pk, + 'rear_port_position': 1, + }, + { + 'device': self.device.pk, + 'name': 'Front Port 5', + 'type': PortTypeChoices.TYPE_8P8C, + 'rear_port': rear_ports[1].pk, + 'rear_port_position': 1, + }, + { + 'device': self.device.pk, + 'name': 'Front Port 6', + 'type': PortTypeChoices.TYPE_8P8C, + 'rear_port': rear_ports[2].pk, + 'rear_port_position': 1, + }, + ] + + url = reverse('dcim-api:frontport-list') + response = self.client.post(url, data, format='json', **self.header) + + self.assertHttpStatus(response, status.HTTP_201_CREATED) + self.assertEqual(FrontPort.objects.count(), 6) + self.assertEqual(response.data[0]['name'], data[0]['name']) + self.assertEqual(response.data[1]['name'], data[1]['name']) + self.assertEqual(response.data[2]['name'], data[2]['name']) + + def test_update_frontport(self): + + rear_port = RearPort.objects.get(name='Rear Port 4') + data = { + 'device': self.device.pk, + 'name': 'Front Port X', + 'type': PortTypeChoices.TYPE_110_PUNCH, + 'rear_port': rear_port.pk, + 'rear_port_position': 1, + } + + url = reverse('dcim-api:frontport-detail', kwargs={'pk': self.frontport1.pk}) + response = self.client.put(url, data, format='json', **self.header) + + self.assertHttpStatus(response, status.HTTP_200_OK) + self.assertEqual(FrontPort.objects.count(), 3) + frontport1 = FrontPort.objects.get(pk=response.data['id']) + self.assertEqual(frontport1.name, data['name']) + self.assertEqual(frontport1.type, data['type']) + self.assertEqual(frontport1.rear_port, rear_port) + + def test_delete_frontport(self): + + url = reverse('dcim-api:frontport-detail', kwargs={'pk': self.frontport1.pk}) + response = self.client.delete(url, **self.header) + + self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) + self.assertEqual(FrontPort.objects.count(), 2) + + +class RearPortTest(APITestCase): + + def setUp(self): + + super().setUp() + + site = Site.objects.create(name='Test Site 1', slug='test-site-1') + manufacturer = Manufacturer.objects.create(name='Test Manufacturer 1', slug='test-manufacturer-1') + devicetype = DeviceType.objects.create( + manufacturer=manufacturer, model='Test Device Type 1', slug='test-device-type-1' + ) + devicerole = DeviceRole.objects.create( + name='Test Device Role 1', slug='test-device-role-1', color='ff0000' + ) + self.device = Device.objects.create( + device_type=devicetype, device_role=devicerole, name='Test Device 1', site=site + ) + self.rearport1 = RearPort.objects.create(device=self.device, type=PortTypeChoices.TYPE_8P8C, name='Rear Port 1') + self.rearport3 = RearPort.objects.create(device=self.device, type=PortTypeChoices.TYPE_8P8C, name='Rear Port 2') + self.rearport1 = RearPort.objects.create(device=self.device, type=PortTypeChoices.TYPE_8P8C, name='Rear Port 3') + + def test_get_rearport(self): + + url = reverse('dcim-api:rearport-detail', kwargs={'pk': self.rearport1.pk}) + response = self.client.get(url, **self.header) + + self.assertEqual(response.data['name'], self.rearport1.name) + + def test_list_rearports(self): + + url = reverse('dcim-api:rearport-list') + response = self.client.get(url, **self.header) + + self.assertEqual(response.data['count'], 3) + + def test_list_rearports_brief(self): + + url = reverse('dcim-api:rearport-list') + response = self.client.get('{}?brief=1'.format(url), **self.header) + + self.assertEqual( + sorted(response.data['results'][0]), + ['cable', 'device', 'id', 'name', 'url'] + ) + + def test_create_rearport(self): + + data = { + 'device': self.device.pk, + 'name': 'Front Port 4', + 'type': PortTypeChoices.TYPE_8P8C, + } + + url = reverse('dcim-api:rearport-list') + response = self.client.post(url, data, format='json', **self.header) + + self.assertHttpStatus(response, status.HTTP_201_CREATED) + self.assertEqual(RearPort.objects.count(), 4) + rearport4 = RearPort.objects.get(pk=response.data['id']) + self.assertEqual(rearport4.device_id, data['device']) + self.assertEqual(rearport4.name, data['name']) + + def test_create_rearport_bulk(self): + + data = [ + { + 'device': self.device.pk, + 'name': 'Rear Port 4', + 'type': PortTypeChoices.TYPE_8P8C, + }, + { + 'device': self.device.pk, + 'name': 'Rear Port 5', + 'type': PortTypeChoices.TYPE_8P8C, + }, + { + 'device': self.device.pk, + 'name': 'Rear Port 6', + 'type': PortTypeChoices.TYPE_8P8C, + }, + ] + + url = reverse('dcim-api:rearport-list') + response = self.client.post(url, data, format='json', **self.header) + + self.assertHttpStatus(response, status.HTTP_201_CREATED) + self.assertEqual(RearPort.objects.count(), 6) + self.assertEqual(response.data[0]['name'], data[0]['name']) + self.assertEqual(response.data[1]['name'], data[1]['name']) + self.assertEqual(response.data[2]['name'], data[2]['name']) + + def test_update_rearport(self): + + data = { + 'device': self.device.pk, + 'name': 'Front Port X', + 'type': PortTypeChoices.TYPE_110_PUNCH + } + + url = reverse('dcim-api:rearport-detail', kwargs={'pk': self.rearport1.pk}) + response = self.client.put(url, data, format='json', **self.header) + + self.assertHttpStatus(response, status.HTTP_200_OK) + self.assertEqual(RearPort.objects.count(), 3) + rearport1 = RearPort.objects.get(pk=response.data['id']) + self.assertEqual(rearport1.name, data['name']) + self.assertEqual(rearport1.type, data['type']) + + def test_delete_rearport(self): + + url = reverse('dcim-api:rearport-detail', kwargs={'pk': self.rearport1.pk}) + response = self.client.delete(url, **self.header) + + self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) + self.assertEqual(RearPort.objects.count(), 2) + + class DeviceBayTest(APITestCase): def setUp(self): diff --git a/netbox/dcim/tests/test_models.py b/netbox/dcim/tests/test_models.py index 7573d2cc4..32d864a51 100644 --- a/netbox/dcim/tests/test_models.py +++ b/netbox/dcim/tests/test_models.py @@ -285,7 +285,28 @@ class DeviceTestCase(TestCase): name='Device Bay 1' ) - def test_device_duplicate_name_per_site(self): + def test_multiple_unnamed_devices(self): + + device1 = Device( + site=self.site, + device_type=self.device_type, + device_role=self.device_role, + name='' + ) + device1.save() + + device2 = Device( + site=device1.site, + device_type=device1.device_type, + device_role=device1.device_role, + name='' + ) + device2.full_clean() + device2.save() + + self.assertEqual(Device.objects.filter(name='').count(), 2) + + def test_device_duplicate_names(self): device1 = Device( site=self.site, diff --git a/netbox/dcim/views.py b/netbox/dcim/views.py index e41d44d95..c83cee8fb 100644 --- a/netbox/dcim/views.py +++ b/netbox/dcim/views.py @@ -1945,6 +1945,12 @@ class CableCreateView(PermissionRequiredMixin, GetReturnURLMixin, View): # Parse initial data manually to avoid setting field values as lists initial_data = {k: request.GET[k] for k in request.GET} + # Set initial site and rack based on side A termination (if not already set) + if 'termination_b_site' not in initial_data: + initial_data['termination_b_site'] = getattr(self.obj.termination_a.parent, 'site', None) + if 'termination_b_rack' not in initial_data: + initial_data['termination_b_rack'] = getattr(self.obj.termination_a.parent, 'rack', None) + form = self.form_class(instance=self.obj, initial=initial_data) return render(request, self.template_name, { diff --git a/netbox/extras/scripts.py b/netbox/extras/scripts.py index fed003bed..bd7e864e1 100644 --- a/netbox/extras/scripts.py +++ b/netbox/extras/scripts.py @@ -14,10 +14,10 @@ from django.db import transaction from mptt.forms import TreeNodeChoiceField, TreeNodeMultipleChoiceField from mptt.models import MPTTModel -from ipam.formfields import IPFormField -from utilities.exceptions import AbortTransaction -from utilities.validators import MaxPrefixLengthValidator, MinPrefixLengthValidator +from ipam.formfields import IPAddressFormField, IPNetworkFormField +from ipam.validators import MaxPrefixLengthValidator, MinPrefixLengthValidator, prefix_validator from .constants import LOG_DEFAULT, LOG_FAILURE, LOG_INFO, LOG_SUCCESS, LOG_WARNING +from utilities.exceptions import AbortTransaction from .forms import ScriptForm from .signals import purge_changelog @@ -27,6 +27,8 @@ __all__ = [ 'ChoiceVar', 'FileVar', 'IntegerVar', + 'IPAddressVar', + 'IPAddressWithMaskVar', 'IPNetworkVar', 'MultiObjectVar', 'ObjectVar', @@ -48,15 +50,19 @@ class ScriptVariable: def __init__(self, label='', description='', default=None, required=True): - # Default field attributes - self.field_attrs = { - 'help_text': description, - 'required': required - } + # Initialize field attributes + if not hasattr(self, 'field_attrs'): + self.field_attrs = {} + if description: + self.field_attrs['help_text'] = description if label: self.field_attrs['label'] = label if default: self.field_attrs['initial'] = default + if required: + self.field_attrs['required'] = True + if 'validators' not in self.field_attrs: + self.field_attrs['validators'] = [] def as_field(self): """ @@ -196,17 +202,32 @@ class FileVar(ScriptVariable): form_field = forms.FileField +class IPAddressVar(ScriptVariable): + """ + An IPv4 or IPv6 address without a mask. + """ + form_field = IPAddressFormField + + +class IPAddressWithMaskVar(ScriptVariable): + """ + An IPv4 or IPv6 address with a mask. + """ + form_field = IPNetworkFormField + + class IPNetworkVar(ScriptVariable): """ An IPv4 or IPv6 prefix. """ - form_field = IPFormField + form_field = IPNetworkFormField + field_attrs = { + 'validators': [prefix_validator] + } def __init__(self, min_prefix_length=None, max_prefix_length=None, *args, **kwargs): super().__init__(*args, **kwargs) - self.field_attrs['validators'] = list() - # Optional minimum/maximum prefix lengths if min_prefix_length is not None: self.field_attrs['validators'].append( diff --git a/netbox/extras/tests/test_scripts.py b/netbox/extras/tests/test_scripts.py index 26e12772f..6237d1d95 100644 --- a/netbox/extras/tests/test_scripts.py +++ b/netbox/extras/tests/test_scripts.py @@ -1,6 +1,6 @@ from django.core.files.uploadedfile import SimpleUploadedFile from django.test import TestCase -from netaddr import IPNetwork +from netaddr import IPAddress, IPNetwork from dcim.models import DeviceRole from extras.scripts import * @@ -186,6 +186,54 @@ class ScriptVariablesTest(TestCase): self.assertTrue(form.is_valid()) self.assertEqual(form.cleaned_data['var1'], testfile) + def test_ipaddressvar(self): + + class TestScript(Script): + + var1 = IPAddressVar() + + # Validate IP network enforcement + data = {'var1': '1.2.3'} + form = TestScript().as_form(data, None) + self.assertFalse(form.is_valid()) + self.assertIn('var1', form.errors) + + # Validate IP mask exclusion + data = {'var1': '192.0.2.0/24'} + form = TestScript().as_form(data, None) + self.assertFalse(form.is_valid()) + self.assertIn('var1', form.errors) + + # Validate valid data + data = {'var1': '192.0.2.1'} + form = TestScript().as_form(data, None) + self.assertTrue(form.is_valid()) + self.assertEqual(form.cleaned_data['var1'], IPAddress(data['var1'])) + + def test_ipaddresswithmaskvar(self): + + class TestScript(Script): + + var1 = IPAddressWithMaskVar() + + # Validate IP network enforcement + data = {'var1': '1.2.3'} + form = TestScript().as_form(data, None) + self.assertFalse(form.is_valid()) + self.assertIn('var1', form.errors) + + # Validate IP mask requirement + data = {'var1': '192.0.2.0'} + form = TestScript().as_form(data, None) + self.assertFalse(form.is_valid()) + self.assertIn('var1', form.errors) + + # Validate valid data + data = {'var1': '192.0.2.0/24'} + form = TestScript().as_form(data, None) + self.assertTrue(form.is_valid()) + self.assertEqual(form.cleaned_data['var1'], IPNetwork(data['var1'])) + def test_ipnetworkvar(self): class TestScript(Script): @@ -198,6 +246,12 @@ class ScriptVariablesTest(TestCase): self.assertFalse(form.is_valid()) self.assertIn('var1', form.errors) + # Validate host IP check + data = {'var1': '192.0.2.1/24'} + form = TestScript().as_form(data, None) + self.assertFalse(form.is_valid()) + self.assertIn('var1', form.errors) + # Validate valid data data = {'var1': '192.0.2.0/24'} form = TestScript().as_form(data, None) diff --git a/netbox/extras/tests/test_webhooks.py b/netbox/extras/tests/test_webhooks.py index 02698b7dd..026a82bb8 100644 --- a/netbox/extras/tests/test_webhooks.py +++ b/netbox/extras/tests/test_webhooks.py @@ -1,11 +1,19 @@ +import json +import uuid +from unittest.mock import patch + import django_rq from django.contrib.contenttypes.models import ContentType +from django.http import HttpResponse from django.urls import reverse +from requests import Session from rest_framework import status from dcim.models import Site from extras.choices import ObjectChangeActionChoices from extras.models import Webhook +from extras.webhooks import enqueue_webhooks, generate_signature +from extras.webhooks_worker import process_webhook from utilities.testing import APITestCase @@ -22,11 +30,13 @@ class WebhookTest(APITestCase): def setUpTestData(cls): site_ct = ContentType.objects.get_for_model(Site) - PAYLOAD_URL = "http://localhost/" + DUMMY_URL = "http://localhost/" + DUMMY_SECRET = "LOOKATMEIMASECRETSTRING" + webhooks = Webhook.objects.bulk_create(( - Webhook(name='Site Create Webhook', type_create=True, payload_url=PAYLOAD_URL), - Webhook(name='Site Update Webhook', type_update=True, payload_url=PAYLOAD_URL), - Webhook(name='Site Delete Webhook', type_delete=True, payload_url=PAYLOAD_URL), + Webhook(name='Site Create Webhook', type_create=True, payload_url=DUMMY_URL, secret=DUMMY_SECRET, additional_headers={'X-Foo': 'Bar'}), + Webhook(name='Site Update Webhook', type_update=True, payload_url=DUMMY_URL, secret=DUMMY_SECRET), + Webhook(name='Site Delete Webhook', type_delete=True, payload_url=DUMMY_URL, secret=DUMMY_SECRET), )) for webhook in webhooks: webhook.obj_type.set([site_ct]) @@ -87,3 +97,47 @@ class WebhookTest(APITestCase): self.assertEqual(job.args[1]['id'], site.pk) self.assertEqual(job.args[2], 'site') self.assertEqual(job.args[3], ObjectChangeActionChoices.ACTION_DELETE) + + def test_webhooks_worker(self): + + request_id = uuid.uuid4() + + def dummy_send(_, request): + """ + A dummy implementation of Session.send() to be used for testing. + Always returns a 200 HTTP response. + """ + webhook = Webhook.objects.get(type_create=True) + signature = generate_signature(request.body, webhook.secret) + + # Validate the outgoing request headers + self.assertEqual(request.headers['Content-Type'], webhook.http_content_type) + self.assertEqual(request.headers['X-Hook-Signature'], signature) + self.assertEqual(request.headers['X-Foo'], 'Bar') + + # Validate the outgoing request body + body = json.loads(request.body) + self.assertEqual(body['event'], 'created') + self.assertEqual(body['timestamp'], job.args[4]) + self.assertEqual(body['model'], 'site') + self.assertEqual(body['username'], 'testuser') + self.assertEqual(body['request_id'], str(request_id)) + self.assertEqual(body['data']['name'], 'Site 1') + + return HttpResponse() + + # Enqueue a webhook for processing + site = Site.objects.create(name='Site 1', slug='site-1') + enqueue_webhooks( + instance=site, + user=self.user, + request_id=request_id, + action=ObjectChangeActionChoices.ACTION_CREATE + ) + + # Retrieve the job from queue + job = self.queue.jobs[0] + + # Patch the Session object with our dummy_send() method, then process the webhook for sending + with patch.object(Session, 'send', dummy_send) as mock_send: + process_webhook(*job.args) diff --git a/netbox/extras/webhooks.py b/netbox/extras/webhooks.py index 5017582cc..04eca4dfe 100644 --- a/netbox/extras/webhooks.py +++ b/netbox/extras/webhooks.py @@ -1,4 +1,6 @@ import datetime +import hashlib +import hmac from django.contrib.contenttypes.models import ContentType @@ -8,6 +10,18 @@ from .choices import * from .constants import * +def generate_signature(request_body, secret): + """ + Return a cryptographic signature that can be used to verify the authenticity of webhook data. + """ + hmac_prep = hmac.new( + key=secret.encode('utf8'), + msg=request_body.encode('utf8'), + digestmod=hashlib.sha512 + ) + return hmac_prep.hexdigest() + + def enqueue_webhooks(instance, user, request_id, action): """ Find Webhook(s) assigned to this instance + action and enqueue them diff --git a/netbox/extras/webhooks_worker.py b/netbox/extras/webhooks_worker.py index 6f7ede4e4..e48d8a2d7 100644 --- a/netbox/extras/webhooks_worker.py +++ b/netbox/extras/webhooks_worker.py @@ -1,5 +1,3 @@ -import hashlib -import hmac import json import requests @@ -7,6 +5,7 @@ from django_rq import job from rest_framework.utils.encoders import JSONEncoder from .choices import ObjectChangeActionChoices, WebhookContentTypeChoices +from .webhooks import generate_signature @job('default') @@ -23,7 +22,7 @@ def process_webhook(webhook, data, model_name, event, timestamp, username, reque 'data': data } headers = { - 'Content-Type': webhook.get_http_content_type_display(), + 'Content-Type': webhook.http_content_type, } if webhook.additional_headers: headers.update(webhook.additional_headers) @@ -43,12 +42,7 @@ def process_webhook(webhook, data, model_name, event, timestamp, username, reque if webhook.secret != '': # Sign the request with a hash of the secret key and its content. - hmac_prep = hmac.new( - key=webhook.secret.encode('utf8'), - msg=prepared_request.body.encode('utf8'), - digestmod=hashlib.sha512 - ) - prepared_request.headers['X-Hook-Signature'] = hmac_prep.hexdigest() + prepared_request.headers['X-Hook-Signature'] = generate_signature(prepared_request.body, webhook.secret) with requests.Session() as session: session.verify = webhook.ssl_verification @@ -56,7 +50,7 @@ def process_webhook(webhook, data, model_name, event, timestamp, username, reque session.verify = webhook.ca_file_path response = session.send(prepared_request) - if response.status_code >= 200 and response.status_code <= 299: + if 200 <= response.status_code <= 299: return 'Status {} returned, webhook successfully processed.'.format(response.status_code) else: raise requests.exceptions.RequestException( diff --git a/netbox/ipam/fields.py b/netbox/ipam/fields.py index 72600d1b9..456a7debc 100644 --- a/netbox/ipam/fields.py +++ b/netbox/ipam/fields.py @@ -2,13 +2,8 @@ from django.core.exceptions import ValidationError from django.db import models from netaddr import AddrFormatError, IPNetwork -from . import lookups -from .formfields import IPFormField - - -def prefix_validator(prefix): - if prefix.ip != prefix.cidr.ip: - raise ValidationError("{} is not a valid prefix. Did you mean {}?".format(prefix, prefix.cidr)) +from . import lookups, validators +from .formfields import IPNetworkFormField class BaseIPField(models.Field): @@ -38,7 +33,7 @@ class BaseIPField(models.Field): return str(self.to_python(value)) def form_class(self): - return IPFormField + return IPNetworkFormField def formfield(self, **kwargs): defaults = {'form_class': self.form_class()} @@ -51,7 +46,7 @@ class IPNetworkField(BaseIPField): IP prefix (network and mask) """ description = "PostgreSQL CIDR field" - default_validators = [prefix_validator] + default_validators = [validators.prefix_validator] def db_type(self, connection): return 'cidr' diff --git a/netbox/ipam/formfields.py b/netbox/ipam/formfields.py index 2909a54b1..e8d171d7f 100644 --- a/netbox/ipam/formfields.py +++ b/netbox/ipam/formfields.py @@ -1,13 +1,44 @@ from django import forms from django.core.exceptions import ValidationError -from netaddr import IPNetwork, AddrFormatError +from django.core.validators import validate_ipv4_address, validate_ipv6_address +from netaddr import IPAddress, IPNetwork, AddrFormatError # # Form fields # -class IPFormField(forms.Field): +class IPAddressFormField(forms.Field): + default_error_messages = { + 'invalid': "Enter a valid IPv4 or IPv6 address (without a mask).", + } + + def to_python(self, value): + if not value: + return None + + if isinstance(value, IPAddress): + return value + + # netaddr is a bit too liberal with what it accepts as a valid IP address. For example, '1.2.3' will become + # IPAddress('1.2.0.3'). Here, we employ Django's built-in IPv4 and IPv6 address validators as a sanity check. + try: + validate_ipv4_address(value) + except ValidationError: + try: + validate_ipv6_address(value) + except ValidationError: + raise ValidationError("Invalid IPv4/IPv6 address format: {}".format(value)) + + try: + return IPAddress(value) + except ValueError: + raise ValidationError('This field requires an IP address without a mask.') + except AddrFormatError: + raise ValidationError("Please specify a valid IPv4 or IPv6 address.") + + +class IPNetworkFormField(forms.Field): default_error_messages = { 'invalid': "Enter a valid IPv4 or IPv6 address (with CIDR mask).", } diff --git a/netbox/ipam/validators.py b/netbox/ipam/validators.py index 960675643..879e20e6a 100644 --- a/netbox/ipam/validators.py +++ b/netbox/ipam/validators.py @@ -1,4 +1,26 @@ -from django.core.validators import RegexValidator +from django.core.exceptions import ValidationError +from django.core.validators import BaseValidator, RegexValidator + + +def prefix_validator(prefix): + if prefix.ip != prefix.cidr.ip: + raise ValidationError("{} is not a valid prefix. Did you mean {}?".format(prefix, prefix.cidr)) + + +class MaxPrefixLengthValidator(BaseValidator): + message = 'The prefix length must be less than or equal to %(limit_value)s.' + code = 'max_prefix_length' + + def compare(self, a, b): + return a.prefixlen > b + + +class MinPrefixLengthValidator(BaseValidator): + message = 'The prefix length must be greater than or equal to %(limit_value)s.' + code = 'min_prefix_length' + + def compare(self, a, b): + return a.prefixlen < b DNSValidator = RegexValidator( diff --git a/netbox/secrets/tests/test_form.py b/netbox/secrets/tests/test_form.py index 42111abbf..d122358cc 100644 --- a/netbox/secrets/tests/test_form.py +++ b/netbox/secrets/tests/test_form.py @@ -29,5 +29,4 @@ class UserKeyFormTestCase(TestCase): data={'public_key': SSH_PUBLIC_KEY}, instance=self.userkey, ) - print(form.is_valid()) self.assertFalse(form.is_valid()) diff --git a/netbox/templates/dcim/cable_connect.html b/netbox/templates/dcim/cable_connect.html index b1609f578..aa4c4bf8c 100644 --- a/netbox/templates/dcim/cable_connect.html +++ b/netbox/templates/dcim/cable_connect.html @@ -144,25 +144,8 @@