mirror of
https://github.com/netbox-community/netbox.git
synced 2025-12-18 19:32:24 -06:00
Merge branch 'develop' into fix/generic_prefetch_4.2
This commit is contained in:
@@ -129,7 +129,7 @@ def get_annotations_for_serializer(serializer_class, fields_to_include=None):
|
||||
|
||||
for field_name, field in serializer_class._declared_fields.items():
|
||||
if field_name in fields_to_include and type(field) is RelatedObjectCountField:
|
||||
related_field = model._meta.get_field(field.relation).field
|
||||
related_field = getattr(model, field.relation).field
|
||||
annotations[field_name] = count_related(related_field.model, related_field.name)
|
||||
|
||||
return annotations
|
||||
|
||||
@@ -93,3 +93,7 @@ HTML_ALLOWED_ATTRIBUTES = {
|
||||
"td": {"align"},
|
||||
"th": {"align"},
|
||||
}
|
||||
|
||||
HTTP_PROXY_SUPPORTED_SOCK_SCHEMAS = ['socks4', 'socks4a', 'socks4h', 'socks5', 'socks5a', 'socks5h']
|
||||
HTTP_PROXY_SOCK_RDNS_SCHEMAS = ['socks4h', 'socks4a', 'socks5h', 'socks5a']
|
||||
HTTP_PROXY_SUPPORTED_SCHEMAS = ['http', 'https', 'socks4', 'socks4a', 'socks4h', 'socks5', 'socks5a', 'socks5h']
|
||||
|
||||
@@ -2,7 +2,8 @@ from decimal import Decimal
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from dcim.choices import CableLengthUnitChoices, WeightUnitChoices
|
||||
from dcim.choices import CableLengthUnitChoices
|
||||
from netbox.choices import WeightUnitChoices
|
||||
|
||||
__all__ = (
|
||||
'to_grams',
|
||||
@@ -10,9 +11,9 @@ __all__ = (
|
||||
)
|
||||
|
||||
|
||||
def to_grams(weight, unit):
|
||||
def to_grams(weight, unit) -> int:
|
||||
"""
|
||||
Convert the given weight to kilograms.
|
||||
Convert the given weight to integer grams.
|
||||
"""
|
||||
try:
|
||||
if weight < 0:
|
||||
@@ -21,13 +22,13 @@ def to_grams(weight, unit):
|
||||
raise TypeError(_("Invalid value '{weight}' for weight (must be a number)").format(weight=weight))
|
||||
|
||||
if unit == WeightUnitChoices.UNIT_KILOGRAM:
|
||||
return weight * 1000
|
||||
return int(weight * 1000)
|
||||
if unit == WeightUnitChoices.UNIT_GRAM:
|
||||
return weight
|
||||
return int(weight)
|
||||
if unit == WeightUnitChoices.UNIT_POUND:
|
||||
return weight * Decimal(453.592)
|
||||
return int(weight * Decimal(453.592))
|
||||
if unit == WeightUnitChoices.UNIT_OUNCE:
|
||||
return weight * Decimal(28.3495)
|
||||
return int(weight * Decimal(28.3495))
|
||||
raise ValueError(
|
||||
_("Unknown unit {unit}. Must be one of the following: {valid_units}").format(
|
||||
unit=unit,
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import decimal
|
||||
from django.db.backends.postgresql.psycopg_any import NumericRange
|
||||
from itertools import count, groupby
|
||||
|
||||
__all__ = (
|
||||
'array_to_ranges',
|
||||
'array_to_string',
|
||||
'check_ranges_overlap',
|
||||
'deepmerge',
|
||||
'drange',
|
||||
'flatten_dict',
|
||||
'ranges_to_string',
|
||||
'shallow_compare_dict',
|
||||
'string_to_ranges',
|
||||
)
|
||||
|
||||
|
||||
@@ -113,3 +117,52 @@ def drange(start, end, step=decimal.Decimal(1)):
|
||||
while start > end:
|
||||
yield start
|
||||
start += step
|
||||
|
||||
|
||||
def check_ranges_overlap(ranges):
|
||||
"""
|
||||
Check for overlap in an iterable of NumericRanges.
|
||||
"""
|
||||
ranges.sort(key=lambda x: x.lower)
|
||||
|
||||
for i in range(1, len(ranges)):
|
||||
prev_range = ranges[i - 1]
|
||||
prev_upper = prev_range.upper if prev_range.upper_inc else prev_range.upper - 1
|
||||
lower = ranges[i].lower if ranges[i].lower_inc else ranges[i].lower + 1
|
||||
if prev_upper >= lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def ranges_to_string(ranges):
|
||||
"""
|
||||
Generate a human-friendly string from a set of ranges. Intended for use with ArrayField. For example:
|
||||
[[1, 100)], [200, 300)] => "1-99,200-299"
|
||||
"""
|
||||
if not ranges:
|
||||
return ''
|
||||
output = []
|
||||
for r in ranges:
|
||||
lower = r.lower if r.lower_inc else r.lower + 1
|
||||
upper = r.upper if r.upper_inc else r.upper - 1
|
||||
output.append(f'{lower}-{upper}')
|
||||
return ','.join(output)
|
||||
|
||||
|
||||
def string_to_ranges(value):
|
||||
"""
|
||||
Given a string in the format "1-100, 200-300" return an list of NumericRanges. Intended for use with ArrayField.
|
||||
For example:
|
||||
"1-99,200-299" => [NumericRange(1, 100), NumericRange(200, 300)]
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
value.replace(' ', '') # Remove whitespace
|
||||
values = []
|
||||
for dash_range in value.split(','):
|
||||
if '-' not in dash_range:
|
||||
return None
|
||||
lower, upper = dash_range.split('-')
|
||||
values.append(NumericRange(int(lower), int(upper), bounds='[]'))
|
||||
return values
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import datetime
|
||||
|
||||
from django.utils import timezone
|
||||
from django.utils.timezone import localtime
|
||||
|
||||
__all__ = (
|
||||
'datetime_from_timestamp',
|
||||
'local_now',
|
||||
)
|
||||
|
||||
@@ -11,3 +14,15 @@ def local_now():
|
||||
Return the current date & time in the system timezone.
|
||||
"""
|
||||
return localtime(timezone.now())
|
||||
|
||||
|
||||
def datetime_from_timestamp(value):
|
||||
"""
|
||||
Convert an ISO 8601 or RFC 3339 timestamp to a datetime object.
|
||||
"""
|
||||
# Work around UTC issue for Python < 3.11; see
|
||||
# https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat
|
||||
# TODO: Remove this once Python 3.10 is no longer supported
|
||||
if type(value) is str and value.endswith('Z'):
|
||||
value = f'{value[:-1]}+00:00'
|
||||
return datetime.datetime.fromisoformat(value)
|
||||
|
||||
@@ -39,7 +39,7 @@ def handle_protectederror(obj_list, request, e):
|
||||
if hasattr(dependent, 'get_absolute_url'):
|
||||
dependent_objects.append(f'<a href="{dependent.get_absolute_url()}">{escape(dependent)}</a>')
|
||||
else:
|
||||
dependent_objects.append(str(dependent))
|
||||
dependent_objects.append(escape(str(dependent)))
|
||||
err_message += ', '.join(dependent_objects)
|
||||
|
||||
messages.error(request, mark_safe(err_message))
|
||||
@@ -49,11 +49,11 @@ def handle_rest_api_exception(request, *args, **kwargs):
|
||||
"""
|
||||
Handle exceptions and return a useful error message for REST API requests.
|
||||
"""
|
||||
type_, error, traceback = sys.exc_info()
|
||||
type_, error = sys.exc_info()[:2]
|
||||
data = {
|
||||
'error': str(error),
|
||||
'exception': type_.__name__,
|
||||
'netbox_version': settings.VERSION,
|
||||
'netbox_version': settings.RELEASE.full_version,
|
||||
'python_version': platform.python_version(),
|
||||
}
|
||||
return JsonResponse(data, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
@@ -2,9 +2,9 @@ from collections import defaultdict
|
||||
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey
|
||||
from django.db import models
|
||||
from django.utils.safestring import mark_safe
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from utilities.ordering import naturalize
|
||||
from .forms.widgets import ColorSelect
|
||||
from .validators import ColorValidator
|
||||
|
||||
@@ -26,6 +26,7 @@ class ColorField(models.CharField):
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
kwargs['widget'] = ColorSelect
|
||||
kwargs['help_text'] = mark_safe(_('RGB color in hexadecimal. Example: ') + '<code>00ff00</code>')
|
||||
return super().formfield(**kwargs)
|
||||
|
||||
|
||||
@@ -38,7 +39,7 @@ class NaturalOrderingField(models.CharField):
|
||||
"""
|
||||
description = "Stores a representation of its target field suitable for natural ordering"
|
||||
|
||||
def __init__(self, target_field, naturalize_function=naturalize, *args, **kwargs):
|
||||
def __init__(self, target_field, naturalize_function, *args, **kwargs):
|
||||
self.target_field = target_field
|
||||
self.naturalize_function = naturalize_function
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -70,14 +71,24 @@ class RestrictedGenericForeignKey(GenericForeignKey):
|
||||
# 1. Capture restrict_params from RestrictedPrefetch (hack)
|
||||
# 2. If restrict_params is set, call restrict() on the queryset for
|
||||
# the related model
|
||||
def get_prefetch_queryset(self, instances, queryset=None):
|
||||
def get_prefetch_querysets(self, instances, querysets=None):
|
||||
restrict_params = {}
|
||||
custom_queryset_dict = {}
|
||||
|
||||
# Compensate for the hack in RestrictedPrefetch
|
||||
if type(queryset) is dict:
|
||||
restrict_params = queryset
|
||||
elif queryset is not None:
|
||||
raise ValueError(_("Custom queryset can't be used for this lookup."))
|
||||
if type(querysets) is dict:
|
||||
restrict_params = querysets
|
||||
|
||||
elif querysets is not None:
|
||||
for queryset in querysets:
|
||||
ct_id = self.get_content_type(
|
||||
model=queryset.query.model, using=queryset.db
|
||||
).pk
|
||||
if ct_id in custom_queryset_dict:
|
||||
raise ValueError(
|
||||
"Only one queryset is allowed for each content type."
|
||||
)
|
||||
custom_queryset_dict[ct_id] = queryset
|
||||
|
||||
# For efficiency, group the instances by content type and then do one
|
||||
# query per model
|
||||
@@ -100,15 +111,16 @@ class RestrictedGenericForeignKey(GenericForeignKey):
|
||||
|
||||
ret_val = []
|
||||
for ct_id, fkeys in fk_dict.items():
|
||||
instance = instance_dict[ct_id]
|
||||
ct = self.get_content_type(id=ct_id, using=instance._state.db)
|
||||
if restrict_params:
|
||||
# Override the default behavior to call restrict() on each model's queryset
|
||||
qs = ct.model_class().objects.filter(pk__in=fkeys).restrict(**restrict_params)
|
||||
ret_val.extend(qs)
|
||||
if ct_id in custom_queryset_dict:
|
||||
# Return values from the custom queryset, if provided.
|
||||
ret_val.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
|
||||
else:
|
||||
# Default behavior
|
||||
ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
|
||||
instance = instance_dict[ct_id]
|
||||
ct = self.get_content_type(id=ct_id, using=instance._state.db)
|
||||
qs = ct.model_class().objects.filter(pk__in=fkeys)
|
||||
if restrict_params:
|
||||
qs = qs.restrict(**restrict_params)
|
||||
ret_val.extend(qs)
|
||||
|
||||
# For doing the join in Python, we have to match both the FK val and the
|
||||
# content type, so we use a callable that returns a (fk, class) pair.
|
||||
|
||||
@@ -3,8 +3,8 @@ from django import forms
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ValidationError
|
||||
from django_filters.constants import EMPTY_VALUES
|
||||
from drf_spectacular.utils import extend_schema_field
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema_field
|
||||
|
||||
__all__ = (
|
||||
'ContentTypeFilter',
|
||||
@@ -116,6 +116,7 @@ class MultiValueWWNFilter(django_filters.MultipleChoiceFilter):
|
||||
field_class = multivalue_field_factory(forms.CharField)
|
||||
|
||||
|
||||
@extend_schema_field(OpenApiTypes.STR)
|
||||
class TreeNodeMultipleChoiceFilter(django_filters.ModelMultipleChoiceFilter):
|
||||
"""
|
||||
Filters for a set of Models, including all descendant models within a Tree. Example: [<Region: R1>,<Region: R2>]
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from django import forms
|
||||
from django.contrib.postgres.forms import SimpleArrayField
|
||||
from django.utils.safestring import mark_safe
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from utilities.data import ranges_to_string, string_to_ranges
|
||||
|
||||
from ..utils import parse_numeric_range
|
||||
|
||||
__all__ = (
|
||||
'NumericArrayField',
|
||||
'NumericRangeArrayField',
|
||||
)
|
||||
|
||||
|
||||
@@ -24,3 +27,31 @@ class NumericArrayField(SimpleArrayField):
|
||||
if isinstance(value, str):
|
||||
value = ','.join([str(n) for n in parse_numeric_range(value)])
|
||||
return super().to_python(value)
|
||||
|
||||
|
||||
class NumericRangeArrayField(forms.CharField):
|
||||
"""
|
||||
A field which allows for array of numeric ranges:
|
||||
Example: 1-5,7-20,30-50
|
||||
"""
|
||||
def __init__(self, *args, help_text='', **kwargs):
|
||||
if not help_text:
|
||||
help_text = mark_safe(
|
||||
_("Specify one or more numeric ranges separated by commas. Example: " + "<code>1-5,20-30</code>")
|
||||
)
|
||||
super().__init__(*args, help_text=help_text, **kwargs)
|
||||
|
||||
def clean(self, value):
|
||||
if value and not self.to_python(value):
|
||||
raise forms.ValidationError(
|
||||
_("Invalid ranges ({value}). Must be a range of integers in ascending order.").format(value=value)
|
||||
)
|
||||
return super().clean(value)
|
||||
|
||||
def prepare_value(self, value):
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return ranges_to_string(value)
|
||||
|
||||
def to_python(self, value):
|
||||
return string_to_ranges(value)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
|
||||
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist, FieldError
|
||||
from django.db.models import Q
|
||||
|
||||
from utilities.choices import unpack_grouped_choices
|
||||
@@ -64,6 +64,10 @@ class CSVModelChoiceField(forms.ModelChoiceField):
|
||||
raise forms.ValidationError(
|
||||
_('"{value}" is not a unique value for this field; multiple objects were found').format(value=value)
|
||||
)
|
||||
except FieldError:
|
||||
raise forms.ValidationError(
|
||||
_('"{field_name}" is an invalid accessor field name.').format(field_name=self.to_field_name)
|
||||
)
|
||||
|
||||
|
||||
class CSVModelMultipleChoiceField(forms.ModelMultipleChoiceField):
|
||||
|
||||
@@ -2,7 +2,7 @@ import django_filters
|
||||
from django import forms
|
||||
from django.conf import settings
|
||||
from django.forms import BoundField
|
||||
from django.urls import reverse
|
||||
from django.urls import reverse, reverse_lazy
|
||||
|
||||
from utilities.forms import widgets
|
||||
from utilities.views import get_viewname
|
||||
@@ -66,6 +66,10 @@ class DynamicModelChoiceMixin:
|
||||
choice (DEPRECATED: pass `context={'disabled': '$fieldname'}` instead)
|
||||
context: A mapping of <option> template variables to their API data keys (optional; see below)
|
||||
selector: Include an advanced object selection widget to assist the user in identifying the desired object
|
||||
quick_add: Include a widget to quickly create a new related object for assignment. NOTE: Nested usage of
|
||||
quick-add fields is not currently supported.
|
||||
quick_add_params: A dictionary of initial data to include when launching the quick-add form (optional). The
|
||||
token string "$pk" will be replaced with the primary key of the form's instance, if any.
|
||||
|
||||
Context keys:
|
||||
value: The name of the attribute which contains the option's value (default: 'id')
|
||||
@@ -90,6 +94,8 @@ class DynamicModelChoiceMixin:
|
||||
disabled_indicator=None,
|
||||
context=None,
|
||||
selector=False,
|
||||
quick_add=False,
|
||||
quick_add_params=None,
|
||||
**kwargs
|
||||
):
|
||||
self.model = queryset.model
|
||||
@@ -99,6 +105,8 @@ class DynamicModelChoiceMixin:
|
||||
self.disabled_indicator = disabled_indicator
|
||||
self.context = context or {}
|
||||
self.selector = selector
|
||||
self.quick_add = quick_add
|
||||
self.quick_add_params = quick_add_params or {}
|
||||
|
||||
super().__init__(queryset, **kwargs)
|
||||
|
||||
@@ -113,11 +121,6 @@ class DynamicModelChoiceMixin:
|
||||
for var, accessor in self.context.items():
|
||||
attrs[f'ts-{var}-field'] = accessor
|
||||
|
||||
# TODO: Remove in v4.1
|
||||
# Legacy means of specifying the disabled indicator
|
||||
if self.disabled_indicator is not None:
|
||||
attrs['ts-disabled-field'] = self.disabled_indicator
|
||||
|
||||
# Attach any static query parameters
|
||||
if len(self.query_params) > 0:
|
||||
widget.add_query_params(self.query_params)
|
||||
@@ -147,7 +150,7 @@ class DynamicModelChoiceMixin:
|
||||
|
||||
if data:
|
||||
# When the field is multiple choice pass the data as a list if it's not already
|
||||
if isinstance(bound_field.field, DynamicModelMultipleChoiceField) and not type(data) is list:
|
||||
if isinstance(bound_field.field, DynamicModelMultipleChoiceField) and type(data) is not list:
|
||||
data = [data]
|
||||
|
||||
field_name = getattr(self, 'to_field_name') or 'pk'
|
||||
@@ -166,6 +169,22 @@ class DynamicModelChoiceMixin:
|
||||
viewname = get_viewname(self.queryset.model, action='list', rest_api=True)
|
||||
widget.attrs['data-url'] = reverse(viewname)
|
||||
|
||||
# Include quick add?
|
||||
if self.quick_add:
|
||||
app_label = self.model._meta.app_label
|
||||
model_name = self.model._meta.model_name
|
||||
widget.quick_add_context = {
|
||||
'url': reverse_lazy(f'{app_label}:{model_name}_add'),
|
||||
'params': {},
|
||||
}
|
||||
for k, v in self.quick_add_params.items():
|
||||
if v == '$pk':
|
||||
# Replace "$pk" token with the primary key of the form's instance (if any)
|
||||
if getattr(form.instance, 'pk', None):
|
||||
widget.quick_add_context['params'][k] = form.instance.pk
|
||||
else:
|
||||
widget.quick_add_context['params'][k] = v
|
||||
|
||||
return bound_field
|
||||
|
||||
|
||||
@@ -197,6 +216,6 @@ class DynamicModelMultipleChoiceField(DynamicModelChoiceMixin, forms.ModelMultip
|
||||
# string 'null'. This will check for that condition and gracefully handle the conversion to a NoneType.
|
||||
if self.null_option is not None and settings.FILTERS_NULL_CHOICE_VALUE in value:
|
||||
value = [v for v in value if v != settings.FILTERS_NULL_CHOICE_VALUE]
|
||||
return [None, *value]
|
||||
return [None, *super().clean(value)]
|
||||
|
||||
return super().clean(value)
|
||||
|
||||
@@ -22,6 +22,15 @@ class APISelect(forms.Select):
|
||||
dynamic_params: Dict[str, str]
|
||||
static_params: Dict[str, List[str]]
|
||||
|
||||
def get_context(self, name, value, attrs):
|
||||
context = super().get_context(name, value, attrs)
|
||||
|
||||
# Add quick-add context data, if enabled for the widget
|
||||
if hasattr(self, 'quick_add_context'):
|
||||
context['quick_add'] = self.quick_add_context
|
||||
|
||||
return context
|
||||
|
||||
def __init__(self, api_url=None, full=False, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -43,9 +43,12 @@ class HTMXSelect(forms.Select):
|
||||
"""
|
||||
Selection widget that will re-generate the HTML form upon the selection of a new option.
|
||||
"""
|
||||
def __init__(self, hx_url='.', hx_target_id='form_fields', attrs=None, **kwargs):
|
||||
def __init__(self, method='get', hx_url='.', hx_target_id='form_fields', attrs=None, **kwargs):
|
||||
method = method.lower()
|
||||
if method not in ('delete', 'get', 'patch', 'post', 'put'):
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
_attrs = {
|
||||
'hx-get': hx_url,
|
||||
f'hx-{method}': hx_url,
|
||||
'hx-include': f'#{hx_target_id}',
|
||||
'hx-target': f'#{hx_target_id}',
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ def highlight(value, highlight, trim_pre=None, trim_post=None, trim_placeholder=
|
||||
else:
|
||||
highlight = re.escape(highlight)
|
||||
pre, match, post = re.split(fr'({highlight})', value, maxsplit=1, flags=re.IGNORECASE)
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
# Match not found
|
||||
return escape(value)
|
||||
|
||||
|
||||
@@ -28,10 +28,14 @@ class DataFileLoader(BaseLoader):
|
||||
raise TemplateNotFound(template)
|
||||
|
||||
# Find and pre-fetch referenced templates
|
||||
if referenced_templates := find_referenced_templates(environment.parse(template_source)):
|
||||
if referenced_templates := tuple(find_referenced_templates(environment.parse(template_source))):
|
||||
related_files = DataFile.objects.filter(source=self.data_source)
|
||||
# None indicates the use of dynamic resolution. If dependent files are statically
|
||||
# defined, we can filter by path for optimization.
|
||||
if None not in referenced_templates:
|
||||
related_files = related_files.filter(path__in=referenced_templates)
|
||||
self.cache_templates({
|
||||
df.path: df.data_as_string for df in
|
||||
DataFile.objects.filter(source=self.data_source, path__in=referenced_templates)
|
||||
df.path: df.data_as_string for df in related_files
|
||||
})
|
||||
|
||||
return template_source, template, lambda: True
|
||||
|
||||
@@ -3,6 +3,7 @@ import decimal
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
|
||||
__all__ = (
|
||||
'ConfigJSONEncoder',
|
||||
'CustomFieldJSONEncoder',
|
||||
)
|
||||
|
||||
@@ -15,3 +16,16 @@ class CustomFieldJSONEncoder(DjangoJSONEncoder):
|
||||
if isinstance(o, decimal.Decimal):
|
||||
return float(o)
|
||||
return super().default(o)
|
||||
|
||||
|
||||
class ConfigJSONEncoder(DjangoJSONEncoder):
|
||||
"""
|
||||
Override Django's built-in JSON encoder to serialize CustomValidator classes as strings.
|
||||
"""
|
||||
def default(self, o):
|
||||
from extras.validators import CustomValidator
|
||||
|
||||
if issubclass(type(o), CustomValidator):
|
||||
return type(o).__name__
|
||||
|
||||
return super().default(o)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db.models import Count, OuterRef, Subquery
|
||||
|
||||
from netbox.registry import registry
|
||||
from utilities.counters import update_counts
|
||||
|
||||
@@ -14,7 +14,7 @@ class StrikethroughExtension(markdown.Extension):
|
||||
"""
|
||||
def extendMarkdown(self, md):
|
||||
md.inlinePatterns.register(
|
||||
markdown.inlinepatterns.SimpleTagPattern(STRIKE_RE, 'del'),
|
||||
SimpleTagPattern(STRIKE_RE, 'del'),
|
||||
'strikethrough',
|
||||
200
|
||||
)
|
||||
|
||||
@@ -87,7 +87,7 @@ def get_paginate_count(request):
|
||||
pass
|
||||
|
||||
if request.user.is_authenticated:
|
||||
per_page = request.user.config.get('pagination.per_page', config.PAGINATE_COUNT)
|
||||
per_page = request.user.config.get('pagination.per_page') or config.PAGINATE_COUNT
|
||||
return _max_allowed(per_page)
|
||||
|
||||
return _max_allowed(config.PAGINATE_COUNT)
|
||||
|
||||
27
netbox/utilities/password_validation.py
Normal file
27
netbox/utilities/password_validation.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
|
||||
class AlphanumericPasswordValidator:
|
||||
"""
|
||||
Validate that the password has at least one numeral, one uppercase letter and one lowercase letter.
|
||||
"""
|
||||
|
||||
def validate(self, password, user=None):
|
||||
if not any(char.isdigit() for char in password):
|
||||
raise ValidationError(
|
||||
_("Password must have at least one numeral."),
|
||||
)
|
||||
|
||||
if not any(char.isupper() for char in password):
|
||||
raise ValidationError(
|
||||
_("Password must have at least one uppercase letter."),
|
||||
)
|
||||
|
||||
if not any(char.islower() for char in password):
|
||||
raise ValidationError(
|
||||
_("Password must have at least one lowercase letter."),
|
||||
)
|
||||
|
||||
def get_help_text(self):
|
||||
return _("Your password must contain at least one numeral, one uppercase letter and one lowercase letter.")
|
||||
@@ -1,7 +1,10 @@
|
||||
from django.conf import settings
|
||||
from django.apps import apps
|
||||
from django.db.models import Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from users.constants import CONSTRAINT_TOKEN_USER
|
||||
|
||||
__all__ = (
|
||||
'get_permission_for_model',
|
||||
'permission_is_exempt',
|
||||
@@ -90,6 +93,11 @@ def qs_filter_from_constraints(constraints, tokens=None):
|
||||
if tokens is None:
|
||||
tokens = {}
|
||||
|
||||
User = apps.get_model('users.User')
|
||||
for token, value in tokens.items():
|
||||
if token == CONSTRAINT_TOKEN_USER and isinstance(value, User):
|
||||
tokens[token] = value.id
|
||||
|
||||
def _replace_tokens(value, tokens):
|
||||
if type(value) is list:
|
||||
return list(map(lambda v: tokens.get(v, v), value))
|
||||
|
||||
@@ -55,7 +55,7 @@ def prepare_cloned_fields(instance):
|
||||
for key, value in attrs.items():
|
||||
if type(value) in (list, tuple):
|
||||
params.extend([(key, v) for v in value])
|
||||
elif value not in (False, None):
|
||||
elif value is not False and value is not None:
|
||||
params.append((key, value))
|
||||
else:
|
||||
params.append((key, ''))
|
||||
|
||||
@@ -20,14 +20,14 @@ class RestrictedPrefetch(Prefetch):
|
||||
|
||||
super().__init__(lookup, queryset=queryset, to_attr=to_attr)
|
||||
|
||||
def get_current_queryset(self, level):
|
||||
def get_current_querysets(self, level):
|
||||
params = {
|
||||
'user': self.restrict_user,
|
||||
'action': self.restrict_action,
|
||||
}
|
||||
|
||||
if qs := super().get_current_queryset(level):
|
||||
return qs.restrict(**params)
|
||||
if querysets := super().get_current_querysets(level):
|
||||
return [qs.restrict(**params) for qs in querysets]
|
||||
|
||||
# Bit of a hack. If no queryset is defined, pass through the dict of restrict()
|
||||
# kwargs to be handled by the field. This is necessary e.g. for GenericForeignKey
|
||||
@@ -49,11 +49,11 @@ class RestrictedQuerySet(QuerySet):
|
||||
permission_required = get_permission_for_model(self.model, action)
|
||||
|
||||
# Bypass restriction for superusers and exempt views
|
||||
if user.is_superuser or permission_is_exempt(permission_required):
|
||||
if user and user.is_superuser or permission_is_exempt(permission_required):
|
||||
qs = self
|
||||
|
||||
# User is anonymous or has not been granted the requisite permission
|
||||
elif not user.is_authenticated or permission_required not in user.get_all_permissions():
|
||||
elif user is None or not user.is_authenticated or permission_required not in user.get_all_permissions():
|
||||
qs = self.none()
|
||||
|
||||
# Filter the queryset to include only objects with allowed attributes
|
||||
|
||||
76
netbox/utilities/release.py
Normal file
76
netbox/utilities/release.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import datetime
|
||||
import os
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
||||
from utilities.datetime import datetime_from_timestamp
|
||||
|
||||
RELEASE_PATH = 'release.yaml'
|
||||
LOCAL_RELEASE_PATH = 'local/release.yaml'
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureSet:
|
||||
"""
|
||||
A map of all available NetBox features.
|
||||
"""
|
||||
# Commercial support is provided by NetBox Labs
|
||||
commercial: bool = False
|
||||
|
||||
# Live help center is enabled
|
||||
help_center: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReleaseInfo:
|
||||
version: str
|
||||
edition: str
|
||||
published: Union[datetime.date, None] = None
|
||||
designation: Union[str, None] = None
|
||||
features: FeatureSet = field(default_factory=FeatureSet)
|
||||
|
||||
@property
|
||||
def full_version(self):
|
||||
if self.designation:
|
||||
return f"{self.version}-{self.designation}"
|
||||
return self.version
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f"NetBox {self.edition} v{self.full_version}"
|
||||
|
||||
def asdict(self):
|
||||
return asdict(self)
|
||||
|
||||
|
||||
def load_release_data():
|
||||
"""
|
||||
Load any locally-defined release attributes and return a ReleaseInfo instance.
|
||||
"""
|
||||
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
# Load canonical release attributes
|
||||
with open(os.path.join(base_path, RELEASE_PATH), 'r') as release_file:
|
||||
data = yaml.safe_load(release_file)
|
||||
|
||||
# Overlay any local release date (if defined)
|
||||
try:
|
||||
with open(os.path.join(base_path, LOCAL_RELEASE_PATH), 'r') as release_file:
|
||||
local_data = yaml.safe_load(release_file)
|
||||
except FileNotFoundError:
|
||||
local_data = {}
|
||||
if local_data is not None:
|
||||
if type(local_data) is not dict:
|
||||
raise ImproperlyConfigured(
|
||||
f"{LOCAL_RELEASE_PATH}: Local release data must be defined as a dictionary."
|
||||
)
|
||||
data.update(local_data)
|
||||
|
||||
# Convert the published date to a date object
|
||||
if 'published' in data:
|
||||
data['published'] = datetime_from_timestamp(data['published'])
|
||||
|
||||
return ReleaseInfo(**data)
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.core import serializers
|
||||
from mptt.models import MPTTModel
|
||||
|
||||
from extras.utils import is_taggable
|
||||
|
||||
@@ -16,8 +15,7 @@ def serialize_object(obj, resolve_tags=True, extra=None, exclude=None):
|
||||
"""
|
||||
Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
|
||||
change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
|
||||
can be provided to exclude them from the returned dictionary. Private fields (prefaced with an underscore) are
|
||||
implicitly excluded.
|
||||
can be provided to exclude them from the returned dictionary.
|
||||
|
||||
Args:
|
||||
obj: The object to serialize
|
||||
@@ -30,11 +28,6 @@ def serialize_object(obj, resolve_tags=True, extra=None, exclude=None):
|
||||
data = json.loads(json_str)[0]['fields']
|
||||
exclude = exclude or []
|
||||
|
||||
# Exclude any MPTTModel fields
|
||||
if issubclass(obj.__class__, MPTTModel):
|
||||
for field in ['level', 'lft', 'rght', 'tree_id']:
|
||||
data.pop(field)
|
||||
|
||||
# Include custom_field_data as "custom_fields"
|
||||
if hasattr(obj, 'custom_field_data'):
|
||||
data['custom_fields'] = data.pop('custom_field_data')
|
||||
@@ -45,9 +38,9 @@ def serialize_object(obj, resolve_tags=True, extra=None, exclude=None):
|
||||
tags = getattr(obj, '_tags', None) or obj.tags.all()
|
||||
data['tags'] = sorted([tag.name for tag in tags])
|
||||
|
||||
# Skip excluded and private (prefixes with an underscore) attributes
|
||||
# Skip any excluded attributes
|
||||
for key in list(data.keys()):
|
||||
if key in exclude or (isinstance(key, str) and key.startswith('_')):
|
||||
if key in exclude:
|
||||
data.pop(key)
|
||||
|
||||
# Append any extra data
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.core.serializers.json import Deserializer, Serializer as Serializer_ # noqa
|
||||
from django.core.serializers.json import Deserializer, Serializer as Serializer_ # noqa: F401
|
||||
from django.utils.encoding import is_protected_type
|
||||
|
||||
# NOTE: Module must contain both Serializer and Deserializer
|
||||
|
||||
104
netbox/utilities/socks.py
Normal file
104
netbox/utilities/socks.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import logging
|
||||
|
||||
from urllib.parse import urlparse
|
||||
from urllib3 import PoolManager, HTTPConnectionPool, HTTPSConnectionPool
|
||||
from urllib3.connection import HTTPConnection, HTTPSConnection
|
||||
from .constants import HTTP_PROXY_SOCK_RDNS_SCHEMAS
|
||||
|
||||
|
||||
logger = logging.getLogger('netbox.utilities')
|
||||
|
||||
|
||||
class ProxyHTTPConnection(HTTPConnection):
|
||||
"""
|
||||
A Proxy connection class that uses a SOCK proxy - used to create
|
||||
a urllib3 PoolManager that routes connections via the proxy.
|
||||
This is for an HTTP (not HTTPS) connection
|
||||
"""
|
||||
use_rdns = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
socks_options = kwargs.pop('_socks_options')
|
||||
self._proxy_url = socks_options['proxy_url']
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _new_conn(self):
|
||||
try:
|
||||
from python_socks.sync import Proxy
|
||||
except ModuleNotFoundError as e:
|
||||
logger.info(
|
||||
"Configuring an HTTP proxy using SOCKS requires the python_socks library. Check that it has been "
|
||||
"installed."
|
||||
)
|
||||
raise e
|
||||
|
||||
proxy = Proxy.from_url(self._proxy_url, rdns=self.use_rdns)
|
||||
return proxy.connect(
|
||||
dest_host=self.host,
|
||||
dest_port=self.port,
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
|
||||
class ProxyHTTPSConnection(ProxyHTTPConnection, HTTPSConnection):
|
||||
"""
|
||||
A Proxy connection class for an HTTPS (not HTTP) connection.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class RdnsProxyHTTPConnection(ProxyHTTPConnection):
|
||||
"""
|
||||
A Proxy connection class for an HTTP remote-dns connection.
|
||||
I.E. socks4a, socks4h, socks5a, socks5h
|
||||
"""
|
||||
use_rdns = True
|
||||
|
||||
|
||||
class RdnsProxyHTTPSConnection(ProxyHTTPSConnection):
|
||||
"""
|
||||
A Proxy connection class for an HTTPS remote-dns connection.
|
||||
I.E. socks4a, socks4h, socks5a, socks5h
|
||||
"""
|
||||
use_rdns = True
|
||||
|
||||
|
||||
class ProxyHTTPConnectionPool(HTTPConnectionPool):
|
||||
ConnectionCls = ProxyHTTPConnection
|
||||
|
||||
|
||||
class ProxyHTTPSConnectionPool(HTTPSConnectionPool):
|
||||
ConnectionCls = ProxyHTTPSConnection
|
||||
|
||||
|
||||
class RdnsProxyHTTPConnectionPool(HTTPConnectionPool):
|
||||
ConnectionCls = RdnsProxyHTTPConnection
|
||||
|
||||
|
||||
class RdnsProxyHTTPSConnectionPool(HTTPSConnectionPool):
|
||||
ConnectionCls = RdnsProxyHTTPSConnection
|
||||
|
||||
|
||||
class ProxyPoolManager(PoolManager):
|
||||
def __init__(self, proxy_url, timeout=5, num_pools=10, headers=None, **connection_pool_kw):
|
||||
# python_socks uses rdns param to denote remote DNS parsing and
|
||||
# doesn't accept the 'h' or 'a' in the proxy URL
|
||||
if use_rdns := urlparse(proxy_url).scheme in HTTP_PROXY_SOCK_RDNS_SCHEMAS:
|
||||
proxy_url = proxy_url.replace('socks5h:', 'socks5:').replace('socks5a:', 'socks5:')
|
||||
proxy_url = proxy_url.replace('socks4h:', 'socks4:').replace('socks4a:', 'socks4:')
|
||||
|
||||
connection_pool_kw['_socks_options'] = {'proxy_url': proxy_url}
|
||||
connection_pool_kw['timeout'] = timeout
|
||||
|
||||
super().__init__(num_pools, headers, **connection_pool_kw)
|
||||
|
||||
if use_rdns:
|
||||
self.pool_classes_by_scheme = {
|
||||
'http': RdnsProxyHTTPConnectionPool,
|
||||
'https': RdnsProxyHTTPSConnectionPool,
|
||||
}
|
||||
else:
|
||||
self.pool_classes_by_scheme = {
|
||||
'http': ProxyHTTPConnectionPool,
|
||||
'https': ProxyHTTPSConnectionPool,
|
||||
}
|
||||
@@ -29,7 +29,7 @@ def linkify_phone(value):
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
return f"tel:{value}"
|
||||
return f"tel:{value.replace(' ', '')}"
|
||||
|
||||
|
||||
def register_table_column(column, name, *tables):
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
{% elif customfield.type == 'date' and value %}
|
||||
{{ value|isodate }}
|
||||
{% elif customfield.type == 'datetime' and value %}
|
||||
{{ value|isodate }} {{ value|isodatetime }}
|
||||
{{ value|isodatetime }}
|
||||
{% elif customfield.type == 'url' and value %}
|
||||
<a href="{{ value }}">{{ value|truncatechars:70 }}</a>
|
||||
{% elif customfield.type == 'json' and value %}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<div class="htmx-container table-responsive"
|
||||
hx-get="{% url viewname %}{% if url_params %}?{{ url_params.urlencode }}{% endif %}"
|
||||
hx-get="{% url viewname %}?embedded=True{% if url_params %}&{{ url_params.urlencode }}{% endif %}"
|
||||
hx-target="this"
|
||||
hx-trigger="load" hx-select="table" hx-swap="innerHTML"
|
||||
hx-trigger="load" hx-select=".htmx-container" hx-swap="outerHTML"
|
||||
></div>
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
</button>
|
||||
{% else %}
|
||||
<button type="submit" class="btn btn-cyan">
|
||||
<i class="mdi mdi-bookmark-check"></i> {% trans "Bookmark" %}
|
||||
<i class="mdi mdi-bookmark-plus"></i> {% trans "Bookmark" %}
|
||||
</button>
|
||||
{% endif %}
|
||||
</form>
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
<i class="mdi mdi-download"></i> {% trans "Export" %}
|
||||
</button>
|
||||
<ul class="dropdown-menu dropdown-menu-end">
|
||||
<li><a class="dropdown-item" href="?{% if url_params %}{{ url_params }}&{% endif %}export=table">{% trans "Current View" %}</a></li>
|
||||
<li><a id="export_current_view" class="dropdown-item" href="?{% if url_params %}{{ url_params }}&{% endif %}export=table">{% trans "Current View" %}</a></li>
|
||||
<li><a class="dropdown-item" href="?{% if url_params %}{{ url_params }}&{% endif %}export">{% trans "All Data" %} ({{ data_format }})</a></li>
|
||||
{% if export_templates %}
|
||||
<li>
|
||||
|
||||
18
netbox/utilities/templates/buttons/subscribe.html
Normal file
18
netbox/utilities/templates/buttons/subscribe.html
Normal file
@@ -0,0 +1,18 @@
|
||||
{% load i18n %}
|
||||
{% if form_url %}
|
||||
<form action="{{ form_url }}?return_url={{ return_url }}" method="post">
|
||||
{% csrf_token %}
|
||||
{% for field, value in form_data.items %}
|
||||
<input type="hidden" name="{{ field }}" value="{{ value }}" />
|
||||
{% endfor %}
|
||||
{% if subscription %}
|
||||
<button type="submit" class="btn btn-cyan">
|
||||
<i class="mdi mdi-bell-minus"></i> {% trans "Unsubscribe" %}
|
||||
</button>
|
||||
{% else %}
|
||||
<button type="submit" class="btn btn-cyan">
|
||||
<i class="mdi mdi-bell-plus"></i> {% trans "Subscribe" %}
|
||||
</button>
|
||||
{% endif %}
|
||||
</form>
|
||||
{% endif %}
|
||||
@@ -3,7 +3,7 @@
|
||||
{% for group, fields in form.custom_field_groups.items %}
|
||||
{% if group %}
|
||||
<div class="row">
|
||||
<h6 class="offset-sm-3 mb-3">{{ group }}</h6>
|
||||
<h3 class="col-9 offset-3 mb-3 h4">{{ group }}</h3>
|
||||
</div>
|
||||
{% endif %}
|
||||
{% for name in fields %}
|
||||
|
||||
@@ -6,9 +6,11 @@
|
||||
|
||||
{# Render the field label (if any), except for checkboxes #}
|
||||
{% if label and not field|widget_type == 'checkboxinput' %}
|
||||
<label for="{{ field.id_for_label }}" class="col-sm-3 col-form-label text-lg-end{% if field.field.required %} required{% endif %}">
|
||||
{{ label }}
|
||||
</label>
|
||||
<div class="col-sm-3 text-lg-end">
|
||||
<label for="{{ field.id_for_label }}" class="col-form-label d-inline-block{% if field.field.required %} required{% endif %}">
|
||||
{{ label }}
|
||||
</label>
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
{# Render the field itself #}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
<div class="field-group mb-5">
|
||||
{% if heading %}
|
||||
<div class="row">
|
||||
<h5 class="col-9 offset-3">{{ heading }}</h5>
|
||||
<h2 class="col-9 offset-3">{{ heading }}</h2>
|
||||
</div>
|
||||
{% endif %}
|
||||
{% for layout, title, items in rows %}
|
||||
|
||||
@@ -16,10 +16,10 @@
|
||||
</div>
|
||||
<div class="col-2 d-flex align-items-center">
|
||||
<div>
|
||||
<a class="btn btn-success btn-sm w-100 my-2" id="add_columns">
|
||||
<a tabindex="0" class="btn btn-success btn-sm w-100 my-2" id="add_columns">
|
||||
<i class="mdi mdi-arrow-right-bold"></i> {% trans "Add" %}
|
||||
</a>
|
||||
<a class="btn btn-danger btn-sm w-100 my-2" id="remove_columns">
|
||||
<a tabindex="0" class="btn btn-danger btn-sm w-100 my-2" id="remove_columns">
|
||||
<i class="mdi mdi-arrow-left-bold"></i> {% trans "Remove" %}
|
||||
</a>
|
||||
</div>
|
||||
@@ -27,10 +27,10 @@
|
||||
<div class="col-5 text-center">
|
||||
{{ form.columns.label }}
|
||||
{{ form.columns }}
|
||||
<a class="btn btn-primary btn-sm mt-2" id="move-option-up" data-target="id_columns">
|
||||
<a tabindex="0" class="btn btn-primary btn-sm mt-2" id="move-option-up" data-target="id_columns">
|
||||
<i class="mdi mdi-arrow-up-bold"></i> {% trans "Move Up" %}
|
||||
</a>
|
||||
<a class="btn btn-primary btn-sm mt-2" id="move-option-down" data-target="id_columns">
|
||||
<a tabindex="0" class="btn btn-primary btn-sm mt-2" id="move-option-down" data-target="id_columns">
|
||||
<i class="mdi mdi-arrow-down-bold"></i> {% trans "Move Down" %}
|
||||
</a>
|
||||
</div>
|
||||
|
||||
@@ -1,7 +1,23 @@
|
||||
{% load helpers %}
|
||||
{% load i18n %}
|
||||
{% load navigation %}
|
||||
|
||||
<ul class="navbar-nav pt-lg-2" {% htmx_boost %}>
|
||||
<li class="nav-item d-block d-lg-none">
|
||||
<form action="{% url 'search' %}" method="get" autocomplete="off" novalidate>
|
||||
<div class="input-group mb-1 mt-2">
|
||||
<div class="input-group-prepend">
|
||||
<span class="input-group-text">
|
||||
<i class="mdi mdi-magnify"></i>
|
||||
</span>
|
||||
</div>
|
||||
<input type="text" name="q" value="" class="form-control" placeholder="{% trans "Search…" %}" aria-label="{% trans "Search NetBox" %}">
|
||||
<div class="input-group-append">
|
||||
<button type="submit" class="form-control">{% trans "Search" %}</button>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
</li>
|
||||
{% for menu, groups in nav_items %}
|
||||
<li class="nav-item dropdown">
|
||||
|
||||
@@ -29,7 +45,7 @@
|
||||
{% if buttons %}
|
||||
<div class="btn-group ms-1">
|
||||
{% for button in buttons %}
|
||||
<a href="{% url button.link %}" class="btn btn-sm btn-{{ button.color|default:"outline-dark" }} lh-2 px-2" title="{{ button.title }}">
|
||||
<a href="{% url button.link %}" class="btn btn-sm btn-{{ button.color|default:"outline" }} lh-2 px-2" title="{{ button.title }}">
|
||||
<i class="{{ button.icon_class }}"></i>
|
||||
</a>
|
||||
{% endfor %}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{% load i18n %}
|
||||
{% if widget.attrs.selector and not widget.attrs.disabled %}
|
||||
<div class="d-flex">
|
||||
{% include 'django/forms/widgets/select.html' %}
|
||||
<div class="d-flex">
|
||||
{% include 'django/forms/widgets/select.html' %}
|
||||
{% if widget.attrs.selector and not widget.attrs.disabled %}
|
||||
{# Opens the object selector modal #}
|
||||
<button
|
||||
type="button"
|
||||
title="{% trans "Open selector" %}"
|
||||
@@ -13,7 +14,19 @@
|
||||
>
|
||||
<i class="mdi mdi-database-search-outline"></i>
|
||||
</button>
|
||||
</div>
|
||||
{% else %}
|
||||
{% include 'django/forms/widgets/select.html' %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% if quick_add and not widget.attrs.disabled %}
|
||||
{# Opens the quick add modal #}
|
||||
<button
|
||||
type="button"
|
||||
title="{% trans "Quick add" %}"
|
||||
class="btn btn-outline-secondary ms-1"
|
||||
data-bs-toggle="modal"
|
||||
data-bs-target="#htmx-modal"
|
||||
hx-get="{{ quick_add.url }}?_quickadd=True&target={{ widget.attrs.id }}{% for k, v in quick_add.params.items %}&{{ k }}={{ v }}{% endfor %}"
|
||||
hx-target="#htmx-modal-content"
|
||||
>
|
||||
<i class="mdi mdi-plus-circle"></i>
|
||||
</button>
|
||||
{% endif %}
|
||||
</div>
|
||||
|
||||
@@ -8,6 +8,7 @@ from django.contrib.contenttypes.models import ContentType
|
||||
from django.contrib.humanize.templatetags.humanize import naturalday, naturaltime
|
||||
from django.utils.html import escape
|
||||
from django.utils.safestring import mark_safe
|
||||
from django.utils.timezone import localtime
|
||||
from markdown import markdown
|
||||
from markdown.extensions.tables import TableExtension
|
||||
|
||||
@@ -58,7 +59,7 @@ def linkify(instance, attr=None):
|
||||
url = instance.get_absolute_url()
|
||||
return mark_safe(f'<a href="{url}">{escape(text)}</a>')
|
||||
except (AttributeError, TypeError):
|
||||
return text
|
||||
return escape(text)
|
||||
|
||||
|
||||
@register.filter()
|
||||
@@ -218,7 +219,8 @@ def isodate(value):
|
||||
text = value.isoformat()
|
||||
return mark_safe(f'<span title="{naturalday(value)}">{text}</span>')
|
||||
elif type(value) is datetime.datetime:
|
||||
text = value.date().isoformat()
|
||||
local_value = localtime(value) if value.tzinfo else value
|
||||
text = local_value.date().isoformat()
|
||||
return mark_safe(f'<span title="{naturaltime(value)}">{text}</span>')
|
||||
else:
|
||||
return ''
|
||||
@@ -229,7 +231,8 @@ def isotime(value, spec='seconds'):
|
||||
if type(value) is datetime.time:
|
||||
return value.isoformat(timespec=spec)
|
||||
if type(value) is datetime.datetime:
|
||||
return value.time().isoformat(timespec=spec)
|
||||
local_value = localtime(value) if value.tzinfo else value
|
||||
return local_value.time().isoformat(timespec=spec)
|
||||
return ''
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from django import template
|
||||
from django.utils.safestring import mark_safe
|
||||
|
||||
from extras.choices import CustomFieldTypeChoices
|
||||
from utilities.querydict import dict_to_querydict
|
||||
@@ -124,5 +125,5 @@ def formaction(context):
|
||||
if HTMX navigation is enabled (per the user's preferences).
|
||||
"""
|
||||
if context.get('htmx_navigation', False):
|
||||
return 'hx-push-url="true" hx-post'
|
||||
return mark_safe('hx-push-url="true" hx-post')
|
||||
return 'formaction'
|
||||
|
||||
@@ -3,7 +3,8 @@ from django.contrib.contenttypes.models import ContentType
|
||||
from django.urls import NoReverseMatch, reverse
|
||||
|
||||
from core.models import ObjectType
|
||||
from extras.models import Bookmark, ExportTemplate
|
||||
from extras.models import Bookmark, ExportTemplate, Subscription
|
||||
from netbox.models.features import NotificationsMixin
|
||||
from utilities.querydict import prepare_cloned_fields
|
||||
from utilities.views import get_viewname
|
||||
|
||||
@@ -17,6 +18,7 @@ __all__ = (
|
||||
'edit_button',
|
||||
'export_button',
|
||||
'import_button',
|
||||
'subscribe_button',
|
||||
'sync_button',
|
||||
)
|
||||
|
||||
@@ -94,6 +96,41 @@ def delete_button(instance):
|
||||
}
|
||||
|
||||
|
||||
@register.inclusion_tag('buttons/subscribe.html', takes_context=True)
|
||||
def subscribe_button(context, instance):
|
||||
# Skip for objects which don't support notifications
|
||||
if not (issubclass(instance.__class__, NotificationsMixin)):
|
||||
return {}
|
||||
|
||||
# Check if this user has already subscribed to the object
|
||||
content_type = ContentType.objects.get_for_model(instance)
|
||||
subscription = Subscription.objects.filter(
|
||||
object_type=content_type,
|
||||
object_id=instance.pk,
|
||||
user=context['request'].user
|
||||
).first()
|
||||
|
||||
# Compile form URL & data
|
||||
if subscription:
|
||||
form_url = reverse('extras:subscription_delete', kwargs={'pk': subscription.pk})
|
||||
form_data = {
|
||||
'confirm': 'true',
|
||||
}
|
||||
else:
|
||||
form_url = reverse('extras:subscription_add')
|
||||
form_data = {
|
||||
'object_type': content_type.pk,
|
||||
'object_id': instance.pk,
|
||||
}
|
||||
|
||||
return {
|
||||
'subscription': subscription,
|
||||
'form_url': form_url,
|
||||
'form_data': form_data,
|
||||
'return_url': instance.get_absolute_url(),
|
||||
}
|
||||
|
||||
|
||||
@register.inclusion_tag('buttons/sync.html')
|
||||
def sync_button(instance):
|
||||
viewname = get_viewname(instance, 'sync')
|
||||
@@ -121,7 +158,7 @@ def add_button(model, action='add'):
|
||||
|
||||
|
||||
@register.inclusion_tag('buttons/import.html')
|
||||
def import_button(model, action='import'):
|
||||
def import_button(model, action='bulk_import'):
|
||||
try:
|
||||
url = reverse(get_viewname(model, action))
|
||||
except NoReverseMatch:
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import warnings
|
||||
|
||||
from django import template
|
||||
|
||||
from utilities.forms.rendering import FieldSet, InlineFields, ObjectAttribute, TabbedGroups
|
||||
from utilities.forms.rendering import InlineFields, ObjectAttribute, TabbedGroups
|
||||
|
||||
__all__ = (
|
||||
'getfield',
|
||||
@@ -54,15 +52,6 @@ def render_fieldset(form, fieldset):
|
||||
"""
|
||||
Render a group set of fields.
|
||||
"""
|
||||
# TODO: Remove in NetBox v4.1
|
||||
# Handle legacy tuple-based fieldset definitions, e.g. (_('Label'), ('field1, 'field2', 'field3'))
|
||||
if type(fieldset) is not FieldSet:
|
||||
warnings.warn(
|
||||
f"{form.__class__} fieldsets contains a non-FieldSet item: {fieldset}"
|
||||
)
|
||||
name, fields = fieldset
|
||||
fieldset = FieldSet(*fields, name=name)
|
||||
|
||||
rows = []
|
||||
for item in fieldset.items:
|
||||
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from django import template
|
||||
from django.conf import settings
|
||||
from django.template.defaultfilters import date
|
||||
from django.urls import NoReverseMatch, reverse
|
||||
from django.utils import timezone
|
||||
from django.utils.safestring import mark_safe
|
||||
|
||||
from core.models import ObjectType
|
||||
from utilities.forms import get_selected_values, TableConfigForm
|
||||
@@ -92,15 +87,22 @@ def humanize_speed(speed):
|
||||
@register.filter()
|
||||
def humanize_megabytes(mb):
|
||||
"""
|
||||
Express a number of megabytes in the most suitable unit (e.g. gigabytes or terabytes).
|
||||
Express a number of megabytes in the most suitable unit (e.g. gigabytes, terabytes, etc.).
|
||||
"""
|
||||
if not mb:
|
||||
return ''
|
||||
if not mb % 1048576: # 1024^2
|
||||
return f'{int(mb / 1048576)} TB'
|
||||
if not mb % 1024:
|
||||
return f'{int(mb / 1024)} GB'
|
||||
return f'{mb} MB'
|
||||
return ""
|
||||
|
||||
PB_SIZE = 1000000000
|
||||
TB_SIZE = 1000000
|
||||
GB_SIZE = 1000
|
||||
|
||||
if mb >= PB_SIZE:
|
||||
return f"{mb / PB_SIZE:.2f} PB"
|
||||
if mb >= TB_SIZE:
|
||||
return f"{mb / TB_SIZE:.2f} TB"
|
||||
if mb >= GB_SIZE:
|
||||
return f"{mb / GB_SIZE:.2f} GB"
|
||||
return f"{mb} MB"
|
||||
|
||||
|
||||
@register.filter()
|
||||
@@ -279,6 +281,10 @@ def applied_filters(context, model, form, query_params):
|
||||
if filter_name not in querydict:
|
||||
continue
|
||||
|
||||
# Skip saved filters, as they're displayed alongside the quick search widget
|
||||
if filter_name == 'filter_id':
|
||||
continue
|
||||
|
||||
bound_field = form.fields[filter_name].get_bound_field(form, filter_name)
|
||||
querydict.pop(filter_name)
|
||||
display_value = ', '.join([str(v) for v in get_selected_values(form, filter_name)])
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from django import template
|
||||
from django.utils.html import escape
|
||||
from django.utils.safestring import mark_safe
|
||||
|
||||
register = template.Library()
|
||||
@@ -15,6 +16,6 @@ def nested_tree(obj):
|
||||
nodes = obj.get_ancestors(include_self=True)
|
||||
return mark_safe(
|
||||
' / '.join(
|
||||
f'<a href="{node.get_absolute_url()}">{node}</a>' for node in nodes
|
||||
f'<a href="{node.get_absolute_url()}">{escape(node)}</a>' for node in nodes
|
||||
)
|
||||
)
|
||||
|
||||
@@ -22,8 +22,10 @@ def _get_registered_content(obj, method, template_context):
|
||||
'perms': template_context['perms'],
|
||||
}
|
||||
|
||||
model_name = obj._meta.label_lower
|
||||
template_extensions = registry['plugins']['template_extensions'].get(model_name, [])
|
||||
template_extensions = list(registry['plugins']['template_extensions'].get(None, []))
|
||||
if hasattr(obj, '_meta'):
|
||||
model_name = obj._meta.label_lower
|
||||
template_extensions.extend(registry['plugins']['template_extensions'].get(model_name, []))
|
||||
for template_extension in template_extensions:
|
||||
|
||||
# If the class has not overridden the specified method, we can skip it (because we know it
|
||||
@@ -43,6 +45,22 @@ def _get_registered_content(obj, method, template_context):
|
||||
return mark_safe(html)
|
||||
|
||||
|
||||
@register.simple_tag(takes_context=True)
|
||||
def plugin_navbar(context):
|
||||
"""
|
||||
Render any navbar content embedded by plugins
|
||||
"""
|
||||
return _get_registered_content(None, 'navbar', context)
|
||||
|
||||
|
||||
@register.simple_tag(takes_context=True)
|
||||
def plugin_list_buttons(context, model):
|
||||
"""
|
||||
Render all list buttons registered by plugins
|
||||
"""
|
||||
return _get_registered_content(model, 'list_buttons', context)
|
||||
|
||||
|
||||
@register.simple_tag(takes_context=True)
|
||||
def plugin_buttons(context, obj):
|
||||
"""
|
||||
@@ -51,6 +69,14 @@ def plugin_buttons(context, obj):
|
||||
return _get_registered_content(obj, 'buttons', context)
|
||||
|
||||
|
||||
@register.simple_tag(takes_context=True)
|
||||
def plugin_alerts(context, obj):
|
||||
"""
|
||||
Render all object alerts registered by plugins
|
||||
"""
|
||||
return _get_registered_content(obj, 'alerts', context)
|
||||
|
||||
|
||||
@register.simple_tag(takes_context=True)
|
||||
def plugin_left_page(context, obj):
|
||||
"""
|
||||
@@ -73,11 +99,3 @@ def plugin_full_width_page(context, obj):
|
||||
Render all full width page content registered by plugins
|
||||
"""
|
||||
return _get_registered_content(obj, 'full_width_page', context)
|
||||
|
||||
|
||||
@register.simple_tag(takes_context=True)
|
||||
def plugin_list_buttons(context, model):
|
||||
"""
|
||||
Render all list buttons registered by plugins
|
||||
"""
|
||||
return _get_registered_content(model, 'list_buttons', context)
|
||||
|
||||
@@ -1,28 +1,24 @@
|
||||
import inspect
|
||||
import json
|
||||
import strawberry_django
|
||||
|
||||
import strawberry_django
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.urls import reverse
|
||||
from django.test import override_settings
|
||||
from django.urls import reverse
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
from strawberry.types.base import StrawberryList, StrawberryOptional
|
||||
from strawberry.types.lazy_type import LazyType
|
||||
from strawberry.types.union import StrawberryUnion
|
||||
|
||||
from core.models import ObjectType
|
||||
from extras.choices import ObjectChangeActionChoices
|
||||
from extras.models import ObjectChange
|
||||
from users.models import ObjectPermission, Token
|
||||
from core.choices import ObjectChangeActionChoices
|
||||
from core.models import ObjectChange, ObjectType
|
||||
from ipam.graphql.types import IPAddressFamilyType
|
||||
from users.models import ObjectPermission, Token, User
|
||||
from utilities.api import get_graphql_type_for_model
|
||||
from .base import ModelTestCase
|
||||
from .utils import disable_warnings
|
||||
|
||||
from ipam.graphql.types import IPAddressFamilyType
|
||||
from strawberry.field import StrawberryField
|
||||
from strawberry.lazy_type import LazyType
|
||||
from strawberry.type import StrawberryList, StrawberryOptional
|
||||
from strawberry.union import StrawberryUnion
|
||||
from .utils import disable_logging, disable_warnings
|
||||
|
||||
__all__ = (
|
||||
'APITestCase',
|
||||
@@ -30,9 +26,6 @@ __all__ = (
|
||||
)
|
||||
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
#
|
||||
# REST/GraphQL API Tests
|
||||
#
|
||||
@@ -73,7 +66,7 @@ class APIViewTestCases:
|
||||
|
||||
class GetObjectViewTestCase(APITestCase):
|
||||
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'], LOGIN_REQUIRED=False)
|
||||
def test_get_object_anonymous(self):
|
||||
"""
|
||||
GET a single object as an unauthenticated user.
|
||||
@@ -135,7 +128,7 @@ class APIViewTestCases:
|
||||
class ListObjectsViewTestCase(APITestCase):
|
||||
brief_fields = []
|
||||
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'], LOGIN_REQUIRED=False)
|
||||
def test_list_objects_anonymous(self):
|
||||
"""
|
||||
GET a list of objects as an unauthenticated user.
|
||||
@@ -440,18 +433,20 @@ class APIViewTestCases:
|
||||
base_name = self.model._meta.verbose_name.lower().replace(' ', '_')
|
||||
return getattr(self, 'graphql_base_name', base_name)
|
||||
|
||||
def _build_query(self, name, **filters):
|
||||
def _build_query_with_filter(self, name, filter_string):
|
||||
"""
|
||||
Called by either _build_query or _build_filtered_query - construct the actual
|
||||
query given a name and filter string
|
||||
"""
|
||||
type_class = get_graphql_type_for_model(self.model)
|
||||
if filters:
|
||||
filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items())
|
||||
filter_string = f'({filter_string})'
|
||||
else:
|
||||
filter_string = ''
|
||||
|
||||
# Compile list of fields to include
|
||||
fields_string = ''
|
||||
|
||||
file_fields = (strawberry_django.fields.types.DjangoFileType, strawberry_django.fields.types.DjangoImageType)
|
||||
file_fields = (
|
||||
strawberry_django.fields.types.DjangoFileType,
|
||||
strawberry_django.fields.types.DjangoImageType,
|
||||
)
|
||||
for field in type_class.__strawberry_definition__.fields:
|
||||
if (
|
||||
field.type in file_fields or (
|
||||
@@ -492,8 +487,39 @@ class APIViewTestCases:
|
||||
|
||||
return query
|
||||
|
||||
def _build_filtered_query(self, name, **filters):
|
||||
"""
|
||||
Create a filtered query: i.e. device_list(filters: {name: {i_contains: "akron"}}){.
|
||||
"""
|
||||
# TODO: This should be extended to support AND, OR multi-lookups
|
||||
if filters:
|
||||
for field_name, params in filters.items():
|
||||
lookup = params['lookup']
|
||||
value = params['value']
|
||||
if lookup:
|
||||
query = f'{{{lookup}: "{value}"}}'
|
||||
filter_string = f'{field_name}: {query}'
|
||||
else:
|
||||
filter_string = f'{field_name}: "{value}"'
|
||||
filter_string = f'(filters: {{{filter_string}}})'
|
||||
else:
|
||||
filter_string = ''
|
||||
|
||||
return self._build_query_with_filter(name, filter_string)
|
||||
|
||||
def _build_query(self, name, **filters):
|
||||
"""
|
||||
Create a normal query - unfiltered or with a string query: i.e. site(name: "aaa"){.
|
||||
"""
|
||||
if filters:
|
||||
filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items())
|
||||
filter_string = f'({filter_string})'
|
||||
else:
|
||||
filter_string = ''
|
||||
|
||||
return self._build_query_with_filter(name, filter_string)
|
||||
|
||||
@override_settings(LOGIN_REQUIRED=True)
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user'])
|
||||
def test_graphql_get_object(self):
|
||||
url = reverse('graphql')
|
||||
field_name = self._get_graphql_base_name()
|
||||
@@ -501,39 +527,92 @@ class APIViewTestCases:
|
||||
query = self._build_query(field_name, id=object_id)
|
||||
|
||||
# Non-authenticated requests should fail
|
||||
header = {
|
||||
'HTTP_ACCEPT': 'application/json',
|
||||
}
|
||||
with disable_warnings('django.request'):
|
||||
header = {
|
||||
'HTTP_ACCEPT': 'application/json',
|
||||
}
|
||||
self.assertHttpStatus(self.client.post(url, data={'query': query}, format="json", **header), status.HTTP_403_FORBIDDEN)
|
||||
response = self.client.post(url, data={'query': query}, format="json", **header)
|
||||
self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN)
|
||||
|
||||
# Add object-level permission
|
||||
# Add constrained permission
|
||||
obj_perm = ObjectPermission(
|
||||
name='Test permission',
|
||||
actions=['view']
|
||||
actions=['view'],
|
||||
constraints={'id': 0} # Impossible constraint
|
||||
)
|
||||
obj_perm.save()
|
||||
obj_perm.users.add(self.user)
|
||||
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
|
||||
|
||||
# Request should succeed but return empty result
|
||||
with disable_logging():
|
||||
response = self.client.post(url, data={'query': query}, format="json", **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertIn('errors', data)
|
||||
self.assertIsNone(data['data'])
|
||||
|
||||
# Remove permission constraint
|
||||
obj_perm.constraints = None
|
||||
obj_perm.save()
|
||||
|
||||
# Request should return requested object
|
||||
response = self.client.post(url, data={'query': query}, format="json", **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
self.assertIsNotNone(data['data'])
|
||||
|
||||
@override_settings(LOGIN_REQUIRED=True)
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user'])
|
||||
def test_graphql_list_objects(self):
|
||||
url = reverse('graphql')
|
||||
field_name = f'{self._get_graphql_base_name()}_list'
|
||||
query = self._build_query(field_name)
|
||||
|
||||
# Non-authenticated requests should fail
|
||||
header = {
|
||||
'HTTP_ACCEPT': 'application/json',
|
||||
}
|
||||
with disable_warnings('django.request'):
|
||||
header = {
|
||||
'HTTP_ACCEPT': 'application/json',
|
||||
}
|
||||
self.assertHttpStatus(self.client.post(url, data={'query': query}, format="json", **header), status.HTTP_403_FORBIDDEN)
|
||||
response = self.client.post(url, data={'query': query}, format="json", **header)
|
||||
self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN)
|
||||
|
||||
# Add constrained permission
|
||||
obj_perm = ObjectPermission(
|
||||
name='Test permission',
|
||||
actions=['view'],
|
||||
constraints={'id': 0} # Impossible constraint
|
||||
)
|
||||
obj_perm.save()
|
||||
obj_perm.users.add(self.user)
|
||||
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
|
||||
|
||||
# Request should succeed but return empty results list
|
||||
response = self.client.post(url, data={'query': query}, format="json", **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
self.assertEqual(len(data['data'][field_name]), 0)
|
||||
|
||||
# Remove permission constraint
|
||||
obj_perm.constraints = None
|
||||
obj_perm.save()
|
||||
|
||||
# Request should return all objects
|
||||
response = self.client.post(url, data={'query': query}, format="json", **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
self.assertEqual(len(data['data'][field_name]), self.model.objects.count())
|
||||
|
||||
@override_settings(LOGIN_REQUIRED=True)
|
||||
def test_graphql_filter_objects(self):
|
||||
if not hasattr(self, 'graphql_filter'):
|
||||
return
|
||||
|
||||
url = reverse('graphql')
|
||||
field_name = f'{self._get_graphql_base_name()}_list'
|
||||
query = self._build_filtered_query(field_name, **self.graphql_filter)
|
||||
|
||||
# Add object-level permission
|
||||
obj_perm = ObjectPermission(
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import json
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.contrib.postgres.fields import ArrayField, RangeField
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.db.models import ManyToManyField, ManyToManyRel, JSONField
|
||||
from django.forms.models import model_to_dict
|
||||
@@ -11,7 +11,8 @@ from netaddr import IPNetwork
|
||||
from taggit.managers import TaggableManager
|
||||
|
||||
from core.models import ObjectType
|
||||
from users.models import ObjectPermission
|
||||
from users.models import ObjectPermission, User
|
||||
from utilities.data import ranges_to_string
|
||||
from utilities.object_types import object_type_identifier
|
||||
from utilities.permissions import resolve_permission_type
|
||||
from .utils import DUMMY_CF_DATA, extract_form_failures
|
||||
@@ -28,7 +29,7 @@ class TestCase(_TestCase):
|
||||
def setUp(self):
|
||||
|
||||
# Create the test user and assign permissions
|
||||
self.user = get_user_model().objects.create_user(username='testuser')
|
||||
self.user = User.objects.create_user(username='testuser')
|
||||
self.add_permissions(*self.user_permissions)
|
||||
|
||||
# Initialize the test client
|
||||
@@ -120,6 +121,10 @@ class ModelTestCase(TestCase):
|
||||
else:
|
||||
model_dict[key] = sorted([obj.pk for obj in value])
|
||||
|
||||
# Handle GenericForeignKeys
|
||||
elif value and type(field) is GenericForeignKey:
|
||||
model_dict[key] = value.pk
|
||||
|
||||
elif api:
|
||||
|
||||
# Replace ContentType numeric IDs with <app_label>.<model>
|
||||
@@ -136,9 +141,15 @@ class ModelTestCase(TestCase):
|
||||
|
||||
# Convert ArrayFields to CSV strings
|
||||
if type(field) is ArrayField:
|
||||
if type(field.base_field) is ArrayField:
|
||||
if getattr(field.base_field, 'choices', None):
|
||||
# Values for fields with pre-defined choices can be returned as lists
|
||||
model_dict[key] = value
|
||||
elif type(field.base_field) is ArrayField:
|
||||
# Handle nested arrays (e.g. choice sets)
|
||||
model_dict[key] = '\n'.join([f'{k},{v}' for k, v in value])
|
||||
elif issubclass(type(field.base_field), RangeField):
|
||||
# Handle arrays of numeric ranges (e.g. VLANGroup VLAN ID ranges)
|
||||
model_dict[key] = ranges_to_string(value)
|
||||
else:
|
||||
model_dict[key] = ','.join([str(v) for v in value])
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import logging
|
||||
import re
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.utils.text import slugify
|
||||
|
||||
@@ -11,6 +10,7 @@ from core.models import ObjectType
|
||||
from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Site
|
||||
from extras.choices import CustomFieldTypeChoices
|
||||
from extras.models import CustomField, Tag
|
||||
from users.models import User
|
||||
from virtualization.models import Cluster, ClusterType, VirtualMachine
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ def create_test_user(username='testuser', permissions=None):
|
||||
"""
|
||||
Create a User with the given permissions.
|
||||
"""
|
||||
user = get_user_model().objects.create_user(username=username)
|
||||
user = User.objects.create_user(username=username)
|
||||
if permissions is None:
|
||||
permissions = ()
|
||||
for perm_name in permissions:
|
||||
@@ -107,6 +107,16 @@ def disable_warnings(logger_name):
|
||||
logger.setLevel(current_level)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_logging(level=logging.CRITICAL):
|
||||
"""
|
||||
Temporarily suppress log messages at or below the specified level (default: critical).
|
||||
"""
|
||||
logging.disable(level)
|
||||
yield
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
|
||||
#
|
||||
# Custom field testing
|
||||
#
|
||||
|
||||
@@ -8,9 +8,8 @@ from django.test import override_settings
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from core.models import ObjectType
|
||||
from extras.choices import ObjectChangeActionChoices
|
||||
from extras.models import ObjectChange
|
||||
from core.choices import ObjectChangeActionChoices
|
||||
from core.models import ObjectChange, ObjectType
|
||||
from netbox.choices import CSVDelimiterChoices, ImportFormatChoices
|
||||
from netbox.models.features import ChangeLoggingMixin, CustomFieldsMixin
|
||||
from users.models import ObjectPermission
|
||||
@@ -62,7 +61,7 @@ class ViewTestCases:
|
||||
"""
|
||||
Retrieve a single instance.
|
||||
"""
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'], LOGIN_REQUIRED=False)
|
||||
def test_get_object_anonymous(self):
|
||||
# Make the request as an unauthenticated user
|
||||
self.client.logout()
|
||||
@@ -421,7 +420,7 @@ class ViewTestCases:
|
||||
"""
|
||||
Retrieve multiple instances.
|
||||
"""
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'], LOGIN_REQUIRED=False)
|
||||
def test_list_objects_anonymous(self):
|
||||
# Make the request as an unauthenticated user
|
||||
self.client.logout()
|
||||
@@ -595,10 +594,10 @@ class ViewTestCases:
|
||||
|
||||
# Test GET without permission
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.get(self._get_url('import')), 403)
|
||||
self.assertHttpStatus(self.client.get(self._get_url('bulk_import')), 403)
|
||||
|
||||
# Try POST without permission
|
||||
response = self.client.post(self._get_url('import'), data)
|
||||
response = self.client.post(self._get_url('bulk_import'), data)
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(response, 403)
|
||||
|
||||
@@ -621,10 +620,10 @@ class ViewTestCases:
|
||||
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
|
||||
|
||||
# Try GET with model-level permission
|
||||
self.assertHttpStatus(self.client.get(self._get_url('import')), 200)
|
||||
self.assertHttpStatus(self.client.get(self._get_url('bulk_import')), 200)
|
||||
|
||||
# Test POST with permission
|
||||
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302)
|
||||
self.assertHttpStatus(self.client.post(self._get_url('bulk_import'), data), 302)
|
||||
self.assertEqual(self._get_queryset().count(), initial_count + len(self.csv_data) - 1)
|
||||
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
|
||||
@@ -650,7 +649,7 @@ class ViewTestCases:
|
||||
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
|
||||
|
||||
# Test POST with permission
|
||||
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302)
|
||||
self.assertHttpStatus(self.client.post(self._get_url('bulk_import'), data), 302)
|
||||
self.assertEqual(initial_count, self._get_queryset().count())
|
||||
|
||||
reader = csv.DictReader(array, delimiter=',')
|
||||
@@ -685,7 +684,7 @@ class ViewTestCases:
|
||||
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
|
||||
|
||||
# Attempt to import non-permitted objects
|
||||
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200)
|
||||
self.assertHttpStatus(self.client.post(self._get_url('bulk_import'), data), 200)
|
||||
self.assertEqual(self._get_queryset().count(), initial_count)
|
||||
|
||||
# Update permission constraints
|
||||
@@ -693,7 +692,7 @@ class ViewTestCases:
|
||||
obj_perm.save()
|
||||
|
||||
# Import permitted objects
|
||||
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302)
|
||||
self.assertHttpStatus(self.client.post(self._get_url('bulk_import'), data), 302)
|
||||
self.assertEqual(self._get_queryset().count(), initial_count + len(self.csv_data) - 1)
|
||||
|
||||
class BulkEditObjectsViewTestCase(ModelViewTestCase):
|
||||
|
||||
@@ -144,24 +144,37 @@ class APIPaginationTestCase(APITestCase):
|
||||
self.assertIsNone(response.data['previous'])
|
||||
self.assertEqual(len(response.data['results']), page_size)
|
||||
|
||||
@override_settings(MAX_PAGE_SIZE=30)
|
||||
def test_default_page_size_with_small_max_page_size(self):
|
||||
response = self.client.get(self.url, format='json', **self.header)
|
||||
page_size = get_config().MAX_PAGE_SIZE
|
||||
paginate_count = get_config().PAGINATE_COUNT
|
||||
self.assertLess(page_size, 100, "Default page size not sufficient for data set")
|
||||
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data['count'], 100)
|
||||
self.assertTrue(response.data['next'].endswith(f'?limit={paginate_count}&offset={paginate_count}'))
|
||||
self.assertIsNone(response.data['previous'])
|
||||
self.assertEqual(len(response.data['results']), paginate_count)
|
||||
|
||||
def test_custom_page_size(self):
|
||||
response = self.client.get(f'{self.url}?limit=10', format='json', **self.header)
|
||||
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data['count'], 100)
|
||||
self.assertTrue(response.data['next'].endswith(f'?limit=10&offset=10'))
|
||||
self.assertTrue(response.data['next'].endswith('?limit=10&offset=10'))
|
||||
self.assertIsNone(response.data['previous'])
|
||||
self.assertEqual(len(response.data['results']), 10)
|
||||
|
||||
@override_settings(MAX_PAGE_SIZE=20)
|
||||
@override_settings(MAX_PAGE_SIZE=80)
|
||||
def test_max_page_size(self):
|
||||
response = self.client.get(f'{self.url}?limit=0', format='json', **self.header)
|
||||
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data['count'], 100)
|
||||
self.assertTrue(response.data['next'].endswith(f'?limit=20&offset=20'))
|
||||
self.assertTrue(response.data['next'].endswith('?limit=80&offset=80'))
|
||||
self.assertIsNone(response.data['previous'])
|
||||
self.assertEqual(len(response.data['results']), 20)
|
||||
self.assertEqual(len(response.data['results']), 80)
|
||||
|
||||
@override_settings(MAX_PAGE_SIZE=0)
|
||||
def test_max_page_size_disabled(self):
|
||||
|
||||
@@ -83,9 +83,9 @@ class CountersTest(TestCase):
|
||||
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
|
||||
def test_mptt_child_delete(self):
|
||||
device1, device2 = Device.objects.all()
|
||||
device1 = Device.objects.first()
|
||||
inventory_item1 = InventoryItem.objects.create(device=device1, name='Inventory Item 1')
|
||||
inventory_item2 = InventoryItem.objects.create(device=device1, name='Inventory Item 2', parent=inventory_item1)
|
||||
InventoryItem.objects.create(device=device1, name='Inventory Item 2', parent=inventory_item1)
|
||||
device1.refresh_from_db()
|
||||
self.assertEqual(device1.inventory_item_count, 2)
|
||||
|
||||
|
||||
68
netbox/utilities/tests/test_data.py
Normal file
68
netbox/utilities/tests/test_data.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from django.db.backends.postgresql.psycopg_any import NumericRange
|
||||
from django.test import TestCase
|
||||
|
||||
from utilities.data import check_ranges_overlap, ranges_to_string, string_to_ranges
|
||||
|
||||
|
||||
class RangeFunctionsTestCase(TestCase):
|
||||
|
||||
def test_check_ranges_overlap(self):
|
||||
# Non-overlapping ranges
|
||||
self.assertFalse(
|
||||
check_ranges_overlap([
|
||||
NumericRange(9, 19, bounds='(]'), # 10-19
|
||||
NumericRange(19, 30, bounds='(]'), # 20-29
|
||||
])
|
||||
)
|
||||
self.assertFalse(
|
||||
check_ranges_overlap([
|
||||
NumericRange(10, 19, bounds='[]'), # 10-19
|
||||
NumericRange(20, 29, bounds='[]'), # 20-29
|
||||
])
|
||||
)
|
||||
self.assertFalse(
|
||||
check_ranges_overlap([
|
||||
NumericRange(10, 20, bounds='[)'), # 10-19
|
||||
NumericRange(20, 30, bounds='[)'), # 20-29
|
||||
])
|
||||
)
|
||||
|
||||
# Overlapping ranges
|
||||
self.assertTrue(
|
||||
check_ranges_overlap([
|
||||
NumericRange(9, 20, bounds='(]'), # 10-20
|
||||
NumericRange(19, 30, bounds='(]'), # 20-30
|
||||
])
|
||||
)
|
||||
self.assertTrue(
|
||||
check_ranges_overlap([
|
||||
NumericRange(10, 20, bounds='[]'), # 10-20
|
||||
NumericRange(20, 30, bounds='[]'), # 20-30
|
||||
])
|
||||
)
|
||||
self.assertTrue(
|
||||
check_ranges_overlap([
|
||||
NumericRange(10, 21, bounds='[)'), # 10-20
|
||||
NumericRange(20, 31, bounds='[)'), # 10-30
|
||||
])
|
||||
)
|
||||
|
||||
def test_ranges_to_string(self):
|
||||
self.assertEqual(
|
||||
ranges_to_string([
|
||||
NumericRange(10, 20), # 10-19
|
||||
NumericRange(30, 40), # 30-39
|
||||
NumericRange(100, 200), # 100-199
|
||||
]),
|
||||
'10-19,30-39,100-199'
|
||||
)
|
||||
|
||||
def test_string_to_ranges(self):
|
||||
self.assertEqual(
|
||||
string_to_ranges('10-19, 30-39, 100-199'),
|
||||
[
|
||||
NumericRange(10, 19, bounds='[]'), # 10-19
|
||||
NumericRange(30, 39, bounds='[]'), # 30-39
|
||||
NumericRange(100, 199, bounds='[]'), # 100-199
|
||||
]
|
||||
)
|
||||
@@ -7,15 +7,16 @@ from taggit.managers import TaggableManager
|
||||
|
||||
from dcim.choices import *
|
||||
from dcim.fields import MACAddressField
|
||||
from dcim.filtersets import DeviceFilterSet, SiteFilterSet
|
||||
from dcim.filtersets import DeviceFilterSet, SiteFilterSet, InterfaceFilterSet
|
||||
from dcim.models import (
|
||||
Device, DeviceRole, DeviceType, Interface, Manufacturer, Platform, Rack, Region, Site
|
||||
Device, DeviceRole, DeviceType, Interface, MACAddress, Manufacturer, Platform, Rack, Region, Site
|
||||
)
|
||||
from extras.filters import TagFilter
|
||||
from extras.models import TaggedItem
|
||||
from ipam.filtersets import ASNFilterSet
|
||||
from ipam.models import RIR, ASN
|
||||
from netbox.filtersets import BaseFilterSet
|
||||
from wireless.choices import WirelessRoleChoices
|
||||
from utilities.filters import (
|
||||
MultiValueCharFilter, MultiValueDateFilter, MultiValueDateTimeFilter, MultiValueMACAddressFilter,
|
||||
MultiValueNumberFilter, MultiValueTimeFilter, TreeNodeMultipleChoiceFilter,
|
||||
@@ -408,9 +409,9 @@ class DynamicFilterLookupExpressionTest(TestCase):
|
||||
region.save()
|
||||
|
||||
sites = (
|
||||
Site(name='Site 1', slug='abc-site-1', region=regions[0]),
|
||||
Site(name='Site 2', slug='def-site-2', region=regions[1]),
|
||||
Site(name='Site 3', slug='ghi-site-3', region=regions[2]),
|
||||
Site(name='Site 1', slug='abc-site-1', region=regions[0], status=SiteStatusChoices.STATUS_ACTIVE),
|
||||
Site(name='Site 2', slug='def-site-2', region=regions[1], status=SiteStatusChoices.STATUS_ACTIVE),
|
||||
Site(name='Site 3', slug='ghi-site-3', region=regions[2], status=SiteStatusChoices.STATUS_PLANNED),
|
||||
)
|
||||
Site.objects.bulk_create(sites)
|
||||
|
||||
@@ -426,26 +427,88 @@ class DynamicFilterLookupExpressionTest(TestCase):
|
||||
Rack.objects.bulk_create(racks)
|
||||
|
||||
devices = (
|
||||
Device(name='Device 1', device_type=device_types[0], role=roles[0], platform=platforms[0], serial='ABC', asset_tag='1001', site=sites[0], rack=racks[0], position=1, face=DeviceFaceChoices.FACE_FRONT, status=DeviceStatusChoices.STATUS_ACTIVE, local_context_data={"foo": 123}),
|
||||
Device(name='Device 2', device_type=device_types[1], role=roles[1], platform=platforms[1], serial='DEF', asset_tag='1002', site=sites[1], rack=racks[1], position=2, face=DeviceFaceChoices.FACE_FRONT, status=DeviceStatusChoices.STATUS_STAGED),
|
||||
Device(name='Device 3', device_type=device_types[2], role=roles[2], platform=platforms[2], serial='GHI', asset_tag='1003', site=sites[2], rack=racks[2], position=3, face=DeviceFaceChoices.FACE_REAR, status=DeviceStatusChoices.STATUS_FAILED),
|
||||
Device(
|
||||
name='Device 1',
|
||||
device_type=device_types[0],
|
||||
role=roles[0],
|
||||
platform=platforms[0],
|
||||
serial='ABC',
|
||||
asset_tag='1001',
|
||||
site=sites[0],
|
||||
rack=racks[0],
|
||||
position=1,
|
||||
face=DeviceFaceChoices.FACE_FRONT,
|
||||
status=DeviceStatusChoices.STATUS_ACTIVE,
|
||||
local_context_data={'foo': 123},
|
||||
),
|
||||
Device(
|
||||
name='Device 2',
|
||||
device_type=device_types[1],
|
||||
role=roles[1],
|
||||
platform=platforms[1],
|
||||
serial='DEF',
|
||||
asset_tag='1002',
|
||||
site=sites[1],
|
||||
rack=racks[1],
|
||||
position=2,
|
||||
face=DeviceFaceChoices.FACE_FRONT,
|
||||
status=DeviceStatusChoices.STATUS_STAGED,
|
||||
),
|
||||
Device(
|
||||
name='Device 3',
|
||||
device_type=device_types[2],
|
||||
role=roles[2],
|
||||
platform=platforms[2],
|
||||
serial='GHI',
|
||||
asset_tag='1003',
|
||||
site=sites[2],
|
||||
rack=racks[2],
|
||||
position=3,
|
||||
face=DeviceFaceChoices.FACE_REAR,
|
||||
status=DeviceStatusChoices.STATUS_FAILED,
|
||||
),
|
||||
)
|
||||
Device.objects.bulk_create(devices)
|
||||
|
||||
mac_addresses = (
|
||||
MACAddress(mac_address='00-00-00-00-00-01'),
|
||||
MACAddress(mac_address='aa-00-00-00-00-01'),
|
||||
MACAddress(mac_address='00-00-00-00-00-02'),
|
||||
MACAddress(mac_address='bb-00-00-00-00-02'),
|
||||
MACAddress(mac_address='00-00-00-00-00-03'),
|
||||
MACAddress(mac_address='cc-00-00-00-00-03'),
|
||||
)
|
||||
MACAddress.objects.bulk_create(mac_addresses)
|
||||
|
||||
interfaces = (
|
||||
Interface(device=devices[0], name='Interface 1', mac_address='00-00-00-00-00-01'),
|
||||
Interface(device=devices[0], name='Interface 2', mac_address='aa-00-00-00-00-01'),
|
||||
Interface(device=devices[1], name='Interface 3', mac_address='00-00-00-00-00-02'),
|
||||
Interface(device=devices[1], name='Interface 4', mac_address='bb-00-00-00-00-02'),
|
||||
Interface(device=devices[2], name='Interface 5', mac_address='00-00-00-00-00-03'),
|
||||
Interface(device=devices[2], name='Interface 6', mac_address='cc-00-00-00-00-03'),
|
||||
Interface(device=devices[0], name='Interface 1'),
|
||||
Interface(device=devices[0], name='Interface 2'),
|
||||
Interface(device=devices[1], name='Interface 3'),
|
||||
Interface(device=devices[1], name='Interface 4'),
|
||||
Interface(device=devices[2], name='Interface 5'),
|
||||
Interface(device=devices[2], name='Interface 6', rf_role=WirelessRoleChoices.ROLE_AP),
|
||||
)
|
||||
Interface.objects.bulk_create(interfaces)
|
||||
|
||||
interfaces[0].mac_addresses.set([mac_addresses[0]])
|
||||
interfaces[1].mac_addresses.set([mac_addresses[1]])
|
||||
interfaces[2].mac_addresses.set([mac_addresses[2]])
|
||||
interfaces[3].mac_addresses.set([mac_addresses[3]])
|
||||
interfaces[4].mac_addresses.set([mac_addresses[4]])
|
||||
interfaces[5].mac_addresses.set([mac_addresses[5]])
|
||||
|
||||
def test_site_name_negation(self):
|
||||
params = {'name__n': ['Site 1']}
|
||||
self.assertEqual(SiteFilterSet(params, Site.objects.all()).qs.count(), 2)
|
||||
|
||||
def test_site_status_icontains(self):
|
||||
params = {'status__ic': [SiteStatusChoices.STATUS_ACTIVE]}
|
||||
self.assertEqual(SiteFilterSet(params, Site.objects.all()).qs.count(), 2)
|
||||
|
||||
def test_site_status_icontains_negation(self):
|
||||
params = {'status__nic': [SiteStatusChoices.STATUS_ACTIVE]}
|
||||
self.assertEqual(SiteFilterSet(params, Site.objects.all()).qs.count(), 1)
|
||||
|
||||
def test_site_slug_icontains(self):
|
||||
params = {'slug__ic': ['-1']}
|
||||
self.assertEqual(SiteFilterSet(params, Site.objects.all()).qs.count(), 1)
|
||||
@@ -553,3 +616,9 @@ class DynamicFilterLookupExpressionTest(TestCase):
|
||||
def test_device_mac_address_icontains_negation(self):
|
||||
params = {'mac_address__nic': ['aa:', 'bb']}
|
||||
self.assertEqual(DeviceFilterSet(params, Device.objects.all()).qs.count(), 1)
|
||||
|
||||
def test_interface_rf_role_empty(self):
|
||||
params = {'rf_role__empty': 'true'}
|
||||
self.assertEqual(InterfaceFilterSet(params, Interface.objects.all()).qs.count(), 5)
|
||||
params = {'rf_role__empty': 'false'}
|
||||
self.assertEqual(InterfaceFilterSet(params, Interface.objects.all()).qs.count(), 1)
|
||||
|
||||
@@ -9,22 +9,27 @@ __all__ = (
|
||||
)
|
||||
|
||||
|
||||
def get_model_urls(app_label, model_name):
|
||||
def get_model_urls(app_label, model_name, detail=True):
|
||||
"""
|
||||
Return a list of URL paths for detail views registered to the given model.
|
||||
|
||||
Args:
|
||||
app_label: App/plugin name
|
||||
model_name: Model name
|
||||
detail: If True (default), return only URL views for an individual object.
|
||||
Otherwise, return only list views.
|
||||
"""
|
||||
paths = []
|
||||
|
||||
# Retrieve registered views for this model
|
||||
try:
|
||||
views = registry['views'][app_label][model_name]
|
||||
views = [
|
||||
view for view in registry['views'][app_label][model_name]
|
||||
if view['detail'] == detail
|
||||
]
|
||||
except KeyError:
|
||||
# No views have been registered for this model
|
||||
views = []
|
||||
return []
|
||||
|
||||
for config in views:
|
||||
# Import the view class or function
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
from typing import Iterable
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.mixins import AccessMixin
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.urls import reverse
|
||||
from django.urls.exceptions import NoReverseMatch
|
||||
from django.utils.http import url_has_allowed_host_and_scheme
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from netbox.plugins import PluginConfig
|
||||
from netbox.registry import registry
|
||||
from utilities.relations import get_related_models
|
||||
from .permissions import resolve_permission
|
||||
|
||||
__all__ = (
|
||||
'ConditionalLoginRequiredMixin',
|
||||
'ContentTypePermissionRequiredMixin',
|
||||
'GetRelatedModelsMixin',
|
||||
'GetReturnURLMixin',
|
||||
'ObjectPermissionRequiredMixin',
|
||||
'ViewTab',
|
||||
@@ -22,10 +29,20 @@ __all__ = (
|
||||
# View Mixins
|
||||
#
|
||||
|
||||
class ContentTypePermissionRequiredMixin(AccessMixin):
|
||||
class ConditionalLoginRequiredMixin(AccessMixin):
|
||||
"""
|
||||
Similar to Django's LoginRequiredMixin, but enforces authentication only if LOGIN_REQUIRED is True.
|
||||
"""
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
if settings.LOGIN_REQUIRED and not request.user.is_authenticated:
|
||||
return self.handle_no_permission()
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
|
||||
class ContentTypePermissionRequiredMixin(ConditionalLoginRequiredMixin):
|
||||
"""
|
||||
Similar to Django's built-in PermissionRequiredMixin, but extended to check model-level permission assignments.
|
||||
This is related to ObjectPermissionRequiredMixin, except that is does not enforce object-level permissions,
|
||||
This is related to ObjectPermissionRequiredMixin, except that it does not enforce object-level permissions,
|
||||
and fits within NetBox's custom permission enforcement system.
|
||||
|
||||
additional_permissions: An optional iterable of statically declared permissions to evaluate in addition to those
|
||||
@@ -58,7 +75,7 @@ class ContentTypePermissionRequiredMixin(AccessMixin):
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
|
||||
class ObjectPermissionRequiredMixin(AccessMixin):
|
||||
class ObjectPermissionRequiredMixin(ConditionalLoginRequiredMixin):
|
||||
"""
|
||||
Similar to Django's built-in PermissionRequiredMixin, but extended to check for both model-level and object-level
|
||||
permission assignments. If the user has only object-level permissions assigned, the view's queryset is filtered
|
||||
@@ -119,7 +136,7 @@ class GetReturnURLMixin:
|
||||
# First, see if `return_url` was specified as a query parameter or form data. Use this URL only if it's
|
||||
# considered safe.
|
||||
return_url = request.GET.get('return_url') or request.POST.get('return_url')
|
||||
if return_url and return_url.startswith('/'):
|
||||
if return_url and url_has_allowed_host_and_scheme(return_url, allowed_hosts=None):
|
||||
return return_url
|
||||
|
||||
# Next, check if the object being modified (if any) has an absolute URL.
|
||||
@@ -142,6 +159,46 @@ class GetReturnURLMixin:
|
||||
return reverse('home')
|
||||
|
||||
|
||||
class GetRelatedModelsMixin:
|
||||
"""
|
||||
Provides logic for collecting all related models for the currently viewed model.
|
||||
"""
|
||||
|
||||
def get_related_models(self, request, instance, omit=[], extra=[]):
|
||||
"""
|
||||
Get related models of the view's `queryset` model without those listed in `omit`. Will be sorted alphabetical.
|
||||
|
||||
Args:
|
||||
request: Current request being processed.
|
||||
instance: The instance related models should be looked up for. A list of instances can be passed to match
|
||||
related objects in this list (e.g. to find sites of a region including child regions).
|
||||
omit: Remove relationships to these models from the result. Needs to be passed, if related models don't
|
||||
provide a `_list` view.
|
||||
extra: Add extra models to the list of automatically determined related models. Can be used to add indirect
|
||||
relationships.
|
||||
"""
|
||||
model = self.queryset.model
|
||||
related = filter(
|
||||
lambda m: m[0] is not model and m[0] not in omit,
|
||||
get_related_models(model, False)
|
||||
)
|
||||
|
||||
related_models = [
|
||||
(
|
||||
model.objects.restrict(request.user, 'view').filter(**(
|
||||
{f'{field}__in': instance}
|
||||
if isinstance(instance, Iterable)
|
||||
else {field: instance}
|
||||
)),
|
||||
f'{field}_id'
|
||||
)
|
||||
for model, field in related
|
||||
]
|
||||
related_models.extend(extra)
|
||||
|
||||
return sorted(related_models, key=lambda x: x[0].model._meta.verbose_name.lower())
|
||||
|
||||
|
||||
class ViewTab:
|
||||
"""
|
||||
ViewTabs are used for navigation among multiple object-specific views, such as the changelog or journal for
|
||||
@@ -215,7 +272,7 @@ def get_viewname(model, action=None, rest_api=False):
|
||||
return viewname
|
||||
|
||||
|
||||
def register_model_view(model, name='', path=None, kwargs=None):
|
||||
def register_model_view(model, name='', path=None, detail=True, kwargs=None):
|
||||
"""
|
||||
This decorator can be used to "attach" a view to any model in NetBox. This is typically used to inject
|
||||
additional tabs within a model's detail view. For example, to add a custom tab to NetBox's dcim.Site model:
|
||||
@@ -232,6 +289,7 @@ def register_model_view(model, name='', path=None, kwargs=None):
|
||||
name: The string used to form the view's name for URL resolution (e.g. via `reverse()`). This will be appended
|
||||
to the name of the base view for the model using an underscore. If blank, the model name will be used.
|
||||
path: The URL path by which the view can be reached (optional). If not provided, `name` will be used.
|
||||
detail: True if the path applied to an individual object; False if it attaches to the base (list) path.
|
||||
kwargs: A dictionary of keyword arguments for the view to include when registering its URL path (optional).
|
||||
"""
|
||||
def _wrapper(cls):
|
||||
@@ -244,7 +302,8 @@ def register_model_view(model, name='', path=None, kwargs=None):
|
||||
registry['views'][app_label][model_name].append({
|
||||
'name': name,
|
||||
'view': cls,
|
||||
'path': path or name,
|
||||
'path': path if path is not None else name,
|
||||
'detail': detail,
|
||||
'kwargs': kwargs or {},
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user