Merge branch 'develop' into fix/generic_prefetch_4.2

This commit is contained in:
Andrey Tikhonov
2025-03-06 16:05:25 +01:00
1032 changed files with 253959 additions and 331699 deletions

View File

@@ -129,7 +129,7 @@ def get_annotations_for_serializer(serializer_class, fields_to_include=None):
for field_name, field in serializer_class._declared_fields.items():
if field_name in fields_to_include and type(field) is RelatedObjectCountField:
related_field = model._meta.get_field(field.relation).field
related_field = getattr(model, field.relation).field
annotations[field_name] = count_related(related_field.model, related_field.name)
return annotations

View File

@@ -93,3 +93,7 @@ HTML_ALLOWED_ATTRIBUTES = {
"td": {"align"},
"th": {"align"},
}
HTTP_PROXY_SUPPORTED_SOCK_SCHEMAS = ['socks4', 'socks4a', 'socks4h', 'socks5', 'socks5a', 'socks5h']
HTTP_PROXY_SOCK_RDNS_SCHEMAS = ['socks4h', 'socks4a', 'socks5h', 'socks5a']
HTTP_PROXY_SUPPORTED_SCHEMAS = ['http', 'https', 'socks4', 'socks4a', 'socks4h', 'socks5', 'socks5a', 'socks5h']

View File

@@ -2,7 +2,8 @@ from decimal import Decimal
from django.utils.translation import gettext as _
from dcim.choices import CableLengthUnitChoices, WeightUnitChoices
from dcim.choices import CableLengthUnitChoices
from netbox.choices import WeightUnitChoices
__all__ = (
'to_grams',
@@ -10,9 +11,9 @@ __all__ = (
)
def to_grams(weight, unit):
def to_grams(weight, unit) -> int:
"""
Convert the given weight to kilograms.
Convert the given weight to integer grams.
"""
try:
if weight < 0:
@@ -21,13 +22,13 @@ def to_grams(weight, unit):
raise TypeError(_("Invalid value '{weight}' for weight (must be a number)").format(weight=weight))
if unit == WeightUnitChoices.UNIT_KILOGRAM:
return weight * 1000
return int(weight * 1000)
if unit == WeightUnitChoices.UNIT_GRAM:
return weight
return int(weight)
if unit == WeightUnitChoices.UNIT_POUND:
return weight * Decimal(453.592)
return int(weight * Decimal(453.592))
if unit == WeightUnitChoices.UNIT_OUNCE:
return weight * Decimal(28.3495)
return int(weight * Decimal(28.3495))
raise ValueError(
_("Unknown unit {unit}. Must be one of the following: {valid_units}").format(
unit=unit,

View File

@@ -1,13 +1,17 @@
import decimal
from django.db.backends.postgresql.psycopg_any import NumericRange
from itertools import count, groupby
__all__ = (
'array_to_ranges',
'array_to_string',
'check_ranges_overlap',
'deepmerge',
'drange',
'flatten_dict',
'ranges_to_string',
'shallow_compare_dict',
'string_to_ranges',
)
@@ -113,3 +117,52 @@ def drange(start, end, step=decimal.Decimal(1)):
while start > end:
yield start
start += step
def check_ranges_overlap(ranges):
"""
Check for overlap in an iterable of NumericRanges.
"""
ranges.sort(key=lambda x: x.lower)
for i in range(1, len(ranges)):
prev_range = ranges[i - 1]
prev_upper = prev_range.upper if prev_range.upper_inc else prev_range.upper - 1
lower = ranges[i].lower if ranges[i].lower_inc else ranges[i].lower + 1
if prev_upper >= lower:
return True
return False
def ranges_to_string(ranges):
"""
Generate a human-friendly string from a set of ranges. Intended for use with ArrayField. For example:
[[1, 100)], [200, 300)] => "1-99,200-299"
"""
if not ranges:
return ''
output = []
for r in ranges:
lower = r.lower if r.lower_inc else r.lower + 1
upper = r.upper if r.upper_inc else r.upper - 1
output.append(f'{lower}-{upper}')
return ','.join(output)
def string_to_ranges(value):
"""
Given a string in the format "1-100, 200-300" return an list of NumericRanges. Intended for use with ArrayField.
For example:
"1-99,200-299" => [NumericRange(1, 100), NumericRange(200, 300)]
"""
if not value:
return None
value.replace(' ', '') # Remove whitespace
values = []
for dash_range in value.split(','):
if '-' not in dash_range:
return None
lower, upper = dash_range.split('-')
values.append(NumericRange(int(lower), int(upper), bounds='[]'))
return values

View File

@@ -1,7 +1,10 @@
import datetime
from django.utils import timezone
from django.utils.timezone import localtime
__all__ = (
'datetime_from_timestamp',
'local_now',
)
@@ -11,3 +14,15 @@ def local_now():
Return the current date & time in the system timezone.
"""
return localtime(timezone.now())
def datetime_from_timestamp(value):
"""
Convert an ISO 8601 or RFC 3339 timestamp to a datetime object.
"""
# Work around UTC issue for Python < 3.11; see
# https://docs.python.org/3/library/datetime.html#datetime.datetime.fromisoformat
# TODO: Remove this once Python 3.10 is no longer supported
if type(value) is str and value.endswith('Z'):
value = f'{value[:-1]}+00:00'
return datetime.datetime.fromisoformat(value)

View File

@@ -39,7 +39,7 @@ def handle_protectederror(obj_list, request, e):
if hasattr(dependent, 'get_absolute_url'):
dependent_objects.append(f'<a href="{dependent.get_absolute_url()}">{escape(dependent)}</a>')
else:
dependent_objects.append(str(dependent))
dependent_objects.append(escape(str(dependent)))
err_message += ', '.join(dependent_objects)
messages.error(request, mark_safe(err_message))
@@ -49,11 +49,11 @@ def handle_rest_api_exception(request, *args, **kwargs):
"""
Handle exceptions and return a useful error message for REST API requests.
"""
type_, error, traceback = sys.exc_info()
type_, error = sys.exc_info()[:2]
data = {
'error': str(error),
'exception': type_.__name__,
'netbox_version': settings.VERSION,
'netbox_version': settings.RELEASE.full_version,
'python_version': platform.python_version(),
}
return JsonResponse(data, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

View File

@@ -2,9 +2,9 @@ from collections import defaultdict
from django.contrib.contenttypes.fields import GenericForeignKey
from django.db import models
from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _
from utilities.ordering import naturalize
from .forms.widgets import ColorSelect
from .validators import ColorValidator
@@ -26,6 +26,7 @@ class ColorField(models.CharField):
def formfield(self, **kwargs):
kwargs['widget'] = ColorSelect
kwargs['help_text'] = mark_safe(_('RGB color in hexadecimal. Example: ') + '<code>00ff00</code>')
return super().formfield(**kwargs)
@@ -38,7 +39,7 @@ class NaturalOrderingField(models.CharField):
"""
description = "Stores a representation of its target field suitable for natural ordering"
def __init__(self, target_field, naturalize_function=naturalize, *args, **kwargs):
def __init__(self, target_field, naturalize_function, *args, **kwargs):
self.target_field = target_field
self.naturalize_function = naturalize_function
super().__init__(*args, **kwargs)
@@ -70,14 +71,24 @@ class RestrictedGenericForeignKey(GenericForeignKey):
# 1. Capture restrict_params from RestrictedPrefetch (hack)
# 2. If restrict_params is set, call restrict() on the queryset for
# the related model
def get_prefetch_queryset(self, instances, queryset=None):
def get_prefetch_querysets(self, instances, querysets=None):
restrict_params = {}
custom_queryset_dict = {}
# Compensate for the hack in RestrictedPrefetch
if type(queryset) is dict:
restrict_params = queryset
elif queryset is not None:
raise ValueError(_("Custom queryset can't be used for this lookup."))
if type(querysets) is dict:
restrict_params = querysets
elif querysets is not None:
for queryset in querysets:
ct_id = self.get_content_type(
model=queryset.query.model, using=queryset.db
).pk
if ct_id in custom_queryset_dict:
raise ValueError(
"Only one queryset is allowed for each content type."
)
custom_queryset_dict[ct_id] = queryset
# For efficiency, group the instances by content type and then do one
# query per model
@@ -100,15 +111,16 @@ class RestrictedGenericForeignKey(GenericForeignKey):
ret_val = []
for ct_id, fkeys in fk_dict.items():
instance = instance_dict[ct_id]
ct = self.get_content_type(id=ct_id, using=instance._state.db)
if restrict_params:
# Override the default behavior to call restrict() on each model's queryset
qs = ct.model_class().objects.filter(pk__in=fkeys).restrict(**restrict_params)
ret_val.extend(qs)
if ct_id in custom_queryset_dict:
# Return values from the custom queryset, if provided.
ret_val.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
else:
# Default behavior
ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
instance = instance_dict[ct_id]
ct = self.get_content_type(id=ct_id, using=instance._state.db)
qs = ct.model_class().objects.filter(pk__in=fkeys)
if restrict_params:
qs = qs.restrict(**restrict_params)
ret_val.extend(qs)
# For doing the join in Python, we have to match both the FK val and the
# content type, so we use a callable that returns a (fk, class) pair.

View File

@@ -3,8 +3,8 @@ from django import forms
from django.conf import settings
from django.core.exceptions import ValidationError
from django_filters.constants import EMPTY_VALUES
from drf_spectacular.utils import extend_schema_field
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field
__all__ = (
'ContentTypeFilter',
@@ -116,6 +116,7 @@ class MultiValueWWNFilter(django_filters.MultipleChoiceFilter):
field_class = multivalue_field_factory(forms.CharField)
@extend_schema_field(OpenApiTypes.STR)
class TreeNodeMultipleChoiceFilter(django_filters.ModelMultipleChoiceFilter):
"""
Filters for a set of Models, including all descendant models within a Tree. Example: [<Region: R1>,<Region: R2>]

View File

@@ -1,11 +1,14 @@
from django import forms
from django.contrib.postgres.forms import SimpleArrayField
from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _
from utilities.data import ranges_to_string, string_to_ranges
from ..utils import parse_numeric_range
__all__ = (
'NumericArrayField',
'NumericRangeArrayField',
)
@@ -24,3 +27,31 @@ class NumericArrayField(SimpleArrayField):
if isinstance(value, str):
value = ','.join([str(n) for n in parse_numeric_range(value)])
return super().to_python(value)
class NumericRangeArrayField(forms.CharField):
"""
A field which allows for array of numeric ranges:
Example: 1-5,7-20,30-50
"""
def __init__(self, *args, help_text='', **kwargs):
if not help_text:
help_text = mark_safe(
_("Specify one or more numeric ranges separated by commas. Example: " + "<code>1-5,20-30</code>")
)
super().__init__(*args, help_text=help_text, **kwargs)
def clean(self, value):
if value and not self.to_python(value):
raise forms.ValidationError(
_("Invalid ranges ({value}). Must be a range of integers in ascending order.").format(value=value)
)
return super().clean(value)
def prepare_value(self, value):
if isinstance(value, str):
return value
return ranges_to_string(value)
def to_python(self, value):
return string_to_ranges(value)

View File

@@ -1,7 +1,7 @@
from django import forms
from django.utils.translation import gettext_lazy as _
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist, FieldError
from django.db.models import Q
from utilities.choices import unpack_grouped_choices
@@ -64,6 +64,10 @@ class CSVModelChoiceField(forms.ModelChoiceField):
raise forms.ValidationError(
_('"{value}" is not a unique value for this field; multiple objects were found').format(value=value)
)
except FieldError:
raise forms.ValidationError(
_('"{field_name}" is an invalid accessor field name.').format(field_name=self.to_field_name)
)
class CSVModelMultipleChoiceField(forms.ModelMultipleChoiceField):

View File

@@ -2,7 +2,7 @@ import django_filters
from django import forms
from django.conf import settings
from django.forms import BoundField
from django.urls import reverse
from django.urls import reverse, reverse_lazy
from utilities.forms import widgets
from utilities.views import get_viewname
@@ -66,6 +66,10 @@ class DynamicModelChoiceMixin:
choice (DEPRECATED: pass `context={'disabled': '$fieldname'}` instead)
context: A mapping of <option> template variables to their API data keys (optional; see below)
selector: Include an advanced object selection widget to assist the user in identifying the desired object
quick_add: Include a widget to quickly create a new related object for assignment. NOTE: Nested usage of
quick-add fields is not currently supported.
quick_add_params: A dictionary of initial data to include when launching the quick-add form (optional). The
token string "$pk" will be replaced with the primary key of the form's instance, if any.
Context keys:
value: The name of the attribute which contains the option's value (default: 'id')
@@ -90,6 +94,8 @@ class DynamicModelChoiceMixin:
disabled_indicator=None,
context=None,
selector=False,
quick_add=False,
quick_add_params=None,
**kwargs
):
self.model = queryset.model
@@ -99,6 +105,8 @@ class DynamicModelChoiceMixin:
self.disabled_indicator = disabled_indicator
self.context = context or {}
self.selector = selector
self.quick_add = quick_add
self.quick_add_params = quick_add_params or {}
super().__init__(queryset, **kwargs)
@@ -113,11 +121,6 @@ class DynamicModelChoiceMixin:
for var, accessor in self.context.items():
attrs[f'ts-{var}-field'] = accessor
# TODO: Remove in v4.1
# Legacy means of specifying the disabled indicator
if self.disabled_indicator is not None:
attrs['ts-disabled-field'] = self.disabled_indicator
# Attach any static query parameters
if len(self.query_params) > 0:
widget.add_query_params(self.query_params)
@@ -147,7 +150,7 @@ class DynamicModelChoiceMixin:
if data:
# When the field is multiple choice pass the data as a list if it's not already
if isinstance(bound_field.field, DynamicModelMultipleChoiceField) and not type(data) is list:
if isinstance(bound_field.field, DynamicModelMultipleChoiceField) and type(data) is not list:
data = [data]
field_name = getattr(self, 'to_field_name') or 'pk'
@@ -166,6 +169,22 @@ class DynamicModelChoiceMixin:
viewname = get_viewname(self.queryset.model, action='list', rest_api=True)
widget.attrs['data-url'] = reverse(viewname)
# Include quick add?
if self.quick_add:
app_label = self.model._meta.app_label
model_name = self.model._meta.model_name
widget.quick_add_context = {
'url': reverse_lazy(f'{app_label}:{model_name}_add'),
'params': {},
}
for k, v in self.quick_add_params.items():
if v == '$pk':
# Replace "$pk" token with the primary key of the form's instance (if any)
if getattr(form.instance, 'pk', None):
widget.quick_add_context['params'][k] = form.instance.pk
else:
widget.quick_add_context['params'][k] = v
return bound_field
@@ -197,6 +216,6 @@ class DynamicModelMultipleChoiceField(DynamicModelChoiceMixin, forms.ModelMultip
# string 'null'. This will check for that condition and gracefully handle the conversion to a NoneType.
if self.null_option is not None and settings.FILTERS_NULL_CHOICE_VALUE in value:
value = [v for v in value if v != settings.FILTERS_NULL_CHOICE_VALUE]
return [None, *value]
return [None, *super().clean(value)]
return super().clean(value)

View File

@@ -22,6 +22,15 @@ class APISelect(forms.Select):
dynamic_params: Dict[str, str]
static_params: Dict[str, List[str]]
def get_context(self, name, value, attrs):
context = super().get_context(name, value, attrs)
# Add quick-add context data, if enabled for the widget
if hasattr(self, 'quick_add_context'):
context['quick_add'] = self.quick_add_context
return context
def __init__(self, api_url=None, full=False, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@@ -43,9 +43,12 @@ class HTMXSelect(forms.Select):
"""
Selection widget that will re-generate the HTML form upon the selection of a new option.
"""
def __init__(self, hx_url='.', hx_target_id='form_fields', attrs=None, **kwargs):
def __init__(self, method='get', hx_url='.', hx_target_id='form_fields', attrs=None, **kwargs):
method = method.lower()
if method not in ('delete', 'get', 'patch', 'post', 'put'):
raise ValueError(f"Unsupported HTTP method: {method}")
_attrs = {
'hx-get': hx_url,
f'hx-{method}': hx_url,
'hx-include': f'#{hx_target_id}',
'hx-target': f'#{hx_target_id}',
}

View File

@@ -59,7 +59,7 @@ def highlight(value, highlight, trim_pre=None, trim_post=None, trim_placeholder=
else:
highlight = re.escape(highlight)
pre, match, post = re.split(fr'({highlight})', value, maxsplit=1, flags=re.IGNORECASE)
except ValueError as e:
except ValueError:
# Match not found
return escape(value)

View File

@@ -28,10 +28,14 @@ class DataFileLoader(BaseLoader):
raise TemplateNotFound(template)
# Find and pre-fetch referenced templates
if referenced_templates := find_referenced_templates(environment.parse(template_source)):
if referenced_templates := tuple(find_referenced_templates(environment.parse(template_source))):
related_files = DataFile.objects.filter(source=self.data_source)
# None indicates the use of dynamic resolution. If dependent files are statically
# defined, we can filter by path for optimization.
if None not in referenced_templates:
related_files = related_files.filter(path__in=referenced_templates)
self.cache_templates({
df.path: df.data_as_string for df in
DataFile.objects.filter(source=self.data_source, path__in=referenced_templates)
df.path: df.data_as_string for df in related_files
})
return template_source, template, lambda: True

View File

@@ -3,6 +3,7 @@ import decimal
from django.core.serializers.json import DjangoJSONEncoder
__all__ = (
'ConfigJSONEncoder',
'CustomFieldJSONEncoder',
)
@@ -15,3 +16,16 @@ class CustomFieldJSONEncoder(DjangoJSONEncoder):
if isinstance(o, decimal.Decimal):
return float(o)
return super().default(o)
class ConfigJSONEncoder(DjangoJSONEncoder):
"""
Override Django's built-in JSON encoder to serialize CustomValidator classes as strings.
"""
def default(self, o):
from extras.validators import CustomValidator
if issubclass(type(o), CustomValidator):
return type(o).__name__
return super().default(o)

View File

@@ -1,7 +1,6 @@
from collections import defaultdict
from django.core.management.base import BaseCommand
from django.db.models import Count, OuterRef, Subquery
from netbox.registry import registry
from utilities.counters import update_counts

View File

@@ -14,7 +14,7 @@ class StrikethroughExtension(markdown.Extension):
"""
def extendMarkdown(self, md):
md.inlinePatterns.register(
markdown.inlinepatterns.SimpleTagPattern(STRIKE_RE, 'del'),
SimpleTagPattern(STRIKE_RE, 'del'),
'strikethrough',
200
)

View File

@@ -87,7 +87,7 @@ def get_paginate_count(request):
pass
if request.user.is_authenticated:
per_page = request.user.config.get('pagination.per_page', config.PAGINATE_COUNT)
per_page = request.user.config.get('pagination.per_page') or config.PAGINATE_COUNT
return _max_allowed(per_page)
return _max_allowed(config.PAGINATE_COUNT)

View File

@@ -0,0 +1,27 @@
from django.core.exceptions import ValidationError
from django.utils.translation import gettext as _
class AlphanumericPasswordValidator:
"""
Validate that the password has at least one numeral, one uppercase letter and one lowercase letter.
"""
def validate(self, password, user=None):
if not any(char.isdigit() for char in password):
raise ValidationError(
_("Password must have at least one numeral."),
)
if not any(char.isupper() for char in password):
raise ValidationError(
_("Password must have at least one uppercase letter."),
)
if not any(char.islower() for char in password):
raise ValidationError(
_("Password must have at least one lowercase letter."),
)
def get_help_text(self):
return _("Your password must contain at least one numeral, one uppercase letter and one lowercase letter.")

View File

@@ -1,7 +1,10 @@
from django.conf import settings
from django.apps import apps
from django.db.models import Q
from django.utils.translation import gettext_lazy as _
from users.constants import CONSTRAINT_TOKEN_USER
__all__ = (
'get_permission_for_model',
'permission_is_exempt',
@@ -90,6 +93,11 @@ def qs_filter_from_constraints(constraints, tokens=None):
if tokens is None:
tokens = {}
User = apps.get_model('users.User')
for token, value in tokens.items():
if token == CONSTRAINT_TOKEN_USER and isinstance(value, User):
tokens[token] = value.id
def _replace_tokens(value, tokens):
if type(value) is list:
return list(map(lambda v: tokens.get(v, v), value))

View File

@@ -55,7 +55,7 @@ def prepare_cloned_fields(instance):
for key, value in attrs.items():
if type(value) in (list, tuple):
params.extend([(key, v) for v in value])
elif value not in (False, None):
elif value is not False and value is not None:
params.append((key, value))
else:
params.append((key, ''))

View File

@@ -20,14 +20,14 @@ class RestrictedPrefetch(Prefetch):
super().__init__(lookup, queryset=queryset, to_attr=to_attr)
def get_current_queryset(self, level):
def get_current_querysets(self, level):
params = {
'user': self.restrict_user,
'action': self.restrict_action,
}
if qs := super().get_current_queryset(level):
return qs.restrict(**params)
if querysets := super().get_current_querysets(level):
return [qs.restrict(**params) for qs in querysets]
# Bit of a hack. If no queryset is defined, pass through the dict of restrict()
# kwargs to be handled by the field. This is necessary e.g. for GenericForeignKey
@@ -49,11 +49,11 @@ class RestrictedQuerySet(QuerySet):
permission_required = get_permission_for_model(self.model, action)
# Bypass restriction for superusers and exempt views
if user.is_superuser or permission_is_exempt(permission_required):
if user and user.is_superuser or permission_is_exempt(permission_required):
qs = self
# User is anonymous or has not been granted the requisite permission
elif not user.is_authenticated or permission_required not in user.get_all_permissions():
elif user is None or not user.is_authenticated or permission_required not in user.get_all_permissions():
qs = self.none()
# Filter the queryset to include only objects with allowed attributes

View File

@@ -0,0 +1,76 @@
import datetime
import os
from dataclasses import asdict, dataclass, field
from typing import Union
import yaml
from django.core.exceptions import ImproperlyConfigured
from utilities.datetime import datetime_from_timestamp
RELEASE_PATH = 'release.yaml'
LOCAL_RELEASE_PATH = 'local/release.yaml'
@dataclass
class FeatureSet:
"""
A map of all available NetBox features.
"""
# Commercial support is provided by NetBox Labs
commercial: bool = False
# Live help center is enabled
help_center: bool = False
@dataclass
class ReleaseInfo:
version: str
edition: str
published: Union[datetime.date, None] = None
designation: Union[str, None] = None
features: FeatureSet = field(default_factory=FeatureSet)
@property
def full_version(self):
if self.designation:
return f"{self.version}-{self.designation}"
return self.version
@property
def name(self):
return f"NetBox {self.edition} v{self.full_version}"
def asdict(self):
return asdict(self)
def load_release_data():
"""
Load any locally-defined release attributes and return a ReleaseInfo instance.
"""
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Load canonical release attributes
with open(os.path.join(base_path, RELEASE_PATH), 'r') as release_file:
data = yaml.safe_load(release_file)
# Overlay any local release date (if defined)
try:
with open(os.path.join(base_path, LOCAL_RELEASE_PATH), 'r') as release_file:
local_data = yaml.safe_load(release_file)
except FileNotFoundError:
local_data = {}
if local_data is not None:
if type(local_data) is not dict:
raise ImproperlyConfigured(
f"{LOCAL_RELEASE_PATH}: Local release data must be defined as a dictionary."
)
data.update(local_data)
# Convert the published date to a date object
if 'published' in data:
data['published'] = datetime_from_timestamp(data['published'])
return ReleaseInfo(**data)

View File

@@ -2,7 +2,6 @@ import json
from django.contrib.contenttypes.models import ContentType
from django.core import serializers
from mptt.models import MPTTModel
from extras.utils import is_taggable
@@ -16,8 +15,7 @@ def serialize_object(obj, resolve_tags=True, extra=None, exclude=None):
"""
Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
can be provided to exclude them from the returned dictionary. Private fields (prefaced with an underscore) are
implicitly excluded.
can be provided to exclude them from the returned dictionary.
Args:
obj: The object to serialize
@@ -30,11 +28,6 @@ def serialize_object(obj, resolve_tags=True, extra=None, exclude=None):
data = json.loads(json_str)[0]['fields']
exclude = exclude or []
# Exclude any MPTTModel fields
if issubclass(obj.__class__, MPTTModel):
for field in ['level', 'lft', 'rght', 'tree_id']:
data.pop(field)
# Include custom_field_data as "custom_fields"
if hasattr(obj, 'custom_field_data'):
data['custom_fields'] = data.pop('custom_field_data')
@@ -45,9 +38,9 @@ def serialize_object(obj, resolve_tags=True, extra=None, exclude=None):
tags = getattr(obj, '_tags', None) or obj.tags.all()
data['tags'] = sorted([tag.name for tag in tags])
# Skip excluded and private (prefixes with an underscore) attributes
# Skip any excluded attributes
for key in list(data.keys()):
if key in exclude or (isinstance(key, str) and key.startswith('_')):
if key in exclude:
data.pop(key)
# Append any extra data

View File

@@ -1,5 +1,5 @@
from django.contrib.postgres.fields import ArrayField
from django.core.serializers.json import Deserializer, Serializer as Serializer_ # noqa
from django.core.serializers.json import Deserializer, Serializer as Serializer_ # noqa: F401
from django.utils.encoding import is_protected_type
# NOTE: Module must contain both Serializer and Deserializer

104
netbox/utilities/socks.py Normal file
View File

@@ -0,0 +1,104 @@
import logging
from urllib.parse import urlparse
from urllib3 import PoolManager, HTTPConnectionPool, HTTPSConnectionPool
from urllib3.connection import HTTPConnection, HTTPSConnection
from .constants import HTTP_PROXY_SOCK_RDNS_SCHEMAS
logger = logging.getLogger('netbox.utilities')
class ProxyHTTPConnection(HTTPConnection):
"""
A Proxy connection class that uses a SOCK proxy - used to create
a urllib3 PoolManager that routes connections via the proxy.
This is for an HTTP (not HTTPS) connection
"""
use_rdns = False
def __init__(self, *args, **kwargs):
socks_options = kwargs.pop('_socks_options')
self._proxy_url = socks_options['proxy_url']
super().__init__(*args, **kwargs)
def _new_conn(self):
try:
from python_socks.sync import Proxy
except ModuleNotFoundError as e:
logger.info(
"Configuring an HTTP proxy using SOCKS requires the python_socks library. Check that it has been "
"installed."
)
raise e
proxy = Proxy.from_url(self._proxy_url, rdns=self.use_rdns)
return proxy.connect(
dest_host=self.host,
dest_port=self.port,
timeout=self.timeout
)
class ProxyHTTPSConnection(ProxyHTTPConnection, HTTPSConnection):
"""
A Proxy connection class for an HTTPS (not HTTP) connection.
"""
pass
class RdnsProxyHTTPConnection(ProxyHTTPConnection):
"""
A Proxy connection class for an HTTP remote-dns connection.
I.E. socks4a, socks4h, socks5a, socks5h
"""
use_rdns = True
class RdnsProxyHTTPSConnection(ProxyHTTPSConnection):
"""
A Proxy connection class for an HTTPS remote-dns connection.
I.E. socks4a, socks4h, socks5a, socks5h
"""
use_rdns = True
class ProxyHTTPConnectionPool(HTTPConnectionPool):
ConnectionCls = ProxyHTTPConnection
class ProxyHTTPSConnectionPool(HTTPSConnectionPool):
ConnectionCls = ProxyHTTPSConnection
class RdnsProxyHTTPConnectionPool(HTTPConnectionPool):
ConnectionCls = RdnsProxyHTTPConnection
class RdnsProxyHTTPSConnectionPool(HTTPSConnectionPool):
ConnectionCls = RdnsProxyHTTPSConnection
class ProxyPoolManager(PoolManager):
def __init__(self, proxy_url, timeout=5, num_pools=10, headers=None, **connection_pool_kw):
# python_socks uses rdns param to denote remote DNS parsing and
# doesn't accept the 'h' or 'a' in the proxy URL
if use_rdns := urlparse(proxy_url).scheme in HTTP_PROXY_SOCK_RDNS_SCHEMAS:
proxy_url = proxy_url.replace('socks5h:', 'socks5:').replace('socks5a:', 'socks5:')
proxy_url = proxy_url.replace('socks4h:', 'socks4:').replace('socks4a:', 'socks4:')
connection_pool_kw['_socks_options'] = {'proxy_url': proxy_url}
connection_pool_kw['timeout'] = timeout
super().__init__(num_pools, headers, **connection_pool_kw)
if use_rdns:
self.pool_classes_by_scheme = {
'http': RdnsProxyHTTPConnectionPool,
'https': RdnsProxyHTTPSConnectionPool,
}
else:
self.pool_classes_by_scheme = {
'http': ProxyHTTPConnectionPool,
'https': ProxyHTTPSConnectionPool,
}

View File

@@ -29,7 +29,7 @@ def linkify_phone(value):
"""
if value is None:
return None
return f"tel:{value}"
return f"tel:{value.replace(' ', '')}"
def register_table_column(column, name, *tables):

View File

@@ -11,7 +11,7 @@
{% elif customfield.type == 'date' and value %}
{{ value|isodate }}
{% elif customfield.type == 'datetime' and value %}
{{ value|isodate }} {{ value|isodatetime }}
{{ value|isodatetime }}
{% elif customfield.type == 'url' and value %}
<a href="{{ value }}">{{ value|truncatechars:70 }}</a>
{% elif customfield.type == 'json' and value %}

View File

@@ -1,5 +1,5 @@
<div class="htmx-container table-responsive"
hx-get="{% url viewname %}{% if url_params %}?{{ url_params.urlencode }}{% endif %}"
hx-get="{% url viewname %}?embedded=True{% if url_params %}&{{ url_params.urlencode }}{% endif %}"
hx-target="this"
hx-trigger="load" hx-select="table" hx-swap="innerHTML"
hx-trigger="load" hx-select=".htmx-container" hx-swap="outerHTML"
></div>

View File

@@ -10,7 +10,7 @@
</button>
{% else %}
<button type="submit" class="btn btn-cyan">
<i class="mdi mdi-bookmark-check"></i> {% trans "Bookmark" %}
<i class="mdi mdi-bookmark-plus"></i> {% trans "Bookmark" %}
</button>
{% endif %}
</form>

View File

@@ -4,7 +4,7 @@
<i class="mdi mdi-download"></i> {% trans "Export" %}
</button>
<ul class="dropdown-menu dropdown-menu-end">
<li><a class="dropdown-item" href="?{% if url_params %}{{ url_params }}&{% endif %}export=table">{% trans "Current View" %}</a></li>
<li><a id="export_current_view" class="dropdown-item" href="?{% if url_params %}{{ url_params }}&{% endif %}export=table">{% trans "Current View" %}</a></li>
<li><a class="dropdown-item" href="?{% if url_params %}{{ url_params }}&{% endif %}export">{% trans "All Data" %} ({{ data_format }})</a></li>
{% if export_templates %}
<li>

View File

@@ -0,0 +1,18 @@
{% load i18n %}
{% if form_url %}
<form action="{{ form_url }}?return_url={{ return_url }}" method="post">
{% csrf_token %}
{% for field, value in form_data.items %}
<input type="hidden" name="{{ field }}" value="{{ value }}" />
{% endfor %}
{% if subscription %}
<button type="submit" class="btn btn-cyan">
<i class="mdi mdi-bell-minus"></i> {% trans "Unsubscribe" %}
</button>
{% else %}
<button type="submit" class="btn btn-cyan">
<i class="mdi mdi-bell-plus"></i> {% trans "Subscribe" %}
</button>
{% endif %}
</form>
{% endif %}

View File

@@ -3,7 +3,7 @@
{% for group, fields in form.custom_field_groups.items %}
{% if group %}
<div class="row">
<h6 class="offset-sm-3 mb-3">{{ group }}</h6>
<h3 class="col-9 offset-3 mb-3 h4">{{ group }}</h3>
</div>
{% endif %}
{% for name in fields %}

View File

@@ -6,9 +6,11 @@
{# Render the field label (if any), except for checkboxes #}
{% if label and not field|widget_type == 'checkboxinput' %}
<label for="{{ field.id_for_label }}" class="col-sm-3 col-form-label text-lg-end{% if field.field.required %} required{% endif %}">
{{ label }}
</label>
<div class="col-sm-3 text-lg-end">
<label for="{{ field.id_for_label }}" class="col-form-label d-inline-block{% if field.field.required %} required{% endif %}">
{{ label }}
</label>
</div>
{% endif %}
{# Render the field itself #}

View File

@@ -3,7 +3,7 @@
<div class="field-group mb-5">
{% if heading %}
<div class="row">
<h5 class="col-9 offset-3">{{ heading }}</h5>
<h2 class="col-9 offset-3">{{ heading }}</h2>
</div>
{% endif %}
{% for layout, title, items in rows %}

View File

@@ -16,10 +16,10 @@
</div>
<div class="col-2 d-flex align-items-center">
<div>
<a class="btn btn-success btn-sm w-100 my-2" id="add_columns">
<a tabindex="0" class="btn btn-success btn-sm w-100 my-2" id="add_columns">
<i class="mdi mdi-arrow-right-bold"></i> {% trans "Add" %}
</a>
<a class="btn btn-danger btn-sm w-100 my-2" id="remove_columns">
<a tabindex="0" class="btn btn-danger btn-sm w-100 my-2" id="remove_columns">
<i class="mdi mdi-arrow-left-bold"></i> {% trans "Remove" %}
</a>
</div>
@@ -27,10 +27,10 @@
<div class="col-5 text-center">
{{ form.columns.label }}
{{ form.columns }}
<a class="btn btn-primary btn-sm mt-2" id="move-option-up" data-target="id_columns">
<a tabindex="0" class="btn btn-primary btn-sm mt-2" id="move-option-up" data-target="id_columns">
<i class="mdi mdi-arrow-up-bold"></i> {% trans "Move Up" %}
</a>
<a class="btn btn-primary btn-sm mt-2" id="move-option-down" data-target="id_columns">
<a tabindex="0" class="btn btn-primary btn-sm mt-2" id="move-option-down" data-target="id_columns">
<i class="mdi mdi-arrow-down-bold"></i> {% trans "Move Down" %}
</a>
</div>

View File

@@ -1,7 +1,23 @@
{% load helpers %}
{% load i18n %}
{% load navigation %}
<ul class="navbar-nav pt-lg-2" {% htmx_boost %}>
<li class="nav-item d-block d-lg-none">
<form action="{% url 'search' %}" method="get" autocomplete="off" novalidate>
<div class="input-group mb-1 mt-2">
<div class="input-group-prepend">
<span class="input-group-text">
<i class="mdi mdi-magnify"></i>
</span>
</div>
<input type="text" name="q" value="" class="form-control" placeholder="{% trans "Search" %}" aria-label="{% trans "Search NetBox" %}">
<div class="input-group-append">
<button type="submit" class="form-control">{% trans "Search" %}</button>
</div>
</div>
</form>
</li>
{% for menu, groups in nav_items %}
<li class="nav-item dropdown">
@@ -29,7 +45,7 @@
{% if buttons %}
<div class="btn-group ms-1">
{% for button in buttons %}
<a href="{% url button.link %}" class="btn btn-sm btn-{{ button.color|default:"outline-dark" }} lh-2 px-2" title="{{ button.title }}">
<a href="{% url button.link %}" class="btn btn-sm btn-{{ button.color|default:"outline" }} lh-2 px-2" title="{{ button.title }}">
<i class="{{ button.icon_class }}"></i>
</a>
{% endfor %}

View File

@@ -1,7 +1,8 @@
{% load i18n %}
{% if widget.attrs.selector and not widget.attrs.disabled %}
<div class="d-flex">
{% include 'django/forms/widgets/select.html' %}
<div class="d-flex">
{% include 'django/forms/widgets/select.html' %}
{% if widget.attrs.selector and not widget.attrs.disabled %}
{# Opens the object selector modal #}
<button
type="button"
title="{% trans "Open selector" %}"
@@ -13,7 +14,19 @@
>
<i class="mdi mdi-database-search-outline"></i>
</button>
</div>
{% else %}
{% include 'django/forms/widgets/select.html' %}
{% endif %}
{% endif %}
{% if quick_add and not widget.attrs.disabled %}
{# Opens the quick add modal #}
<button
type="button"
title="{% trans "Quick add" %}"
class="btn btn-outline-secondary ms-1"
data-bs-toggle="modal"
data-bs-target="#htmx-modal"
hx-get="{{ quick_add.url }}?_quickadd=True&target={{ widget.attrs.id }}{% for k, v in quick_add.params.items %}&{{ k }}={{ v }}{% endfor %}"
hx-target="#htmx-modal-content"
>
<i class="mdi mdi-plus-circle"></i>
</button>
{% endif %}
</div>

View File

@@ -8,6 +8,7 @@ from django.contrib.contenttypes.models import ContentType
from django.contrib.humanize.templatetags.humanize import naturalday, naturaltime
from django.utils.html import escape
from django.utils.safestring import mark_safe
from django.utils.timezone import localtime
from markdown import markdown
from markdown.extensions.tables import TableExtension
@@ -58,7 +59,7 @@ def linkify(instance, attr=None):
url = instance.get_absolute_url()
return mark_safe(f'<a href="{url}">{escape(text)}</a>')
except (AttributeError, TypeError):
return text
return escape(text)
@register.filter()
@@ -218,7 +219,8 @@ def isodate(value):
text = value.isoformat()
return mark_safe(f'<span title="{naturalday(value)}">{text}</span>')
elif type(value) is datetime.datetime:
text = value.date().isoformat()
local_value = localtime(value) if value.tzinfo else value
text = local_value.date().isoformat()
return mark_safe(f'<span title="{naturaltime(value)}">{text}</span>')
else:
return ''
@@ -229,7 +231,8 @@ def isotime(value, spec='seconds'):
if type(value) is datetime.time:
return value.isoformat(timespec=spec)
if type(value) is datetime.datetime:
return value.time().isoformat(timespec=spec)
local_value = localtime(value) if value.tzinfo else value
return local_value.time().isoformat(timespec=spec)
return ''

View File

@@ -1,4 +1,5 @@
from django import template
from django.utils.safestring import mark_safe
from extras.choices import CustomFieldTypeChoices
from utilities.querydict import dict_to_querydict
@@ -124,5 +125,5 @@ def formaction(context):
if HTMX navigation is enabled (per the user's preferences).
"""
if context.get('htmx_navigation', False):
return 'hx-push-url="true" hx-post'
return mark_safe('hx-push-url="true" hx-post')
return 'formaction'

View File

@@ -3,7 +3,8 @@ from django.contrib.contenttypes.models import ContentType
from django.urls import NoReverseMatch, reverse
from core.models import ObjectType
from extras.models import Bookmark, ExportTemplate
from extras.models import Bookmark, ExportTemplate, Subscription
from netbox.models.features import NotificationsMixin
from utilities.querydict import prepare_cloned_fields
from utilities.views import get_viewname
@@ -17,6 +18,7 @@ __all__ = (
'edit_button',
'export_button',
'import_button',
'subscribe_button',
'sync_button',
)
@@ -94,6 +96,41 @@ def delete_button(instance):
}
@register.inclusion_tag('buttons/subscribe.html', takes_context=True)
def subscribe_button(context, instance):
# Skip for objects which don't support notifications
if not (issubclass(instance.__class__, NotificationsMixin)):
return {}
# Check if this user has already subscribed to the object
content_type = ContentType.objects.get_for_model(instance)
subscription = Subscription.objects.filter(
object_type=content_type,
object_id=instance.pk,
user=context['request'].user
).first()
# Compile form URL & data
if subscription:
form_url = reverse('extras:subscription_delete', kwargs={'pk': subscription.pk})
form_data = {
'confirm': 'true',
}
else:
form_url = reverse('extras:subscription_add')
form_data = {
'object_type': content_type.pk,
'object_id': instance.pk,
}
return {
'subscription': subscription,
'form_url': form_url,
'form_data': form_data,
'return_url': instance.get_absolute_url(),
}
@register.inclusion_tag('buttons/sync.html')
def sync_button(instance):
viewname = get_viewname(instance, 'sync')
@@ -121,7 +158,7 @@ def add_button(model, action='add'):
@register.inclusion_tag('buttons/import.html')
def import_button(model, action='import'):
def import_button(model, action='bulk_import'):
try:
url = reverse(get_viewname(model, action))
except NoReverseMatch:

View File

@@ -1,8 +1,6 @@
import warnings
from django import template
from utilities.forms.rendering import FieldSet, InlineFields, ObjectAttribute, TabbedGroups
from utilities.forms.rendering import InlineFields, ObjectAttribute, TabbedGroups
__all__ = (
'getfield',
@@ -54,15 +52,6 @@ def render_fieldset(form, fieldset):
"""
Render a group set of fields.
"""
# TODO: Remove in NetBox v4.1
# Handle legacy tuple-based fieldset definitions, e.g. (_('Label'), ('field1, 'field2', 'field3'))
if type(fieldset) is not FieldSet:
warnings.warn(
f"{form.__class__} fieldsets contains a non-FieldSet item: {fieldset}"
)
name, fields = fieldset
fieldset = FieldSet(*fields, name=name)
rows = []
for item in fieldset.items:

View File

@@ -1,14 +1,9 @@
import datetime
import json
from typing import Dict, Any
from urllib.parse import quote
from django import template
from django.conf import settings
from django.template.defaultfilters import date
from django.urls import NoReverseMatch, reverse
from django.utils import timezone
from django.utils.safestring import mark_safe
from core.models import ObjectType
from utilities.forms import get_selected_values, TableConfigForm
@@ -92,15 +87,22 @@ def humanize_speed(speed):
@register.filter()
def humanize_megabytes(mb):
"""
Express a number of megabytes in the most suitable unit (e.g. gigabytes or terabytes).
Express a number of megabytes in the most suitable unit (e.g. gigabytes, terabytes, etc.).
"""
if not mb:
return ''
if not mb % 1048576: # 1024^2
return f'{int(mb / 1048576)} TB'
if not mb % 1024:
return f'{int(mb / 1024)} GB'
return f'{mb} MB'
return ""
PB_SIZE = 1000000000
TB_SIZE = 1000000
GB_SIZE = 1000
if mb >= PB_SIZE:
return f"{mb / PB_SIZE:.2f} PB"
if mb >= TB_SIZE:
return f"{mb / TB_SIZE:.2f} TB"
if mb >= GB_SIZE:
return f"{mb / GB_SIZE:.2f} GB"
return f"{mb} MB"
@register.filter()
@@ -279,6 +281,10 @@ def applied_filters(context, model, form, query_params):
if filter_name not in querydict:
continue
# Skip saved filters, as they're displayed alongside the quick search widget
if filter_name == 'filter_id':
continue
bound_field = form.fields[filter_name].get_bound_field(form, filter_name)
querydict.pop(filter_name)
display_value = ', '.join([str(v) for v in get_selected_values(form, filter_name)])

View File

@@ -1,4 +1,5 @@
from django import template
from django.utils.html import escape
from django.utils.safestring import mark_safe
register = template.Library()
@@ -15,6 +16,6 @@ def nested_tree(obj):
nodes = obj.get_ancestors(include_self=True)
return mark_safe(
' / '.join(
f'<a href="{node.get_absolute_url()}">{node}</a>' for node in nodes
f'<a href="{node.get_absolute_url()}">{escape(node)}</a>' for node in nodes
)
)

View File

@@ -22,8 +22,10 @@ def _get_registered_content(obj, method, template_context):
'perms': template_context['perms'],
}
model_name = obj._meta.label_lower
template_extensions = registry['plugins']['template_extensions'].get(model_name, [])
template_extensions = list(registry['plugins']['template_extensions'].get(None, []))
if hasattr(obj, '_meta'):
model_name = obj._meta.label_lower
template_extensions.extend(registry['plugins']['template_extensions'].get(model_name, []))
for template_extension in template_extensions:
# If the class has not overridden the specified method, we can skip it (because we know it
@@ -43,6 +45,22 @@ def _get_registered_content(obj, method, template_context):
return mark_safe(html)
@register.simple_tag(takes_context=True)
def plugin_navbar(context):
"""
Render any navbar content embedded by plugins
"""
return _get_registered_content(None, 'navbar', context)
@register.simple_tag(takes_context=True)
def plugin_list_buttons(context, model):
"""
Render all list buttons registered by plugins
"""
return _get_registered_content(model, 'list_buttons', context)
@register.simple_tag(takes_context=True)
def plugin_buttons(context, obj):
"""
@@ -51,6 +69,14 @@ def plugin_buttons(context, obj):
return _get_registered_content(obj, 'buttons', context)
@register.simple_tag(takes_context=True)
def plugin_alerts(context, obj):
"""
Render all object alerts registered by plugins
"""
return _get_registered_content(obj, 'alerts', context)
@register.simple_tag(takes_context=True)
def plugin_left_page(context, obj):
"""
@@ -73,11 +99,3 @@ def plugin_full_width_page(context, obj):
Render all full width page content registered by plugins
"""
return _get_registered_content(obj, 'full_width_page', context)
@register.simple_tag(takes_context=True)
def plugin_list_buttons(context, model):
"""
Render all list buttons registered by plugins
"""
return _get_registered_content(model, 'list_buttons', context)

View File

@@ -1,28 +1,24 @@
import inspect
import json
import strawberry_django
import strawberry_django
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.contenttypes.models import ContentType
from django.urls import reverse
from django.test import override_settings
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APIClient
from strawberry.types.base import StrawberryList, StrawberryOptional
from strawberry.types.lazy_type import LazyType
from strawberry.types.union import StrawberryUnion
from core.models import ObjectType
from extras.choices import ObjectChangeActionChoices
from extras.models import ObjectChange
from users.models import ObjectPermission, Token
from core.choices import ObjectChangeActionChoices
from core.models import ObjectChange, ObjectType
from ipam.graphql.types import IPAddressFamilyType
from users.models import ObjectPermission, Token, User
from utilities.api import get_graphql_type_for_model
from .base import ModelTestCase
from .utils import disable_warnings
from ipam.graphql.types import IPAddressFamilyType
from strawberry.field import StrawberryField
from strawberry.lazy_type import LazyType
from strawberry.type import StrawberryList, StrawberryOptional
from strawberry.union import StrawberryUnion
from .utils import disable_logging, disable_warnings
__all__ = (
'APITestCase',
@@ -30,9 +26,6 @@ __all__ = (
)
User = get_user_model()
#
# REST/GraphQL API Tests
#
@@ -73,7 +66,7 @@ class APIViewTestCases:
class GetObjectViewTestCase(APITestCase):
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'], LOGIN_REQUIRED=False)
def test_get_object_anonymous(self):
"""
GET a single object as an unauthenticated user.
@@ -135,7 +128,7 @@ class APIViewTestCases:
class ListObjectsViewTestCase(APITestCase):
brief_fields = []
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'], LOGIN_REQUIRED=False)
def test_list_objects_anonymous(self):
"""
GET a list of objects as an unauthenticated user.
@@ -440,18 +433,20 @@ class APIViewTestCases:
base_name = self.model._meta.verbose_name.lower().replace(' ', '_')
return getattr(self, 'graphql_base_name', base_name)
def _build_query(self, name, **filters):
def _build_query_with_filter(self, name, filter_string):
"""
Called by either _build_query or _build_filtered_query - construct the actual
query given a name and filter string
"""
type_class = get_graphql_type_for_model(self.model)
if filters:
filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items())
filter_string = f'({filter_string})'
else:
filter_string = ''
# Compile list of fields to include
fields_string = ''
file_fields = (strawberry_django.fields.types.DjangoFileType, strawberry_django.fields.types.DjangoImageType)
file_fields = (
strawberry_django.fields.types.DjangoFileType,
strawberry_django.fields.types.DjangoImageType,
)
for field in type_class.__strawberry_definition__.fields:
if (
field.type in file_fields or (
@@ -492,8 +487,39 @@ class APIViewTestCases:
return query
def _build_filtered_query(self, name, **filters):
"""
Create a filtered query: i.e. device_list(filters: {name: {i_contains: "akron"}}){.
"""
# TODO: This should be extended to support AND, OR multi-lookups
if filters:
for field_name, params in filters.items():
lookup = params['lookup']
value = params['value']
if lookup:
query = f'{{{lookup}: "{value}"}}'
filter_string = f'{field_name}: {query}'
else:
filter_string = f'{field_name}: "{value}"'
filter_string = f'(filters: {{{filter_string}}})'
else:
filter_string = ''
return self._build_query_with_filter(name, filter_string)
def _build_query(self, name, **filters):
"""
Create a normal query - unfiltered or with a string query: i.e. site(name: "aaa"){.
"""
if filters:
filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items())
filter_string = f'({filter_string})'
else:
filter_string = ''
return self._build_query_with_filter(name, filter_string)
@override_settings(LOGIN_REQUIRED=True)
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user'])
def test_graphql_get_object(self):
url = reverse('graphql')
field_name = self._get_graphql_base_name()
@@ -501,39 +527,92 @@ class APIViewTestCases:
query = self._build_query(field_name, id=object_id)
# Non-authenticated requests should fail
header = {
'HTTP_ACCEPT': 'application/json',
}
with disable_warnings('django.request'):
header = {
'HTTP_ACCEPT': 'application/json',
}
self.assertHttpStatus(self.client.post(url, data={'query': query}, format="json", **header), status.HTTP_403_FORBIDDEN)
response = self.client.post(url, data={'query': query}, format="json", **header)
self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN)
# Add object-level permission
# Add constrained permission
obj_perm = ObjectPermission(
name='Test permission',
actions=['view']
actions=['view'],
constraints={'id': 0} # Impossible constraint
)
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
# Request should succeed but return empty result
with disable_logging():
response = self.client.post(url, data={'query': query}, format="json", **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = json.loads(response.content)
self.assertIn('errors', data)
self.assertIsNone(data['data'])
# Remove permission constraint
obj_perm.constraints = None
obj_perm.save()
# Request should return requested object
response = self.client.post(url, data={'query': query}, format="json", **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = json.loads(response.content)
self.assertNotIn('errors', data)
self.assertIsNotNone(data['data'])
@override_settings(LOGIN_REQUIRED=True)
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user'])
def test_graphql_list_objects(self):
url = reverse('graphql')
field_name = f'{self._get_graphql_base_name()}_list'
query = self._build_query(field_name)
# Non-authenticated requests should fail
header = {
'HTTP_ACCEPT': 'application/json',
}
with disable_warnings('django.request'):
header = {
'HTTP_ACCEPT': 'application/json',
}
self.assertHttpStatus(self.client.post(url, data={'query': query}, format="json", **header), status.HTTP_403_FORBIDDEN)
response = self.client.post(url, data={'query': query}, format="json", **header)
self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN)
# Add constrained permission
obj_perm = ObjectPermission(
name='Test permission',
actions=['view'],
constraints={'id': 0} # Impossible constraint
)
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
# Request should succeed but return empty results list
response = self.client.post(url, data={'query': query}, format="json", **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = json.loads(response.content)
self.assertNotIn('errors', data)
self.assertEqual(len(data['data'][field_name]), 0)
# Remove permission constraint
obj_perm.constraints = None
obj_perm.save()
# Request should return all objects
response = self.client.post(url, data={'query': query}, format="json", **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = json.loads(response.content)
self.assertNotIn('errors', data)
self.assertEqual(len(data['data'][field_name]), self.model.objects.count())
@override_settings(LOGIN_REQUIRED=True)
def test_graphql_filter_objects(self):
if not hasattr(self, 'graphql_filter'):
return
url = reverse('graphql')
field_name = f'{self._get_graphql_base_name()}_list'
query = self._build_filtered_query(field_name, **self.graphql_filter)
# Add object-level permission
obj_perm = ObjectPermission(

View File

@@ -1,8 +1,8 @@
import json
from django.contrib.auth import get_user_model
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.fields import ArrayField, RangeField
from django.core.exceptions import FieldDoesNotExist
from django.db.models import ManyToManyField, ManyToManyRel, JSONField
from django.forms.models import model_to_dict
@@ -11,7 +11,8 @@ from netaddr import IPNetwork
from taggit.managers import TaggableManager
from core.models import ObjectType
from users.models import ObjectPermission
from users.models import ObjectPermission, User
from utilities.data import ranges_to_string
from utilities.object_types import object_type_identifier
from utilities.permissions import resolve_permission_type
from .utils import DUMMY_CF_DATA, extract_form_failures
@@ -28,7 +29,7 @@ class TestCase(_TestCase):
def setUp(self):
# Create the test user and assign permissions
self.user = get_user_model().objects.create_user(username='testuser')
self.user = User.objects.create_user(username='testuser')
self.add_permissions(*self.user_permissions)
# Initialize the test client
@@ -120,6 +121,10 @@ class ModelTestCase(TestCase):
else:
model_dict[key] = sorted([obj.pk for obj in value])
# Handle GenericForeignKeys
elif value and type(field) is GenericForeignKey:
model_dict[key] = value.pk
elif api:
# Replace ContentType numeric IDs with <app_label>.<model>
@@ -136,9 +141,15 @@ class ModelTestCase(TestCase):
# Convert ArrayFields to CSV strings
if type(field) is ArrayField:
if type(field.base_field) is ArrayField:
if getattr(field.base_field, 'choices', None):
# Values for fields with pre-defined choices can be returned as lists
model_dict[key] = value
elif type(field.base_field) is ArrayField:
# Handle nested arrays (e.g. choice sets)
model_dict[key] = '\n'.join([f'{k},{v}' for k, v in value])
elif issubclass(type(field.base_field), RangeField):
# Handle arrays of numeric ranges (e.g. VLANGroup VLAN ID ranges)
model_dict[key] = ranges_to_string(value)
else:
model_dict[key] = ','.join([str(v) for v in value])

View File

@@ -3,7 +3,6 @@ import logging
import re
from contextlib import contextmanager
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Permission
from django.utils.text import slugify
@@ -11,6 +10,7 @@ from core.models import ObjectType
from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Site
from extras.choices import CustomFieldTypeChoices
from extras.models import CustomField, Tag
from users.models import User
from virtualization.models import Cluster, ClusterType, VirtualMachine
@@ -67,7 +67,7 @@ def create_test_user(username='testuser', permissions=None):
"""
Create a User with the given permissions.
"""
user = get_user_model().objects.create_user(username=username)
user = User.objects.create_user(username=username)
if permissions is None:
permissions = ()
for perm_name in permissions:
@@ -107,6 +107,16 @@ def disable_warnings(logger_name):
logger.setLevel(current_level)
@contextmanager
def disable_logging(level=logging.CRITICAL):
"""
Temporarily suppress log messages at or below the specified level (default: critical).
"""
logging.disable(level)
yield
logging.disable(logging.NOTSET)
#
# Custom field testing
#

View File

@@ -8,9 +8,8 @@ from django.test import override_settings
from django.urls import reverse
from django.utils.translation import gettext as _
from core.models import ObjectType
from extras.choices import ObjectChangeActionChoices
from extras.models import ObjectChange
from core.choices import ObjectChangeActionChoices
from core.models import ObjectChange, ObjectType
from netbox.choices import CSVDelimiterChoices, ImportFormatChoices
from netbox.models.features import ChangeLoggingMixin, CustomFieldsMixin
from users.models import ObjectPermission
@@ -62,7 +61,7 @@ class ViewTestCases:
"""
Retrieve a single instance.
"""
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'], LOGIN_REQUIRED=False)
def test_get_object_anonymous(self):
# Make the request as an unauthenticated user
self.client.logout()
@@ -421,7 +420,7 @@ class ViewTestCases:
"""
Retrieve multiple instances.
"""
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'], LOGIN_REQUIRED=False)
def test_list_objects_anonymous(self):
# Make the request as an unauthenticated user
self.client.logout()
@@ -595,10 +594,10 @@ class ViewTestCases:
# Test GET without permission
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.get(self._get_url('import')), 403)
self.assertHttpStatus(self.client.get(self._get_url('bulk_import')), 403)
# Try POST without permission
response = self.client.post(self._get_url('import'), data)
response = self.client.post(self._get_url('bulk_import'), data)
with disable_warnings('django.request'):
self.assertHttpStatus(response, 403)
@@ -621,10 +620,10 @@ class ViewTestCases:
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
# Try GET with model-level permission
self.assertHttpStatus(self.client.get(self._get_url('import')), 200)
self.assertHttpStatus(self.client.get(self._get_url('bulk_import')), 200)
# Test POST with permission
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302)
self.assertHttpStatus(self.client.post(self._get_url('bulk_import'), data), 302)
self.assertEqual(self._get_queryset().count(), initial_count + len(self.csv_data) - 1)
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
@@ -650,7 +649,7 @@ class ViewTestCases:
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
# Test POST with permission
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302)
self.assertHttpStatus(self.client.post(self._get_url('bulk_import'), data), 302)
self.assertEqual(initial_count, self._get_queryset().count())
reader = csv.DictReader(array, delimiter=',')
@@ -685,7 +684,7 @@ class ViewTestCases:
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
# Attempt to import non-permitted objects
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200)
self.assertHttpStatus(self.client.post(self._get_url('bulk_import'), data), 200)
self.assertEqual(self._get_queryset().count(), initial_count)
# Update permission constraints
@@ -693,7 +692,7 @@ class ViewTestCases:
obj_perm.save()
# Import permitted objects
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302)
self.assertHttpStatus(self.client.post(self._get_url('bulk_import'), data), 302)
self.assertEqual(self._get_queryset().count(), initial_count + len(self.csv_data) - 1)
class BulkEditObjectsViewTestCase(ModelViewTestCase):

View File

@@ -144,24 +144,37 @@ class APIPaginationTestCase(APITestCase):
self.assertIsNone(response.data['previous'])
self.assertEqual(len(response.data['results']), page_size)
@override_settings(MAX_PAGE_SIZE=30)
def test_default_page_size_with_small_max_page_size(self):
response = self.client.get(self.url, format='json', **self.header)
page_size = get_config().MAX_PAGE_SIZE
paginate_count = get_config().PAGINATE_COUNT
self.assertLess(page_size, 100, "Default page size not sufficient for data set")
self.assertHttpStatus(response, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 100)
self.assertTrue(response.data['next'].endswith(f'?limit={paginate_count}&offset={paginate_count}'))
self.assertIsNone(response.data['previous'])
self.assertEqual(len(response.data['results']), paginate_count)
def test_custom_page_size(self):
response = self.client.get(f'{self.url}?limit=10', format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 100)
self.assertTrue(response.data['next'].endswith(f'?limit=10&offset=10'))
self.assertTrue(response.data['next'].endswith('?limit=10&offset=10'))
self.assertIsNone(response.data['previous'])
self.assertEqual(len(response.data['results']), 10)
@override_settings(MAX_PAGE_SIZE=20)
@override_settings(MAX_PAGE_SIZE=80)
def test_max_page_size(self):
response = self.client.get(f'{self.url}?limit=0', format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 100)
self.assertTrue(response.data['next'].endswith(f'?limit=20&offset=20'))
self.assertTrue(response.data['next'].endswith('?limit=80&offset=80'))
self.assertIsNone(response.data['previous'])
self.assertEqual(len(response.data['results']), 20)
self.assertEqual(len(response.data['results']), 80)
@override_settings(MAX_PAGE_SIZE=0)
def test_max_page_size_disabled(self):

View File

@@ -83,9 +83,9 @@ class CountersTest(TestCase):
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_mptt_child_delete(self):
device1, device2 = Device.objects.all()
device1 = Device.objects.first()
inventory_item1 = InventoryItem.objects.create(device=device1, name='Inventory Item 1')
inventory_item2 = InventoryItem.objects.create(device=device1, name='Inventory Item 2', parent=inventory_item1)
InventoryItem.objects.create(device=device1, name='Inventory Item 2', parent=inventory_item1)
device1.refresh_from_db()
self.assertEqual(device1.inventory_item_count, 2)

View File

@@ -0,0 +1,68 @@
from django.db.backends.postgresql.psycopg_any import NumericRange
from django.test import TestCase
from utilities.data import check_ranges_overlap, ranges_to_string, string_to_ranges
class RangeFunctionsTestCase(TestCase):
def test_check_ranges_overlap(self):
# Non-overlapping ranges
self.assertFalse(
check_ranges_overlap([
NumericRange(9, 19, bounds='(]'), # 10-19
NumericRange(19, 30, bounds='(]'), # 20-29
])
)
self.assertFalse(
check_ranges_overlap([
NumericRange(10, 19, bounds='[]'), # 10-19
NumericRange(20, 29, bounds='[]'), # 20-29
])
)
self.assertFalse(
check_ranges_overlap([
NumericRange(10, 20, bounds='[)'), # 10-19
NumericRange(20, 30, bounds='[)'), # 20-29
])
)
# Overlapping ranges
self.assertTrue(
check_ranges_overlap([
NumericRange(9, 20, bounds='(]'), # 10-20
NumericRange(19, 30, bounds='(]'), # 20-30
])
)
self.assertTrue(
check_ranges_overlap([
NumericRange(10, 20, bounds='[]'), # 10-20
NumericRange(20, 30, bounds='[]'), # 20-30
])
)
self.assertTrue(
check_ranges_overlap([
NumericRange(10, 21, bounds='[)'), # 10-20
NumericRange(20, 31, bounds='[)'), # 10-30
])
)
def test_ranges_to_string(self):
self.assertEqual(
ranges_to_string([
NumericRange(10, 20), # 10-19
NumericRange(30, 40), # 30-39
NumericRange(100, 200), # 100-199
]),
'10-19,30-39,100-199'
)
def test_string_to_ranges(self):
self.assertEqual(
string_to_ranges('10-19, 30-39, 100-199'),
[
NumericRange(10, 19, bounds='[]'), # 10-19
NumericRange(30, 39, bounds='[]'), # 30-39
NumericRange(100, 199, bounds='[]'), # 100-199
]
)

View File

@@ -7,15 +7,16 @@ from taggit.managers import TaggableManager
from dcim.choices import *
from dcim.fields import MACAddressField
from dcim.filtersets import DeviceFilterSet, SiteFilterSet
from dcim.filtersets import DeviceFilterSet, SiteFilterSet, InterfaceFilterSet
from dcim.models import (
Device, DeviceRole, DeviceType, Interface, Manufacturer, Platform, Rack, Region, Site
Device, DeviceRole, DeviceType, Interface, MACAddress, Manufacturer, Platform, Rack, Region, Site
)
from extras.filters import TagFilter
from extras.models import TaggedItem
from ipam.filtersets import ASNFilterSet
from ipam.models import RIR, ASN
from netbox.filtersets import BaseFilterSet
from wireless.choices import WirelessRoleChoices
from utilities.filters import (
MultiValueCharFilter, MultiValueDateFilter, MultiValueDateTimeFilter, MultiValueMACAddressFilter,
MultiValueNumberFilter, MultiValueTimeFilter, TreeNodeMultipleChoiceFilter,
@@ -408,9 +409,9 @@ class DynamicFilterLookupExpressionTest(TestCase):
region.save()
sites = (
Site(name='Site 1', slug='abc-site-1', region=regions[0]),
Site(name='Site 2', slug='def-site-2', region=regions[1]),
Site(name='Site 3', slug='ghi-site-3', region=regions[2]),
Site(name='Site 1', slug='abc-site-1', region=regions[0], status=SiteStatusChoices.STATUS_ACTIVE),
Site(name='Site 2', slug='def-site-2', region=regions[1], status=SiteStatusChoices.STATUS_ACTIVE),
Site(name='Site 3', slug='ghi-site-3', region=regions[2], status=SiteStatusChoices.STATUS_PLANNED),
)
Site.objects.bulk_create(sites)
@@ -426,26 +427,88 @@ class DynamicFilterLookupExpressionTest(TestCase):
Rack.objects.bulk_create(racks)
devices = (
Device(name='Device 1', device_type=device_types[0], role=roles[0], platform=platforms[0], serial='ABC', asset_tag='1001', site=sites[0], rack=racks[0], position=1, face=DeviceFaceChoices.FACE_FRONT, status=DeviceStatusChoices.STATUS_ACTIVE, local_context_data={"foo": 123}),
Device(name='Device 2', device_type=device_types[1], role=roles[1], platform=platforms[1], serial='DEF', asset_tag='1002', site=sites[1], rack=racks[1], position=2, face=DeviceFaceChoices.FACE_FRONT, status=DeviceStatusChoices.STATUS_STAGED),
Device(name='Device 3', device_type=device_types[2], role=roles[2], platform=platforms[2], serial='GHI', asset_tag='1003', site=sites[2], rack=racks[2], position=3, face=DeviceFaceChoices.FACE_REAR, status=DeviceStatusChoices.STATUS_FAILED),
Device(
name='Device 1',
device_type=device_types[0],
role=roles[0],
platform=platforms[0],
serial='ABC',
asset_tag='1001',
site=sites[0],
rack=racks[0],
position=1,
face=DeviceFaceChoices.FACE_FRONT,
status=DeviceStatusChoices.STATUS_ACTIVE,
local_context_data={'foo': 123},
),
Device(
name='Device 2',
device_type=device_types[1],
role=roles[1],
platform=platforms[1],
serial='DEF',
asset_tag='1002',
site=sites[1],
rack=racks[1],
position=2,
face=DeviceFaceChoices.FACE_FRONT,
status=DeviceStatusChoices.STATUS_STAGED,
),
Device(
name='Device 3',
device_type=device_types[2],
role=roles[2],
platform=platforms[2],
serial='GHI',
asset_tag='1003',
site=sites[2],
rack=racks[2],
position=3,
face=DeviceFaceChoices.FACE_REAR,
status=DeviceStatusChoices.STATUS_FAILED,
),
)
Device.objects.bulk_create(devices)
mac_addresses = (
MACAddress(mac_address='00-00-00-00-00-01'),
MACAddress(mac_address='aa-00-00-00-00-01'),
MACAddress(mac_address='00-00-00-00-00-02'),
MACAddress(mac_address='bb-00-00-00-00-02'),
MACAddress(mac_address='00-00-00-00-00-03'),
MACAddress(mac_address='cc-00-00-00-00-03'),
)
MACAddress.objects.bulk_create(mac_addresses)
interfaces = (
Interface(device=devices[0], name='Interface 1', mac_address='00-00-00-00-00-01'),
Interface(device=devices[0], name='Interface 2', mac_address='aa-00-00-00-00-01'),
Interface(device=devices[1], name='Interface 3', mac_address='00-00-00-00-00-02'),
Interface(device=devices[1], name='Interface 4', mac_address='bb-00-00-00-00-02'),
Interface(device=devices[2], name='Interface 5', mac_address='00-00-00-00-00-03'),
Interface(device=devices[2], name='Interface 6', mac_address='cc-00-00-00-00-03'),
Interface(device=devices[0], name='Interface 1'),
Interface(device=devices[0], name='Interface 2'),
Interface(device=devices[1], name='Interface 3'),
Interface(device=devices[1], name='Interface 4'),
Interface(device=devices[2], name='Interface 5'),
Interface(device=devices[2], name='Interface 6', rf_role=WirelessRoleChoices.ROLE_AP),
)
Interface.objects.bulk_create(interfaces)
interfaces[0].mac_addresses.set([mac_addresses[0]])
interfaces[1].mac_addresses.set([mac_addresses[1]])
interfaces[2].mac_addresses.set([mac_addresses[2]])
interfaces[3].mac_addresses.set([mac_addresses[3]])
interfaces[4].mac_addresses.set([mac_addresses[4]])
interfaces[5].mac_addresses.set([mac_addresses[5]])
def test_site_name_negation(self):
params = {'name__n': ['Site 1']}
self.assertEqual(SiteFilterSet(params, Site.objects.all()).qs.count(), 2)
def test_site_status_icontains(self):
params = {'status__ic': [SiteStatusChoices.STATUS_ACTIVE]}
self.assertEqual(SiteFilterSet(params, Site.objects.all()).qs.count(), 2)
def test_site_status_icontains_negation(self):
params = {'status__nic': [SiteStatusChoices.STATUS_ACTIVE]}
self.assertEqual(SiteFilterSet(params, Site.objects.all()).qs.count(), 1)
def test_site_slug_icontains(self):
params = {'slug__ic': ['-1']}
self.assertEqual(SiteFilterSet(params, Site.objects.all()).qs.count(), 1)
@@ -553,3 +616,9 @@ class DynamicFilterLookupExpressionTest(TestCase):
def test_device_mac_address_icontains_negation(self):
params = {'mac_address__nic': ['aa:', 'bb']}
self.assertEqual(DeviceFilterSet(params, Device.objects.all()).qs.count(), 1)
def test_interface_rf_role_empty(self):
params = {'rf_role__empty': 'true'}
self.assertEqual(InterfaceFilterSet(params, Interface.objects.all()).qs.count(), 5)
params = {'rf_role__empty': 'false'}
self.assertEqual(InterfaceFilterSet(params, Interface.objects.all()).qs.count(), 1)

View File

@@ -9,22 +9,27 @@ __all__ = (
)
def get_model_urls(app_label, model_name):
def get_model_urls(app_label, model_name, detail=True):
"""
Return a list of URL paths for detail views registered to the given model.
Args:
app_label: App/plugin name
model_name: Model name
detail: If True (default), return only URL views for an individual object.
Otherwise, return only list views.
"""
paths = []
# Retrieve registered views for this model
try:
views = registry['views'][app_label][model_name]
views = [
view for view in registry['views'][app_label][model_name]
if view['detail'] == detail
]
except KeyError:
# No views have been registered for this model
views = []
return []
for config in views:
# Import the view class or function

View File

@@ -1,15 +1,22 @@
from typing import Iterable
from django.conf import settings
from django.contrib.auth.mixins import AccessMixin
from django.core.exceptions import ImproperlyConfigured
from django.urls import reverse
from django.urls.exceptions import NoReverseMatch
from django.utils.http import url_has_allowed_host_and_scheme
from django.utils.translation import gettext_lazy as _
from netbox.plugins import PluginConfig
from netbox.registry import registry
from utilities.relations import get_related_models
from .permissions import resolve_permission
__all__ = (
'ConditionalLoginRequiredMixin',
'ContentTypePermissionRequiredMixin',
'GetRelatedModelsMixin',
'GetReturnURLMixin',
'ObjectPermissionRequiredMixin',
'ViewTab',
@@ -22,10 +29,20 @@ __all__ = (
# View Mixins
#
class ContentTypePermissionRequiredMixin(AccessMixin):
class ConditionalLoginRequiredMixin(AccessMixin):
"""
Similar to Django's LoginRequiredMixin, but enforces authentication only if LOGIN_REQUIRED is True.
"""
def dispatch(self, request, *args, **kwargs):
if settings.LOGIN_REQUIRED and not request.user.is_authenticated:
return self.handle_no_permission()
return super().dispatch(request, *args, **kwargs)
class ContentTypePermissionRequiredMixin(ConditionalLoginRequiredMixin):
"""
Similar to Django's built-in PermissionRequiredMixin, but extended to check model-level permission assignments.
This is related to ObjectPermissionRequiredMixin, except that is does not enforce object-level permissions,
This is related to ObjectPermissionRequiredMixin, except that it does not enforce object-level permissions,
and fits within NetBox's custom permission enforcement system.
additional_permissions: An optional iterable of statically declared permissions to evaluate in addition to those
@@ -58,7 +75,7 @@ class ContentTypePermissionRequiredMixin(AccessMixin):
return super().dispatch(request, *args, **kwargs)
class ObjectPermissionRequiredMixin(AccessMixin):
class ObjectPermissionRequiredMixin(ConditionalLoginRequiredMixin):
"""
Similar to Django's built-in PermissionRequiredMixin, but extended to check for both model-level and object-level
permission assignments. If the user has only object-level permissions assigned, the view's queryset is filtered
@@ -119,7 +136,7 @@ class GetReturnURLMixin:
# First, see if `return_url` was specified as a query parameter or form data. Use this URL only if it's
# considered safe.
return_url = request.GET.get('return_url') or request.POST.get('return_url')
if return_url and return_url.startswith('/'):
if return_url and url_has_allowed_host_and_scheme(return_url, allowed_hosts=None):
return return_url
# Next, check if the object being modified (if any) has an absolute URL.
@@ -142,6 +159,46 @@ class GetReturnURLMixin:
return reverse('home')
class GetRelatedModelsMixin:
"""
Provides logic for collecting all related models for the currently viewed model.
"""
def get_related_models(self, request, instance, omit=[], extra=[]):
"""
Get related models of the view's `queryset` model without those listed in `omit`. Will be sorted alphabetical.
Args:
request: Current request being processed.
instance: The instance related models should be looked up for. A list of instances can be passed to match
related objects in this list (e.g. to find sites of a region including child regions).
omit: Remove relationships to these models from the result. Needs to be passed, if related models don't
provide a `_list` view.
extra: Add extra models to the list of automatically determined related models. Can be used to add indirect
relationships.
"""
model = self.queryset.model
related = filter(
lambda m: m[0] is not model and m[0] not in omit,
get_related_models(model, False)
)
related_models = [
(
model.objects.restrict(request.user, 'view').filter(**(
{f'{field}__in': instance}
if isinstance(instance, Iterable)
else {field: instance}
)),
f'{field}_id'
)
for model, field in related
]
related_models.extend(extra)
return sorted(related_models, key=lambda x: x[0].model._meta.verbose_name.lower())
class ViewTab:
"""
ViewTabs are used for navigation among multiple object-specific views, such as the changelog or journal for
@@ -215,7 +272,7 @@ def get_viewname(model, action=None, rest_api=False):
return viewname
def register_model_view(model, name='', path=None, kwargs=None):
def register_model_view(model, name='', path=None, detail=True, kwargs=None):
"""
This decorator can be used to "attach" a view to any model in NetBox. This is typically used to inject
additional tabs within a model's detail view. For example, to add a custom tab to NetBox's dcim.Site model:
@@ -232,6 +289,7 @@ def register_model_view(model, name='', path=None, kwargs=None):
name: The string used to form the view's name for URL resolution (e.g. via `reverse()`). This will be appended
to the name of the base view for the model using an underscore. If blank, the model name will be used.
path: The URL path by which the view can be reached (optional). If not provided, `name` will be used.
detail: True if the path applied to an individual object; False if it attaches to the base (list) path.
kwargs: A dictionary of keyword arguments for the view to include when registering its URL path (optional).
"""
def _wrapper(cls):
@@ -244,7 +302,8 @@ def register_model_view(model, name='', path=None, kwargs=None):
registry['views'][app_label][model_name].append({
'name': name,
'view': cls,
'path': path or name,
'path': path if path is not None else name,
'detail': detail,
'kwargs': kwargs or {},
})