diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index e1012212d..54dc5ca8c 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -30,10 +30,9 @@ about: Report a reproducible bug in the current release of NetBox library such as pynetbox. --> ### Steps to Reproduce -1. Disable any installed plugins by commenting out the `PLUGINS` setting in - `configuration.py`. -2. -3. +1. +2. +3. ### Expected Behavior diff --git a/.gitignore b/.gitignore index 485b46d59..95e4ff702 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ /netbox/static /venv/ /*.sh +local_requirements.txt !upgrade.sh fabfile.py gunicorn.py diff --git a/base_requirements.txt b/base_requirements.txt index caf7ba5f3..a57e88604 100644 --- a/base_requirements.txt +++ b/base_requirements.txt @@ -42,10 +42,6 @@ django-tables2 # https://github.com/alex/django-taggit django-taggit -# A Django REST Framework serializer which represents tags -# https://github.com/glemmaPaul/django-taggit-serializer -django-taggit-serializer - # A Django field for representing time zones # https://github.com/mfogel/django-timezone-field/ django-timezone-field diff --git a/docs/additional-features/custom-links.md b/docs/additional-features/custom-links.md index 7c96eba8b..56d67a7be 100644 --- a/docs/additional-features/custom-links.md +++ b/docs/additional-features/custom-links.md @@ -24,7 +24,7 @@ Only links which render with non-empty text are included on the page. You can em For example, if you only want to display a link for active devices, you could set the link text to ``` -{% if obj.status == 1 %}View NMS{% endif %} +{% if obj.status == 'active' %}View NMS{% endif %} ``` The link will not appear when viewing a device with any status other than "active." diff --git a/docs/additional-features/prometheus-metrics.md b/docs/additional-features/prometheus-metrics.md index 0aa944b74..1429fb0a7 100644 --- a/docs/additional-features/prometheus-metrics.md +++ b/docs/additional-features/prometheus-metrics.md @@ -32,3 +32,7 @@ This can be setup by first creating a shared directory and then adding this line ``` environment=prometheus_multiproc_dir=/tmp/prometheus_metrics ``` + +#### Accuracy + +If having accurate long-term metrics in a multiprocess environment is important to you then it's recommended you use the `uwsgi` library instead of `gunicorn`. The issue lies in the way `gunicorn` tracks worker processes (vs `uwsgi`) which helps manage the metrics files created by the above configurations. If you're using Netbox with gunicorn in a containerized enviroment following the one-process-per-container methodology, then you will likely not need to change to `uwsgi`. More details can be found in [issue #3779](https://github.com/netbox-community/netbox/issues/3779#issuecomment-590547562). \ No newline at end of file diff --git a/docs/additional-features/reports.md b/docs/additional-features/reports.md index 74137ebb8..bcd39fc3c 100644 --- a/docs/additional-features/reports.md +++ b/docs/additional-features/reports.md @@ -33,7 +33,6 @@ Within each report class, we'll create a number of test methods to execute our r ``` from dcim.choices import DeviceStatusChoices -from dcim.constants import CONNECTION_STATUS_PLANNED from dcim.models import ConsolePort, Device, PowerPort from extras.reports import Report @@ -53,7 +52,7 @@ class DeviceConnectionsReport(Report): console_port.device, "No console connection defined for {}".format(console_port.name) ) - elif console_port.connection_status == CONNECTION_STATUS_PLANNED: + elif not console_port.connection_status: self.log_warning( console_port.device, "Console connection for {} marked as planned".format(console_port.name) @@ -69,7 +68,7 @@ class DeviceConnectionsReport(Report): for power_port in PowerPort.objects.filter(device=device): if power_port.connected_endpoint is not None: connected_ports += 1 - if power_port.connection_status == CONNECTION_STATUS_PLANNED: + if not power_port.connection_status: self.log_warning( device, "Power connection for {} marked as planned".format(power_port.name) diff --git a/docs/administration/permissions.md b/docs/administration/permissions.md new file mode 100644 index 000000000..7e47db0d9 --- /dev/null +++ b/docs/administration/permissions.md @@ -0,0 +1,43 @@ +# Permissions + +NetBox v2.9 introduced a new object-based permissions framework, which replace's Django's built-in permission model. Object-based permissions allow for the assignment of permissions to an arbitrary subset of objects of a certain type, rather than only by type of object. For example, it is possible to grant a user permission to view only sites within a particular region, or to modify only VLANs with a numeric ID within a certain range. + +{!docs/models/users/objectpermission.md!} + +### Example Constraint Definitions + +| Query Filter | Permission Constraints | +| ------------ | --------------------- | +| `filter(status='active')` | `{"status": "active"}` | +| `filter(status='active', role='testing')` | `{"status": "active", "role": "testing"}` | +| `filter(status__in=['planned', 'reserved'])` | `{"status__in": ["planned", "reserved"]}` | +| `filter(name__startswith('Foo')` | `{"name__startswith": "Foo"}` | +| `filter(vid__gte=100, vid__lt=200)` | `{"vid__gte": 100, "vid__lt": 200}` | + +## Permissions Enforcement + +### Viewing Objects + +Object-based permissions work by filtering the database query generated by a user's request to restrict the set of objects returned. When a request is received, NetBox first determines whether the user is authenticated and has been granted to perform the requested action. For example, if the requested URL is `/dcim/devices/`, NetBox will check for the `dcim.view_device` permission. If the user has not been assigned this permission (either directly or via a group assignment), NetBox will return a 403 (forbidden) HTTP response. + +If the permission has been granted, NetBox will compile any specified constraints for the model and action. For example, suppose two permissions have been assigned to the user granting view access to the device model, with the following constraints: + +```json +[ + {"site__name__in": ["NYC1", "NYC2"]}, + {"status": "offline", "tenant__isnull": true} +] +``` + +This grants the user access to view any device that is in NYC1 or NYC2, **or** which has a status of "offline" and has no tenant assigned. These constraints will result in the following ORM query: + +```no-highlight +Site.objects.filter( + Q(site__name__in=['NYC1', 'NYC2']), + Q(status='active', tenant__isnull=True) +) +``` + +### Creating and Modifying Objects + +The same sort of logic is in play when a user attempts to create or modify an object in NetBox, with a twist. Once validation has completed, NetBox starts an atomic database transaction to facilitate the change, and the object is created or saved normally. Next, still within the transaction, NetBox issues a second query to retrieve the newly created/updated object, filtering the restricted queryset with the object's primary key. If this query fails to return the object, NetBox knows that the new revision does not match the constraints imposed by the permission. The transaction is then aborted, and the database is left in its original state. diff --git a/docs/api/authentication.md b/docs/api/authentication.md index 8e38c4de9..e8e6ddc96 100644 --- a/docs/api/authentication.md +++ b/docs/api/authentication.md @@ -2,18 +2,7 @@ The NetBox API employs token-based authentication. For convenience, cookie authentication can also be used when navigating the browsable API. -## Tokens - -A token is a unique identifier that identifies a user to the API. Each user in NetBox may have one or more tokens which he or she can use to authenticate to the API. To create a token, navigate to the API tokens page at `/user/api-tokens/`. - -!!! note - The creation and modification of API tokens can be restricted per user by an administrator. If you don't see an option to create an API token, ask an administrator to grant you access. - -Each token contains a 160-bit key represented as 40 hexadecimal characters. When creating a token, you'll typically leave the key field blank so that a random key will be automatically generated. However, NetBox allows you to specify a key in case you need to restore a previously deleted token to operation. - -By default, a token can be used for all operations available via the API. Deselecting the "write enabled" option will restrict API requests made with the token to read operations (e.g. GET) only. - -Additionally, a token can be set to expire at a specific time. This can be useful if an external client needs to be granted temporary access to NetBox. +{!docs/models/users/token.md!} ## Authenticating to the API diff --git a/docs/api/examples.md b/docs/api/examples.md index 1906d0db9..f4348907f 100644 --- a/docs/api/examples.md +++ b/docs/api/examples.md @@ -145,3 +145,18 @@ $ curl -v -X DELETE -H "Authorization: Token d2f763479f703d80de0ec15254237bc651f ``` The response to a successful `DELETE` request will have code 204 (No Content); the body of the response will be empty. + + +## Bulk Object Creation + +The REST API supports the creation of multiple objects of the same type using a single `POST` request. For example, to create multiple devices: + +``` +curl -X POST -H "Authorization: Token " -H "Content-Type: application/json" -H "Accept: application/json; indent=4" http://localhost:8000/api/dcim/devices/ --data '[ +{"name": "device1", "device_type": 24, "device_role": 17, "site": 6}, +{"name": "device2", "device_type": 24, "device_role": 17, "site": 6}, +{"name": "device3", "device_type": 24, "device_role": 17, "site": 6}, +]' +``` + +Bulk creation is all-or-none: If any of the creations fails, the entire operation is rolled back. A successful response returns an HTTP code 201 and the body of the response will be a list/array of the objects created. \ No newline at end of file diff --git a/docs/configuration/optional-settings.md b/docs/configuration/optional-settings.md index 3f2b29b87..8ef2b4b21 100644 --- a/docs/configuration/optional-settings.md +++ b/docs/configuration/optional-settings.md @@ -13,6 +13,14 @@ ADMINS = [ --- +## ALLOWED_URL_SCHEMES + +Default: `('file', 'ftp', 'ftps', 'http', 'https', 'irc', 'mailto', 'sftp', 'ssh', 'tel', 'telnet', 'tftp', 'vnc', 'xmpp')` + +A list of permitted URL schemes referenced when rendering links within NetBox. Note that only the schemes specified in this list will be accepted: If adding your own, be sure to replicate the entire default list as well (excluding those schemes which are not desirable). + +--- + ## BANNER_TOP ## BANNER_BOTTOM @@ -86,7 +94,12 @@ CORS_ORIGIN_WHITELIST = [ Default: False -This setting enables debugging. This should be done only during development or troubleshooting. Never enable debugging on a production system, as it can expose sensitive data to unauthenticated users. +This setting enables debugging. This should be done only during development or troubleshooting. Note that only clients +which access NetBox from a recognized [internal IP address](#internal_ips) will see debugging tools in the user +interface. + +!!! warning + Never enable debugging on a production system, as it can expose sensitive data to unauthenticated users. --- @@ -108,16 +121,20 @@ The file path to NetBox's documentation. This is used when presenting context-se ## EMAIL -In order to send email, NetBox needs an email server configured. The following items can be defined within the `EMAIL` setting: +In order to send email, NetBox needs an email server configured. The following items can be defined within the `EMAIL` configuration parameter: -* SERVER - Host name or IP address of the email server (use `localhost` if running locally) -* PORT - TCP port to use for the connection (default: 25) -* USERNAME - Username with which to authenticate -* PASSSWORD - Password with which to authenticate -* TIMEOUT - Amount of time to wait for a connection (seconds) -* FROM_EMAIL - Sender address for emails sent by NetBox +* `SERVER` - Host name or IP address of the email server (use `localhost` if running locally) +* `PORT` - TCP port to use for the connection (default: `25`) +* `USERNAME` - Username with which to authenticate +* `PASSSWORD` - Password with which to authenticate +* `USE_SSL` - Use SSL when connecting to the server (default: `False`). Mutually exclusive with `USE_TLS`. +* `USE_TLS` - Use TLS when connecting to the server (default: `False`). Mutually exclusive with `USE_SSL`. +* `SSL_CERTFILE` - Path to the PEM-formatted SSL certificate file (optional) +* `SSL_KEYFILE` - Path to the PEM-formatted SSL private key file (optional) +* `TIMEOUT` - Amount of time to wait for a connection, in seconds (default: `10`) +* `FROM_EMAIL` - Sender address for emails sent by NetBox (default: `root@localhost`) -Email is sent from NetBox only for critical events. If you would like to test the email server configuration please use the django function [send_mail()](https://docs.djangoproject.com/en/stable/topics/email/#send-mail): +Email is sent from NetBox only for critical events or if configured for [logging](#logging). If you would like to test the email server configuration please use the django function [send_mail()](https://docs.djangoproject.com/en/stable/topics/email/#send-mail): ``` # python ./manage.py nbshell @@ -180,6 +197,16 @@ HTTP_PROXIES = { --- +## INTERNAL_IPS + +Default: `('127.0.0.1', '::1',)` + +A list of IP addresses recognized as internal to the system, used to control the display of debugging output. For +example, the debugging toolbar will be viewable only when a client is accessing NetBox from one of the listed IP +addresses (and [`DEBUG`](#debug) is true). + +--- + ## LOGGING By default, all messages of INFO severity or higher will be logged to the console. Additionally, if `DEBUG` is False and email access has been configured, ERROR and CRITICAL messages will be emailed to the users defined in `ADMINS`. @@ -365,9 +392,12 @@ NetBox can be configured to support remote user authentication by inferring user ## REMOTE_AUTH_BACKEND -Default: `'utilities.auth_backends.RemoteUserBackend'` +Default: `'netbox.authentication.RemoteUserBackend'` -Python path to the custom [Django authentication backend](https://docs.djangoproject.com/en/stable/topics/auth/customizing/) to use for external user authentication, if not using NetBox's built-in backend. (Requires `REMOTE_AUTH_ENABLED`.) +Python path to the custom [Django authentication backend](https://docs.djangoproject.com/en/stable/topics/auth/customizing/) to use for external user authentication. NetBox provides two built-in backends (listed below), though backends may also be provided via other packages. + +* `netbox.authentication.RemoteUserBackend` +* `netbox.authentication.LDAPBackend` --- @@ -381,7 +411,7 @@ When remote user authentication is in use, this is the name of the HTTP header w ## REMOTE_AUTH_AUTO_CREATE_USER -Default: `True` +Default: `False` If true, NetBox will automatically create local accounts for users authenticated via a remote service. (Requires `REMOTE_AUTH_ENABLED`.) @@ -397,9 +427,9 @@ The list of groups to assign a new user account when created using remote authen ## REMOTE_AUTH_DEFAULT_PERMISSIONS -Default: `[]` (Empty list) +Default: `{}` (Empty dictionary) -The list of permissions to assign a new user account when created using remote authentication. (Requires `REMOTE_AUTH_ENABLED`.) +A mapping of permissions to assign a new user account when created using remote authentication. Each key in the dictionary should be set to a dictionary of the attributes to be applied to the permission, or `None` to allow all objects. (Requires `REMOTE_AUTH_ENABLED`.) --- diff --git a/docs/development/utility-views.md b/docs/development/utility-views.md index a6e50f71e..3b9c1053d 100644 --- a/docs/development/utility-views.md +++ b/docs/development/utility-views.md @@ -4,6 +4,10 @@ Utility views are reusable views that handle common CRUD tasks, such as listing ## Individual Views +### ObjectView + +Retrieve and display a single object. + ### ObjectListView Generates a paginated table of objects from a given queryset, which may optionally be filtered. diff --git a/docs/index.md b/docs/index.md index 3880c9d07..ee7f77f69 100644 --- a/docs/index.md +++ b/docs/index.md @@ -49,7 +49,7 @@ NetBox is built on the [Django](https://djangoproject.com/) Python framework and | HTTP service | nginx or Apache | | WSGI service | gunicorn or uWSGI | | Application | Django/Python | -| Database | PostgreSQL 9.4+ | +| Database | PostgreSQL 9.6+ | | Task queuing | Redis/django-rq | | Live device access | NAPALM | diff --git a/docs/installation/1-postgresql.md b/docs/installation/1-postgresql.md index afe3a51d2..933e32edc 100644 --- a/docs/installation/1-postgresql.md +++ b/docs/installation/1-postgresql.md @@ -3,7 +3,7 @@ This section entails the installation and configuration of a local PostgreSQL database. If you already have a PostgreSQL database service in place, skip to [the next section](2-redis.md). !!! warning - NetBox requires PostgreSQL 9.4 or higher. Please note that MySQL and other relational databases are **not** supported. + NetBox requires PostgreSQL 9.6 or higher. Please note that MySQL and other relational databases are **not** supported. The installation instructions provided here have been tested to work on Ubuntu 18.04 and CentOS 7.5. The particular commands needed to install dependencies on other distributions may vary significantly. Unfortunately, this is outside the control of the NetBox maintainers. Please consult your distribution's documentation for assistance with any errors. @@ -51,7 +51,7 @@ At a minimum, we need to create a database for NetBox and assign it a username a ```no-highlight # sudo -u postgres psql -psql (9.4.5) +psql (10.10) Type "help" for help. postgres=# CREATE DATABASE netbox; diff --git a/docs/installation/3-netbox.md b/docs/installation/3-netbox.md index 5237e617e..c583d08fe 100644 --- a/docs/installation/3-netbox.md +++ b/docs/installation/3-netbox.md @@ -78,7 +78,8 @@ Create a system user account named `netbox`. We'll configure the WSGI and HTTP s CentOS users may need to create the `netbox` group first. ``` -# adduser --system --group netbox +# groupadd --system netbox +# adduser --system --gid netbox netbox # chown --recursive netbox /opt/netbox/netbox/media/ ``` diff --git a/docs/installation/5-ldap.md b/docs/installation/5-ldap.md index 2fd88b841..bb1300c08 100644 --- a/docs/installation/5-ldap.md +++ b/docs/installation/5-ldap.md @@ -36,7 +36,13 @@ Once installed, add the package to `local_requirements.txt` to ensure it is re-i ## Configuration -Create a file in the same directory as `configuration.py` (typically `netbox/netbox/`) named `ldap_config.py`. Define all of the parameters required below in `ldap_config.py`. Complete documentation of all `django-auth-ldap` configuration options is included in the project's [official documentation](http://django-auth-ldap.readthedocs.io/). +First, enable the LDAP authentication backend in `configuration.py`. (Be sure to overwrite this definition if it is already set to `RemoteUserBackend`.) + +```python +REMOTE_AUTH_BACKEND = 'netbox.authentication.LDAPBackend' +``` + +Next, create a file in the same directory as `configuration.py` (typically `netbox/netbox/`) named `ldap_config.py`. Define all of the parameters required below in `ldap_config.py`. Complete documentation of all `django-auth-ldap` configuration options is included in the project's [official documentation](http://django-auth-ldap.readthedocs.io/). ### General Server Configuration @@ -145,7 +151,8 @@ logfile = "/opt/netbox/logs/django-ldap-debug.log" my_logger = logging.getLogger('django_auth_ldap') my_logger.setLevel(logging.DEBUG) handler = logging.handlers.RotatingFileHandler( - logfile, maxBytes=1024 * 500, backupCount=5) + logfile, maxBytes=1024 * 500, backupCount=5 +) my_logger.addHandler(handler) ``` diff --git a/docs/models/users/objectpermission.md b/docs/models/users/objectpermission.md new file mode 100644 index 000000000..80313fc0b --- /dev/null +++ b/docs/models/users/objectpermission.md @@ -0,0 +1,36 @@ +# Object Permissions + +Assigning a permission in NetBox entails defining a relationship among several components: + +* Object type(s) - One or more types of object in NetBox +* User(s) - One or more users or groups of users +* Actions - The actions that can be performed (view, add, change, and/or delete) +* Constraints - An arbitrary filter used to limit the granted action(s) to a specific subset of objects + +At a minimum, a permission assignment must specify one object type, one user or group, and one action. The specification of constraints is optional: A permission without any constraints specified will apply to all instances of the selected model(s). + +## Actions + +There are four core actions that can be permitted for each type of object within NetBox, roughly analogous to the CRUD convention (create, read, update, and delete): + +* View - Retrieve an object from the database +* Add - Create a new object +* Change - Modify an existing object +* Delete - Delete an existing object + +Some models introduce additional permissions that can be granted to allow other actions. For example, the `napalm_read` permission on the device model allows a user to execute NAPALM queries on a device via NetBox's REST API. These can be specified when granting a permission in the "additional actions" field. + +## Constraints + +Constraints are defined as a JSON object representing a [Django query filter](https://docs.djangoproject.com/en/stable/ref/models/querysets/#field-lookups). This is the same syntax that you would pass to the QuerySet `filter()` method when performing a query using the Django ORM. As with query filters, double underscores can be used to traverse related objects or invoke lookup expressions. Some example queries and their corresponding definitions are shown below. + +All constraints defined on a permission are applied with a logic AND. For example, suppose you assign a permission for the site model with the following constraints. + +```json +{ + "status": "active", + "region__name": "Americas" +} +``` + +The permission will grant access only to sites which have a status of "active" **and** which are assigned to the "Americas" region. To achieve a logical OR with a different set of constraints, simply create another permission assignment for the same model and user/group. diff --git a/docs/models/users/token.md b/docs/models/users/token.md new file mode 100644 index 000000000..bbeb2284b --- /dev/null +++ b/docs/models/users/token.md @@ -0,0 +1,12 @@ +## Tokens + +A token is a unique identifier that identifies a user to the API. Each user in NetBox may have one or more tokens which he or she can use to authenticate to the API. To create a token, navigate to the API tokens page at `/user/api-tokens/`. + +!!! note + The creation and modification of API tokens can be restricted per user by an administrator. If you don't see an option to create an API token, ask an administrator to grant you access. + +Each token contains a 160-bit key represented as 40 hexadecimal characters. When creating a token, you'll typically leave the key field blank so that a random key will be automatically generated. However, NetBox allows you to specify a key in case you need to restore a previously deleted token to operation. + +By default, a token can be used for all operations available via the API. Deselecting the "write enabled" option will restrict API requests made with the token to read operations (e.g. GET) only. + +Additionally, a token can be set to expire at a specific time. This can be useful if an external client needs to be granted temporary access to NetBox. diff --git a/docs/release-notes/index.md b/docs/release-notes/index.md index 364b2cd9d..f314c5371 120000 --- a/docs/release-notes/index.md +++ b/docs/release-notes/index.md @@ -1 +1 @@ -version-2.8.md \ No newline at end of file +version-2.9.md \ No newline at end of file diff --git a/docs/release-notes/version-2.8.md b/docs/release-notes/version-2.8.md index e75bf4ab9..ca264806b 100644 --- a/docs/release-notes/version-2.8.md +++ b/docs/release-notes/version-2.8.md @@ -1,5 +1,93 @@ # NetBox v2.8 +## v2.8.7 (FUTURE) + +### Bug Fixes + +* [#4766](https://github.com/netbox-community/netbox/issues/4766) - Fix redirect after login when `next` is not specified +* [#4772](https://github.com/netbox-community/netbox/issues/4772) - Fix "brief" format for the secrets REST API endpoint +* [#4775](https://github.com/netbox-community/netbox/issues/4775) - Allow selecting an alternate device type when creating component templates + +--- + +## v2.8.6 (2020-06-15) + +### Enhancements + +* [#4698](https://github.com/netbox-community/netbox/issues/4698) - Improve display of template code for object in admin UI +* [#4717](https://github.com/netbox-community/netbox/issues/4717) - Introduce `ALLOWED_URL_SCHEMES` configuration parameter to mitigate dangerous hyperlinks +* [#4744](https://github.com/netbox-community/netbox/issues/4744) - Hide "IP addresses" tab when viewing a container prefix +* [#4755](https://github.com/netbox-community/netbox/issues/4755) - Enable creation of rack reservations directly from navigation menu +* [#4761](https://github.com/netbox-community/netbox/issues/4761) - Enable tag assignment during bulk creation of IP addresses + +### Bug Fixes + +* [#4674](https://github.com/netbox-community/netbox/issues/4674) - Fix API definition for available prefix and IP address endpoints +* [#4702](https://github.com/netbox-community/netbox/issues/4702) - Catch IntegrityError exception when adding a non-unique secret +* [#4707](https://github.com/netbox-community/netbox/issues/4707) - Fix `prefix_count` population on VLAN API serializer +* [#4710](https://github.com/netbox-community/netbox/issues/4710) - Fix merging of form fields among custom scripts +* [#4725](https://github.com/netbox-community/netbox/issues/4725) - Fix "brief" rendering of various REST API endpoints +* [#4736](https://github.com/netbox-community/netbox/issues/4736) - Add cable trace endpoints for pass-through ports +* [#4737](https://github.com/netbox-community/netbox/issues/4737) - Fix display of role labels in virtual machines table +* [#4743](https://github.com/netbox-community/netbox/issues/4743) - Allow users to create "next available" IPs without needing permission to create prefixes +* [#4756](https://github.com/netbox-community/netbox/issues/4756) - Filter parent group by site when creating rack groups +* [#4760](https://github.com/netbox-community/netbox/issues/4760) - Enable power port template assignment when bulk editing power outlet templates + +--- + +## v2.8.5 (2020-05-26) + +**Note:** The minimum required version of PostgreSQL is now 9.6. + +### Enhancements + +* [#4650](https://github.com/netbox-community/netbox/issues/4650) - Expose `INTERNAL_IPS` configuration parameter +* [#4651](https://github.com/netbox-community/netbox/issues/4651) - Add `csrf_token` context for plugin templates +* [#4652](https://github.com/netbox-community/netbox/issues/4652) - Add permissions context for plugin templates +* [#4665](https://github.com/netbox-community/netbox/issues/4665) - Add NEMA L14 and L21 power port/outlet types +* [#4672](https://github.com/netbox-community/netbox/issues/4672) - Set default color for rack and devices roles + +### Bug Fixes + +* [#3304](https://github.com/netbox-community/netbox/issues/3304) - Fix caching invalidation issue related to device/virtual machine primary IP addresses +* [#4525](https://github.com/netbox-community/netbox/issues/4525) - Allow passing initial data to custom script MultiObjectVar +* [#4644](https://github.com/netbox-community/netbox/issues/4644) - Fix ordering of services table by parent +* [#4646](https://github.com/netbox-community/netbox/issues/4646) - Correct UI link for reports with custom name +* [#4647](https://github.com/netbox-community/netbox/issues/4647) - Fix caching invalidation issue related to assigning new IP addresses to interfaces +* [#4648](https://github.com/netbox-community/netbox/issues/4648) - Fix bulk CSV import of child devices +* [#4649](https://github.com/netbox-community/netbox/issues/4649) - Fix interface assignment for bulk-imported IP addresses +* [#4676](https://github.com/netbox-community/netbox/issues/4676) - Set default value of `REMOTE_AUTH_AUTO_CREATE_USER` as `False` in docs +* [#4684](https://github.com/netbox-community/netbox/issues/4684) - Respect `comments` field when importing device type in YAML/JSON format + +--- + +## v2.8.4 (2020-05-13) + +### Enhancements + +* [#4632](https://github.com/netbox-community/netbox/issues/4632) - Extend email configuration parameters to support SSL/TLS + +### Bug Fixes + +* [#4598](https://github.com/netbox-community/netbox/issues/4598) - Display error message when invalid cable length is specified +* [#4604](https://github.com/netbox-community/netbox/issues/4604) - Multi-position rear ports may only be connected to other rear ports +* [#4607](https://github.com/netbox-community/netbox/issues/4607) - Missing Contextual help for API Tokens +* [#4613](https://github.com/netbox-community/netbox/issues/4613) - Fix tag assignment on config contexts (regression from #4527) +* [#4617](https://github.com/netbox-community/netbox/issues/4617) - Restore IP prefix depth notation in list view +* [#4629](https://github.com/netbox-community/netbox/issues/4629) - Replicate assigned interface when cloning IP addresses +* [#4633](https://github.com/netbox-community/netbox/issues/4633) - Bump django-rq to v2.3.2 to fix ImportError with rq 1.4.0 +* [#4634](https://github.com/netbox-community/netbox/issues/4634) - Inventory Item List view exception caused by incorrect accessor definition + +--- + +## v2.8.3 (2020-05-06) + +### Bug Fixes + +* [#4593](https://github.com/netbox-community/netbox/issues/4593) - Fix AttributeError exception when viewing object lists as a non-authenticated user + +--- + ## v2.8.2 (2020-05-06) ### Enhancements diff --git a/docs/release-notes/version-2.9.md b/docs/release-notes/version-2.9.md index 520fb2187..d0caea7ad 100644 --- a/docs/release-notes/version-2.9.md +++ b/docs/release-notes/version-2.9.md @@ -1,7 +1,43 @@ -# Netbox v2.9 +# NetBox v2.9 ## v2.9.0 (FUTURE) +### New Features + +#### Object-Based Permissions ([#554](https://github.com/netbox-community/netbox/issues/554)) + +NetBox v2.9 replaces Django's built-in permissions framework with one that supports object-based assignment of permissions using arbitrary constraints. When granting a user or group to perform a certain action on one or more types of objects, an administrator can optionally specify a set of constraints. The permission will apply only to objects which match the specified constraints. For example, assigning permission to modify devices with the constraint `{"tenant__group__name": "Customers"}` would grant the permission only for devices assigned to a tenant belonging to the "Customers" group. + ### Enhancements +* [#3703](https://github.com/netbox-community/netbox/issues/3703) - Tags must be created administratively before being assigned to an object * [#4573](https://github.com/netbox-community/netbox/issues/4573) - Support plugins as a delivery mechanism for reports and custom scripts +* [#4615](https://github.com/netbox-community/netbox/issues/4615) - Add `label` field for all device components +* [#4742](https://github.com/netbox-community/netbox/issues/4742) - Add tagging for cables, power panels, and rack reservations + +### Configuration Changes + +* If in use, LDAP authentication must be enabled by setting `REMOTE_AUTH_BACKEND` to `'netbox.authentication.LDAPBackend'`. (LDAP configuration parameters in `ldap_config.py` remain unchanged.) +* `REMOTE_AUTH_DEFAULT_PERMISSIONS` now takes a dictionary rather than a list. This is a mapping of permission names to a dictionary of constraining attributes, or `None`. For example, `['dcim.add_site', 'dcim.change_site']` would become `{'dcim.add_site': None, 'dcim.change_site': None}`. + +### REST API Changes + +* The count of `tagged_items` is no longer included when viewing the tags list when `brief` is passed. +* The assignment of tags to an object is now achieved in the same manner as specifying any other related device. The `tags` field accepts a list of JSON objects each matching a desired tag. (Alternatively, a list of numeric primary keys corresponding to tags may be passed instead.) For example: + +```json +"tags": [ + {"name": "First Tag"}, + {"name": "Second Tag"} +] +``` + +* The `tags` field of an object now includes a more complete representation of each tag, rather than just its name. +* A `label` field has been added to all device components and component templates. + +### Other Changes + +* The `secrets.activate_userkey` permission no longer exists. Instead, `secrets.change_userkey` is checked to determine whether a user has the ability to activate a UserKey. +* The `users.delete_token` permission is no longer enforced. All users are permitted to delete their own API tokens. +* Dropped backward compatibility for the `webhooks` Redis queue configuration (use `tasks` instead). +* Dropped backward compatibility for the `/admin/webhook-backend-status` URL (moved to `/admin/background-tasks/`). diff --git a/mkdocs.yml b/mkdocs.yml index 5d4636001..f73438eab 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -58,6 +58,7 @@ nav: - Using Plugins: 'plugins/index.md' - Developing Plugins: 'plugins/development.md' - Administration: + - Permissions: 'administration/permissions.md' - Replicating NetBox: 'administration/replicating-netbox.md' - NetBox Shell: 'administration/netbox-shell.md' - API: diff --git a/netbox/circuits/api/serializers.py b/netbox/circuits/api/serializers.py index 6bac48a59..e8171e2fb 100644 --- a/netbox/circuits/api/serializers.py +++ b/netbox/circuits/api/serializers.py @@ -1,11 +1,11 @@ from rest_framework import serializers -from taggit_serializer.serializers import TaggitSerializer, TagListSerializerField from circuits.choices import CircuitStatusChoices from circuits.models import Provider, Circuit, CircuitTermination, CircuitType from dcim.api.nested_serializers import NestedCableSerializer, NestedInterfaceSerializer, NestedSiteSerializer from dcim.api.serializers import ConnectedEndpointSerializer from extras.api.customfields import CustomFieldModelSerializer +from extras.api.serializers import TaggedObjectSerializer from tenancy.api.nested_serializers import NestedTenantSerializer from utilities.api import ChoiceField, ValidatedModelSerializer, WritableNestedSerializer from .nested_serializers import * @@ -15,8 +15,7 @@ from .nested_serializers import * # Providers # -class ProviderSerializer(TaggitSerializer, CustomFieldModelSerializer): - tags = TagListSerializerField(required=False) +class ProviderSerializer(TaggedObjectSerializer, CustomFieldModelSerializer): circuit_count = serializers.IntegerField(read_only=True) class Meta: @@ -49,14 +48,13 @@ class CircuitCircuitTerminationSerializer(WritableNestedSerializer): fields = ['id', 'url', 'site', 'connected_endpoint', 'port_speed', 'upstream_speed', 'xconnect_id'] -class CircuitSerializer(TaggitSerializer, CustomFieldModelSerializer): +class CircuitSerializer(TaggedObjectSerializer, CustomFieldModelSerializer): provider = NestedProviderSerializer() status = ChoiceField(choices=CircuitStatusChoices, required=False) type = NestedCircuitTypeSerializer() tenant = NestedTenantSerializer(required=False, allow_null=True) termination_a = CircuitCircuitTerminationSerializer(read_only=True) termination_z = CircuitCircuitTerminationSerializer(read_only=True) - tags = TagListSerializerField(required=False) class Meta: model = Circuit diff --git a/netbox/circuits/filters.py b/netbox/circuits/filters.py index 206dcc305..a81d6acca 100644 --- a/netbox/circuits/filters.py +++ b/netbox/circuits/filters.py @@ -24,13 +24,13 @@ class ProviderFilterSet(BaseFilterSet, CustomFieldFilterSet, CreatedUpdatedFilte label='Search', ) region_id = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='circuits__terminations__site__region', lookup_expr='in', label='Region (ID)', ) region = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='circuits__terminations__site__region', lookup_expr='in', to_field_name='slug', @@ -38,12 +38,12 @@ class ProviderFilterSet(BaseFilterSet, CustomFieldFilterSet, CreatedUpdatedFilte ) site_id = django_filters.ModelMultipleChoiceFilter( field_name='circuits__terminations__site', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site', ) site = django_filters.ModelMultipleChoiceFilter( field_name='circuits__terminations__site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site (slug)', ) @@ -78,22 +78,22 @@ class CircuitFilterSet(BaseFilterSet, CustomFieldFilterSet, TenancyFilterSet, Cr label='Search', ) provider_id = django_filters.ModelMultipleChoiceFilter( - queryset=Provider.objects.all(), + queryset=Provider.objects.unrestricted(), label='Provider (ID)', ) provider = django_filters.ModelMultipleChoiceFilter( field_name='provider__slug', - queryset=Provider.objects.all(), + queryset=Provider.objects.unrestricted(), to_field_name='slug', label='Provider (slug)', ) type_id = django_filters.ModelMultipleChoiceFilter( - queryset=CircuitType.objects.all(), + queryset=CircuitType.objects.unrestricted(), label='Circuit type (ID)', ) type = django_filters.ModelMultipleChoiceFilter( field_name='type__slug', - queryset=CircuitType.objects.all(), + queryset=CircuitType.objects.unrestricted(), to_field_name='slug', label='Circuit type (slug)', ) @@ -103,23 +103,23 @@ class CircuitFilterSet(BaseFilterSet, CustomFieldFilterSet, TenancyFilterSet, Cr ) site_id = django_filters.ModelMultipleChoiceFilter( field_name='terminations__site', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site (ID)', ) site = django_filters.ModelMultipleChoiceFilter( field_name='terminations__site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site (slug)', ) region_id = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='terminations__site__region', lookup_expr='in', label='Region (ID)', ) region = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='terminations__site__region', lookup_expr='in', to_field_name='slug', @@ -150,16 +150,16 @@ class CircuitTerminationFilterSet(BaseFilterSet): label='Search', ) circuit_id = django_filters.ModelMultipleChoiceFilter( - queryset=Circuit.objects.all(), + queryset=Circuit.objects.unrestricted(), label='Circuit', ) site_id = django_filters.ModelMultipleChoiceFilter( - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site (ID)', ) site = django_filters.ModelMultipleChoiceFilter( field_name='site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site (slug)', ) diff --git a/netbox/circuits/forms.py b/netbox/circuits/forms.py index 427dc2e89..341a7a9b7 100644 --- a/netbox/circuits/forms.py +++ b/netbox/circuits/forms.py @@ -1,10 +1,10 @@ from django import forms -from taggit.forms import TagField from dcim.models import Region, Site from extras.forms import ( AddRemoveTagsForm, CustomFieldBulkEditForm, CustomFieldFilterForm, CustomFieldModelForm, CustomFieldModelCSVForm, ) +from extras.models import Tag from tenancy.forms import TenancyFilterForm, TenancyForm from tenancy.models import Tenant from utilities.forms import ( @@ -23,7 +23,8 @@ from .models import Circuit, CircuitTermination, CircuitType, Provider class ProviderForm(BootstrapMixin, CustomFieldModelForm): slug = SlugField() comments = CommentField() - tags = TagField( + tags = DynamicModelMultipleChoiceField( + queryset=Tag.objects.all(), required=False ) @@ -165,7 +166,8 @@ class CircuitForm(BootstrapMixin, TenancyForm, CustomFieldModelForm): queryset=CircuitType.objects.all() ) comments = CommentField() - tags = TagField( + tags = DynamicModelMultipleChoiceField( + queryset=Tag.objects.all(), required=False ) diff --git a/netbox/circuits/models.py b/netbox/circuits/models.py index 57d41a994..dcf1c5118 100644 --- a/netbox/circuits/models.py +++ b/netbox/circuits/models.py @@ -8,6 +8,7 @@ from dcim.fields import ASNField from dcim.models import CableTermination from extras.models import CustomFieldModel, ObjectChange, TaggedItem from extras.utils import extras_features +from utilities.querysets import RestrictedQuerySet from utilities.models import ChangeLoggedModel from utilities.utils import serialize_object from .choices import * @@ -66,9 +67,10 @@ class Provider(ChangeLoggedModel, CustomFieldModel): content_type_field='obj_type', object_id_field='obj_id' ) - tags = TaggableManager(through=TaggedItem) + objects = RestrictedQuerySet.as_manager() + csv_headers = [ 'name', 'slug', 'asn', 'account', 'portal_url', 'noc_contact', 'admin_contact', 'comments', ] @@ -115,6 +117,8 @@ class CircuitType(ChangeLoggedModel): blank=True, ) + objects = RestrictedQuerySet.as_manager() + csv_headers = ['name', 'slug', 'description'] class Meta: @@ -300,6 +304,8 @@ class CircuitTermination(CableTermination): blank=True ) + objects = RestrictedQuerySet.as_manager() + class Meta: ordering = ['circuit', 'term_side'] unique_together = ['circuit', 'term_side'] diff --git a/netbox/circuits/querysets.py b/netbox/circuits/querysets.py index 60956f32a..8a9bd50a4 100644 --- a/netbox/circuits/querysets.py +++ b/netbox/circuits/querysets.py @@ -1,7 +1,9 @@ -from django.db.models import OuterRef, QuerySet, Subquery +from django.db.models import OuterRef, Subquery + +from utilities.querysets import RestrictedQuerySet -class CircuitQuerySet(QuerySet): +class CircuitQuerySet(RestrictedQuerySet): def annotate_sites(self): """ diff --git a/netbox/circuits/tests/test_api.py b/netbox/circuits/tests/test_api.py index b5f8758e7..4e062cc1a 100644 --- a/netbox/circuits/tests/test_api.py +++ b/netbox/circuits/tests/test_api.py @@ -1,443 +1,189 @@ from django.contrib.contenttypes.models import ContentType from django.urls import reverse -from rest_framework import status from circuits.choices import * from circuits.models import Circuit, CircuitTermination, CircuitType, Provider from dcim.models import Site from extras.models import Graph -from utilities.testing import APITestCase +from utilities.testing import APITestCase, APIViewTestCases class AppTest(APITestCase): def test_root(self): - url = reverse('circuits-api:api-root') response = self.client.get('{}?format=api'.format(url), **self.header) self.assertEqual(response.status_code, 200) -class ProviderTest(APITestCase): +class ProviderTest(APIViewTestCases.APIViewTestCase): + model = Provider + brief_fields = ['circuit_count', 'id', 'name', 'slug', 'url'] + create_data = [ + { + 'name': 'Provider 4', + 'slug': 'provider-4', + }, + { + 'name': 'Provider 5', + 'slug': 'provider-5', + }, + { + 'name': 'Provider 6', + 'slug': 'provider-6', + }, + ] - def setUp(self): + @classmethod + def setUpTestData(cls): - super().setUp() - - self.provider1 = Provider.objects.create(name='Test Provider 1', slug='test-provider-1') - self.provider2 = Provider.objects.create(name='Test Provider 2', slug='test-provider-2') - self.provider3 = Provider.objects.create(name='Test Provider 3', slug='test-provider-3') - - def test_get_provider(self): - - url = reverse('circuits-api:provider-detail', kwargs={'pk': self.provider1.pk}) - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['name'], self.provider1.name) + providers = ( + Provider(name='Provider 1', slug='provider-1'), + Provider(name='Provider 2', slug='provider-2'), + Provider(name='Provider 3', slug='provider-3'), + ) + Provider.objects.bulk_create(providers) def test_get_provider_graphs(self): + """ + Test retrieval of Graphs assigned to Providers. + """ + provider = self.model.objects.first() + ct = ContentType.objects.get(app_label='circuits', model='provider') + graphs = ( + Graph(type=ct, name='Graph 1', source='http://example.com/graphs.py?provider={{ obj.slug }}&foo=1'), + Graph(type=ct, name='Graph 2', source='http://example.com/graphs.py?provider={{ obj.slug }}&foo=2'), + Graph(type=ct, name='Graph 3', source='http://example.com/graphs.py?provider={{ obj.slug }}&foo=3'), + ) + Graph.objects.bulk_create(graphs) - provider_ct = ContentType.objects.get(app_label='circuits', model='provider') - self.graph1 = Graph.objects.create( - type=provider_ct, - name='Test Graph 1', - source='http://example.com/graphs.py?provider={{ obj.slug }}&foo=1' - ) - self.graph2 = Graph.objects.create( - type=provider_ct, - name='Test Graph 2', - source='http://example.com/graphs.py?provider={{ obj.slug }}&foo=2' - ) - self.graph3 = Graph.objects.create( - type=provider_ct, - name='Test Graph 3', - source='http://example.com/graphs.py?provider={{ obj.slug }}&foo=3' - ) - - url = reverse('circuits-api:provider-graphs', kwargs={'pk': self.provider1.pk}) + self.add_permissions('circuits.view_provider') + url = reverse('circuits-api:provider-graphs', kwargs={'pk': provider.pk}) response = self.client.get(url, **self.header) self.assertEqual(len(response.data), 3) - self.assertEqual(response.data[0]['embed_url'], 'http://example.com/graphs.py?provider=test-provider-1&foo=1') + self.assertEqual(response.data[0]['embed_url'], 'http://example.com/graphs.py?provider=provider-1&foo=1') - def test_list_providers(self): - url = reverse('circuits-api:provider-list') - response = self.client.get(url, **self.header) +class CircuitTypeTest(APIViewTestCases.APIViewTestCase): + model = CircuitType + brief_fields = ['circuit_count', 'id', 'name', 'slug', 'url'] + create_data = ( + { + 'name': 'Circuit Type 4', + 'slug': 'circuit-type-4', + }, + { + 'name': 'Circuit Type 5', + 'slug': 'circuit-type-5', + }, + { + 'name': 'Circuit Type 6', + 'slug': 'circuit-type-6', + }, + ) - self.assertEqual(response.data['count'], 3) + @classmethod + def setUpTestData(cls): - def test_list_providers_brief(self): - - url = reverse('circuits-api:provider-list') - response = self.client.get('{}?brief=1'.format(url), **self.header) - - self.assertEqual( - sorted(response.data['results'][0]), - ['circuit_count', 'id', 'name', 'slug', 'url'] + circuit_types = ( + CircuitType(name='Circuit Type 1', slug='circuit-type-1'), + CircuitType(name='Circuit Type 2', slug='circuit-type-2'), + CircuitType(name='Circuit Type 3', slug='circuit-type-3'), ) + CircuitType.objects.bulk_create(circuit_types) - def test_create_provider(self): - data = { - 'name': 'Test Provider 4', - 'slug': 'test-provider-4', - } +class CircuitTest(APIViewTestCases.APIViewTestCase): + model = Circuit + brief_fields = ['cid', 'id', 'url'] - url = reverse('circuits-api:provider-list') - response = self.client.post(url, data, format='json', **self.header) + @classmethod + def setUpTestData(cls): - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Provider.objects.count(), 4) - provider4 = Provider.objects.get(pk=response.data['id']) - self.assertEqual(provider4.name, data['name']) - self.assertEqual(provider4.slug, data['slug']) + providers = ( + Provider(name='Provider 1', slug='provider-1'), + Provider(name='Provider 2', slug='provider-2'), + ) + Provider.objects.bulk_create(providers) - def test_create_provider_bulk(self): + circuit_types = ( + CircuitType(name='Circuit Type 1', slug='circuit-type-1'), + CircuitType(name='Circuit Type 2', slug='circuit-type-2'), + ) + CircuitType.objects.bulk_create(circuit_types) - data = [ + circuits = ( + Circuit(cid='Circuit 1', provider=providers[0], type=circuit_types[0]), + Circuit(cid='Circuit 2', provider=providers[0], type=circuit_types[0]), + Circuit(cid='Circuit 3', provider=providers[0], type=circuit_types[0]), + ) + Circuit.objects.bulk_create(circuits) + + cls.create_data = [ { - 'name': 'Test Provider 4', - 'slug': 'test-provider-4', + 'cid': 'Circuit 4', + 'provider': providers[1].pk, + 'type': circuit_types[1].pk, }, { - 'name': 'Test Provider 5', - 'slug': 'test-provider-5', + 'cid': 'Circuit 5', + 'provider': providers[1].pk, + 'type': circuit_types[1].pk, }, { - 'name': 'Test Provider 6', - 'slug': 'test-provider-6', + 'cid': 'Circuit 6', + 'provider': providers[1].pk, + 'type': circuit_types[1].pk, }, ] - url = reverse('circuits-api:provider-list') - response = self.client.post(url, data, format='json', **self.header) - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Provider.objects.count(), 6) - self.assertEqual(response.data[0]['name'], data[0]['name']) - self.assertEqual(response.data[1]['name'], data[1]['name']) - self.assertEqual(response.data[2]['name'], data[2]['name']) +class CircuitTerminationTest(APIViewTestCases.APIViewTestCase): + model = CircuitTermination + brief_fields = ['circuit', 'id', 'term_side', 'url'] - def test_update_provider(self): + @classmethod + def setUpTestData(cls): + SIDE_A = CircuitTerminationSideChoices.SIDE_A + SIDE_Z = CircuitTerminationSideChoices.SIDE_Z - data = { - 'name': 'Test Provider X', - 'slug': 'test-provider-x', - } - - url = reverse('circuits-api:provider-detail', kwargs={'pk': self.provider1.pk}) - response = self.client.put(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(Provider.objects.count(), 3) - provider1 = Provider.objects.get(pk=response.data['id']) - self.assertEqual(provider1.name, data['name']) - self.assertEqual(provider1.slug, data['slug']) - - def test_delete_provider(self): - - url = reverse('circuits-api:provider-detail', kwargs={'pk': self.provider1.pk}) - response = self.client.delete(url, **self.header) - - self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(Provider.objects.count(), 2) - - -class CircuitTypeTest(APITestCase): - - def setUp(self): - - super().setUp() - - self.circuittype1 = CircuitType.objects.create(name='Test Circuit Type 1', slug='test-circuit-type-1') - self.circuittype2 = CircuitType.objects.create(name='Test Circuit Type 2', slug='test-circuit-type-2') - self.circuittype3 = CircuitType.objects.create(name='Test Circuit Type 3', slug='test-circuit-type-3') - - def test_get_circuittype(self): - - url = reverse('circuits-api:circuittype-detail', kwargs={'pk': self.circuittype1.pk}) - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['name'], self.circuittype1.name) - - def test_list_circuittypes(self): - - url = reverse('circuits-api:circuittype-list') - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['count'], 3) - - def test_list_circuittypes_brief(self): - - url = reverse('circuits-api:circuittype-list') - response = self.client.get('{}?brief=1'.format(url), **self.header) - - self.assertEqual( - sorted(response.data['results'][0]), - ['circuit_count', 'id', 'name', 'slug', 'url'] + sites = ( + Site(name='Site 1', slug='site-1'), + Site(name='Site 2', slug='site-2'), ) + Site.objects.bulk_create(sites) - def test_create_circuittype(self): + provider = Provider.objects.create(name='Provider 1', slug='provider-1') + circuit_type = CircuitType.objects.create(name='Circuit Type 1', slug='circuit-type-1') - data = { - 'name': 'Test Circuit Type 4', - 'slug': 'test-circuit-type-4', - } - - url = reverse('circuits-api:circuittype-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(CircuitType.objects.count(), 4) - circuittype4 = CircuitType.objects.get(pk=response.data['id']) - self.assertEqual(circuittype4.name, data['name']) - self.assertEqual(circuittype4.slug, data['slug']) - - def test_update_circuittype(self): - - data = { - 'name': 'Test Circuit Type X', - 'slug': 'test-circuit-type-x', - } - - url = reverse('circuits-api:circuittype-detail', kwargs={'pk': self.circuittype1.pk}) - response = self.client.put(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(CircuitType.objects.count(), 3) - circuittype1 = CircuitType.objects.get(pk=response.data['id']) - self.assertEqual(circuittype1.name, data['name']) - self.assertEqual(circuittype1.slug, data['slug']) - - def test_delete_circuittype(self): - - url = reverse('circuits-api:circuittype-detail', kwargs={'pk': self.circuittype1.pk}) - response = self.client.delete(url, **self.header) - - self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(CircuitType.objects.count(), 2) - - -class CircuitTest(APITestCase): - - def setUp(self): - - super().setUp() - - self.provider1 = Provider.objects.create(name='Test Provider 1', slug='test-provider-1') - self.provider2 = Provider.objects.create(name='Test Provider 2', slug='test-provider-2') - self.circuittype1 = CircuitType.objects.create(name='Test Circuit Type 1', slug='test-circuit-type-1') - self.circuittype2 = CircuitType.objects.create(name='Test Circuit Type 2', slug='test-circuit-type-2') - self.circuit1 = Circuit.objects.create(cid='TEST0001', provider=self.provider1, type=self.circuittype1) - self.circuit2 = Circuit.objects.create(cid='TEST0002', provider=self.provider1, type=self.circuittype1) - self.circuit3 = Circuit.objects.create(cid='TEST0003', provider=self.provider1, type=self.circuittype1) - - def test_get_circuit(self): - - url = reverse('circuits-api:circuit-detail', kwargs={'pk': self.circuit1.pk}) - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['cid'], self.circuit1.cid) - - def test_list_circuits(self): - - url = reverse('circuits-api:circuit-list') - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['count'], 3) - - def test_list_circuits_brief(self): - - url = reverse('circuits-api:circuit-list') - response = self.client.get('{}?brief=1'.format(url), **self.header) - - self.assertEqual( - sorted(response.data['results'][0]), - ['cid', 'id', 'url'] + circuits = ( + Circuit(cid='Circuit 1', provider=provider, type=circuit_type), + Circuit(cid='Circuit 2', provider=provider, type=circuit_type), + Circuit(cid='Circuit 3', provider=provider, type=circuit_type), ) + Circuit.objects.bulk_create(circuits) - def test_create_circuit(self): + circuit_terminations = ( + CircuitTermination(circuit=circuits[0], site=sites[0], port_speed=100000, term_side=SIDE_A), + CircuitTermination(circuit=circuits[0], site=sites[1], port_speed=100000, term_side=SIDE_Z), + CircuitTermination(circuit=circuits[1], site=sites[0], port_speed=100000, term_side=SIDE_A), + CircuitTermination(circuit=circuits[1], site=sites[1], port_speed=100000, term_side=SIDE_Z), + ) + CircuitTermination.objects.bulk_create(circuit_terminations) - data = { - 'cid': 'TEST0004', - 'provider': self.provider1.pk, - 'type': self.circuittype1.pk, - 'status': CircuitStatusChoices.STATUS_ACTIVE, - } - - url = reverse('circuits-api:circuit-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Circuit.objects.count(), 4) - circuit4 = Circuit.objects.get(pk=response.data['id']) - self.assertEqual(circuit4.cid, data['cid']) - self.assertEqual(circuit4.provider_id, data['provider']) - self.assertEqual(circuit4.type_id, data['type']) - - def test_create_circuit_bulk(self): - - data = [ + cls.create_data = [ { - 'cid': 'TEST0004', - 'provider': self.provider1.pk, - 'type': self.circuittype1.pk, - 'status': CircuitStatusChoices.STATUS_ACTIVE, + 'circuit': circuits[2].pk, + 'term_side': SIDE_A, + 'site': sites[1].pk, + 'port_speed': 200000, }, { - 'cid': 'TEST0005', - 'provider': self.provider1.pk, - 'type': self.circuittype1.pk, - 'status': CircuitStatusChoices.STATUS_ACTIVE, - }, - { - 'cid': 'TEST0006', - 'provider': self.provider1.pk, - 'type': self.circuittype1.pk, - 'status': CircuitStatusChoices.STATUS_ACTIVE, + 'circuit': circuits[2].pk, + 'term_side': SIDE_Z, + 'site': sites[1].pk, + 'port_speed': 200000, }, ] - - url = reverse('circuits-api:circuit-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Circuit.objects.count(), 6) - self.assertEqual(response.data[0]['cid'], data[0]['cid']) - self.assertEqual(response.data[1]['cid'], data[1]['cid']) - self.assertEqual(response.data[2]['cid'], data[2]['cid']) - - def test_update_circuit(self): - - data = { - 'cid': 'TEST000X', - 'provider': self.provider2.pk, - 'type': self.circuittype2.pk, - } - - url = reverse('circuits-api:circuit-detail', kwargs={'pk': self.circuit1.pk}) - response = self.client.put(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(Circuit.objects.count(), 3) - circuit1 = Circuit.objects.get(pk=response.data['id']) - self.assertEqual(circuit1.cid, data['cid']) - self.assertEqual(circuit1.provider_id, data['provider']) - self.assertEqual(circuit1.type_id, data['type']) - - def test_delete_circuit(self): - - url = reverse('circuits-api:circuit-detail', kwargs={'pk': self.circuit1.pk}) - response = self.client.delete(url, **self.header) - - self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(Circuit.objects.count(), 2) - - -class CircuitTerminationTest(APITestCase): - - def setUp(self): - - super().setUp() - - self.site1 = Site.objects.create(name='Test Site 1', slug='test-site-1') - self.site2 = Site.objects.create(name='Test Site 2', slug='test-site-2') - provider = Provider.objects.create(name='Test Provider', slug='test-provider') - circuittype = CircuitType.objects.create(name='Test Circuit Type', slug='test-circuit-type') - self.circuit1 = Circuit.objects.create(cid='TEST0001', provider=provider, type=circuittype) - self.circuit2 = Circuit.objects.create(cid='TEST0002', provider=provider, type=circuittype) - self.circuit3 = Circuit.objects.create(cid='TEST0003', provider=provider, type=circuittype) - self.circuittermination1 = CircuitTermination.objects.create( - circuit=self.circuit1, - term_side=CircuitTerminationSideChoices.SIDE_A, - site=self.site1, - port_speed=1000000 - ) - self.circuittermination2 = CircuitTermination.objects.create( - circuit=self.circuit1, - term_side=CircuitTerminationSideChoices.SIDE_Z, - site=self.site2, - port_speed=1000000 - ) - self.circuittermination3 = CircuitTermination.objects.create( - circuit=self.circuit2, - term_side=CircuitTerminationSideChoices.SIDE_A, - site=self.site1, - port_speed=1000000 - ) - self.circuittermination4 = CircuitTermination.objects.create( - circuit=self.circuit2, - term_side=CircuitTerminationSideChoices.SIDE_Z, - site=self.site2, - port_speed=1000000 - ) - - def test_get_circuittermination(self): - - url = reverse('circuits-api:circuittermination-detail', kwargs={'pk': self.circuittermination1.pk}) - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['id'], self.circuittermination1.pk) - - def test_list_circuitterminations(self): - - url = reverse('circuits-api:circuittermination-list') - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['count'], 4) - - def test_create_circuittermination(self): - - data = { - 'circuit': self.circuit3.pk, - 'term_side': CircuitTerminationSideChoices.SIDE_A, - 'site': self.site1.pk, - 'port_speed': 1000000, - } - - url = reverse('circuits-api:circuittermination-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(CircuitTermination.objects.count(), 5) - circuittermination4 = CircuitTermination.objects.get(pk=response.data['id']) - self.assertEqual(circuittermination4.circuit_id, data['circuit']) - self.assertEqual(circuittermination4.term_side, data['term_side']) - self.assertEqual(circuittermination4.site_id, data['site']) - self.assertEqual(circuittermination4.port_speed, data['port_speed']) - - def test_update_circuittermination(self): - - circuittermination5 = CircuitTermination.objects.create( - circuit=self.circuit3, - term_side=CircuitTerminationSideChoices.SIDE_A, - site=self.site1, - port_speed=1000000 - ) - - data = { - 'circuit': self.circuit3.pk, - 'term_side': CircuitTerminationSideChoices.SIDE_Z, - 'site': self.site2.pk, - 'port_speed': 1000000, - } - - url = reverse('circuits-api:circuittermination-detail', kwargs={'pk': circuittermination5.pk}) - response = self.client.put(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(CircuitTermination.objects.count(), 5) - circuittermination1 = CircuitTermination.objects.get(pk=response.data['id']) - self.assertEqual(circuittermination1.term_side, data['term_side']) - self.assertEqual(circuittermination1.site_id, data['site']) - self.assertEqual(circuittermination1.port_speed, data['port_speed']) - - def test_delete_circuittermination(self): - - url = reverse('circuits-api:circuittermination-detail', kwargs={'pk': self.circuittermination1.pk}) - response = self.client.delete(url, **self.header) - - self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(CircuitTermination.objects.count(), 3) diff --git a/netbox/circuits/tests/test_views.py b/netbox/circuits/tests/test_views.py index 9cc7af6ae..3356fca8f 100644 --- a/netbox/circuits/tests/test_views.py +++ b/netbox/circuits/tests/test_views.py @@ -17,6 +17,8 @@ class ProviderTestCase(ViewTestCases.PrimaryObjectViewTestCase): Provider(name='Provider 3', slug='provider-3', asn=65003), ]) + tags = cls.create_tags('Alpha', 'Bravo', 'Charlie') + cls.form_data = { 'name': 'Provider X', 'slug': 'provider-x', @@ -26,7 +28,7 @@ class ProviderTestCase(ViewTestCases.PrimaryObjectViewTestCase): 'noc_contact': 'noc@example.com', 'admin_contact': 'admin@example.com', 'comments': 'Another provider', - 'tags': 'Alpha,Bravo,Charlie', + 'tags': [t.pk for t in tags], } cls.csv_data = ( @@ -96,6 +98,8 @@ class CircuitTestCase(ViewTestCases.PrimaryObjectViewTestCase): Circuit(cid='Circuit 3', provider=providers[0], type=circuittypes[0]), ]) + tags = cls.create_tags('Alpha', 'Bravo', 'Charlie') + cls.form_data = { 'cid': 'Circuit X', 'provider': providers[1].pk, @@ -106,7 +110,7 @@ class CircuitTestCase(ViewTestCases.PrimaryObjectViewTestCase): 'commit_rate': 1000, 'description': 'A new circuit', 'comments': 'Some comments', - 'tags': 'Alpha,Bravo,Charlie', + 'tags': [t.pk for t in tags], } cls.csv_data = ( @@ -124,5 +128,4 @@ class CircuitTestCase(ViewTestCases.PrimaryObjectViewTestCase): 'commit_rate': 2000, 'description': 'New description', 'comments': 'New comments', - } diff --git a/netbox/circuits/urls.py b/netbox/circuits/urls.py index 72d9720df..1c0f0715b 100644 --- a/netbox/circuits/urls.py +++ b/netbox/circuits/urls.py @@ -10,7 +10,7 @@ urlpatterns = [ # Providers path('providers/', views.ProviderListView.as_view(), name='provider_list'), - path('providers/add/', views.ProviderCreateView.as_view(), name='provider_add'), + path('providers/add/', views.ProviderEditView.as_view(), name='provider_add'), path('providers/import/', views.ProviderBulkImportView.as_view(), name='provider_import'), path('providers/edit/', views.ProviderBulkEditView.as_view(), name='provider_bulk_edit'), path('providers/delete/', views.ProviderBulkDeleteView.as_view(), name='provider_bulk_delete'), @@ -21,7 +21,7 @@ urlpatterns = [ # Circuit types path('circuit-types/', views.CircuitTypeListView.as_view(), name='circuittype_list'), - path('circuit-types/add/', views.CircuitTypeCreateView.as_view(), name='circuittype_add'), + path('circuit-types/add/', views.CircuitTypeEditView.as_view(), name='circuittype_add'), path('circuit-types/import/', views.CircuitTypeBulkImportView.as_view(), name='circuittype_import'), path('circuit-types/delete/', views.CircuitTypeBulkDeleteView.as_view(), name='circuittype_bulk_delete'), path('circuit-types//edit/', views.CircuitTypeEditView.as_view(), name='circuittype_edit'), @@ -29,7 +29,7 @@ urlpatterns = [ # Circuits path('circuits/', views.CircuitListView.as_view(), name='circuit_list'), - path('circuits/add/', views.CircuitCreateView.as_view(), name='circuit_add'), + path('circuits/add/', views.CircuitEditView.as_view(), name='circuit_add'), path('circuits/import/', views.CircuitBulkImportView.as_view(), name='circuit_import'), path('circuits/edit/', views.CircuitBulkEditView.as_view(), name='circuit_bulk_edit'), path('circuits/delete/', views.CircuitBulkDeleteView.as_view(), name='circuit_bulk_delete'), @@ -37,11 +37,10 @@ urlpatterns = [ path('circuits//edit/', views.CircuitEditView.as_view(), name='circuit_edit'), path('circuits//delete/', views.CircuitDeleteView.as_view(), name='circuit_delete'), path('circuits//changelog/', ObjectChangeLogView.as_view(), name='circuit_changelog', kwargs={'model': Circuit}), - path('circuits//terminations/swap/', views.circuit_terminations_swap, name='circuit_terminations_swap'), + path('circuits//terminations/swap/', views.CircuitSwapTerminations.as_view(), name='circuit_terminations_swap'), # Circuit terminations - - path('circuits//terminations/add/', views.CircuitTerminationCreateView.as_view(), name='circuittermination_add'), + path('circuits//terminations/add/', views.CircuitTerminationEditView.as_view(), name='circuittermination_add'), path('circuit-terminations//edit/', views.CircuitTerminationEditView.as_view(), name='circuittermination_edit'), path('circuit-terminations//delete/', views.CircuitTerminationDeleteView.as_view(), name='circuittermination_delete'), path('circuit-terminations//connect//', CableCreateView.as_view(), name='circuittermination_connect', kwargs={'termination_a_type': CircuitTermination}), diff --git a/netbox/circuits/views.py b/netbox/circuits/views.py index 0546b3832..f100dd3c7 100644 --- a/netbox/circuits/views.py +++ b/netbox/circuits/views.py @@ -1,18 +1,15 @@ from django.conf import settings from django.contrib import messages -from django.contrib.auth.decorators import permission_required -from django.contrib.auth.mixins import PermissionRequiredMixin from django.db import transaction -from django.db.models import Count, OuterRef, Subquery +from django.db.models import Count, OuterRef from django.shortcuts import get_object_or_404, redirect, render -from django.views.generic import View from django_tables2 import RequestConfig from extras.models import Graph from utilities.forms import ConfirmationForm from utilities.paginator import EnhancedPaginator from utilities.views import ( - BulkDeleteView, BulkEditView, BulkImportView, ObjectDeleteView, ObjectEditView, ObjectListView, + BulkDeleteView, BulkEditView, BulkImportView, ObjectView, ObjectDeleteView, ObjectEditView, ObjectListView, ) from . import filters, forms, tables from .choices import CircuitTerminationSideChoices @@ -23,21 +20,20 @@ from .models import Circuit, CircuitTermination, CircuitType, Provider # Providers # -class ProviderListView(PermissionRequiredMixin, ObjectListView): - permission_required = 'circuits.view_provider' +class ProviderListView(ObjectListView): queryset = Provider.objects.annotate(count_circuits=Count('circuits')) filterset = filters.ProviderFilterSet filterset_form = forms.ProviderFilterForm table = tables.ProviderTable -class ProviderView(PermissionRequiredMixin, View): - permission_required = 'circuits.view_provider' +class ProviderView(ObjectView): + queryset = Provider.objects.all() def get(self, request, slug): - provider = get_object_or_404(Provider, slug=slug) - circuits = Circuit.objects.filter( + provider = get_object_or_404(self.queryset, slug=slug) + circuits = Circuit.objects.restrict(request.user, 'view').filter( provider=provider ).prefetch_related( 'type', 'tenant', 'terminations__site' @@ -60,33 +56,26 @@ class ProviderView(PermissionRequiredMixin, View): }) -class ProviderCreateView(PermissionRequiredMixin, ObjectEditView): - permission_required = 'circuits.add_provider' +class ProviderEditView(ObjectEditView): queryset = Provider.objects.all() model_form = forms.ProviderForm template_name = 'circuits/provider_edit.html' default_return_url = 'circuits:provider_list' -class ProviderEditView(ProviderCreateView): - permission_required = 'circuits.change_provider' - - -class ProviderDeleteView(PermissionRequiredMixin, ObjectDeleteView): - permission_required = 'circuits.delete_provider' +class ProviderDeleteView(ObjectDeleteView): queryset = Provider.objects.all() default_return_url = 'circuits:provider_list' -class ProviderBulkImportView(PermissionRequiredMixin, BulkImportView): - permission_required = 'circuits.add_provider' +class ProviderBulkImportView(BulkImportView): + queryset = Provider.objects.all() model_form = forms.ProviderCSVForm table = tables.ProviderTable default_return_url = 'circuits:provider_list' -class ProviderBulkEditView(PermissionRequiredMixin, BulkEditView): - permission_required = 'circuits.change_provider' +class ProviderBulkEditView(BulkEditView): queryset = Provider.objects.annotate(count_circuits=Count('circuits')) filterset = filters.ProviderFilterSet table = tables.ProviderTable @@ -94,8 +83,7 @@ class ProviderBulkEditView(PermissionRequiredMixin, BulkEditView): default_return_url = 'circuits:provider_list' -class ProviderBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): - permission_required = 'circuits.delete_provider' +class ProviderBulkDeleteView(BulkDeleteView): queryset = Provider.objects.annotate(count_circuits=Count('circuits')) filterset = filters.ProviderFilterSet table = tables.ProviderTable @@ -106,32 +94,25 @@ class ProviderBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): # Circuit Types # -class CircuitTypeListView(PermissionRequiredMixin, ObjectListView): - permission_required = 'circuits.view_circuittype' +class CircuitTypeListView(ObjectListView): queryset = CircuitType.objects.annotate(circuit_count=Count('circuits')) table = tables.CircuitTypeTable -class CircuitTypeCreateView(PermissionRequiredMixin, ObjectEditView): - permission_required = 'circuits.add_circuittype' +class CircuitTypeEditView(ObjectEditView): queryset = CircuitType.objects.all() model_form = forms.CircuitTypeForm default_return_url = 'circuits:circuittype_list' -class CircuitTypeEditView(CircuitTypeCreateView): - permission_required = 'circuits.change_circuittype' - - -class CircuitTypeBulkImportView(PermissionRequiredMixin, BulkImportView): - permission_required = 'circuits.add_circuittype' +class CircuitTypeBulkImportView(BulkImportView): + queryset = CircuitType.objects.all() model_form = forms.CircuitTypeCSVForm table = tables.CircuitTypeTable default_return_url = 'circuits:circuittype_list' -class CircuitTypeBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): - permission_required = 'circuits.delete_circuittype' +class CircuitTypeBulkDeleteView(BulkDeleteView): queryset = CircuitType.objects.annotate(circuit_count=Count('circuits')) table = tables.CircuitTypeTable default_return_url = 'circuits:circuittype_list' @@ -141,8 +122,7 @@ class CircuitTypeBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): # Circuits # -class CircuitListView(PermissionRequiredMixin, ObjectListView): - permission_required = 'circuits.view_circuit' +class CircuitListView(ObjectListView): _terminations = CircuitTermination.objects.filter(circuit=OuterRef('pk')) queryset = Circuit.objects.prefetch_related( 'provider', 'type', 'tenant', 'terminations__site' @@ -152,22 +132,27 @@ class CircuitListView(PermissionRequiredMixin, ObjectListView): table = tables.CircuitTable -class CircuitView(PermissionRequiredMixin, View): - permission_required = 'circuits.view_circuit' +class CircuitView(ObjectView): + queryset = Circuit.objects.prefetch_related('provider', 'type', 'tenant__group') def get(self, request, pk): + circuit = get_object_or_404(self.queryset, pk=pk) - circuit = get_object_or_404(Circuit.objects.prefetch_related('provider', 'type', 'tenant__group'), pk=pk) - termination_a = CircuitTermination.objects.prefetch_related( + termination_a = CircuitTermination.objects.restrict(request.user, 'view').prefetch_related( 'site__region', 'connected_endpoint__device' ).filter( circuit=circuit, term_side=CircuitTerminationSideChoices.SIDE_A ).first() - termination_z = CircuitTermination.objects.prefetch_related( + if termination_a and termination_a.connected_endpoint: + termination_a.ip_addresses = termination_a.connected_endpoint.ip_addresses.restrict(request.user, 'view') + + termination_z = CircuitTermination.objects.restrict(request.user, 'view').prefetch_related( 'site__region', 'connected_endpoint__device' ).filter( circuit=circuit, term_side=CircuitTerminationSideChoices.SIDE_Z ).first() + if termination_z and termination_z.connected_endpoint: + termination_z.ip_addresses = termination_z.connected_endpoint.ip_addresses.restrict(request.user, 'view') return render(request, 'circuits/circuit.html', { 'circuit': circuit, @@ -176,33 +161,26 @@ class CircuitView(PermissionRequiredMixin, View): }) -class CircuitCreateView(PermissionRequiredMixin, ObjectEditView): - permission_required = 'circuits.add_circuit' +class CircuitEditView(ObjectEditView): queryset = Circuit.objects.all() model_form = forms.CircuitForm template_name = 'circuits/circuit_edit.html' default_return_url = 'circuits:circuit_list' -class CircuitEditView(CircuitCreateView): - permission_required = 'circuits.change_circuit' - - -class CircuitDeleteView(PermissionRequiredMixin, ObjectDeleteView): - permission_required = 'circuits.delete_circuit' +class CircuitDeleteView(ObjectDeleteView): queryset = Circuit.objects.all() default_return_url = 'circuits:circuit_list' -class CircuitBulkImportView(PermissionRequiredMixin, BulkImportView): - permission_required = 'circuits.add_circuit' +class CircuitBulkImportView(BulkImportView): + queryset = Circuit.objects.all() model_form = forms.CircuitCSVForm table = tables.CircuitTable default_return_url = 'circuits:circuit_list' -class CircuitBulkEditView(PermissionRequiredMixin, BulkEditView): - permission_required = 'circuits.change_circuit' +class CircuitBulkEditView(BulkEditView): queryset = Circuit.objects.prefetch_related('provider', 'type', 'tenant').prefetch_related('terminations__site') filterset = filters.CircuitFilterSet table = tables.CircuitTable @@ -210,33 +188,54 @@ class CircuitBulkEditView(PermissionRequiredMixin, BulkEditView): default_return_url = 'circuits:circuit_list' -class CircuitBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): - permission_required = 'circuits.delete_circuit' +class CircuitBulkDeleteView(BulkDeleteView): queryset = Circuit.objects.prefetch_related('provider', 'type', 'tenant').prefetch_related('terminations__site') filterset = filters.CircuitFilterSet table = tables.CircuitTable default_return_url = 'circuits:circuit_list' -@permission_required('circuits.change_circuittermination') -def circuit_terminations_swap(request, pk): +class CircuitSwapTerminations(ObjectEditView): + """ + Swap the A and Z terminations of a circuit. + """ + queryset = Circuit.objects.all() - circuit = get_object_or_404(Circuit, pk=pk) - termination_a = CircuitTermination.objects.filter( - circuit=circuit, term_side=CircuitTerminationSideChoices.SIDE_A - ).first() - termination_z = CircuitTermination.objects.filter( - circuit=circuit, term_side=CircuitTerminationSideChoices.SIDE_Z - ).first() - if not termination_a and not termination_z: - messages.error(request, "No terminations have been defined for circuit {}.".format(circuit)) - return redirect('circuits:circuit', pk=circuit.pk) + def get(self, request, pk): + circuit = get_object_or_404(self.queryset, pk=pk) + form = ConfirmationForm() - if request.method == 'POST': + # Circuit must have at least one termination to swap + if not circuit.termination_a and not circuit.termination_z: + messages.error(request, "No terminations have been defined for circuit {}.".format(circuit)) + return redirect('circuits:circuit', pk=circuit.pk) + + return render(request, 'circuits/circuit_terminations_swap.html', { + 'circuit': circuit, + 'termination_a': circuit.termination_a, + 'termination_z': circuit.termination_z, + 'form': form, + 'panel_class': 'default', + 'button_class': 'primary', + 'return_url': circuit.get_absolute_url(), + }) + + def post(self, request, pk): + circuit = get_object_or_404(self.queryset, pk=pk) form = ConfirmationForm(request.POST) + if form.is_valid(): + + termination_a = CircuitTermination.objects.filter( + circuit=circuit, term_side=CircuitTerminationSideChoices.SIDE_A + ).first() + termination_z = CircuitTermination.objects.filter( + circuit=circuit, term_side=CircuitTerminationSideChoices.SIDE_Z + ).first() + if termination_a and termination_z: # Use a placeholder to avoid an IntegrityError on the (circuit, term_side) unique constraint + print('swapping') with transaction.atomic(): termination_a.term_side = '_' termination_a.save() @@ -250,29 +249,26 @@ def circuit_terminations_swap(request, pk): else: termination_z.term_side = 'A' termination_z.save() + messages.success(request, "Swapped terminations for circuit {}.".format(circuit)) return redirect('circuits:circuit', pk=circuit.pk) - else: - form = ConfirmationForm() - - return render(request, 'circuits/circuit_terminations_swap.html', { - 'circuit': circuit, - 'termination_a': termination_a, - 'termination_z': termination_z, - 'form': form, - 'panel_class': 'default', - 'button_class': 'primary', - 'return_url': circuit.get_absolute_url(), - }) + return render(request, 'circuits/circuit_terminations_swap.html', { + 'circuit': circuit, + 'termination_a': circuit.termination_a, + 'termination_z': circuit.termination_z, + 'form': form, + 'panel_class': 'default', + 'button_class': 'primary', + 'return_url': circuit.get_absolute_url(), + }) # # Circuit terminations # -class CircuitTerminationCreateView(PermissionRequiredMixin, ObjectEditView): - permission_required = 'circuits.add_circuittermination' +class CircuitTerminationEditView(ObjectEditView): queryset = CircuitTermination.objects.all() model_form = forms.CircuitTerminationForm template_name = 'circuits/circuittermination_edit.html' @@ -286,10 +282,5 @@ class CircuitTerminationCreateView(PermissionRequiredMixin, ObjectEditView): return obj.circuit.get_absolute_url() -class CircuitTerminationEditView(CircuitTerminationCreateView): - permission_required = 'circuits.change_circuittermination' - - -class CircuitTerminationDeleteView(PermissionRequiredMixin, ObjectDeleteView): - permission_required = 'circuits.delete_circuittermination' +class CircuitTerminationDeleteView(ObjectDeleteView): queryset = CircuitTermination.objects.all() diff --git a/netbox/dcim/api/nested_serializers.py b/netbox/dcim/api/nested_serializers.py index bb2d61faa..83fcd7a2a 100644 --- a/netbox/dcim/api/nested_serializers.py +++ b/netbox/dcim/api/nested_serializers.py @@ -1,32 +1,35 @@ from rest_framework import serializers from dcim.constants import CONNECTION_STATUS_CHOICES -from dcim.models import ( - Cable, ConsolePort, ConsoleServerPort, Device, DeviceBay, DeviceType, DeviceRole, FrontPort, FrontPortTemplate, - Interface, Manufacturer, Platform, PowerFeed, PowerOutlet, PowerPanel, PowerPort, PowerPortTemplate, Rack, - RackGroup, RackRole, RearPort, RearPortTemplate, Region, Site, VirtualChassis, -) +from dcim import models from utilities.api import ChoiceField, WritableNestedSerializer __all__ = [ 'NestedCableSerializer', 'NestedConsolePortSerializer', + 'NestedConsolePortTemplateSerializer', 'NestedConsoleServerPortSerializer', + 'NestedConsoleServerPortTemplateSerializer', 'NestedDeviceBaySerializer', + 'NestedDeviceBayTemplateSerializer', 'NestedDeviceRoleSerializer', 'NestedDeviceSerializer', 'NestedDeviceTypeSerializer', 'NestedFrontPortSerializer', 'NestedFrontPortTemplateSerializer', 'NestedInterfaceSerializer', + 'NestedInterfaceTemplateSerializer', + 'NestedInventoryItemSerializer', 'NestedManufacturerSerializer', 'NestedPlatformSerializer', 'NestedPowerFeedSerializer', 'NestedPowerOutletSerializer', + 'NestedPowerOutletTemplateSerializer', 'NestedPowerPanelSerializer', 'NestedPowerPortSerializer', 'NestedPowerPortTemplateSerializer', 'NestedRackGroupSerializer', + 'NestedRackReservationSerializer', 'NestedRackRoleSerializer', 'NestedRackSerializer', 'NestedRearPortSerializer', @@ -46,7 +49,7 @@ class NestedRegionSerializer(WritableNestedSerializer): site_count = serializers.IntegerField(read_only=True) class Meta: - model = Region + model = models.Region fields = ['id', 'url', 'name', 'slug', 'site_count'] @@ -54,7 +57,7 @@ class NestedSiteSerializer(WritableNestedSerializer): url = serializers.HyperlinkedIdentityField(view_name='dcim-api:site-detail') class Meta: - model = Site + model = models.Site fields = ['id', 'url', 'name', 'slug'] @@ -67,7 +70,7 @@ class NestedRackGroupSerializer(WritableNestedSerializer): rack_count = serializers.IntegerField(read_only=True) class Meta: - model = RackGroup + model = models.RackGroup fields = ['id', 'url', 'name', 'slug', 'rack_count'] @@ -76,7 +79,7 @@ class NestedRackRoleSerializer(WritableNestedSerializer): rack_count = serializers.IntegerField(read_only=True) class Meta: - model = RackRole + model = models.RackRole fields = ['id', 'url', 'name', 'slug', 'rack_count'] @@ -85,10 +88,22 @@ class NestedRackSerializer(WritableNestedSerializer): device_count = serializers.IntegerField(read_only=True) class Meta: - model = Rack + model = models.Rack fields = ['id', 'url', 'name', 'display_name', 'device_count'] +class NestedRackReservationSerializer(WritableNestedSerializer): + url = serializers.HyperlinkedIdentityField(view_name='dcim-api:rackreservation-detail') + user = serializers.SerializerMethodField(read_only=True) + + class Meta: + model = models.RackReservation + fields = ['id', 'url', 'user', 'units'] + + def get_user(self, obj): + return obj.user.username + + # # Device types # @@ -98,7 +113,7 @@ class NestedManufacturerSerializer(WritableNestedSerializer): devicetype_count = serializers.IntegerField(read_only=True) class Meta: - model = Manufacturer + model = models.Manufacturer fields = ['id', 'url', 'name', 'slug', 'devicetype_count'] @@ -108,15 +123,47 @@ class NestedDeviceTypeSerializer(WritableNestedSerializer): device_count = serializers.IntegerField(read_only=True) class Meta: - model = DeviceType + model = models.DeviceType fields = ['id', 'url', 'manufacturer', 'model', 'slug', 'display_name', 'device_count'] +class NestedConsolePortTemplateSerializer(WritableNestedSerializer): + url = serializers.HyperlinkedIdentityField(view_name='dcim-api:consoleporttemplate-detail') + + class Meta: + model = models.ConsolePortTemplate + fields = ['id', 'url', 'name'] + + +class NestedConsoleServerPortTemplateSerializer(WritableNestedSerializer): + url = serializers.HyperlinkedIdentityField(view_name='dcim-api:consoleserverporttemplate-detail') + + class Meta: + model = models.ConsoleServerPortTemplate + fields = ['id', 'url', 'name'] + + class NestedPowerPortTemplateSerializer(WritableNestedSerializer): url = serializers.HyperlinkedIdentityField(view_name='dcim-api:powerporttemplate-detail') class Meta: - model = PowerPortTemplate + model = models.PowerPortTemplate + fields = ['id', 'url', 'name'] + + +class NestedPowerOutletTemplateSerializer(WritableNestedSerializer): + url = serializers.HyperlinkedIdentityField(view_name='dcim-api:poweroutlettemplate-detail') + + class Meta: + model = models.PowerOutletTemplate + fields = ['id', 'url', 'name'] + + +class NestedInterfaceTemplateSerializer(WritableNestedSerializer): + url = serializers.HyperlinkedIdentityField(view_name='dcim-api:interfacetemplate-detail') + + class Meta: + model = models.InterfaceTemplate fields = ['id', 'url', 'name'] @@ -124,7 +171,7 @@ class NestedRearPortTemplateSerializer(WritableNestedSerializer): url = serializers.HyperlinkedIdentityField(view_name='dcim-api:rearporttemplate-detail') class Meta: - model = RearPortTemplate + model = models.RearPortTemplate fields = ['id', 'url', 'name'] @@ -132,7 +179,15 @@ class NestedFrontPortTemplateSerializer(WritableNestedSerializer): url = serializers.HyperlinkedIdentityField(view_name='dcim-api:frontporttemplate-detail') class Meta: - model = FrontPortTemplate + model = models.FrontPortTemplate + fields = ['id', 'url', 'name'] + + +class NestedDeviceBayTemplateSerializer(WritableNestedSerializer): + url = serializers.HyperlinkedIdentityField(view_name='dcim-api:devicebaytemplate-detail') + + class Meta: + model = models.DeviceBayTemplate fields = ['id', 'url', 'name'] @@ -146,7 +201,7 @@ class NestedDeviceRoleSerializer(WritableNestedSerializer): virtualmachine_count = serializers.IntegerField(read_only=True) class Meta: - model = DeviceRole + model = models.DeviceRole fields = ['id', 'url', 'name', 'slug', 'device_count', 'virtualmachine_count'] @@ -156,7 +211,7 @@ class NestedPlatformSerializer(WritableNestedSerializer): virtualmachine_count = serializers.IntegerField(read_only=True) class Meta: - model = Platform + model = models.Platform fields = ['id', 'url', 'name', 'slug', 'device_count', 'virtualmachine_count'] @@ -164,7 +219,7 @@ class NestedDeviceSerializer(WritableNestedSerializer): url = serializers.HyperlinkedIdentityField(view_name='dcim-api:device-detail') class Meta: - model = Device + model = models.Device fields = ['id', 'url', 'name', 'display_name'] @@ -174,7 +229,7 @@ class NestedConsoleServerPortSerializer(WritableNestedSerializer): connection_status = ChoiceField(choices=CONNECTION_STATUS_CHOICES, read_only=True) class Meta: - model = ConsoleServerPort + model = models.ConsoleServerPort fields = ['id', 'url', 'device', 'name', 'cable', 'connection_status'] @@ -184,7 +239,7 @@ class NestedConsolePortSerializer(WritableNestedSerializer): connection_status = ChoiceField(choices=CONNECTION_STATUS_CHOICES, read_only=True) class Meta: - model = ConsolePort + model = models.ConsolePort fields = ['id', 'url', 'device', 'name', 'cable', 'connection_status'] @@ -194,7 +249,7 @@ class NestedPowerOutletSerializer(WritableNestedSerializer): connection_status = ChoiceField(choices=CONNECTION_STATUS_CHOICES, read_only=True) class Meta: - model = PowerOutlet + model = models.PowerOutlet fields = ['id', 'url', 'device', 'name', 'cable', 'connection_status'] @@ -204,7 +259,7 @@ class NestedPowerPortSerializer(WritableNestedSerializer): connection_status = ChoiceField(choices=CONNECTION_STATUS_CHOICES, read_only=True) class Meta: - model = PowerPort + model = models.PowerPort fields = ['id', 'url', 'device', 'name', 'cable', 'connection_status'] @@ -214,7 +269,7 @@ class NestedInterfaceSerializer(WritableNestedSerializer): connection_status = ChoiceField(choices=CONNECTION_STATUS_CHOICES, read_only=True) class Meta: - model = Interface + model = models.Interface fields = ['id', 'url', 'device', 'name', 'cable', 'connection_status'] @@ -223,7 +278,7 @@ class NestedRearPortSerializer(WritableNestedSerializer): url = serializers.HyperlinkedIdentityField(view_name='dcim-api:rearport-detail') class Meta: - model = RearPort + model = models.RearPort fields = ['id', 'url', 'device', 'name', 'cable'] @@ -232,7 +287,7 @@ class NestedFrontPortSerializer(WritableNestedSerializer): url = serializers.HyperlinkedIdentityField(view_name='dcim-api:frontport-detail') class Meta: - model = FrontPort + model = models.FrontPort fields = ['id', 'url', 'device', 'name', 'cable'] @@ -241,7 +296,16 @@ class NestedDeviceBaySerializer(WritableNestedSerializer): device = NestedDeviceSerializer(read_only=True) class Meta: - model = DeviceBay + model = models.DeviceBay + fields = ['id', 'url', 'device', 'name'] + + +class NestedInventoryItemSerializer(WritableNestedSerializer): + url = serializers.HyperlinkedIdentityField(view_name='dcim-api:inventoryitem-detail') + device = NestedDeviceSerializer(read_only=True) + + class Meta: + model = models.InventoryItem fields = ['id', 'url', 'device', 'name'] @@ -253,7 +317,7 @@ class NestedCableSerializer(serializers.ModelSerializer): url = serializers.HyperlinkedIdentityField(view_name='dcim-api:cable-detail') class Meta: - model = Cable + model = models.Cable fields = ['id', 'url', 'label'] @@ -267,7 +331,7 @@ class NestedVirtualChassisSerializer(WritableNestedSerializer): member_count = serializers.IntegerField(read_only=True) class Meta: - model = VirtualChassis + model = models.VirtualChassis fields = ['id', 'url', 'master', 'member_count'] @@ -280,7 +344,7 @@ class NestedPowerPanelSerializer(WritableNestedSerializer): powerfeed_count = serializers.IntegerField(read_only=True) class Meta: - model = PowerPanel + model = models.PowerPanel fields = ['id', 'url', 'name', 'powerfeed_count'] @@ -288,5 +352,5 @@ class NestedPowerFeedSerializer(WritableNestedSerializer): url = serializers.HyperlinkedIdentityField(view_name='dcim-api:powerfeed-detail') class Meta: - model = PowerFeed + model = models.PowerFeed fields = ['id', 'url', 'name'] diff --git a/netbox/dcim/api/serializers.py b/netbox/dcim/api/serializers.py index 9ac58dc3a..c684b8041 100644 --- a/netbox/dcim/api/serializers.py +++ b/netbox/dcim/api/serializers.py @@ -2,7 +2,6 @@ from django.contrib.contenttypes.models import ContentType from drf_yasg.utils import swagger_serializer_method from rest_framework import serializers from rest_framework.validators import UniqueTogetherValidator -from taggit_serializer.serializers import TaggitSerializer, TagListSerializerField from dcim.choices import * from dcim.constants import * @@ -14,6 +13,7 @@ from dcim.models import ( VirtualChassis, ) from extras.api.customfields import CustomFieldModelSerializer +from extras.api.serializers import TaggedObjectSerializer from ipam.api.nested_serializers import NestedIPAddressSerializer, NestedVLANSerializer from ipam.models import VLAN from tenancy.api.nested_serializers import NestedTenantSerializer @@ -67,12 +67,11 @@ class RegionSerializer(serializers.ModelSerializer): fields = ['id', 'name', 'slug', 'parent', 'description', 'site_count'] -class SiteSerializer(TaggitSerializer, CustomFieldModelSerializer): +class SiteSerializer(TaggedObjectSerializer, CustomFieldModelSerializer): status = ChoiceField(choices=SiteStatusChoices, required=False) region = NestedRegionSerializer(required=False, allow_null=True) tenant = NestedTenantSerializer(required=False, allow_null=True) time_zone = TimeZoneField(required=False) - tags = TagListSerializerField(required=False) circuit_count = serializers.IntegerField(read_only=True) device_count = serializers.IntegerField(read_only=True) prefix_count = serializers.IntegerField(read_only=True) @@ -112,7 +111,7 @@ class RackRoleSerializer(ValidatedModelSerializer): fields = ['id', 'name', 'slug', 'color', 'description', 'rack_count'] -class RackSerializer(TaggitSerializer, CustomFieldModelSerializer): +class RackSerializer(TaggedObjectSerializer, CustomFieldModelSerializer): site = NestedSiteSerializer() group = NestedRackGroupSerializer(required=False, allow_null=True, default=None) tenant = NestedTenantSerializer(required=False, allow_null=True) @@ -121,7 +120,6 @@ class RackSerializer(TaggitSerializer, CustomFieldModelSerializer): type = ChoiceField(choices=RackTypeChoices, allow_blank=True, required=False) width = ChoiceField(choices=RackWidthChoices, required=False) outer_unit = ChoiceField(choices=RackDimensionUnitChoices, allow_blank=True, required=False) - tags = TagListSerializerField(required=False) device_count = serializers.IntegerField(read_only=True) powerfeed_count = serializers.IntegerField(read_only=True) @@ -161,14 +159,14 @@ class RackUnitSerializer(serializers.Serializer): device = NestedDeviceSerializer(read_only=True) -class RackReservationSerializer(ValidatedModelSerializer): +class RackReservationSerializer(TaggedObjectSerializer, ValidatedModelSerializer): rack = NestedRackSerializer() user = NestedUserSerializer() tenant = NestedTenantSerializer(required=False, allow_null=True) class Meta: model = RackReservation - fields = ['id', 'rack', 'units', 'created', 'user', 'tenant', 'description'] + fields = ['id', 'rack', 'units', 'created', 'user', 'tenant', 'description', 'tags'] class RackElevationDetailFilterSerializer(serializers.Serializer): @@ -223,10 +221,9 @@ class ManufacturerSerializer(ValidatedModelSerializer): ] -class DeviceTypeSerializer(TaggitSerializer, CustomFieldModelSerializer): +class DeviceTypeSerializer(TaggedObjectSerializer, CustomFieldModelSerializer): manufacturer = NestedManufacturerSerializer() subdevice_role = ChoiceField(choices=SubdeviceRoleChoices, allow_blank=True, required=False) - tags = TagListSerializerField(required=False) device_count = serializers.IntegerField(read_only=True) class Meta: @@ -248,7 +245,7 @@ class ConsolePortTemplateSerializer(ValidatedModelSerializer): class Meta: model = ConsolePortTemplate - fields = ['id', 'device_type', 'name', 'type'] + fields = ['id', 'device_type', 'name', 'label', 'type'] class ConsoleServerPortTemplateSerializer(ValidatedModelSerializer): @@ -261,7 +258,7 @@ class ConsoleServerPortTemplateSerializer(ValidatedModelSerializer): class Meta: model = ConsoleServerPortTemplate - fields = ['id', 'device_type', 'name', 'type'] + fields = ['id', 'device_type', 'name', 'label', 'type'] class PowerPortTemplateSerializer(ValidatedModelSerializer): @@ -274,7 +271,7 @@ class PowerPortTemplateSerializer(ValidatedModelSerializer): class Meta: model = PowerPortTemplate - fields = ['id', 'device_type', 'name', 'type', 'maximum_draw', 'allocated_draw'] + fields = ['id', 'device_type', 'name', 'label', 'type', 'maximum_draw', 'allocated_draw'] class PowerOutletTemplateSerializer(ValidatedModelSerializer): @@ -295,7 +292,7 @@ class PowerOutletTemplateSerializer(ValidatedModelSerializer): class Meta: model = PowerOutletTemplate - fields = ['id', 'device_type', 'name', 'type', 'power_port', 'feed_leg'] + fields = ['id', 'device_type', 'name', 'label', 'type', 'power_port', 'feed_leg'] class InterfaceTemplateSerializer(ValidatedModelSerializer): @@ -304,7 +301,7 @@ class InterfaceTemplateSerializer(ValidatedModelSerializer): class Meta: model = InterfaceTemplate - fields = ['id', 'device_type', 'name', 'type', 'mgmt_only'] + fields = ['id', 'device_type', 'name', 'label', 'type', 'mgmt_only'] class RearPortTemplateSerializer(ValidatedModelSerializer): @@ -331,7 +328,7 @@ class DeviceBayTemplateSerializer(ValidatedModelSerializer): class Meta: model = DeviceBayTemplate - fields = ['id', 'device_type', 'name'] + fields = ['id', 'device_type', 'name', 'label'] # @@ -362,7 +359,7 @@ class PlatformSerializer(ValidatedModelSerializer): ] -class DeviceSerializer(TaggitSerializer, CustomFieldModelSerializer): +class DeviceSerializer(TaggedObjectSerializer, CustomFieldModelSerializer): device_type = NestedDeviceTypeSerializer() device_role = NestedDeviceRoleSerializer() tenant = NestedTenantSerializer(required=False, allow_null=True) @@ -377,7 +374,6 @@ class DeviceSerializer(TaggitSerializer, CustomFieldModelSerializer): parent_device = serializers.SerializerMethodField() cluster = NestedClusterSerializer(required=False, allow_null=True) virtual_chassis = NestedVirtualChassisSerializer(required=False, allow_null=True) - tags = TagListSerializerField(required=False) class Meta: model = Device @@ -433,7 +429,7 @@ class DeviceNAPALMSerializer(serializers.Serializer): method = serializers.DictField() -class ConsoleServerPortSerializer(TaggitSerializer, ConnectedEndpointSerializer): +class ConsoleServerPortSerializer(TaggedObjectSerializer, ConnectedEndpointSerializer): device = NestedDeviceSerializer() type = ChoiceField( choices=ConsolePortTypeChoices, @@ -441,17 +437,16 @@ class ConsoleServerPortSerializer(TaggitSerializer, ConnectedEndpointSerializer) required=False ) cable = NestedCableSerializer(read_only=True) - tags = TagListSerializerField(required=False) class Meta: model = ConsoleServerPort fields = [ - 'id', 'device', 'name', 'type', 'description', 'connected_endpoint_type', 'connected_endpoint', + 'id', 'device', 'name', 'label', 'type', 'description', 'connected_endpoint_type', 'connected_endpoint', 'connection_status', 'cable', 'tags', ] -class ConsolePortSerializer(TaggitSerializer, ConnectedEndpointSerializer): +class ConsolePortSerializer(TaggedObjectSerializer, ConnectedEndpointSerializer): device = NestedDeviceSerializer() type = ChoiceField( choices=ConsolePortTypeChoices, @@ -459,17 +454,16 @@ class ConsolePortSerializer(TaggitSerializer, ConnectedEndpointSerializer): required=False ) cable = NestedCableSerializer(read_only=True) - tags = TagListSerializerField(required=False) class Meta: model = ConsolePort fields = [ - 'id', 'device', 'name', 'type', 'description', 'connected_endpoint_type', 'connected_endpoint', + 'id', 'device', 'name', 'label', 'type', 'description', 'connected_endpoint_type', 'connected_endpoint', 'connection_status', 'cable', 'tags', ] -class PowerOutletSerializer(TaggitSerializer, ConnectedEndpointSerializer): +class PowerOutletSerializer(TaggedObjectSerializer, ConnectedEndpointSerializer): device = NestedDeviceSerializer() type = ChoiceField( choices=PowerOutletTypeChoices, @@ -487,19 +481,16 @@ class PowerOutletSerializer(TaggitSerializer, ConnectedEndpointSerializer): cable = NestedCableSerializer( read_only=True ) - tags = TagListSerializerField( - required=False - ) class Meta: model = PowerOutlet fields = [ - 'id', 'device', 'name', 'type', 'power_port', 'feed_leg', 'description', 'connected_endpoint_type', + 'id', 'device', 'name', 'label', 'type', 'power_port', 'feed_leg', 'description', 'connected_endpoint_type', 'connected_endpoint', 'connection_status', 'cable', 'tags', ] -class PowerPortSerializer(TaggitSerializer, ConnectedEndpointSerializer): +class PowerPortSerializer(TaggedObjectSerializer, ConnectedEndpointSerializer): device = NestedDeviceSerializer() type = ChoiceField( choices=PowerPortTypeChoices, @@ -507,17 +498,16 @@ class PowerPortSerializer(TaggitSerializer, ConnectedEndpointSerializer): required=False ) cable = NestedCableSerializer(read_only=True) - tags = TagListSerializerField(required=False) class Meta: model = PowerPort fields = [ - 'id', 'device', 'name', 'type', 'maximum_draw', 'allocated_draw', 'description', 'connected_endpoint_type', + 'id', 'device', 'name', 'label', 'type', 'maximum_draw', 'allocated_draw', 'description', 'connected_endpoint_type', 'connected_endpoint', 'connection_status', 'cable', 'tags', ] -class InterfaceSerializer(TaggitSerializer, ConnectedEndpointSerializer): +class InterfaceSerializer(TaggedObjectSerializer, ConnectedEndpointSerializer): device = NestedDeviceSerializer() type = ChoiceField(choices=InterfaceTypeChoices) lag = NestedInterfaceSerializer(required=False, allow_null=True) @@ -530,13 +520,12 @@ class InterfaceSerializer(TaggitSerializer, ConnectedEndpointSerializer): many=True ) cable = NestedCableSerializer(read_only=True) - tags = TagListSerializerField(required=False) count_ipaddresses = serializers.IntegerField(read_only=True) class Meta: model = Interface fields = [ - 'id', 'device', 'name', 'type', 'enabled', 'lag', 'mtu', 'mac_address', 'mgmt_only', 'description', + 'id', 'device', 'name', 'label', 'type', 'enabled', 'lag', 'mtu', 'mac_address', 'mgmt_only', 'description', 'connected_endpoint_type', 'connected_endpoint', 'connection_status', 'cable', 'mode', 'untagged_vlan', 'tagged_vlans', 'tags', 'count_ipaddresses', ] @@ -562,11 +551,10 @@ class InterfaceSerializer(TaggitSerializer, ConnectedEndpointSerializer): return super().validate(data) -class RearPortSerializer(TaggitSerializer, ValidatedModelSerializer): +class RearPortSerializer(TaggedObjectSerializer, ValidatedModelSerializer): device = NestedDeviceSerializer() type = ChoiceField(choices=PortTypeChoices) cable = NestedCableSerializer(read_only=True) - tags = TagListSerializerField(required=False) class Meta: model = RearPort @@ -584,38 +572,35 @@ class FrontPortRearPortSerializer(WritableNestedSerializer): fields = ['id', 'url', 'name'] -class FrontPortSerializer(TaggitSerializer, ValidatedModelSerializer): +class FrontPortSerializer(TaggedObjectSerializer, ValidatedModelSerializer): device = NestedDeviceSerializer() type = ChoiceField(choices=PortTypeChoices) rear_port = FrontPortRearPortSerializer() cable = NestedCableSerializer(read_only=True) - tags = TagListSerializerField(required=False) class Meta: model = FrontPort fields = ['id', 'device', 'name', 'type', 'rear_port', 'rear_port_position', 'description', 'cable', 'tags'] -class DeviceBaySerializer(TaggitSerializer, ValidatedModelSerializer): +class DeviceBaySerializer(TaggedObjectSerializer, ValidatedModelSerializer): device = NestedDeviceSerializer() installed_device = NestedDeviceSerializer(required=False, allow_null=True) - tags = TagListSerializerField(required=False) class Meta: model = DeviceBay - fields = ['id', 'device', 'name', 'description', 'installed_device', 'tags'] + fields = ['id', 'device', 'name', 'label', 'description', 'installed_device', 'tags'] # # Inventory items # -class InventoryItemSerializer(TaggitSerializer, ValidatedModelSerializer): +class InventoryItemSerializer(TaggedObjectSerializer, ValidatedModelSerializer): device = NestedDeviceSerializer() # Provide a default value to satisfy UniqueTogetherValidator parent = serializers.PrimaryKeyRelatedField(queryset=InventoryItem.objects.all(), allow_null=True, default=None) manufacturer = NestedManufacturerSerializer(required=False, allow_null=True, default=None) - tags = TagListSerializerField(required=False) class Meta: model = InventoryItem @@ -629,7 +614,7 @@ class InventoryItemSerializer(TaggitSerializer, ValidatedModelSerializer): # Cables # -class CableSerializer(ValidatedModelSerializer): +class CableSerializer(TaggedObjectSerializer, ValidatedModelSerializer): termination_a_type = ContentTypeField( queryset=ContentType.objects.filter(CABLE_TERMINATION_MODELS) ) @@ -645,7 +630,7 @@ class CableSerializer(ValidatedModelSerializer): model = Cable fields = [ 'id', 'termination_a_type', 'termination_a_id', 'termination_a', 'termination_b_type', 'termination_b_id', - 'termination_b', 'type', 'status', 'label', 'color', 'length', 'length_unit', + 'termination_b', 'type', 'status', 'label', 'color', 'length', 'length_unit', 'tags', ] def _get_termination(self, obj, side): @@ -708,9 +693,8 @@ class InterfaceConnectionSerializer(ValidatedModelSerializer): # Virtual chassis # -class VirtualChassisSerializer(TaggitSerializer, ValidatedModelSerializer): +class VirtualChassisSerializer(TaggedObjectSerializer, ValidatedModelSerializer): master = NestedDeviceSerializer() - tags = TagListSerializerField(required=False) member_count = serializers.IntegerField(read_only=True) class Meta: @@ -722,7 +706,7 @@ class VirtualChassisSerializer(TaggitSerializer, ValidatedModelSerializer): # Power panels # -class PowerPanelSerializer(ValidatedModelSerializer): +class PowerPanelSerializer(TaggedObjectSerializer, ValidatedModelSerializer): site = NestedSiteSerializer() rack_group = NestedRackGroupSerializer( required=False, @@ -733,10 +717,10 @@ class PowerPanelSerializer(ValidatedModelSerializer): class Meta: model = PowerPanel - fields = ['id', 'site', 'rack_group', 'name', 'powerfeed_count'] + fields = ['id', 'site', 'rack_group', 'name', 'tags', 'powerfeed_count'] -class PowerFeedSerializer(TaggitSerializer, CustomFieldModelSerializer): +class PowerFeedSerializer(TaggedObjectSerializer, CustomFieldModelSerializer): power_panel = NestedPowerPanelSerializer() rack = NestedRackSerializer( required=False, @@ -759,9 +743,6 @@ class PowerFeedSerializer(TaggitSerializer, CustomFieldModelSerializer): choices=PowerFeedPhaseChoices, default=PowerFeedPhaseChoices.PHASE_SINGLE ) - tags = TagListSerializerField( - required=False - ) class Meta: model = PowerFeed diff --git a/netbox/dcim/api/views.py b/netbox/dcim/api/views.py index 9c8fe12de..324edcb49 100644 --- a/netbox/dcim/api/views.py +++ b/netbox/dcim/api/views.py @@ -395,7 +395,7 @@ class DeviceViewSet(CustomFieldModelViewSet): )) # Verify user permission - if not request.user.has_perm('dcim.napalm_read'): + if not request.user.has_perm('dcim.napalm_read_device'): return HttpResponseForbidden() # Connect to the device @@ -502,13 +502,13 @@ class InterfaceViewSet(CableTraceMixin, ModelViewSet): return Response(serializer.data) -class FrontPortViewSet(ModelViewSet): +class FrontPortViewSet(CableTraceMixin, ModelViewSet): queryset = FrontPort.objects.prefetch_related('device__device_type__manufacturer', 'rear_port', 'cable', 'tags') serializer_class = serializers.FrontPortSerializer filterset_class = filters.FrontPortFilterSet -class RearPortViewSet(ModelViewSet): +class RearPortViewSet(CableTraceMixin, ModelViewSet): queryset = RearPort.objects.prefetch_related('device__device_type__manufacturer', 'cable', 'tags') serializer_class = serializers.RearPortSerializer filterset_class = filters.RearPortFilterSet diff --git a/netbox/dcim/choices.py b/netbox/dcim/choices.py index 8433bb152..479563093 100644 --- a/netbox/dcim/choices.py +++ b/netbox/dcim/choices.py @@ -276,6 +276,10 @@ class PowerPortTypeChoices(ChoiceSet): TYPE_NEMA_L620P = 'nema-l6-20p' TYPE_NEMA_L630P = 'nema-l6-30p' TYPE_NEMA_L650P = 'nema-l6-50p' + TYPE_NEMA_L1420P = 'nema-l14-20p' + TYPE_NEMA_L1430P = 'nema-l14-30p' + TYPE_NEMA_L2120P = 'nema-l21-20p' + TYPE_NEMA_L2130P = 'nema-l21-30p' # California style TYPE_CS6361C = 'cs6361c' TYPE_CS6365C = 'cs6365c' @@ -337,6 +341,10 @@ class PowerPortTypeChoices(ChoiceSet): (TYPE_NEMA_L620P, 'NEMA L6-20P'), (TYPE_NEMA_L630P, 'NEMA L6-30P'), (TYPE_NEMA_L650P, 'NEMA L6-50P'), + (TYPE_NEMA_L1420P, 'NEMA L14-20P'), + (TYPE_NEMA_L1430P, 'NEMA L14-30P'), + (TYPE_NEMA_L2120P, 'NEMA L21-20P'), + (TYPE_NEMA_L2130P, 'NEMA L21-30P'), )), ('California Style', ( (TYPE_CS6361C, 'CS6361C'), @@ -405,6 +413,10 @@ class PowerOutletTypeChoices(ChoiceSet): TYPE_NEMA_L620R = 'nema-l6-20r' TYPE_NEMA_L630R = 'nema-l6-30r' TYPE_NEMA_L650R = 'nema-l6-50r' + TYPE_NEMA_L1420R = 'nema-l14-20r' + TYPE_NEMA_L1430R = 'nema-l14-30r' + TYPE_NEMA_L2120R = 'nema-l21-20r' + TYPE_NEMA_L2130R = 'nema-l21-30r' # California style TYPE_CS6360C = 'CS6360C' TYPE_CS6364C = 'CS6364C' @@ -467,6 +479,10 @@ class PowerOutletTypeChoices(ChoiceSet): (TYPE_NEMA_L620R, 'NEMA L6-20R'), (TYPE_NEMA_L630R, 'NEMA L6-30R'), (TYPE_NEMA_L650R, 'NEMA L6-50R'), + (TYPE_NEMA_L1420R, 'NEMA L14-20R'), + (TYPE_NEMA_L1430R, 'NEMA L14-30R'), + (TYPE_NEMA_L2120R, 'NEMA L21-20R'), + (TYPE_NEMA_L2130R, 'NEMA L21-30R'), )), ('California Style', ( (TYPE_CS6360C, 'CS6360C'), diff --git a/netbox/dcim/filters.py b/netbox/dcim/filters.py index 5bc6dd7f0..d22511ede 100644 --- a/netbox/dcim/filters.py +++ b/netbox/dcim/filters.py @@ -4,7 +4,7 @@ from django.contrib.auth.models import User from extras.filters import CustomFieldFilterSet, LocalConfigContextFilterSet, CreatedUpdatedFilterSet from tenancy.filters import TenancyFilterSet from tenancy.models import Tenant -from utilities.constants import COLOR_CHOICES +from utilities.choices import ColorChoices from utilities.filters import ( BaseFilterSet, MultiValueCharFilter, MultiValueMACAddressFilter, MultiValueNumberFilter, NameSlugSearchFilterSet, TagFilter, TreeNodeMultipleChoiceFilter, @@ -62,12 +62,12 @@ __all__ = ( class RegionFilterSet(BaseFilterSet, NameSlugSearchFilterSet): parent_id = django_filters.ModelMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), label='Parent region (ID)', ) parent = django_filters.ModelMultipleChoiceFilter( field_name='parent__slug', - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), to_field_name='slug', label='Parent region (slug)', ) @@ -87,13 +87,13 @@ class SiteFilterSet(BaseFilterSet, TenancyFilterSet, CustomFieldFilterSet, Creat null_value=None ) region_id = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='region', lookup_expr='in', label='Region (ID)', ) region = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='region', lookup_expr='in', to_field_name='slug', @@ -131,35 +131,35 @@ class SiteFilterSet(BaseFilterSet, TenancyFilterSet, CustomFieldFilterSet, Creat class RackGroupFilterSet(BaseFilterSet, NameSlugSearchFilterSet): region_id = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='site__region', lookup_expr='in', label='Region (ID)', ) region = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='site__region', lookup_expr='in', to_field_name='slug', label='Region (slug)', ) site_id = django_filters.ModelMultipleChoiceFilter( - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site (ID)', ) site = django_filters.ModelMultipleChoiceFilter( field_name='site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site (slug)', ) parent_id = django_filters.ModelMultipleChoiceFilter( - queryset=RackGroup.objects.all(), + queryset=RackGroup.objects.unrestricted(), label='Rack group (ID)', ) parent = django_filters.ModelMultipleChoiceFilter( field_name='parent__slug', - queryset=RackGroup.objects.all(), + queryset=RackGroup.objects.unrestricted(), to_field_name='slug', label='Rack group (slug)', ) @@ -182,36 +182,36 @@ class RackFilterSet(BaseFilterSet, TenancyFilterSet, CustomFieldFilterSet, Creat label='Search', ) region_id = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='site__region', lookup_expr='in', label='Region (ID)', ) region = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='site__region', lookup_expr='in', to_field_name='slug', label='Region (slug)', ) site_id = django_filters.ModelMultipleChoiceFilter( - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site (ID)', ) site = django_filters.ModelMultipleChoiceFilter( field_name='site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site (slug)', ) group_id = TreeNodeMultipleChoiceFilter( - queryset=RackGroup.objects.all(), + queryset=RackGroup.objects.unrestricted(), field_name='group', lookup_expr='in', label='Rack group (ID)', ) group = TreeNodeMultipleChoiceFilter( - queryset=RackGroup.objects.all(), + queryset=RackGroup.objects.unrestricted(), field_name='group', lookup_expr='in', to_field_name='slug', @@ -222,12 +222,12 @@ class RackFilterSet(BaseFilterSet, TenancyFilterSet, CustomFieldFilterSet, Creat null_value=None ) role_id = django_filters.ModelMultipleChoiceFilter( - queryset=RackRole.objects.all(), + queryset=RackRole.objects.unrestricted(), label='Role (ID)', ) role = django_filters.ModelMultipleChoiceFilter( field_name='role__slug', - queryset=RackRole.objects.all(), + queryset=RackRole.objects.unrestricted(), to_field_name='slug', label='Role (slug)', ) @@ -261,28 +261,28 @@ class RackReservationFilterSet(BaseFilterSet, TenancyFilterSet): label='Search', ) rack_id = django_filters.ModelMultipleChoiceFilter( - queryset=Rack.objects.all(), + queryset=Rack.objects.unrestricted(), label='Rack (ID)', ) site_id = django_filters.ModelMultipleChoiceFilter( field_name='rack__site', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site (ID)', ) site = django_filters.ModelMultipleChoiceFilter( field_name='rack__site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site (slug)', ) group_id = TreeNodeMultipleChoiceFilter( - queryset=RackGroup.objects.all(), + queryset=RackGroup.objects.unrestricted(), field_name='rack__group', lookup_expr='in', label='Rack group (ID)', ) group = TreeNodeMultipleChoiceFilter( - queryset=RackGroup.objects.all(), + queryset=RackGroup.objects.unrestricted(), field_name='rack__group', lookup_expr='in', to_field_name='slug', @@ -298,6 +298,7 @@ class RackReservationFilterSet(BaseFilterSet, TenancyFilterSet): to_field_name='username', label='User (name)', ) + tag = TagFilter() class Meta: model = RackReservation @@ -327,12 +328,12 @@ class DeviceTypeFilterSet(BaseFilterSet, CustomFieldFilterSet, CreatedUpdatedFil label='Search', ) manufacturer_id = django_filters.ModelMultipleChoiceFilter( - queryset=Manufacturer.objects.all(), + queryset=Manufacturer.objects.unrestricted(), label='Manufacturer (ID)', ) manufacturer = django_filters.ModelMultipleChoiceFilter( field_name='manufacturer__slug', - queryset=Manufacturer.objects.all(), + queryset=Manufacturer.objects.unrestricted(), to_field_name='slug', label='Manufacturer (slug)', ) @@ -409,7 +410,7 @@ class DeviceTypeFilterSet(BaseFilterSet, CustomFieldFilterSet, CreatedUpdatedFil class DeviceTypeComponentFilterSet(NameSlugSearchFilterSet): devicetype_id = django_filters.ModelMultipleChoiceFilter( - queryset=DeviceType.objects.all(), + queryset=DeviceType.objects.unrestricted(), field_name='device_type_id', label='Device type (ID)', ) @@ -481,12 +482,12 @@ class DeviceRoleFilterSet(BaseFilterSet, NameSlugSearchFilterSet): class PlatformFilterSet(BaseFilterSet, NameSlugSearchFilterSet): manufacturer_id = django_filters.ModelMultipleChoiceFilter( field_name='manufacturer', - queryset=Manufacturer.objects.all(), + queryset=Manufacturer.objects.unrestricted(), label='Manufacturer (ID)', ) manufacturer = django_filters.ModelMultipleChoiceFilter( field_name='manufacturer__slug', - queryset=Manufacturer.objects.all(), + queryset=Manufacturer.objects.unrestricted(), to_field_name='slug', label='Manufacturer (slug)', ) @@ -509,81 +510,81 @@ class DeviceFilterSet( ) manufacturer_id = django_filters.ModelMultipleChoiceFilter( field_name='device_type__manufacturer', - queryset=Manufacturer.objects.all(), + queryset=Manufacturer.objects.unrestricted(), label='Manufacturer (ID)', ) manufacturer = django_filters.ModelMultipleChoiceFilter( field_name='device_type__manufacturer__slug', - queryset=Manufacturer.objects.all(), + queryset=Manufacturer.objects.unrestricted(), to_field_name='slug', label='Manufacturer (slug)', ) device_type_id = django_filters.ModelMultipleChoiceFilter( - queryset=DeviceType.objects.all(), + queryset=DeviceType.objects.unrestricted(), label='Device type (ID)', ) role_id = django_filters.ModelMultipleChoiceFilter( field_name='device_role_id', - queryset=DeviceRole.objects.all(), + queryset=DeviceRole.objects.unrestricted(), label='Role (ID)', ) role = django_filters.ModelMultipleChoiceFilter( field_name='device_role__slug', - queryset=DeviceRole.objects.all(), + queryset=DeviceRole.objects.unrestricted(), to_field_name='slug', label='Role (slug)', ) platform_id = django_filters.ModelMultipleChoiceFilter( - queryset=Platform.objects.all(), + queryset=Platform.objects.unrestricted(), label='Platform (ID)', ) platform = django_filters.ModelMultipleChoiceFilter( field_name='platform__slug', - queryset=Platform.objects.all(), + queryset=Platform.objects.unrestricted(), to_field_name='slug', label='Platform (slug)', ) region_id = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='site__region', lookup_expr='in', label='Region (ID)', ) region = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='site__region', lookup_expr='in', to_field_name='slug', label='Region (slug)', ) site_id = django_filters.ModelMultipleChoiceFilter( - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site (ID)', ) site = django_filters.ModelMultipleChoiceFilter( field_name='site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site name (slug)', ) rack_group_id = TreeNodeMultipleChoiceFilter( - queryset=RackGroup.objects.all(), + queryset=RackGroup.objects.unrestricted(), field_name='rack__group', lookup_expr='in', label='Rack group (ID)', ) rack_id = django_filters.ModelMultipleChoiceFilter( field_name='rack', - queryset=Rack.objects.all(), + queryset=Rack.objects.unrestricted(), label='Rack (ID)', ) cluster_id = django_filters.ModelMultipleChoiceFilter( - queryset=Cluster.objects.all(), + queryset=Cluster.objects.unrestricted(), label='VM cluster (ID)', ) model = django_filters.ModelMultipleChoiceFilter( field_name='device_type__slug', - queryset=DeviceType.objects.all(), + queryset=DeviceType.objects.unrestricted(), to_field_name='slug', label='Device model (slug)', ) @@ -608,7 +609,7 @@ class DeviceFilterSet( ) virtual_chassis_id = django_filters.ModelMultipleChoiceFilter( field_name='virtual_chassis', - queryset=VirtualChassis.objects.all(), + queryset=VirtualChassis.objects.unrestricted(), label='Virtual chassis (ID)', ) virtual_chassis_member = django_filters.BooleanFilter( @@ -706,13 +707,13 @@ class DeviceComponentFilterSet(django_filters.FilterSet): label='Search', ) region_id = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='device__site__region', lookup_expr='in', label='Region (ID)', ) region = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='device__site__region', lookup_expr='in', to_field_name='slug', @@ -720,22 +721,22 @@ class DeviceComponentFilterSet(django_filters.FilterSet): ) site_id = django_filters.ModelMultipleChoiceFilter( field_name='device__site', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site (ID)', ) site = django_filters.ModelMultipleChoiceFilter( field_name='device__site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site name (slug)', ) device_id = django_filters.ModelMultipleChoiceFilter( - queryset=Device.objects.all(), + queryset=Device.objects.unrestricted(), label='Device (ID)', ) device = django_filters.ModelMultipleChoiceFilter( field_name='device__name', - queryset=Device.objects.all(), + queryset=Device.objects.unrestricted(), to_field_name='name', label='Device (name)', ) @@ -842,7 +843,7 @@ class InterfaceFilterSet(BaseFilterSet, DeviceComponentFilterSet): ) lag_id = django_filters.ModelMultipleChoiceFilter( field_name='lag', - queryset=Interface.objects.all(), + queryset=Interface.objects.unrestricted(), label='LAG interface (ID)', ) mac_address = MultiValueMACAddressFilter() @@ -949,13 +950,13 @@ class InventoryItemFilterSet(BaseFilterSet, DeviceComponentFilterSet): label='Search', ) region_id = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='device__site__region', lookup_expr='in', label='Region (ID)', ) region = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='device__site__region', lookup_expr='in', to_field_name='slug', @@ -963,35 +964,35 @@ class InventoryItemFilterSet(BaseFilterSet, DeviceComponentFilterSet): ) site_id = django_filters.ModelMultipleChoiceFilter( field_name='device__site', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site (ID)', ) site = django_filters.ModelMultipleChoiceFilter( field_name='device__site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site name (slug)', ) device_id = django_filters.ModelChoiceFilter( - queryset=Device.objects.all(), + queryset=Device.objects.unrestricted(), label='Device (ID)', ) device = django_filters.ModelChoiceFilter( - queryset=Device.objects.all(), + queryset=Device.objects.unrestricted(), to_field_name='name', label='Device (name)', ) parent_id = django_filters.ModelMultipleChoiceFilter( - queryset=InventoryItem.objects.all(), + queryset=InventoryItem.objects.unrestricted(), label='Parent inventory item (ID)', ) manufacturer_id = django_filters.ModelMultipleChoiceFilter( - queryset=Manufacturer.objects.all(), + queryset=Manufacturer.objects.unrestricted(), label='Manufacturer (ID)', ) manufacturer = django_filters.ModelMultipleChoiceFilter( field_name='manufacturer__slug', - queryset=Manufacturer.objects.all(), + queryset=Manufacturer.objects.unrestricted(), to_field_name='slug', label='Manufacturer (slug)', ) @@ -1022,13 +1023,13 @@ class VirtualChassisFilterSet(BaseFilterSet): label='Search', ) region_id = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='master__site__region', lookup_expr='in', label='Region (ID)', ) region = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='master__site__region', lookup_expr='in', to_field_name='slug', @@ -1036,23 +1037,23 @@ class VirtualChassisFilterSet(BaseFilterSet): ) site_id = django_filters.ModelMultipleChoiceFilter( field_name='master__site', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site (ID)', ) site = django_filters.ModelMultipleChoiceFilter( field_name='master__site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site name (slug)', ) tenant_id = django_filters.ModelMultipleChoiceFilter( field_name='master__tenant', - queryset=Tenant.objects.all(), + queryset=Tenant.objects.unrestricted(), label='Tenant (ID)', ) tenant = django_filters.ModelMultipleChoiceFilter( field_name='master__tenant__slug', - queryset=Tenant.objects.all(), + queryset=Tenant.objects.unrestricted(), to_field_name='slug', label='Tenant (slug)', ) @@ -1084,7 +1085,7 @@ class CableFilterSet(BaseFilterSet): choices=CableStatusChoices ) color = django_filters.MultipleChoiceFilter( - choices=COLOR_CHOICES + choices=ColorChoices ) device_id = MultiValueNumberFilter( method='filter_device' @@ -1117,6 +1118,7 @@ class CableFilterSet(BaseFilterSet): method='filter_device', field_name='device__tenant__slug' ) + tag = TagFilter() class Meta: model = Cable @@ -1237,34 +1239,35 @@ class PowerPanelFilterSet(BaseFilterSet): label='Search', ) region_id = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='site__region', lookup_expr='in', label='Region (ID)', ) region = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='site__region', lookup_expr='in', to_field_name='slug', label='Region (slug)', ) site_id = django_filters.ModelMultipleChoiceFilter( - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site (ID)', ) site = django_filters.ModelMultipleChoiceFilter( field_name='site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site name (slug)', ) rack_group_id = TreeNodeMultipleChoiceFilter( - queryset=RackGroup.objects.all(), + queryset=RackGroup.objects.unrestricted(), field_name='rack_group', lookup_expr='in', label='Rack group (ID)', ) + tag = TagFilter() class Meta: model = PowerPanel @@ -1285,13 +1288,13 @@ class PowerFeedFilterSet(BaseFilterSet, CustomFieldFilterSet, CreatedUpdatedFilt label='Search', ) region_id = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='power_panel__site__region', lookup_expr='in', label='Region (ID)', ) region = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='power_panel__site__region', lookup_expr='in', to_field_name='slug', @@ -1299,22 +1302,22 @@ class PowerFeedFilterSet(BaseFilterSet, CustomFieldFilterSet, CreatedUpdatedFilt ) site_id = django_filters.ModelMultipleChoiceFilter( field_name='power_panel__site', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site (ID)', ) site = django_filters.ModelMultipleChoiceFilter( field_name='power_panel__site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site name (slug)', ) power_panel_id = django_filters.ModelMultipleChoiceFilter( - queryset=PowerPanel.objects.all(), + queryset=PowerPanel.objects.unrestricted(), label='Power panel (ID)', ) rack_id = django_filters.ModelMultipleChoiceFilter( field_name='rack', - queryset=Rack.objects.all(), + queryset=Rack.objects.unrestricted(), label='Rack (ID)', ) tag = TagFilter() diff --git a/netbox/dcim/forms.py b/netbox/dcim/forms.py index b104124b4..099676181 100644 --- a/netbox/dcim/forms.py +++ b/netbox/dcim/forms.py @@ -9,7 +9,6 @@ from django.utils.safestring import mark_safe from mptt.forms import TreeNodeChoiceField from netaddr import EUI from netaddr.core import AddrFormatError -from taggit.forms import TagField from timezone_field import TimeZoneFormField from circuits.models import Circuit, Provider @@ -17,18 +16,19 @@ from extras.forms import ( AddRemoveTagsForm, CustomFieldBulkEditForm, CustomFieldModelCSVForm, CustomFieldFilterForm, CustomFieldModelForm, LocalConfigContextFilterForm, ) +from extras.models import Tag from ipam.constants import BGP_ASN_MAX, BGP_ASN_MIN from ipam.models import IPAddress, VLAN from tenancy.forms import TenancyFilterForm, TenancyForm from tenancy.models import Tenant, TenantGroup from utilities.forms import ( - APISelect, APISelectMultiple, add_blank_choice, ArrayFieldSelectMultiple, BootstrapMixin, BulkEditForm, - BulkEditNullBooleanSelect, ColorSelect, CommentField, ConfirmationForm, CSVChoiceField, CSVModelChoiceField, - CSVModelForm, DynamicModelChoiceField, DynamicModelMultipleChoiceField, ExpandableNameField, form_from_model, - JSONField, SelectWithPK, SmallTextarea, SlugField, StaticSelect2, StaticSelect2Multiple, TagFilterField, + APISelect, APISelectMultiple, add_blank_choice, BootstrapMixin, BulkEditForm, BulkEditNullBooleanSelect, + BulkRenameForm, ColorSelect, CommentField, ConfirmationForm, CSVChoiceField, CSVModelChoiceField, CSVModelForm, + DynamicModelChoiceField, DynamicModelMultipleChoiceField, ExpandableNameField, form_from_model, JSONField, + NumericArrayField, SelectWithPK, SmallTextarea, SlugField, StaticSelect2, StaticSelect2Multiple, TagFilterField, BOOLEAN_WITH_BLANK_CHOICES, ) -from virtualization.models import Cluster, ClusterGroup, VirtualMachine +from virtualization.models import Cluster, ClusterGroup from .choices import * from .constants import * from .models import ( @@ -128,28 +128,26 @@ class InterfaceCommonForm: }) -class BulkRenameForm(forms.Form): - """ - An extendable form to be used for renaming device components in bulk. - """ - find = forms.CharField() - replace = forms.CharField() - use_regex = forms.BooleanField( - required=False, - initial=True, - label='Use regular expressions' +class LabeledComponentForm(BootstrapMixin, forms.Form): + name_pattern = ExpandableNameField( + label='Name' + ) + label_pattern = ExpandableNameField( + label='Label', + required=False ) def clean(self): - # Validate regular expression in "find" field - if self.cleaned_data['use_regex']: - try: - re.compile(self.cleaned_data['find']) - except re.error: - raise forms.ValidationError({ - 'find': "Invalid regular expression" - }) + # Validate that the number of components being created from both the name_pattern and label_pattern are equal + name_pattern_count = len(self.cleaned_data['name_pattern']) + label_pattern_count = len(self.cleaned_data['label_pattern']) + if label_pattern_count and name_pattern_count != label_pattern_count: + raise forms.ValidationError({ + 'label_pattern': 'The provided name pattern will create {} components, however {} labels will ' + 'be generated. These counts must match.'.format( + name_pattern_count, label_pattern_count) + }, code='label_pattern_mismatch') # @@ -226,7 +224,8 @@ class SiteForm(BootstrapMixin, TenancyForm, CustomFieldModelForm): ) slug = SlugField() comments = CommentField() - tags = TagField( + tags = DynamicModelMultipleChoiceField( + queryset=Tag.objects.all(), required=False ) @@ -364,7 +363,12 @@ class SiteFilterForm(BootstrapMixin, TenancyFilterForm, CustomFieldFilterForm): class RackGroupForm(BootstrapMixin, forms.ModelForm): site = DynamicModelChoiceField( - queryset=Site.objects.all() + queryset=Site.objects.all(), + widget=APISelect( + filter_for={ + 'parent': 'site_id', + } + ) ) parent = DynamicModelChoiceField( queryset=RackGroup.objects.all(), @@ -482,7 +486,8 @@ class RackForm(BootstrapMixin, TenancyForm, CustomFieldModelForm): required=False ) comments = CommentField() - tags = TagField( + tags = DynamicModelMultipleChoiceField( + queryset=Tag.objects.all(), required=False ) @@ -730,51 +735,49 @@ class RackElevationFilterForm(RackFilterForm): # class RackReservationForm(BootstrapMixin, TenancyForm, forms.ModelForm): - rack = forms.ModelChoiceField( - queryset=Rack.objects.all(), + site = DynamicModelChoiceField( + queryset=Site.objects.all(), required=False, - widget=forms.HiddenInput() - ) - # TODO: Change this to an API-backed form field. We can't do this currently because we want to retain - # the multi-line '); - $('#id_tags').select2({ + $('#id_tags.tagfield').replaceWith(''); + $('#id_tags.tagfield').select2({ tags: true, data: tag_objs, multiple: true, @@ -354,14 +354,14 @@ $(document).ready(function() { } } }); - $('#id_tags').closest('form').submit(function(event){ + $('#id_tags.tagfield').closest('form').submit(function(event){ // django-taggit can only accept a single comma seperated string value - var value = $('#id_tags').val(); + var value = $('#id_tags.tagfield').val(); if (value.length > 0){ var final_tags = value.join(', '); - $('#id_tags').val(null).trigger('change'); + $('#id_tags.tagfield').val(null).trigger('change'); var option = new Option(final_tags, final_tags, true, true); - $('#id_tags').append(option).trigger('change'); + $('#id_tags.tagfield').append(option).trigger('change'); } }); diff --git a/netbox/secrets/admin.py b/netbox/secrets/admin.py index 94cd1c7fa..e11128674 100644 --- a/netbox/secrets/admin.py +++ b/netbox/secrets/admin.py @@ -23,7 +23,7 @@ class UserKeyAdmin(admin.ModelAdmin): actions = super().get_actions(request) if 'delete_selected' in actions: del actions['delete_selected'] - if not request.user.has_perm('secrets.activate_userkey'): + if not request.user.has_perm('secrets.change_userkey'): del actions['activate_selected'] return actions diff --git a/netbox/secrets/api/nested_serializers.py b/netbox/secrets/api/nested_serializers.py index 7aa8087da..13c016c18 100644 --- a/netbox/secrets/api/nested_serializers.py +++ b/netbox/secrets/api/nested_serializers.py @@ -1,13 +1,22 @@ from rest_framework import serializers -from secrets.models import SecretRole +from secrets.models import Secret, SecretRole from utilities.api import WritableNestedSerializer __all__ = [ - 'NestedSecretRoleSerializer' + 'NestedSecretRoleSerializer', + 'NestedSecretSerializer', ] +class NestedSecretSerializer(WritableNestedSerializer): + url = serializers.HyperlinkedIdentityField(view_name='secrets-api:secret-detail') + + class Meta: + model = Secret + fields = ['id', 'url', 'name'] + + class NestedSecretRoleSerializer(WritableNestedSerializer): url = serializers.HyperlinkedIdentityField(view_name='secrets-api:secretrole-detail') secret_count = serializers.IntegerField(read_only=True) diff --git a/netbox/secrets/api/serializers.py b/netbox/secrets/api/serializers.py index 0b73f0002..54132dd34 100644 --- a/netbox/secrets/api/serializers.py +++ b/netbox/secrets/api/serializers.py @@ -1,8 +1,8 @@ from rest_framework import serializers -from taggit_serializer.serializers import TaggitSerializer, TagListSerializerField from dcim.api.nested_serializers import NestedDeviceSerializer from extras.api.customfields import CustomFieldModelSerializer +from extras.api.serializers import TaggedObjectSerializer from secrets.models import Secret, SecretRole from utilities.api import ValidatedModelSerializer from .nested_serializers import * @@ -20,11 +20,10 @@ class SecretRoleSerializer(ValidatedModelSerializer): fields = ['id', 'name', 'slug', 'description', 'secret_count'] -class SecretSerializer(TaggitSerializer, CustomFieldModelSerializer): +class SecretSerializer(TaggedObjectSerializer, CustomFieldModelSerializer): device = NestedDeviceSerializer() role = NestedSecretRoleSerializer() plaintext = serializers.CharField() - tags = TagListSerializerField(required=False) class Meta: model = Secret diff --git a/netbox/secrets/api/views.py b/netbox/secrets/api/views.py index 1795e6c0a..9e330b782 100644 --- a/netbox/secrets/api/views.py +++ b/netbox/secrets/api/views.py @@ -29,7 +29,6 @@ class SecretRoleViewSet(ModelViewSet): secret_count=Count('secrets') ) serializer_class = serializers.SecretRoleSerializer - permission_classes = [IsAuthenticated] filterset_class = filters.SecretRoleFilterSet diff --git a/netbox/secrets/decorators.py b/netbox/secrets/decorators.py deleted file mode 100644 index e2f44ac90..000000000 --- a/netbox/secrets/decorators.py +++ /dev/null @@ -1,24 +0,0 @@ -from django.contrib import messages -from django.shortcuts import redirect - -from .models import UserKey - - -def userkey_required(): - """ - Decorator for views which require that the user has an active UserKey (typically for encryption/decryption of - Secrets). - """ - def _decorator(view): - def wrapped_view(request, *args, **kwargs): - try: - uk = UserKey.objects.get(user=request.user) - except UserKey.DoesNotExist: - messages.warning(request, "This operation requires an active user key, but you don't have one.") - return redirect('user:userkey') - if not uk.is_active(): - messages.warning(request, "This operation is not available. Your user key has not been activated.") - return redirect('user:userkey') - return view(request, *args, **kwargs) - return wrapped_view - return _decorator diff --git a/netbox/secrets/filters.py b/netbox/secrets/filters.py index 78f25952a..fee9b4981 100644 --- a/netbox/secrets/filters.py +++ b/netbox/secrets/filters.py @@ -26,22 +26,22 @@ class SecretFilterSet(BaseFilterSet, CustomFieldFilterSet, CreatedUpdatedFilterS label='Search', ) role_id = django_filters.ModelMultipleChoiceFilter( - queryset=SecretRole.objects.all(), + queryset=SecretRole.objects.unrestricted(), label='Role (ID)', ) role = django_filters.ModelMultipleChoiceFilter( field_name='role__slug', - queryset=SecretRole.objects.all(), + queryset=SecretRole.objects.unrestricted(), to_field_name='slug', label='Role (slug)', ) device_id = django_filters.ModelMultipleChoiceFilter( - queryset=Device.objects.all(), + queryset=Device.objects.unrestricted(), label='Device (ID)', ) device = django_filters.ModelMultipleChoiceFilter( field_name='device__name', - queryset=Device.objects.all(), + queryset=Device.objects.unrestricted(), to_field_name='name', label='Device (name)', ) diff --git a/netbox/secrets/forms.py b/netbox/secrets/forms.py index 368a47590..f62c72293 100644 --- a/netbox/secrets/forms.py +++ b/netbox/secrets/forms.py @@ -1,12 +1,12 @@ from Crypto.Cipher import PKCS1_OAEP from Crypto.PublicKey import RSA from django import forms -from taggit.forms import TagField from dcim.models import Device from extras.forms import ( AddRemoveTagsForm, CustomFieldBulkEditForm, CustomFieldFilterForm, CustomFieldModelForm, CustomFieldModelCSVForm, ) +from extras.models import Tag from utilities.forms import ( APISelectMultiple, BootstrapMixin, CSVModelChoiceField, CSVModelForm, DynamicModelChoiceField, DynamicModelMultipleChoiceField, SlugField, StaticSelect2Multiple, TagFilterField, @@ -90,7 +90,8 @@ class SecretForm(BootstrapMixin, CustomFieldModelForm): role = DynamicModelChoiceField( queryset=SecretRole.objects.all() ) - tags = TagField( + tags = DynamicModelMultipleChoiceField( + queryset=Tag.objects.all(), required=False ) @@ -115,6 +116,16 @@ class SecretForm(BootstrapMixin, CustomFieldModelForm): 'plaintext2': "The two given plaintext values do not match. Please check your input." }) + # Validate uniqueness + if Secret.objects.filter( + device=self.cleaned_data['device'], + role=self.cleaned_data['role'], + name=self.cleaned_data['name'] + ).exists(): + raise forms.ValidationError( + "Each secret assigned to a device must have a unique combination of role and name" + ) + class SecretCSVForm(CustomFieldModelCSVForm): device = CSVModelChoiceField( diff --git a/netbox/secrets/migrations/0001_initial.py b/netbox/secrets/migrations/0001_initial.py index 1281a266a..3664bae63 100644 --- a/netbox/secrets/migrations/0001_initial.py +++ b/netbox/secrets/migrations/0001_initial.py @@ -56,7 +56,6 @@ class Migration(migrations.Migration): ], options={ 'ordering': ['user__username'], - 'permissions': (('activate_userkey', 'Can activate user keys for decryption'),), }, ), migrations.AddField( diff --git a/netbox/secrets/models.py b/netbox/secrets/models.py index 830e91096..bf5858ff8 100644 --- a/netbox/secrets/models.py +++ b/netbox/secrets/models.py @@ -1,5 +1,4 @@ import os -import sys from Crypto.Cipher import AES from Crypto.PublicKey import RSA @@ -18,6 +17,7 @@ from dcim.models import Device from extras.models import CustomFieldModel, TaggedItem from extras.utils import extras_features from utilities.models import ChangeLoggedModel +from utilities.querysets import RestrictedQuerySet from .exceptions import InvalidKey from .hashers import SecretValidationHasher from .querysets import UserKeyQuerySet @@ -64,9 +64,6 @@ class UserKey(models.Model): class Meta: ordering = ['user__username'] - permissions = ( - ('activate_userkey', "Can activate user keys for decryption"), - ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -269,6 +266,8 @@ class SecretRole(ChangeLoggedModel): blank=True ) + objects = RestrictedQuerySet.as_manager() + csv_headers = ['name', 'slug', 'description'] class Meta: @@ -334,9 +333,10 @@ class Secret(ChangeLoggedModel, CustomFieldModel): content_type_field='obj_type', object_id_field='obj_id' ) - tags = TaggableManager(through=TaggedItem) + objects = RestrictedQuerySet.as_manager() + plaintext = None csv_headers = ['device', 'role', 'name', 'plaintext'] diff --git a/netbox/secrets/tests/test_api.py b/netbox/secrets/tests/test_api.py index 339c370d8..89c18b7d7 100644 --- a/netbox/secrets/tests/test_api.py +++ b/netbox/secrets/tests/test_api.py @@ -5,8 +5,7 @@ from rest_framework import status from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Site from secrets.models import Secret, SecretRole, SessionKey, UserKey -from users.models import Token -from utilities.testing import APITestCase, create_test_user +from utilities.testing import APITestCase, APIViewTestCases from .constants import PRIVATE_KEY, PUBLIC_KEY @@ -20,271 +19,100 @@ class AppTest(APITestCase): self.assertEqual(response.status_code, 200) -class SecretRoleTest(APITestCase): +class SecretRoleTest(APIViewTestCases.APIViewTestCase): + model = SecretRole + brief_fields = ['id', 'name', 'secret_count', 'slug', 'url'] + create_data = [ + { + 'name': 'Secret Role 4', + 'slug': 'secret-role-4', + }, + { + 'name': 'Secret Role 5', + 'slug': 'secret-role-5', + }, + { + 'name': 'Secret Role 6', + 'slug': 'secret-role-6', + }, + ] + + @classmethod + def setUpTestData(cls): + + secret_roles = ( + SecretRole(name='Secret Role 1', slug='secret-role-1'), + SecretRole(name='Secret Role 2', slug='secret-role-2'), + SecretRole(name='Secret Role 3', slug='secret-role-3'), + ) + SecretRole.objects.bulk_create(secret_roles) + + +class SecretTest(APIViewTestCases.APIViewTestCase): + model = Secret + brief_fields = ['id', 'name', 'url'] def setUp(self): - super().setUp() - self.secretrole1 = SecretRole.objects.create(name='Test Secret Role 1', slug='test-secret-role-1') - self.secretrole2 = SecretRole.objects.create(name='Test Secret Role 2', slug='test-secret-role-2') - self.secretrole3 = SecretRole.objects.create(name='Test Secret Role 3', slug='test-secret-role-3') - - def test_get_secretrole(self): - - url = reverse('secrets-api:secretrole-detail', kwargs={'pk': self.secretrole1.pk}) - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['name'], self.secretrole1.name) - - def test_list_secretroles(self): - - url = reverse('secrets-api:secretrole-list') - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['count'], 3) - - def test_list_secretroles_brief(self): - - url = reverse('secrets-api:secretrole-list') - response = self.client.get('{}?brief=1'.format(url), **self.header) - - self.assertEqual( - sorted(response.data['results'][0]), - ['id', 'name', 'secret_count', 'slug', 'url'] - ) - - def test_create_secretrole(self): - - data = { - 'name': 'Test Secret Role 4', - 'slug': 'test-secret-role-4', - } - - url = reverse('secrets-api:secretrole-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(SecretRole.objects.count(), 4) - secretrole4 = SecretRole.objects.get(pk=response.data['id']) - self.assertEqual(secretrole4.name, data['name']) - self.assertEqual(secretrole4.slug, data['slug']) - - def test_create_secretrole_bulk(self): - - data = [ - { - 'name': 'Test Secret Role 4', - 'slug': 'test-secret-role-4', - }, - { - 'name': 'Test Secret Role 5', - 'slug': 'test-secret-role-5', - }, - { - 'name': 'Test Secret Role 6', - 'slug': 'test-secret-role-6', - }, - ] - - url = reverse('secrets-api:secretrole-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(SecretRole.objects.count(), 6) - self.assertEqual(response.data[0]['name'], data[0]['name']) - self.assertEqual(response.data[1]['name'], data[1]['name']) - self.assertEqual(response.data[2]['name'], data[2]['name']) - - def test_update_secretrole(self): - - data = { - 'name': 'Test SecretRole X', - 'slug': 'test-secretrole-x', - } - - url = reverse('secrets-api:secretrole-detail', kwargs={'pk': self.secretrole1.pk}) - response = self.client.put(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(SecretRole.objects.count(), 3) - secretrole1 = SecretRole.objects.get(pk=response.data['id']) - self.assertEqual(secretrole1.name, data['name']) - self.assertEqual(secretrole1.slug, data['slug']) - - def test_delete_secretrole(self): - - url = reverse('secrets-api:secretrole-detail', kwargs={'pk': self.secretrole1.pk}) - response = self.client.delete(url, **self.header) - - self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(SecretRole.objects.count(), 2) - - -class SecretTest(APITestCase): - - def setUp(self): - - # Create a non-superuser test user - self.user = create_test_user('testuser', permissions=( - 'secrets.add_secret', - 'secrets.change_secret', - 'secrets.delete_secret', - 'secrets.view_secret', - )) - self.token = Token.objects.create(user=self.user) - self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)} - + # Create a UserKey for the test user userkey = UserKey(user=self.user, public_key=PUBLIC_KEY) userkey.save() + + # Create a SessionKey for the user self.master_key = userkey.get_master_key(PRIVATE_KEY) session_key = SessionKey(userkey=userkey) session_key.save(self.master_key) - self.header = { - 'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key), - 'HTTP_X_SESSION_KEY': base64.b64encode(session_key.key), - } + # Append the session key to the test client's request header + self.header['HTTP_X_SESSION_KEY'] = base64.b64encode(session_key.key) - self.plaintexts = ( - 'Secret #1 Plaintext', - 'Secret #2 Plaintext', - 'Secret #3 Plaintext', + site = Site.objects.create(name='Site 1', slug='site-1') + manufacturer = Manufacturer.objects.create(name='Manufacturer 1', slug='manufacturer-1') + devicetype = DeviceType.objects.create(manufacturer=manufacturer, model='Device Type 1') + devicerole = DeviceRole.objects.create(name='Device Role 1', slug='device-role-1') + device = Device.objects.create(name='Device 1', site=site, device_type=devicetype, device_role=devicerole) + + secret_roles = ( + SecretRole(name='Secret Role 1', slug='secret-role-1'), + SecretRole(name='Secret Role 2', slug='secret-role-2'), ) + SecretRole.objects.bulk_create(secret_roles) - site = Site.objects.create(name='Test Site 1', slug='test-site-1') - manufacturer = Manufacturer.objects.create(name='Test Manufacturer 1', slug='test-manufacturer-1') - devicetype = DeviceType.objects.create(manufacturer=manufacturer, model='Test Device Type 1') - devicerole = DeviceRole.objects.create(name='Test Device Role 1', slug='test-device-role-1') - self.device = Device.objects.create( - name='Test Device 1', site=site, device_type=devicetype, device_role=devicerole + secrets = ( + Secret(device=device, role=secret_roles[0], name='Secret 1', plaintext='ABC'), + Secret(device=device, role=secret_roles[0], name='Secret 2', plaintext='DEF'), + Secret(device=device, role=secret_roles[0], name='Secret 3', plaintext='GHI'), ) - self.secretrole1 = SecretRole.objects.create(name='Test Secret Role 1', slug='test-secret-role-1') - self.secretrole2 = SecretRole.objects.create(name='Test Secret Role 2', slug='test-secret-role-2') - self.secret1 = Secret( - device=self.device, role=self.secretrole1, name='Test Secret 1', plaintext=self.plaintexts[0] - ) - self.secret1.encrypt(self.master_key) - self.secret1.save() - self.secret2 = Secret( - device=self.device, role=self.secretrole1, name='Test Secret 2', plaintext=self.plaintexts[1] - ) - self.secret2.encrypt(self.master_key) - self.secret2.save() - self.secret3 = Secret( - device=self.device, role=self.secretrole1, name='Test Secret 3', plaintext=self.plaintexts[2] - ) - self.secret3.encrypt(self.master_key) - self.secret3.save() + for secret in secrets: + secret.encrypt(self.master_key) + secret.save() - def test_get_secret(self): - - url = reverse('secrets-api:secret-detail', kwargs={'pk': self.secret1.pk}) - - # Secret plaintext not be decrypted as the user has not been assigned to the role - response = self.client.get(url, **self.header) - self.assertIsNone(response.data['plaintext']) - - # The plaintext should be present once the user has been assigned to the role - self.secretrole1.users.add(self.user) - response = self.client.get(url, **self.header) - self.assertEqual(response.data['plaintext'], self.plaintexts[0]) - - def test_list_secrets(self): - - url = reverse('secrets-api:secret-list') - - # Secret plaintext not be decrypted as the user has not been assigned to the role - response = self.client.get(url, **self.header) - self.assertEqual(response.data['count'], 3) - for secret in response.data['results']: - self.assertIsNone(secret['plaintext']) - - # The plaintext should be present once the user has been assigned to the role - self.secretrole1.users.add(self.user) - response = self.client.get(url, **self.header) - self.assertEqual(response.data['count'], 3) - for i, secret in enumerate(response.data['results']): - self.assertEqual(secret['plaintext'], self.plaintexts[i]) - - def test_create_secret(self): - - data = { - 'device': self.device.pk, - 'role': self.secretrole1.pk, - 'name': 'Test Secret 4', - 'plaintext': 'Secret #4 Plaintext', - } - - url = reverse('secrets-api:secret-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(response.data['plaintext'], data['plaintext']) - self.assertEqual(Secret.objects.count(), 4) - secret4 = Secret.objects.get(pk=response.data['id']) - secret4.decrypt(self.master_key) - self.assertEqual(secret4.role_id, data['role']) - self.assertEqual(secret4.plaintext, data['plaintext']) - - def test_create_secret_bulk(self): - - data = [ + self.create_data = [ { - 'device': self.device.pk, - 'role': self.secretrole1.pk, - 'name': 'Test Secret 4', - 'plaintext': 'Secret #4 Plaintext', + 'device': device.pk, + 'role': secret_roles[1].pk, + 'name': 'Secret 4', + 'plaintext': 'JKL', }, { - 'device': self.device.pk, - 'role': self.secretrole1.pk, - 'name': 'Test Secret 5', - 'plaintext': 'Secret #5 Plaintext', + 'device': device.pk, + 'role': secret_roles[1].pk, + 'name': 'Secret 5', + 'plaintext': 'MNO', }, { - 'device': self.device.pk, - 'role': self.secretrole1.pk, - 'name': 'Test Secret 6', - 'plaintext': 'Secret #6 Plaintext', + 'device': device.pk, + 'role': secret_roles[1].pk, + 'name': 'Secret 6', + 'plaintext': 'PQR', }, ] - url = reverse('secrets-api:secret-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Secret.objects.count(), 6) - self.assertEqual(response.data[0]['plaintext'], data[0]['plaintext']) - self.assertEqual(response.data[1]['plaintext'], data[1]['plaintext']) - self.assertEqual(response.data[2]['plaintext'], data[2]['plaintext']) - - def test_update_secret(self): - - data = { - 'device': self.device.pk, - 'role': self.secretrole2.pk, - 'plaintext': 'NewPlaintext', - } - - url = reverse('secrets-api:secret-detail', kwargs={'pk': self.secret1.pk}) - response = self.client.put(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(response.data['plaintext'], data['plaintext']) - self.assertEqual(Secret.objects.count(), 3) - secret1 = Secret.objects.get(pk=response.data['id']) - secret1.decrypt(self.master_key) - self.assertEqual(secret1.role_id, data['role']) - self.assertEqual(secret1.plaintext, data['plaintext']) - - def test_delete_secret(self): - - url = reverse('secrets-api:secret-detail', kwargs={'pk': self.secret1.pk}) - response = self.client.delete(url, **self.header) - - self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(Secret.objects.count(), 2) + def prepare_instance(self, instance): + # Unlock the plaintext prior to evaluation of the instance + instance.decrypt(self.master_key) + return instance class GetSessionKeyTest(APITestCase): diff --git a/netbox/secrets/tests/test_views.py b/netbox/secrets/tests/test_views.py index 96439a10d..577ba4ef4 100644 --- a/netbox/secrets/tests/test_views.py +++ b/netbox/secrets/tests/test_views.py @@ -36,15 +36,16 @@ class SecretRoleTestCase(ViewTestCases.OrganizationalObjectViewTestCase): ) -class SecretTestCase(ViewTestCases.PrimaryObjectViewTestCase): +# TODO: Change base class to PrimaryObjectViewTestCase +class SecretTestCase( + ViewTestCases.GetObjectViewTestCase, + ViewTestCases.DeleteObjectViewTestCase, + ViewTestCases.ListObjectsViewTestCase, + ViewTestCases.BulkEditObjectsViewTestCase, + ViewTestCases.BulkDeleteObjectsViewTestCase +): model = Secret - # Disable inapplicable tests - test_create_object = None - - # TODO: Check permissions enforcement on secrets.views.secret_edit - test_edit_object = None - @classmethod def setUpTestData(cls): diff --git a/netbox/secrets/urls.py b/netbox/secrets/urls.py index a19ec6ae0..84c2da398 100644 --- a/netbox/secrets/urls.py +++ b/netbox/secrets/urls.py @@ -9,7 +9,7 @@ urlpatterns = [ # Secret roles path('secret-roles/', views.SecretRoleListView.as_view(), name='secretrole_list'), - path('secret-roles/add/', views.SecretRoleCreateView.as_view(), name='secretrole_add'), + path('secret-roles/add/', views.SecretRoleEditView.as_view(), name='secretrole_add'), path('secret-roles/import/', views.SecretRoleBulkImportView.as_view(), name='secretrole_import'), path('secret-roles/delete/', views.SecretRoleBulkDeleteView.as_view(), name='secretrole_bulk_delete'), path('secret-roles//edit/', views.SecretRoleEditView.as_view(), name='secretrole_edit'), @@ -17,12 +17,12 @@ urlpatterns = [ # Secrets path('secrets/', views.SecretListView.as_view(), name='secret_list'), - path('secrets/add/', views.secret_add, name='secret_add'), + path('secrets/add/', views.SecretEditView.as_view(), name='secret_add'), path('secrets/import/', views.SecretBulkImportView.as_view(), name='secret_import'), path('secrets/edit/', views.SecretBulkEditView.as_view(), name='secret_bulk_edit'), path('secrets/delete/', views.SecretBulkDeleteView.as_view(), name='secret_bulk_delete'), path('secrets//', views.SecretView.as_view(), name='secret'), - path('secrets//edit/', views.secret_edit, name='secret_edit'), + path('secrets//edit/', views.SecretEditView.as_view(), name='secret_edit'), path('secrets//delete/', views.SecretDeleteView.as_view(), name='secret_delete'), path('secrets//changelog/', ObjectChangeLogView.as_view(), name='secret_changelog', kwargs={'model': Secret}), diff --git a/netbox/secrets/views.py b/netbox/secrets/views.py index b40e41cb3..a5aabaecd 100644 --- a/netbox/secrets/views.py +++ b/netbox/secrets/views.py @@ -1,19 +1,17 @@ import base64 +import logging from django.contrib import messages -from django.contrib.auth.decorators import permission_required -from django.contrib.auth.mixins import PermissionRequiredMixin from django.db.models import Count from django.shortcuts import get_object_or_404, redirect, render -from django.urls import reverse -from django.views.generic import View +from django.utils.html import escape +from django.utils.safestring import mark_safe from utilities.views import ( - BulkDeleteView, BulkEditView, BulkImportView, GetReturnURLMixin, ObjectDeleteView, ObjectEditView, ObjectListView, + BulkDeleteView, BulkEditView, BulkImportView, ObjectView, ObjectDeleteView, ObjectEditView, ObjectListView, ) from . import filters, forms, tables -from .decorators import userkey_required -from .models import SecretRole, Secret, SessionKey +from .models import SecretRole, Secret, SessionKey, UserKey def get_session_key(request): @@ -30,32 +28,25 @@ def get_session_key(request): # Secret roles # -class SecretRoleListView(PermissionRequiredMixin, ObjectListView): - permission_required = 'secrets.view_secretrole' +class SecretRoleListView(ObjectListView): queryset = SecretRole.objects.annotate(secret_count=Count('secrets')) table = tables.SecretRoleTable -class SecretRoleCreateView(PermissionRequiredMixin, ObjectEditView): - permission_required = 'secrets.add_secretrole' +class SecretRoleEditView(ObjectEditView): queryset = SecretRole.objects.all() model_form = forms.SecretRoleForm default_return_url = 'secrets:secretrole_list' -class SecretRoleEditView(SecretRoleCreateView): - permission_required = 'secrets.change_secretrole' - - -class SecretRoleBulkImportView(PermissionRequiredMixin, BulkImportView): - permission_required = 'secrets.add_secretrole' +class SecretRoleBulkImportView(BulkImportView): + queryset = SecretRole.objects.all() model_form = forms.SecretRoleCSVForm table = tables.SecretRoleTable default_return_url = 'secrets:secretrole_list' -class SecretRoleBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): - permission_required = 'secrets.delete_secretrole' +class SecretRoleBulkDeleteView(BulkDeleteView): queryset = SecretRole.objects.annotate(secret_count=Count('secrets')) table = tables.SecretRoleTable default_return_url = 'secrets:secretrole_list' @@ -65,8 +56,7 @@ class SecretRoleBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): # Secrets # -class SecretListView(PermissionRequiredMixin, ObjectListView): - permission_required = 'secrets.view_secret' +class SecretListView(ObjectListView): queryset = Secret.objects.prefetch_related('role', 'device') filterset = filters.SecretFilterSet filterset_form = forms.SecretFilterForm @@ -74,129 +64,94 @@ class SecretListView(PermissionRequiredMixin, ObjectListView): action_buttons = ('import', 'export') -class SecretView(PermissionRequiredMixin, View): - permission_required = 'secrets.view_secret' +class SecretView(ObjectView): + queryset = Secret.objects.all() def get(self, request, pk): - secret = get_object_or_404(Secret, pk=pk) + secret = get_object_or_404(self.queryset, pk=pk) return render(request, 'secrets/secret.html', { 'secret': secret, }) -@permission_required('secrets.add_secret') -@userkey_required() -def secret_add(request): +class SecretEditView(ObjectEditView): + queryset = Secret.objects.all() + model_form = forms.SecretForm + template_name = 'secrets/secret_edit.html' - secret = Secret() - session_key = get_session_key(request) + def dispatch(self, request, *args, **kwargs): + + # Check that the user has a valid UserKey + try: + uk = UserKey.objects.get(user=request.user) + except UserKey.DoesNotExist: + messages.warning(request, "This operation requires an active user key, but you don't have one.") + return redirect('user:userkey') + if not uk.is_active(): + messages.warning(request, "This operation is not available. Your user key has not been activated.") + return redirect('user:userkey') + + return super().dispatch(request, *args, **kwargs) + + def post(self, request, *args, **kwargs): + logger = logging.getLogger('netbox.views.ObjectEditView') + session_key = get_session_key(request) + secret = self.get_object(kwargs) + form = self.model_form(request.POST, instance=secret) - if request.method == 'POST': - form = forms.SecretForm(request.POST, instance=secret) if form.is_valid(): + logger.debug("Form validation was successful") - # We need a valid session key in order to create a Secret - if session_key is None: + # We must have a session key in order to create a secret or update the plaintext of an existing secret + if (form.cleaned_data['plaintext'] or secret.pk is None) and session_key is None: + logger.debug("Unable to proceed: No session key was provided with the request") form.add_error(None, "No session key was provided with the request. Unable to encrypt secret data.") - # Create and encrypt the new Secret else: master_key = None try: sk = SessionKey.objects.get(userkey__user=request.user) master_key = sk.get_master_key(session_key) except SessionKey.DoesNotExist: + logger.debug("Unable to proceed: User has no session key assigned") form.add_error(None, "No session key found for this user.") if master_key is not None: + logger.debug("Successfully resolved master key for encryption") secret = form.save(commit=False) - secret.plaintext = str(form.cleaned_data['plaintext']) + if form.cleaned_data['plaintext']: + secret.plaintext = str(form.cleaned_data['plaintext']) secret.encrypt(master_key) secret.save() form.save_m2m() - messages.success(request, "Added new secret: {}.".format(secret)) - if '_addanother' in request.POST: - return redirect('secrets:secret_add') - else: - return redirect('secrets:secret', pk=secret.pk) + msg = '{} secret'.format('Created' if not form.instance.pk else 'Modified') + logger.info(f"{msg} {secret} (PK: {secret.pk})") + msg = '{} {}'.format(msg, secret.get_absolute_url(), escape(secret)) + messages.success(request, mark_safe(msg)) - else: - initial_data = { - 'device': request.GET.get('device'), - } - form = forms.SecretForm(initial=initial_data) + return redirect(self.get_return_url(request, secret)) - return render(request, 'secrets/secret_edit.html', { - 'secret': secret, - 'form': form, - 'return_url': GetReturnURLMixin().get_return_url(request, secret) - }) + else: + logger.debug("Form validation failed") + + return render(request, self.template_name, { + 'obj': secret, + 'obj_type': self.queryset.model._meta.verbose_name, + 'form': form, + 'return_url': self.get_return_url(request, secret), + }) -@permission_required('secrets.change_secret') -@userkey_required() -def secret_edit(request, pk): - - secret = get_object_or_404(Secret, pk=pk) - session_key = get_session_key(request) - - if request.method == 'POST': - form = forms.SecretForm(request.POST, instance=secret) - if form.is_valid(): - - # Re-encrypt the Secret if a plaintext and session key have been provided. - if form.cleaned_data['plaintext'] and session_key is not None: - - # Retrieve the master key using the provided session key - master_key = None - try: - sk = SessionKey.objects.get(userkey__user=request.user) - master_key = sk.get_master_key(session_key) - except SessionKey.DoesNotExist: - form.add_error(None, "No session key found for this user.") - - # Create and encrypt the new Secret - if master_key is not None: - secret = form.save(commit=False) - secret.plaintext = form.cleaned_data['plaintext'] - secret.encrypt(master_key) - secret.save() - messages.success(request, "Modified secret {}.".format(secret)) - return redirect('secrets:secret', pk=secret.pk) - else: - form.add_error(None, "Invalid session key. Unable to encrypt secret data.") - - # We can't save the plaintext without a session key. - elif form.cleaned_data['plaintext']: - form.add_error(None, "No session key was provided with the request. Unable to encrypt secret data.") - - # If no new plaintext was specified, a session key is not needed. - else: - secret = form.save() - messages.success(request, "Modified secret {}.".format(secret)) - return redirect('secrets:secret', pk=secret.pk) - - else: - form = forms.SecretForm(instance=secret) - - return render(request, 'secrets/secret_edit.html', { - 'secret': secret, - 'form': form, - 'return_url': reverse('secrets:secret', kwargs={'pk': secret.pk}), - }) - - -class SecretDeleteView(PermissionRequiredMixin, ObjectDeleteView): - permission_required = 'secrets.delete_secret' +class SecretDeleteView(ObjectDeleteView): queryset = Secret.objects.all() default_return_url = 'secrets:secret_list' class SecretBulkImportView(BulkImportView): - permission_required = 'secrets.add_secret' + queryset = Secret.objects.all() model_form = forms.SecretCSVForm table = tables.SecretTable template_name = 'secrets/secret_import.html' @@ -243,8 +198,7 @@ class SecretBulkImportView(BulkImportView): }) -class SecretBulkEditView(PermissionRequiredMixin, BulkEditView): - permission_required = 'secrets.change_secret' +class SecretBulkEditView(BulkEditView): queryset = Secret.objects.prefetch_related('role', 'device') filterset = filters.SecretFilterSet table = tables.SecretTable @@ -252,8 +206,7 @@ class SecretBulkEditView(PermissionRequiredMixin, BulkEditView): default_return_url = 'secrets:secret_list' -class SecretBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): - permission_required = 'secrets.delete_secret' +class SecretBulkDeleteView(BulkDeleteView): queryset = Secret.objects.prefetch_related('role', 'device') filterset = filters.SecretFilterSet table = tables.SecretTable diff --git a/netbox/templates/circuits/inc/circuit_termination.html b/netbox/templates/circuits/inc/circuit_termination.html index 8db715711..30d875657 100644 --- a/netbox/templates/circuits/inc/circuit_termination.html +++ b/netbox/templates/circuits/inc/circuit_termination.html @@ -90,7 +90,7 @@ IP Addressing {% if termination.connected_endpoint %} - {% for ip in termination.connected_endpoint.ip_addresses.all %} + {% for ip in termination.ip_addresses %} {% if not forloop.first %}
{% endif %} {{ ip }} ({{ ip.vrf|default:"Global" }}) {% empty %} diff --git a/netbox/templates/circuits/provider.html b/netbox/templates/circuits/provider.html index c02637e8e..42c322ce2 100644 --- a/netbox/templates/circuits/provider.html +++ b/netbox/templates/circuits/provider.html @@ -99,7 +99,7 @@ Circuits - {{ provider.circuits.count }} + {{ circuits_table.rows|length }} diff --git a/netbox/templates/dcim/cable.html b/netbox/templates/dcim/cable.html index e6a2fa008..91c7b1a94 100644 --- a/netbox/templates/dcim/cable.html +++ b/netbox/templates/dcim/cable.html @@ -81,6 +81,7 @@ + {% include 'extras/inc/tags_panel.html' with tags=cable.tags.all url='dcim:cable_list' %} {% plugin_left_page cable %}
diff --git a/netbox/templates/dcim/device.html b/netbox/templates/dcim/device.html index ef1a301e2..a42250a3d 100644 --- a/netbox/templates/dcim/device.html +++ b/netbox/templates/dcim/device.html @@ -101,7 +101,7 @@ Inventory {{ device.inventory_items.count }} - {% if perms.dcim.napalm_read %} + {% if perms.dcim.napalm_read_device %} {% if device.status != 'active' %} {% include 'dcim/inc/device_napalm_tabs.html' with disabled_message='Device must be in active status' %} {% elif not device.platform %} diff --git a/netbox/templates/dcim/inc/cable_form.html b/netbox/templates/dcim/inc/cable_form.html index 0799eb130..98eca17d2 100644 --- a/netbox/templates/dcim/inc/cable_form.html +++ b/netbox/templates/dcim/inc/cable_form.html @@ -10,10 +10,25 @@
{{ form.length }} + {% if form.length.errors %} +
    + {% for error in form.length.errors %} +
  • {{ error }}
  • + {% endfor %} +
+ {% endif %}
{{ form.length_unit }} + {% if form.length_unit.errors %} +
    + {% for error in form.length_unit.errors %} +
  • {{ error }}
  • + {% endfor %} +
+ {% endif %}
+ {% render_field form.tags %} diff --git a/netbox/templates/dcim/inc/devicetype_component_table.html b/netbox/templates/dcim/inc/devicetype_component_table.html index a83059980..010749b93 100644 --- a/netbox/templates/dcim/inc/devicetype_component_table.html +++ b/netbox/templates/dcim/inc/devicetype_component_table.html @@ -9,12 +9,12 @@ - + + + + + @@ -92,7 +92,7 @@ - + @@ -114,7 +114,7 @@ @@ -221,7 +221,7 @@ {% for member in interface.member_interfaces.all %}
{% if interface.device %}Device{% else %}Virtual Machine{% endif %}Device - {{ interface.parent }} + {{ interface.device }}
Name {{ interface.name }}
Label{{ interface.label|placeholder }}
Type {{ interface.get_type_display }}
MAC Address{{ interface.mac_address|placeholder }}{{ interface.mac_address|placeholder }}
802.1Q Mode
Device - {{ connected_interface.device }} + {{ connected_interface.device }}
- {{ member.parent }} + {{ member.device }} {{ member }} diff --git a/netbox/templates/dcim/interface_edit.html b/netbox/templates/dcim/interface_edit.html index a80b7c592..eaffe2bca 100644 --- a/netbox/templates/dcim/interface_edit.html +++ b/netbox/templates/dcim/interface_edit.html @@ -6,6 +6,7 @@
Interface
{% render_field form.name %} + {% render_field form.label %} {% render_field form.type %} {% render_field form.enabled %} {% render_field form.lag %} diff --git a/netbox/templates/dcim/powerpanel.html b/netbox/templates/dcim/powerpanel.html index 3ee8d80e0..90956d2a3 100644 --- a/netbox/templates/dcim/powerpanel.html +++ b/netbox/templates/dcim/powerpanel.html @@ -82,6 +82,7 @@
+ {% include 'extras/inc/tags_panel.html' with tags=powerpanel.tags.all url='dcim:powerpanel_list' %} {% plugin_left_page powerpanel %}
diff --git a/netbox/templates/dcim/rack.html b/netbox/templates/dcim/rack.html index 2c7452ba2..8d63a7095 100644 --- a/netbox/templates/dcim/rack.html +++ b/netbox/templates/dcim/rack.html @@ -138,7 +138,7 @@ Devices - {{ rack.devices.count }} + {{ device_count }} diff --git a/netbox/templates/dcim/rackreservation.html b/netbox/templates/dcim/rackreservation.html index d4bbbc97d..ab0fc0bba 100644 --- a/netbox/templates/dcim/rackreservation.html +++ b/netbox/templates/dcim/rackreservation.html @@ -124,6 +124,7 @@
+ {% include 'extras/inc/tags_panel.html' with tags=rackreservation.tags.all url='dcim:rackreservation_list' %} {% plugin_left_page rackreservation %}
diff --git a/netbox/templates/dcim/rackreservation_edit.html b/netbox/templates/dcim/rackreservation_edit.html index b2304974e..d6fa9cfcb 100644 --- a/netbox/templates/dcim/rackreservation_edit.html +++ b/netbox/templates/dcim/rackreservation_edit.html @@ -3,19 +3,22 @@ {% block form %}
-
{{ obj_type|capfirst }}
+
Rack Reservation
-
- -
-

{{ obj.rack }}

-
-
+ {% render_field form.site %} + {% render_field form.rack_group %} + {% render_field form.rack %} {% render_field form.units %} {% render_field form.user %} + {% render_field form.description %} + {% render_field form.tags %} +
+
+
+
Tenant Assignment
+
{% render_field form.tenant_group %} {% render_field form.tenant %} - {% render_field form.description %}
{% endblock %} diff --git a/netbox/templates/dcim/site.html b/netbox/templates/dcim/site.html index f5823f721..d6c21bf92 100644 --- a/netbox/templates/dcim/site.html +++ b/netbox/templates/dcim/site.html @@ -12,7 +12,7 @@
- {% for tag in tags %} + {% for tag in tags.unrestricted %} {% tag tag url %} {% empty %} No tags assigned diff --git a/netbox/templates/extras/tag.html b/netbox/templates/extras/tag.html index 0c20bcbdc..ff54a4800 100644 --- a/netbox/templates/extras/tag.html +++ b/netbox/templates/extras/tag.html @@ -85,7 +85,7 @@ Description - {{ tag.description }} + {{ tag.description|placeholder }}
diff --git a/netbox/templates/inc/nav_menu.html b/netbox/templates/inc/nav_menu.html index 765df31cc..4704ef613 100644 --- a/netbox/templates/inc/nav_menu.html +++ b/netbox/templates/inc/nav_menu.html @@ -70,6 +70,7 @@ {% if perms.dcim.add_rackreservation %}
+
{% endif %} @@ -101,6 +102,12 @@
  • + {% if perms.extras.add_tag %} +
    + + +
    + {% endif %} Tags @@ -365,6 +372,14 @@ {% endif %} Virtual Machines + + {% if perms.virtualization.add_vminterface %} +
    + +
    + {% endif %} + Interfaces +
  • diff --git a/netbox/templates/ipam/ipaddress.html b/netbox/templates/ipam/ipaddress.html index 6eba1a5e6..ff83061cf 100644 --- a/netbox/templates/ipam/ipaddress.html +++ b/netbox/templates/ipam/ipaddress.html @@ -120,8 +120,8 @@ Assignment - {% if ipaddress.interface %} - {{ ipaddress.interface.parent }} ({{ ipaddress.interface }}) + {% if ipaddress.assigned_object %} + {{ ipaddress.assigned_object.parent }} ({{ ipaddress.assigned_object }}) {% else %} {% endif %} @@ -132,8 +132,8 @@ {% if ipaddress.nat_inside %} {{ ipaddress.nat_inside }} - {% if ipaddress.nat_inside.interface %} - ({{ ipaddress.nat_inside.interface.parent }}) + {% if ipaddress.nat_inside.assigned_object %} + ({{ ipaddress.nat_inside.assigned_object.parent }}) {% endif %} {% else %} None diff --git a/netbox/templates/ipam/ipaddress_bulk_add.html b/netbox/templates/ipam/ipaddress_bulk_add.html index 5d4f4f7cb..bbb179fc8 100644 --- a/netbox/templates/ipam/ipaddress_bulk_add.html +++ b/netbox/templates/ipam/ipaddress_bulk_add.html @@ -26,6 +26,12 @@ {% render_field model_form.tenant %} +
    +
    Tags
    +
    + {% render_field model_form.tags %} +
    +
    {% if model_form.custom_fields %}
    Custom Fields
    diff --git a/netbox/templates/ipam/ipaddress_edit.html b/netbox/templates/ipam/ipaddress_edit.html index d8902595a..4e2706daf 100644 --- a/netbox/templates/ipam/ipaddress_edit.html +++ b/netbox/templates/ipam/ipaddress_edit.html @@ -28,25 +28,30 @@ {% render_field form.tenant %}
    - {% if obj.interface %} -
    -
    - Interface Assignment -
    -
    -
    - -
    -

    - {{ obj.interface.parent }} -

    +
    +
    + Interface Assignment +
    +
    + {% with vm_tab_active=obj.vminterface.exists %} + +
    +
    + {% render_field form.device %} + {% render_field form.interface %} +
    +
    + {% render_field form.virtual_machine %} + {% render_field form.vminterface %}
    - {% render_field form.interface %} - {% render_field form.primary_for_parent %} -
    + {% endwith %} + {% render_field form.primary_for_parent %}
    - {% endif %} +
    NAT IP (Inside)
    diff --git a/netbox/templates/ipam/prefix.html b/netbox/templates/ipam/prefix.html index 4620f6bf4..241cdd9a4 100644 --- a/netbox/templates/ipam/prefix.html +++ b/netbox/templates/ipam/prefix.html @@ -64,7 +64,7 @@ - {% if perms.ipam.view_ipaddress %} + {% if perms.ipam.view_ipaddress and prefix.status != 'container' %} diff --git a/netbox/templates/secrets/secret_edit.html b/netbox/templates/secrets/secret_edit.html index cb3935521..6893e2d14 100644 --- a/netbox/templates/secrets/secret_edit.html +++ b/netbox/templates/secrets/secret_edit.html @@ -9,7 +9,7 @@ {{ form.private_key }}
    -

    {% block title %}{% if secret.pk %}Editing {{ secret }}{% else %}Add a Secret{% endif %}{% endblock %}

    +

    {% block title %}{% if obj.pk %}Editing {{ obj }}{% else %}Add a Secret{% endif %}{% endblock %}

    {% if form.non_field_errors %}
    Errors
    @@ -30,17 +30,17 @@
    Secret Data
    - {% if secret.pk and secret|decryptable_by:request.user %} + {% if obj.pk and obj|decryptable_by:request.user %}
    -

    ********

    +

    ********

    - -
    @@ -69,9 +69,9 @@
    - {% if secret.pk %} + {% if obj.pk %} - Cancel + Cancel {% else %} diff --git a/netbox/templates/dcim/bulk_rename.html b/netbox/templates/utilities/obj_bulk_rename.html similarity index 100% rename from netbox/templates/dcim/bulk_rename.html rename to netbox/templates/utilities/obj_bulk_rename.html diff --git a/netbox/templates/utilities/obj_list.html b/netbox/templates/utilities/obj_list.html index 4cfa8b1ce..85ff050ed 100644 --- a/netbox/templates/utilities/obj_list.html +++ b/netbox/templates/utilities/obj_list.html @@ -5,7 +5,7 @@ {% block content %}
    {% block buttons %}{% endblock %} - {% if table_config_form %} + {% if request.user.is_authenticated and table_config_form %} {% endif %} {% if permissions.add and 'add' in action_buttons %} diff --git a/netbox/templates/virtualization/inc/vminterface.html b/netbox/templates/virtualization/inc/vminterface.html new file mode 100644 index 000000000..5410fba7a --- /dev/null +++ b/netbox/templates/virtualization/inc/vminterface.html @@ -0,0 +1,141 @@ +{% load helpers %} + + + {# Checkbox #} + {% if perms.virtualization.change_interface or perms.virtualization.delete_interface %} + + + + {% endif %} + + {# Name #} + + {{ iface }} + + + {# MAC address #} + + {{ iface.mac_address|default:"—" }} + + + {# MTU #} + {{ iface.mtu|default:"—" }} + + {# 802.1Q mode #} + {{ iface.get_mode_display|default:"—" }} + + {# Description/tags #} + + {% if iface.description %} + {{ iface.description }}
    + {% endif %} + {% for tag in iface.tags.all %} + {% tag tag %} + {% empty %} + {% if not iface.description %}—{% endif %} + {% endfor %} + + + {# Buttons #} + + {% if show_interface_graphs %} + + {% endif %} + {% if perms.ipam.add_ipaddress %} + + + + {% endif %} + {% if perms.virtualization.change_interface %} + + + + {% endif %} + {% if perms.virtualization.delete_interface %} + + + + {% endif %} + + + +{% with ipaddresses=iface.ip_addresses.all %} + {% if ipaddresses %} + + {# Placeholder #} + {% if perms.virtualization.change_interface or perms.virtualization.delete_interface %} + + {% endif %} + + {# IP addresses table #} + + + + + + + + + + + + {% for ip in iface.ip_addresses.all %} + + + {# IP address #} + + + {# Primary/status/role #} + + + {# VRF #} + + + {# Description #} + + + {# Buttons #} + + + + {% endfor %} +
    IP AddressStatus/RoleVRFDescription
    + {{ ip }} + + {% if virtualmachine.primary_ip4 == ip or virtualmachine.primary_ip6 == ip %} + Primary + {% endif %} + {{ ip.get_status_display }} + {% if ip.role %} + {{ ip.get_role_display }} + {% endif %} + + {% if ip.vrf %} + {{ ip.vrf.name }} + {% else %} + Global + {% endif %} + + {% if ip.description %} + {{ ip.description }} + {% else %} + + {% endif %} + + {% if perms.ipam.change_ipaddress %} + + + + {% endif %} + {% if perms.ipam.delete_ipaddress %} + + + + {% endif %} +
    + + + {% endif %} +{% endwith %} diff --git a/netbox/templates/virtualization/virtualmachine.html b/netbox/templates/virtualization/virtualmachine.html index ea8f4fedb..9cc206dac 100644 --- a/netbox/templates/virtualization/virtualmachine.html +++ b/netbox/templates/virtualization/virtualmachine.html @@ -248,7 +248,7 @@
    - {% if perms.dcim.change_interface or perms.dcim.delete_interface %} + {% if perms.virtualization.change_vminterface or perms.virtualization.delete_vminterface %}
    {% csrf_token %} @@ -268,22 +268,20 @@ - {% if perms.dcim.change_interface or perms.dcim.delete_interface %} + {% if perms.virtualization.change_vminterface or perms.virtualization.delete_vminterface %} {% endif %} - - + - - + {% for iface in interfaces %} - {% include 'dcim/inc/interface.html' with device=virtualmachine %} + {% include 'virtualization/inc/vminterface.html' %} {% empty %} @@ -291,24 +289,24 @@ {% endfor %}
    NameLAGDescriptionMAC Address MTU ModeCableConnectionDescription
    — No interfaces defined —
    - {% if perms.dcim.add_interface or perms.dcim.delete_interface %} + {% if perms.virtualization.add_vminterface or perms.virtualization.delete_vminterface %} {% endif %}
    - {% if perms.dcim.delete_interface %} + {% if perms.virtualization.delete_vminterface %} {% endif %}
    diff --git a/netbox/templates/virtualization/virtualmachine_component_add.html b/netbox/templates/virtualization/virtualmachine_component_add.html index 34a8f3c3d..aafefffa1 100644 --- a/netbox/templates/virtualization/virtualmachine_component_add.html +++ b/netbox/templates/virtualization/virtualmachine_component_add.html @@ -22,12 +22,6 @@ {{ component_type|bettertitle }}
    -
    - -
    -

    {{ parent }}

    -
    -
    {% render_form form %}
    diff --git a/netbox/templates/virtualization/virtualmachine_list.html b/netbox/templates/virtualization/virtualmachine_list.html index 74839b250..f8ee77626 100644 --- a/netbox/templates/virtualization/virtualmachine_list.html +++ b/netbox/templates/virtualization/virtualmachine_list.html @@ -7,7 +7,7 @@ Add Components
    {% endif %} diff --git a/netbox/templates/virtualization/vminterface.html b/netbox/templates/virtualization/vminterface.html new file mode 100644 index 000000000..8d46b52fd --- /dev/null +++ b/netbox/templates/virtualization/vminterface.html @@ -0,0 +1,100 @@ +{% extends 'base.html' %} +{% load helpers %} + +{% block header %} +
    +
    + +
    +
    +
    + {% if perms.virtualization.change_vminterface %} + + Edit + + {% endif %} + {% if perms.virtualization.delete_vminterface %} + + Delete + + {% endif %} +
    +

    {% block title %}{{ vminterface.virtual_machine }} / {{ vminterface.name }}{% endblock %}

    + +{% endblock %} + +{% block content %} +
    +
    +
    +
    + Interface +
    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Virtual Machine + {{ vminterface.virtual_machine }} +
    Name{{ vminterface.name }}
    Enabled + {% if vminterface.enabled %} + + {% else %} + + {% endif %} +
    Description{{ vminterface.description|placeholder }}
    MTU{{ vminterface.mtu|placeholder }}
    MAC Address{{ vminterface.mac_address|placeholder }}
    802.1Q Mode{{ vminterface.get_mode_display }}
    +
    +
    +
    + {% include 'extras/inc/tags_panel.html' with tags=vminterface.tags.all %} +
    +
    +
    +
    + {% include 'panel_table.html' with table=ipaddress_table heading="IP Addresses" %} +
    +
    +
    +
    + {% include 'panel_table.html' with table=vlan_table heading="VLANs" %} +
    +
    +{% endblock %} diff --git a/netbox/templates/virtualization/interface_edit.html b/netbox/templates/virtualization/vminterface_edit.html similarity index 91% rename from netbox/templates/virtualization/interface_edit.html rename to netbox/templates/virtualization/vminterface_edit.html index 437b960c9..6b0313284 100644 --- a/netbox/templates/virtualization/interface_edit.html +++ b/netbox/templates/virtualization/vminterface_edit.html @@ -21,7 +21,7 @@ {% block buttons %} {% if obj.pk %} - + {% else %} diff --git a/netbox/tenancy/api/serializers.py b/netbox/tenancy/api/serializers.py index 9c7a099e4..4454ac776 100644 --- a/netbox/tenancy/api/serializers.py +++ b/netbox/tenancy/api/serializers.py @@ -1,7 +1,7 @@ from rest_framework import serializers -from taggit_serializer.serializers import TaggitSerializer, TagListSerializerField from extras.api.customfields import CustomFieldModelSerializer +from extras.api.serializers import TaggedObjectSerializer from tenancy.models import Tenant, TenantGroup from utilities.api import ValidatedModelSerializer from .nested_serializers import * @@ -20,9 +20,8 @@ class TenantGroupSerializer(ValidatedModelSerializer): fields = ['id', 'name', 'slug', 'parent', 'description', 'tenant_count'] -class TenantSerializer(TaggitSerializer, CustomFieldModelSerializer): +class TenantSerializer(TaggedObjectSerializer, CustomFieldModelSerializer): group = NestedTenantGroupSerializer(required=False) - tags = TagListSerializerField(required=False) circuit_count = serializers.IntegerField(read_only=True) device_count = serializers.IntegerField(read_only=True) ipaddress_count = serializers.IntegerField(read_only=True) diff --git a/netbox/tenancy/filters.py b/netbox/tenancy/filters.py index af5ee0b2c..42137d7ca 100644 --- a/netbox/tenancy/filters.py +++ b/netbox/tenancy/filters.py @@ -15,12 +15,12 @@ __all__ = ( class TenantGroupFilterSet(BaseFilterSet, NameSlugSearchFilterSet): parent_id = django_filters.ModelMultipleChoiceFilter( - queryset=TenantGroup.objects.all(), + queryset=TenantGroup.objects.unrestricted(), label='Tenant group (ID)', ) parent = django_filters.ModelMultipleChoiceFilter( field_name='parent__slug', - queryset=TenantGroup.objects.all(), + queryset=TenantGroup.objects.unrestricted(), to_field_name='slug', label='Tenant group group (slug)', ) @@ -36,13 +36,13 @@ class TenantFilterSet(BaseFilterSet, CustomFieldFilterSet, CreatedUpdatedFilterS label='Search', ) group_id = TreeNodeMultipleChoiceFilter( - queryset=TenantGroup.objects.all(), + queryset=TenantGroup.objects.unrestricted(), field_name='group', lookup_expr='in', label='Tenant group (ID)', ) group = TreeNodeMultipleChoiceFilter( - queryset=TenantGroup.objects.all(), + queryset=TenantGroup.objects.unrestricted(), field_name='group', lookup_expr='in', to_field_name='slug', @@ -70,24 +70,24 @@ class TenancyFilterSet(django_filters.FilterSet): An inheritable FilterSet for models which support Tenant assignment. """ tenant_group_id = TreeNodeMultipleChoiceFilter( - queryset=TenantGroup.objects.all(), + queryset=TenantGroup.objects.unrestricted(), field_name='tenant__group', lookup_expr='in', label='Tenant Group (ID)', ) tenant_group = TreeNodeMultipleChoiceFilter( - queryset=TenantGroup.objects.all(), + queryset=TenantGroup.objects.unrestricted(), field_name='tenant__group', to_field_name='slug', lookup_expr='in', label='Tenant Group (slug)', ) tenant_id = django_filters.ModelMultipleChoiceFilter( - queryset=Tenant.objects.all(), + queryset=Tenant.objects.unrestricted(), label='Tenant (ID)', ) tenant = django_filters.ModelMultipleChoiceFilter( - queryset=Tenant.objects.all(), + queryset=Tenant.objects.unrestricted(), field_name='tenant__slug', to_field_name='slug', label='Tenant (slug)', diff --git a/netbox/tenancy/forms.py b/netbox/tenancy/forms.py index 700d88b1d..5bd0657b6 100644 --- a/netbox/tenancy/forms.py +++ b/netbox/tenancy/forms.py @@ -1,9 +1,9 @@ from django import forms -from taggit.forms import TagField from extras.forms import ( AddRemoveTagsForm, CustomFieldModelForm, CustomFieldBulkEditForm, CustomFieldFilterForm, CustomFieldModelCSVForm, ) +from extras.models import Tag from utilities.forms import ( APISelect, APISelectMultiple, BootstrapMixin, CommentField, CSVModelChoiceField, CSVModelForm, DynamicModelChoiceField, DynamicModelMultipleChoiceField, SlugField, TagFilterField, @@ -57,7 +57,8 @@ class TenantForm(BootstrapMixin, CustomFieldModelForm): required=False ) comments = CommentField() - tags = TagField( + tags = DynamicModelMultipleChoiceField( + queryset=Tag.objects.all(), required=False ) diff --git a/netbox/tenancy/models.py b/netbox/tenancy/models.py index 077fb6ad1..2e415b965 100644 --- a/netbox/tenancy/models.py +++ b/netbox/tenancy/models.py @@ -7,6 +7,8 @@ from taggit.managers import TaggableManager from extras.models import CustomFieldModel, ObjectChange, TaggedItem from extras.utils import extras_features from utilities.models import ChangeLoggedModel +from utilities.mptt import TreeManager +from utilities.querysets import RestrictedQuerySet from utilities.utils import serialize_object @@ -40,6 +42,8 @@ class TenantGroup(MPTTModel, ChangeLoggedModel): blank=True ) + objects = TreeManager() + csv_headers = ['name', 'slug', 'parent', 'description'] class Meta: @@ -104,9 +108,10 @@ class Tenant(ChangeLoggedModel, CustomFieldModel): content_type_field='obj_type', object_id_field='obj_id' ) - tags = TaggableManager(through=TaggedItem) + objects = RestrictedQuerySet.as_manager() + csv_headers = ['name', 'slug', 'group', 'description', 'comments'] clone_fields = [ 'group', 'description', diff --git a/netbox/tenancy/tests/test_api.py b/netbox/tenancy/tests/test_api.py index 8da3d7594..b06a8213a 100644 --- a/netbox/tenancy/tests/test_api.py +++ b/netbox/tenancy/tests/test_api.py @@ -1,8 +1,7 @@ from django.urls import reverse -from rest_framework import status from tenancy.models import Tenant, TenantGroup -from utilities.testing import APITestCase +from utilities.testing import APITestCase, APIViewTestCases class AppTest(APITestCase): @@ -15,235 +14,74 @@ class AppTest(APITestCase): self.assertEqual(response.status_code, 200) -class TenantGroupTest(APITestCase): +class TenantGroupTest(APIViewTestCases.APIViewTestCase): + model = TenantGroup + brief_fields = ['id', 'name', 'slug', 'tenant_count', 'url'] - def setUp(self): + @classmethod + def setUpTestData(cls): - super().setUp() - - self.parent_tenant_groups = ( - TenantGroup(name='Parent Tenant Group 1', slug='parent-tenant-group-1'), - TenantGroup(name='Parent Tenant Group 2', slug='parent-tenant-group-2'), - ) - for tenantgroup in self.parent_tenant_groups: - tenantgroup.save() - - self.tenant_groups = ( - TenantGroup(name='Tenant Group 1', slug='tenant-group-1', parent=self.parent_tenant_groups[0]), - TenantGroup(name='Tenant Group 2', slug='tenant-group-2', parent=self.parent_tenant_groups[0]), - TenantGroup(name='Tenant Group 3', slug='tenant-group-3', parent=self.parent_tenant_groups[0]), - ) - for tenantgroup in self.tenant_groups: - tenantgroup.save() - - def test_get_tenantgroup(self): - - url = reverse('tenancy-api:tenantgroup-detail', kwargs={'pk': self.tenant_groups[0].pk}) - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['name'], self.tenant_groups[0].name) - - def test_list_tenantgroups(self): - - url = reverse('tenancy-api:tenantgroup-list') - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['count'], 5) - - def test_list_tenantgroups_brief(self): - - url = reverse('tenancy-api:tenantgroup-list') - response = self.client.get('{}?brief=1'.format(url), **self.header) - - self.assertEqual( - sorted(response.data['results'][0]), - ['id', 'name', 'slug', 'tenant_count', 'url'] + parent_tenant_groups = ( + TenantGroup.objects.create(name='Parent Tenant Group 1', slug='parent-tenant-group-1'), + TenantGroup.objects.create(name='Parent Tenant Group 2', slug='parent-tenant-group-2'), ) - def test_create_tenantgroup(self): + TenantGroup.objects.create(name='Tenant Group 1', slug='tenant-group-1', parent=parent_tenant_groups[0]) + TenantGroup.objects.create(name='Tenant Group 2', slug='tenant-group-2', parent=parent_tenant_groups[0]) + TenantGroup.objects.create(name='Tenant Group 3', slug='tenant-group-3', parent=parent_tenant_groups[0]) - data = { - 'name': 'Tenant Group 4', - 'slug': 'tenant-group-4', - 'parent': self.parent_tenant_groups[0].pk, - } - - url = reverse('tenancy-api:tenantgroup-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(TenantGroup.objects.count(), 6) - tenantgroup4 = TenantGroup.objects.get(pk=response.data['id']) - self.assertEqual(tenantgroup4.name, data['name']) - self.assertEqual(tenantgroup4.slug, data['slug']) - self.assertEqual(tenantgroup4.parent_id, data['parent']) - - def test_create_tenantgroup_bulk(self): - - data = [ + cls.create_data = [ { 'name': 'Tenant Group 4', 'slug': 'tenant-group-4', - 'parent': self.parent_tenant_groups[0].pk, + 'parent': parent_tenant_groups[1].pk, }, { 'name': 'Tenant Group 5', 'slug': 'tenant-group-5', - 'parent': self.parent_tenant_groups[0].pk, + 'parent': parent_tenant_groups[1].pk, }, { 'name': 'Tenant Group 6', 'slug': 'tenant-group-6', - 'parent': self.parent_tenant_groups[0].pk, + 'parent': parent_tenant_groups[1].pk, }, ] - url = reverse('tenancy-api:tenantgroup-list') - response = self.client.post(url, data, format='json', **self.header) - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(TenantGroup.objects.count(), 8) - self.assertEqual(response.data[0]['name'], data[0]['name']) - self.assertEqual(response.data[1]['name'], data[1]['name']) - self.assertEqual(response.data[2]['name'], data[2]['name']) +class TenantTest(APIViewTestCases.APIViewTestCase): + model = Tenant + brief_fields = ['id', 'name', 'slug', 'url'] - def test_update_tenantgroup(self): + @classmethod + def setUpTestData(cls): - data = { - 'name': 'Tenant Group X', - 'slug': 'tenant-group-x', - 'parent': self.parent_tenant_groups[1].pk, - } - - url = reverse('tenancy-api:tenantgroup-detail', kwargs={'pk': self.tenant_groups[0].pk}) - response = self.client.put(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(TenantGroup.objects.count(), 5) - tenantgroup1 = TenantGroup.objects.get(pk=response.data['id']) - self.assertEqual(tenantgroup1.name, data['name']) - self.assertEqual(tenantgroup1.slug, data['slug']) - self.assertEqual(tenantgroup1.parent_id, data['parent']) - - def test_delete_tenantgroup(self): - - url = reverse('tenancy-api:tenantgroup-detail', kwargs={'pk': self.tenant_groups[0].pk}) - response = self.client.delete(url, **self.header) - - self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(TenantGroup.objects.count(), 4) - - -class TenantTest(APITestCase): - - def setUp(self): - - super().setUp() - - self.tenant_groups = ( - TenantGroup(name='Tenant Group 1', slug='tenant-group-1'), - TenantGroup(name='Tenant Group 2', slug='tenant-group-2'), - ) - for tenantgroup in self.tenant_groups: - tenantgroup.save() - - self.tenants = ( - Tenant(name='Test Tenant 1', slug='test-tenant-1', group=self.tenant_groups[0]), - Tenant(name='Test Tenant 2', slug='test-tenant-2', group=self.tenant_groups[0]), - Tenant(name='Test Tenant 3', slug='test-tenant-3', group=self.tenant_groups[0]), - ) - Tenant.objects.bulk_create(self.tenants) - - def test_get_tenant(self): - - url = reverse('tenancy-api:tenant-detail', kwargs={'pk': self.tenants[0].pk}) - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['name'], self.tenants[0].name) - - def test_list_tenants(self): - - url = reverse('tenancy-api:tenant-list') - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['count'], 3) - - def test_list_tenants_brief(self): - - url = reverse('tenancy-api:tenant-list') - response = self.client.get('{}?brief=1'.format(url), **self.header) - - self.assertEqual( - sorted(response.data['results'][0]), - ['id', 'name', 'slug', 'url'] + tenant_groups = ( + TenantGroup.objects.create(name='Tenant Group 1', slug='tenant-group-1'), + TenantGroup.objects.create(name='Tenant Group 2', slug='tenant-group-2'), ) - def test_create_tenant(self): + tenants = ( + Tenant(name='Tenant 1', slug='tenant-1', group=tenant_groups[0]), + Tenant(name='Tenant 2', slug='tenant-2', group=tenant_groups[0]), + Tenant(name='Tenant 3', slug='tenant-3', group=tenant_groups[0]), + ) + Tenant.objects.bulk_create(tenants) - data = { - 'name': 'Test Tenant 4', - 'slug': 'test-tenant-4', - 'group': self.tenant_groups[0].pk, - } - - url = reverse('tenancy-api:tenant-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Tenant.objects.count(), 4) - tenant4 = Tenant.objects.get(pk=response.data['id']) - self.assertEqual(tenant4.name, data['name']) - self.assertEqual(tenant4.slug, data['slug']) - self.assertEqual(tenant4.group_id, data['group']) - - def test_create_tenant_bulk(self): - - data = [ + cls.create_data = [ { - 'name': 'Test Tenant 4', - 'slug': 'test-tenant-4', + 'name': 'Tenant 4', + 'slug': 'tenant-4', + 'group': tenant_groups[1].pk, }, { - 'name': 'Test Tenant 5', - 'slug': 'test-tenant-5', + 'name': 'Tenant 5', + 'slug': 'tenant-5', + 'group': tenant_groups[1].pk, }, { - 'name': 'Test Tenant 6', - 'slug': 'test-tenant-6', + 'name': 'Tenant 6', + 'slug': 'tenant-6', + 'group': tenant_groups[1].pk, }, ] - - url = reverse('tenancy-api:tenant-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Tenant.objects.count(), 6) - self.assertEqual(response.data[0]['name'], data[0]['name']) - self.assertEqual(response.data[1]['name'], data[1]['name']) - self.assertEqual(response.data[2]['name'], data[2]['name']) - - def test_update_tenant(self): - - data = { - 'name': 'Test Tenant X', - 'slug': 'test-tenant-x', - 'group': self.tenant_groups[1].pk, - } - - url = reverse('tenancy-api:tenant-detail', kwargs={'pk': self.tenants[0].pk}) - response = self.client.put(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(Tenant.objects.count(), 3) - tenant1 = Tenant.objects.get(pk=response.data['id']) - self.assertEqual(tenant1.name, data['name']) - self.assertEqual(tenant1.slug, data['slug']) - self.assertEqual(tenant1.group_id, data['group']) - - def test_delete_tenant(self): - - url = reverse('tenancy-api:tenant-detail', kwargs={'pk': self.tenants[0].pk}) - response = self.client.delete(url, **self.header) - - self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(Tenant.objects.count(), 2) diff --git a/netbox/tenancy/tests/test_views.py b/netbox/tenancy/tests/test_views.py index ca2c2633f..5b88b84cf 100644 --- a/netbox/tenancy/tests/test_views.py +++ b/netbox/tenancy/tests/test_views.py @@ -49,13 +49,15 @@ class TenantTestCase(ViewTestCases.PrimaryObjectViewTestCase): Tenant(name='Tenant 3', slug='tenant-3', group=tenant_groups[0]), ]) + tags = cls.create_tags('Alpha', 'Bravo', 'Charlie') + cls.form_data = { 'name': 'Tenant X', 'slug': 'tenant-x', 'group': tenant_groups[1].pk, 'description': 'A new tenant', 'comments': 'Some comments', - 'tags': 'Alpha,Bravo,Charlie', + 'tags': [t.pk for t in tags], } cls.csv_data = ( diff --git a/netbox/tenancy/urls.py b/netbox/tenancy/urls.py index 0218a5674..4c65ce4e8 100644 --- a/netbox/tenancy/urls.py +++ b/netbox/tenancy/urls.py @@ -9,7 +9,7 @@ urlpatterns = [ # Tenant groups path('tenant-groups/', views.TenantGroupListView.as_view(), name='tenantgroup_list'), - path('tenant-groups/add/', views.TenantGroupCreateView.as_view(), name='tenantgroup_add'), + path('tenant-groups/add/', views.TenantGroupEditView.as_view(), name='tenantgroup_add'), path('tenant-groups/import/', views.TenantGroupBulkImportView.as_view(), name='tenantgroup_import'), path('tenant-groups/delete/', views.TenantGroupBulkDeleteView.as_view(), name='tenantgroup_bulk_delete'), path('tenant-groups//edit/', views.TenantGroupEditView.as_view(), name='tenantgroup_edit'), @@ -17,7 +17,7 @@ urlpatterns = [ # Tenants path('tenants/', views.TenantListView.as_view(), name='tenant_list'), - path('tenants/add/', views.TenantCreateView.as_view(), name='tenant_add'), + path('tenants/add/', views.TenantEditView.as_view(), name='tenant_add'), path('tenants/import/', views.TenantBulkImportView.as_view(), name='tenant_import'), path('tenants/edit/', views.TenantBulkEditView.as_view(), name='tenant_bulk_edit'), path('tenants/delete/', views.TenantBulkDeleteView.as_view(), name='tenant_bulk_delete'), diff --git a/netbox/tenancy/views.py b/netbox/tenancy/views.py index 2af44094f..a82b231f5 100644 --- a/netbox/tenancy/views.py +++ b/netbox/tenancy/views.py @@ -1,13 +1,11 @@ -from django.contrib.auth.mixins import PermissionRequiredMixin from django.db.models import Count from django.shortcuts import get_object_or_404, render -from django.views.generic import View from circuits.models import Circuit from dcim.models import Site, Rack, Device, RackReservation from ipam.models import IPAddress, Prefix, VLAN, VRF from utilities.views import ( - BulkDeleteView, BulkEditView, BulkImportView, ObjectDeleteView, ObjectEditView, ObjectListView, + BulkDeleteView, BulkEditView, BulkImportView, ObjectView, ObjectDeleteView, ObjectEditView, ObjectListView, ) from virtualization.models import VirtualMachine, Cluster from . import filters, forms, tables @@ -18,8 +16,7 @@ from .models import Tenant, TenantGroup # Tenant groups # -class TenantGroupListView(PermissionRequiredMixin, ObjectListView): - permission_required = 'tenancy.view_tenantgroup' +class TenantGroupListView(ObjectListView): queryset = TenantGroup.objects.add_related_count( TenantGroup.objects.all(), Tenant, @@ -30,26 +27,20 @@ class TenantGroupListView(PermissionRequiredMixin, ObjectListView): table = tables.TenantGroupTable -class TenantGroupCreateView(PermissionRequiredMixin, ObjectEditView): - permission_required = 'tenancy.add_tenantgroup' +class TenantGroupEditView(ObjectEditView): queryset = TenantGroup.objects.all() model_form = forms.TenantGroupForm default_return_url = 'tenancy:tenantgroup_list' -class TenantGroupEditView(TenantGroupCreateView): - permission_required = 'tenancy.change_tenantgroup' - - -class TenantGroupBulkImportView(PermissionRequiredMixin, BulkImportView): - permission_required = 'tenancy.add_tenantgroup' +class TenantGroupBulkImportView(BulkImportView): + queryset = TenantGroup.objects.all() model_form = forms.TenantGroupCSVForm table = tables.TenantGroupTable default_return_url = 'tenancy:tenantgroup_list' -class TenantGroupBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): - permission_required = 'tenancy.delete_tenantgroup' +class TenantGroupBulkDeleteView(BulkDeleteView): queryset = TenantGroup.objects.annotate(tenant_count=Count('tenants')) table = tables.TenantGroupTable default_return_url = 'tenancy:tenantgroup_list' @@ -59,32 +50,31 @@ class TenantGroupBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): # Tenants # -class TenantListView(PermissionRequiredMixin, ObjectListView): - permission_required = 'tenancy.view_tenant' +class TenantListView(ObjectListView): queryset = Tenant.objects.prefetch_related('group') filterset = filters.TenantFilterSet filterset_form = forms.TenantFilterForm table = tables.TenantTable -class TenantView(PermissionRequiredMixin, View): - permission_required = 'tenancy.view_tenant' +class TenantView(ObjectView): + queryset = Tenant.objects.prefetch_related('group') def get(self, request, slug): - tenant = get_object_or_404(Tenant, slug=slug) + tenant = get_object_or_404(self.queryset, slug=slug) stats = { - 'site_count': Site.objects.filter(tenant=tenant).count(), - 'rack_count': Rack.objects.filter(tenant=tenant).count(), - 'rackreservation_count': RackReservation.objects.filter(tenant=tenant).count(), - 'device_count': Device.objects.filter(tenant=tenant).count(), - 'vrf_count': VRF.objects.filter(tenant=tenant).count(), - 'prefix_count': Prefix.objects.filter(tenant=tenant).count(), - 'ipaddress_count': IPAddress.objects.filter(tenant=tenant).count(), - 'vlan_count': VLAN.objects.filter(tenant=tenant).count(), - 'circuit_count': Circuit.objects.filter(tenant=tenant).count(), - 'virtualmachine_count': VirtualMachine.objects.filter(tenant=tenant).count(), - 'cluster_count': Cluster.objects.filter(tenant=tenant).count(), + 'site_count': Site.objects.restrict(request.user, 'view').filter(tenant=tenant).count(), + 'rack_count': Rack.objects.restrict(request.user, 'view').filter(tenant=tenant).count(), + 'rackreservation_count': RackReservation.objects.restrict(request.user, 'view').filter(tenant=tenant).count(), + 'device_count': Device.objects.restrict(request.user, 'view').filter(tenant=tenant).count(), + 'vrf_count': VRF.objects.restrict(request.user, 'view').filter(tenant=tenant).count(), + 'prefix_count': Prefix.objects.restrict(request.user, 'view').filter(tenant=tenant).count(), + 'ipaddress_count': IPAddress.objects.restrict(request.user, 'view').filter(tenant=tenant).count(), + 'vlan_count': VLAN.objects.restrict(request.user, 'view').filter(tenant=tenant).count(), + 'circuit_count': Circuit.objects.restrict(request.user, 'view').filter(tenant=tenant).count(), + 'virtualmachine_count': VirtualMachine.objects.restrict(request.user, 'view').filter(tenant=tenant).count(), + 'cluster_count': Cluster.objects.restrict(request.user, 'view').filter(tenant=tenant).count(), } return render(request, 'tenancy/tenant.html', { @@ -93,33 +83,26 @@ class TenantView(PermissionRequiredMixin, View): }) -class TenantCreateView(PermissionRequiredMixin, ObjectEditView): - permission_required = 'tenancy.add_tenant' +class TenantEditView(ObjectEditView): queryset = Tenant.objects.all() model_form = forms.TenantForm template_name = 'tenancy/tenant_edit.html' default_return_url = 'tenancy:tenant_list' -class TenantEditView(TenantCreateView): - permission_required = 'tenancy.change_tenant' - - -class TenantDeleteView(PermissionRequiredMixin, ObjectDeleteView): - permission_required = 'tenancy.delete_tenant' +class TenantDeleteView(ObjectDeleteView): queryset = Tenant.objects.all() default_return_url = 'tenancy:tenant_list' -class TenantBulkImportView(PermissionRequiredMixin, BulkImportView): - permission_required = 'tenancy.add_tenant' +class TenantBulkImportView(BulkImportView): + queryset = Tenant.objects.all() model_form = forms.TenantCSVForm table = tables.TenantTable default_return_url = 'tenancy:tenant_list' -class TenantBulkEditView(PermissionRequiredMixin, BulkEditView): - permission_required = 'tenancy.change_tenant' +class TenantBulkEditView(BulkEditView): queryset = Tenant.objects.prefetch_related('group') filterset = filters.TenantFilterSet table = tables.TenantTable @@ -127,8 +110,7 @@ class TenantBulkEditView(PermissionRequiredMixin, BulkEditView): default_return_url = 'tenancy:tenant_list' -class TenantBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): - permission_required = 'tenancy.delete_tenant' +class TenantBulkDeleteView(BulkDeleteView): queryset = Tenant.objects.prefetch_related('group') filterset = filters.TenantFilterSet table = tables.TenantTable diff --git a/netbox/users/admin.py b/netbox/users/admin.py index 42e651712..cc7a1b379 100644 --- a/netbox/users/admin.py +++ b/netbox/users/admin.py @@ -1,12 +1,31 @@ from django import forms from django.contrib import admin from django.contrib.auth.admin import UserAdmin as UserAdmin_ -from django.contrib.auth.models import User +from django.contrib.auth.models import Group as StockGroup, User as StockUser +from django.core.exceptions import FieldError, ValidationError -from .models import Token, UserConfig +from extras.admin import order_content_types +from .models import AdminGroup, AdminUser, ObjectPermission, Token, UserConfig -# Unregister the built-in UserAdmin so that we can use our custom admin view below -admin.site.unregister(User) + +# +# Users & groups +# + +# Unregister the built-in GroupAdmin and UserAdmin classes so that we can use our custom admin classes below +admin.site.unregister(StockGroup) +admin.site.unregister(StockUser) + + +@admin.register(AdminGroup) +class GroupAdmin(admin.ModelAdmin): + fields = ('name',) + list_display = ('name', 'user_count') + ordering = ('name',) + search_fields = ('name',) + + def user_count(self, obj): + return obj.user_set.count() class UserConfigInline(admin.TabularInline): @@ -16,14 +35,48 @@ class UserConfigInline(admin.TabularInline): verbose_name = 'Preferences' -@admin.register(User) +class ObjectPermissionInline(admin.TabularInline): + model = AdminUser.object_permissions.through + fields = ['object_types', 'actions', 'constraints'] + readonly_fields = fields + extra = 0 + verbose_name = 'Permission' + + def object_types(self, instance): + return ', '.join(instance.objectpermission.object_types.values_list('model', flat=True)) + + def actions(self, instance): + return ', '.join(instance.objectpermission.actions) + + def constraints(self, instance): + return instance.objectpermission.constraints + + def has_add_permission(self, request, obj): + # Don't allow the creation of new ObjectPermission assignments via this form + return False + + +@admin.register(AdminUser) class UserAdmin(UserAdmin_): list_display = [ 'username', 'email', 'first_name', 'last_name', 'is_superuser', 'is_staff', 'is_active' ] - inlines = (UserConfigInline,) + fieldsets = ( + (None, {'fields': ('username', 'password', 'first_name', 'last_name', 'email')}), + ('Groups', {'fields': ('groups',)}), + ('Permissions', { + 'fields': ('is_active', 'is_staff', 'is_superuser'), + }), + ('Important dates', {'fields': ('last_login', 'date_joined')}), + ) + inlines = [ObjectPermissionInline, UserConfigInline] + filter_horizontal = ('groups',) +# +# REST API tokens +# + class TokenAdminForm(forms.ModelForm): key = forms.CharField( required=False, @@ -43,3 +96,115 @@ class TokenAdmin(admin.ModelAdmin): list_display = [ 'key', 'user', 'created', 'expires', 'write_enabled', 'description' ] + + +# +# Permissions +# + +class ObjectPermissionForm(forms.ModelForm): + can_view = forms.BooleanField(required=False) + can_add = forms.BooleanField(required=False) + can_change = forms.BooleanField(required=False) + can_delete = forms.BooleanField(required=False) + + class Meta: + model = ObjectPermission + exclude = [] + help_texts = { + 'actions': 'Actions granted in addition to those listed above', + 'constraints': 'JSON expression of a queryset filter that will return only permitted objects. Leave null ' + 'to match all objects of this type.' + } + labels = { + 'actions': 'Additional actions' + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Make the actions field optional since the admin form uses it only for non-CRUD actions + self.fields['actions'].required = False + + # Format ContentType choices + order_content_types(self.fields['object_types']) + self.fields['object_types'].choices.insert(0, ('', '---------')) + + # Order group and user fields + self.fields['groups'].queryset = self.fields['groups'].queryset.order_by('name') + self.fields['users'].queryset = self.fields['users'].queryset.order_by('username') + + # Check the appropriate checkboxes when editing an existing ObjectPermission + if self.instance.pk: + for action in ['view', 'add', 'change', 'delete']: + if action in self.instance.actions: + self.fields[f'can_{action}'].initial = True + self.instance.actions.remove(action) + + def clean(self): + object_types = self.cleaned_data['object_types'] + constraints = self.cleaned_data['constraints'] + + # Append any of the selected CRUD checkboxes to the actions list + if not self.cleaned_data.get('actions'): + self.cleaned_data['actions'] = list() + for action in ['view', 'add', 'change', 'delete']: + if self.cleaned_data[f'can_{action}'] and action not in self.cleaned_data['actions']: + self.cleaned_data['actions'].append(action) + + # At least one action must be specified + if not self.cleaned_data['actions']: + raise ValidationError("At least one action must be selected.") + + # Validate the specified model constraints by attempting to execute a query. We don't care whether the query + # returns anything; we just want to make sure the specified constraints are valid. + if constraints: + for ct in object_types: + model = ct.model_class() + try: + model.objects.filter(**constraints).exists() + except FieldError as e: + raise ValidationError({ + 'constraints': f'Invalid filter for {model}: {e}' + }) + + +@admin.register(ObjectPermission) +class ObjectPermissionAdmin(admin.ModelAdmin): + fieldsets = ( + ('Objects', { + 'fields': ('object_types',) + }), + ('Assignment', { + 'fields': ('groups', 'users') + }), + ('Actions', { + 'fields': (('can_view', 'can_add', 'can_change', 'can_delete'), 'actions') + }), + ('Constraints', { + 'fields': ('constraints',) + }), + ) + filter_horizontal = ('object_types', 'groups', 'users') + form = ObjectPermissionForm + list_display = [ + 'list_models', 'list_users', 'list_groups', 'actions', 'constraints', + ] + list_filter = [ + 'groups', 'users' + ] + + def get_queryset(self, request): + return super().get_queryset(request).prefetch_related('object_types', 'users', 'groups') + + def list_models(self, obj): + return ', '.join([f"{ct}" for ct in obj.object_types.all()]) + list_models.short_description = 'Models' + + def list_users(self, obj): + return ', '.join([u.username for u in obj.users.all()]) + list_users.short_description = 'Users' + + def list_groups(self, obj): + return ', '.join([g.name for g in obj.groups.all()]) + list_groups.short_description = 'Groups' diff --git a/netbox/users/api/nested_serializers.py b/netbox/users/api/nested_serializers.py index d1b649713..f6e5cefbf 100644 --- a/netbox/users/api/nested_serializers.py +++ b/netbox/users/api/nested_serializers.py @@ -1,18 +1,45 @@ -from django.contrib.auth.models import User +from django.contrib.auth.models import Group, User +from django.contrib.contenttypes.models import ContentType +from rest_framework import serializers -from utilities.api import WritableNestedSerializer +from users.models import ObjectPermission +from utilities.api import ContentTypeField, WritableNestedSerializer -_all_ = [ +__all__ = [ + 'NestedGroupSerializer', + 'NestedObjectPermissionSerializer', 'NestedUserSerializer', ] -# -# Users -# +class NestedGroupSerializer(WritableNestedSerializer): + + class Meta: + model = Group + fields = ['id', 'name'] + class NestedUserSerializer(WritableNestedSerializer): class Meta: model = User fields = ['id', 'username'] + + +class NestedObjectPermissionSerializer(WritableNestedSerializer): + object_types = ContentTypeField( + queryset=ContentType.objects.all(), + many=True + ) + groups = serializers.SerializerMethodField(read_only=True) + users = serializers.SerializerMethodField(read_only=True) + + class Meta: + model = ObjectPermission + fields = ['id', 'object_types', 'groups', 'users', 'actions'] + + def get_groups(self, obj): + return [g.name for g in obj.groups.all()] + + def get_users(self, obj): + return [u.username for u in obj.users.all()] diff --git a/netbox/users/api/serializers.py b/netbox/users/api/serializers.py index 86d350e69..052567e47 100644 --- a/netbox/users/api/serializers.py +++ b/netbox/users/api/serializers.py @@ -1,4 +1,29 @@ +from django.contrib.auth.models import Group, User +from django.contrib.contenttypes.models import ContentType + +from users.models import ObjectPermission +from utilities.api import ContentTypeField, SerializedPKRelatedField, ValidatedModelSerializer from .nested_serializers import * -# Placeholder for future serializers +class ObjectPermissionSerializer(ValidatedModelSerializer): + object_types = ContentTypeField( + queryset=ContentType.objects.all(), + many=True + ) + groups = SerializedPKRelatedField( + queryset=Group.objects.all(), + serializer=NestedGroupSerializer, + required=False, + many=True + ) + users = SerializedPKRelatedField( + queryset=User.objects.all(), + serializer=NestedUserSerializer, + required=False, + many=True + ) + + class Meta: + model = ObjectPermission + fields = ('id', 'object_types', 'groups', 'users', 'actions', 'constraints') diff --git a/netbox/users/api/urls.py b/netbox/users/api/urls.py new file mode 100644 index 000000000..fffea5968 --- /dev/null +++ b/netbox/users/api/urls.py @@ -0,0 +1,21 @@ +from rest_framework import routers + +from . import views + + +class UsersRootView(routers.APIRootView): + """ + Users API root view + """ + def get_view_name(self): + return 'Users' + + +router = routers.DefaultRouter() +router.APIRootView = UsersRootView + +# Permissions +router.register('permissions', views.ObjectPermissionViewSet) + +app_name = 'users-api' +urlpatterns = router.urls diff --git a/netbox/users/api/views.py b/netbox/users/api/views.py new file mode 100644 index 000000000..74b315b44 --- /dev/null +++ b/netbox/users/api/views.py @@ -0,0 +1,14 @@ +from utilities.api import ModelViewSet +from . import serializers + +from users.models import ObjectPermission + + +# +# ObjectPermissions +# + +class ObjectPermissionViewSet(ModelViewSet): + queryset = ObjectPermission.objects.prefetch_related('object_types', 'groups', 'users') + serializer_class = serializers.ObjectPermissionSerializer + # filterset_class = filters.ObjectPermissionFilterSet diff --git a/netbox/users/migrations/0007_proxy_group_user.py b/netbox/users/migrations/0007_proxy_group_user.py new file mode 100644 index 000000000..2aec9e425 --- /dev/null +++ b/netbox/users/migrations/0007_proxy_group_user.py @@ -0,0 +1,46 @@ +# Generated by Django 3.0.6 on 2020-05-29 14:30 + +import django.contrib.auth.models +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('auth', '0011_update_proxy_permissions'), + ('users', '0006_create_userconfigs'), + ] + + operations = [ + migrations.CreateModel( + name='AdminGroup', + fields=[ + ], + options={ + 'proxy': True, + 'indexes': [], + 'constraints': [], + 'verbose_name': 'Group', + }, + bases=('auth.group',), + managers=[ + ('objects', django.contrib.auth.models.GroupManager()), + ], + ), + migrations.CreateModel( + name='AdminUser', + fields=[ + ], + options={ + 'proxy': True, + 'indexes': [], + 'constraints': [], + 'verbose_name': 'User', + }, + bases=('auth.user',), + managers=[ + ('objects', django.contrib.auth.models.UserManager()), + ], + ), + ] diff --git a/netbox/users/migrations/0008_objectpermission.py b/netbox/users/migrations/0008_objectpermission.py new file mode 100644 index 000000000..3f16e1ee8 --- /dev/null +++ b/netbox/users/migrations/0008_objectpermission.py @@ -0,0 +1,31 @@ +from django.conf import settings +import django.contrib.postgres.fields +import django.contrib.postgres.fields.jsonb +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('contenttypes', '0002_remove_content_type_name'), + ('auth', '0011_update_proxy_permissions'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('users', '0007_proxy_group_user'), + ] + + operations = [ + migrations.CreateModel( + name='ObjectPermission', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False)), + ('constraints', django.contrib.postgres.fields.jsonb.JSONField(blank=True, null=True)), + ('actions', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=30), size=None)), + ('object_types', models.ManyToManyField(limit_choices_to={'app_label__in': ['circuits', 'dcim', 'extras', 'ipam', 'secrets', 'tenancy', 'virtualization']}, related_name='object_permissions', to='contenttypes.ContentType')), + ('groups', models.ManyToManyField(blank=True, related_name='object_permissions', to='auth.Group')), + ('users', models.ManyToManyField(blank=True, related_name='object_permissions', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'verbose_name': 'Permission', + }, + ), + ] diff --git a/netbox/users/migrations/0009_replicate_permissions.py b/netbox/users/migrations/0009_replicate_permissions.py new file mode 100644 index 000000000..a5d28beac --- /dev/null +++ b/netbox/users/migrations/0009_replicate_permissions.py @@ -0,0 +1,47 @@ +from django.db import migrations + + +ACTIONS = ['view', 'add', 'change', 'delete'] + + +def replicate_permissions(apps, schema_editor): + """ + Replicate all Permission assignments as ObjectPermissions. + """ + Permission = apps.get_model('auth', 'Permission') + ObjectPermission = apps.get_model('users', 'ObjectPermission') + + # TODO: Optimize this iteration so that ObjectPermissions with identical sets of users and groups + # are combined into a single ObjectPermission instance. + for perm in Permission.objects.all(): + if perm.codename.split('_')[0] in ACTIONS: + action = perm.codename.split('_')[0] + elif perm.codename == 'activate_userkey': + action = 'change' + elif perm.codename == 'run_script': + action = 'run' + else: + action = perm.codename + + if perm.group_set.exists() or perm.user_set.exists(): + obj_perm = ObjectPermission(actions=[action]) + obj_perm.save() + obj_perm.object_types.add(perm.content_type) + if perm.group_set.exists(): + obj_perm.groups.add(*list(perm.group_set.all())) + if perm.user_set.exists(): + obj_perm.users.add(*list(perm.user_set.all())) + + +class Migration(migrations.Migration): + + dependencies = [ + ('users', '0008_objectpermission'), + ] + + operations = [ + migrations.RunPython( + code=replicate_permissions, + reverse_code=migrations.RunPython.noop + ) + ] diff --git a/netbox/users/models.py b/netbox/users/models.py index ea5762232..7987ccb7a 100644 --- a/netbox/users/models.py +++ b/netbox/users/models.py @@ -1,23 +1,52 @@ import binascii import os -from django.contrib.auth.models import User -from django.contrib.postgres.fields import JSONField +from django.contrib.auth.models import Group, User +from django.contrib.contenttypes.models import ContentType +from django.contrib.postgres.fields import ArrayField, JSONField from django.core.validators import MinLengthValidator from django.db import models from django.db.models.signals import post_save from django.dispatch import receiver from django.utils import timezone +from utilities.querysets import RestrictedQuerySet from utilities.utils import flatten_dict __all__ = ( + 'ObjectPermission', 'Token', 'UserConfig', ) +# +# Proxy models for admin +# + +class AdminGroup(Group): + """ + Proxy contrib.auth.models.Group for the admin UI + """ + class Meta: + verbose_name = 'Group' + proxy = True + + +class AdminUser(User): + """ + Proxy contrib.auth.models.User for the admin UI + """ + class Meta: + verbose_name = 'User' + proxy = True + + +# +# User preferences +# + class UserConfig(models.Model): """ This model stores arbitrary user-specific preferences in a JSON data structure. @@ -138,6 +167,10 @@ def create_userconfig(instance, created, **kwargs): UserConfig(user=instance).save() +# +# REST API +# + class Token(models.Model): """ An API token used for user authentication. This extends the stock model to allow each user to have multiple tokens. @@ -190,3 +223,53 @@ class Token(models.Model): if self.expires is None or timezone.now() < self.expires: return False return True + + +# +# Permissions +# + +class ObjectPermission(models.Model): + """ + A mapping of view, add, change, and/or delete permission for users and/or groups to an arbitrary set of objects + identified by ORM query parameters. + """ + object_types = models.ManyToManyField( + to=ContentType, + limit_choices_to={ + 'app_label__in': [ + 'circuits', 'dcim', 'extras', 'ipam', 'secrets', 'tenancy', 'virtualization', + ], + }, + related_name='object_permissions' + ) + groups = models.ManyToManyField( + to=Group, + blank=True, + related_name='object_permissions' + ) + users = models.ManyToManyField( + to=User, + blank=True, + related_name='object_permissions' + ) + actions = ArrayField( + base_field=models.CharField(max_length=30), + help_text="The list of actions granted by this permission" + ) + constraints = JSONField( + blank=True, + null=True, + help_text="Queryset filter matching the applicable objects of the selected type(s)" + ) + + objects = RestrictedQuerySet.as_manager() + + class Meta: + verbose_name = "Permission" + + def __str__(self): + return '{}: {}'.format( + ', '.join(self.object_types.values_list('model', flat=True)), + ', '.join(self.actions) + ) diff --git a/netbox/users/tests/test_api.py b/netbox/users/tests/test_api.py new file mode 100644 index 000000000..166473710 --- /dev/null +++ b/netbox/users/tests/test_api.py @@ -0,0 +1,74 @@ +from django.contrib.auth.models import Group, User +from django.contrib.contenttypes.models import ContentType +from django.urls import reverse + +from users.models import ObjectPermission +from utilities.testing import APIViewTestCases, APITestCase + + +class AppTest(APITestCase): + + def test_root(self): + + url = reverse('users-api:api-root') + response = self.client.get('{}?format=api'.format(url), **self.header) + + self.assertEqual(response.status_code, 200) + + +class ObjectPermissionTest(APIViewTestCases.APIViewTestCase): + model = ObjectPermission + brief_fields = ['actions', 'groups', 'id', 'object_types', 'users'] + + @classmethod + def setUpTestData(cls): + + groups = ( + Group(name='Group 1'), + Group(name='Group 2'), + Group(name='Group 3'), + ) + Group.objects.bulk_create(groups) + + users = ( + User(username='User 1', is_active=True), + User(username='User 2', is_active=True), + User(username='User 3', is_active=True), + ) + User.objects.bulk_create(users) + + object_type = ContentType.objects.get(app_label='dcim', model='device') + + for i in range(0, 3): + objectpermission = ObjectPermission( + actions=['view', 'add', 'change', 'delete'], + constraints={'name': f'TEST{i+1}'} + ) + objectpermission.save() + objectpermission.object_types.add(object_type) + objectpermission.groups.add(groups[i]) + objectpermission.users.add(users[i]) + + cls.create_data = [ + { + 'object_types': ['dcim.site'], + 'groups': [groups[0].pk], + 'users': [users[0].pk], + 'actions': ['view', 'add', 'change', 'delete'], + 'constraints': {'name': 'TEST4'}, + }, + { + 'object_types': ['dcim.site'], + 'groups': [groups[1].pk], + 'users': [users[1].pk], + 'actions': ['view', 'add', 'change', 'delete'], + 'constraints': {'name': 'TEST5'}, + }, + { + 'object_types': ['dcim.site'], + 'groups': [groups[2].pk], + 'users': [users[2].pk], + 'actions': ['view', 'add', 'change', 'delete'], + 'constraints': {'name': 'TEST6'}, + }, + ] diff --git a/netbox/users/views.py b/netbox/users/views.py index c3e366542..755232444 100644 --- a/netbox/users/views.py +++ b/netbox/users/views.py @@ -3,7 +3,7 @@ import logging from django.conf import settings from django.contrib import messages from django.contrib.auth import login as auth_login, logout as auth_logout, update_session_auth_hash -from django.contrib.auth.mixins import LoginRequiredMixin, PermissionRequiredMixin +from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.models import update_last_login from django.contrib.auth.signals import user_logged_in from django.http import HttpResponseForbidden, HttpResponseRedirect @@ -50,7 +50,7 @@ class LoginView(View): logger.debug("Login form validation was successful") # Determine where to direct user after successful login - redirect_to = request.POST.get('next') + redirect_to = request.POST.get('next', reverse('home')) if redirect_to and not is_safe_url(url=redirect_to, allowed_hosts=request.get_host()): logger.warning(f"Ignoring unsafe 'next' URL passed to login form: {redirect_to}") redirect_to = reverse('home') @@ -320,8 +320,7 @@ class TokenEditView(LoginRequiredMixin, View): }) -class TokenDeleteView(PermissionRequiredMixin, View): - permission_required = 'users.delete_token' +class TokenDeleteView(LoginRequiredMixin, View): def get(self, request, pk): diff --git a/netbox/utilities/api.py b/netbox/utilities/api.py index 205055669..50401dfd1 100644 --- a/netbox/utilities/api.py +++ b/netbox/utilities/api.py @@ -4,17 +4,18 @@ from collections import OrderedDict import pytz from django.conf import settings from django.contrib.contenttypes.models import ContentType -from django.core.exceptions import FieldError, MultipleObjectsReturned, ObjectDoesNotExist +from django.core.exceptions import FieldError, MultipleObjectsReturned, ObjectDoesNotExist, PermissionDenied +from django.db import transaction from django.db.models import ManyToManyField, ProtectedError -from django.http import Http404 from django.urls import reverse from rest_framework.exceptions import APIException from rest_framework.permissions import BasePermission from rest_framework.relations import PrimaryKeyRelatedField, RelatedField from rest_framework.response import Response from rest_framework.serializers import Field, ModelSerializer, ValidationError -from rest_framework.viewsets import ModelViewSet as _ModelViewSet, ViewSet +from rest_framework.viewsets import ModelViewSet as _ModelViewSet +from netbox.api import TokenPermissions from .utils import dict_to_filter_params, dynamic_import @@ -323,6 +324,27 @@ class ModelViewSet(_ModelViewSet): logger.debug(f"Using serializer {self.serializer_class}") return self.serializer_class + def initial(self, request, *args, **kwargs): + super().initial(request, *args, **kwargs) + + if not request.user.is_authenticated or request.user.is_superuser: + return + + # TODO: Reconcile this with TokenPermissions.perms_map + action = { + 'GET': 'view', + 'OPTIONS': None, + 'HEAD': 'view', + 'POST': 'add', + 'PUT': 'change', + 'PATCH': 'change', + 'DELETE': 'delete', + }[request.method] + + # Restrict the view's QuerySet to allow only the permitted objects + if action: + self.queryset = self.queryset.restrict(request.user, action) + def dispatch(self, request, *args, **kwargs): logger = logging.getLogger('netbox.api.views.ModelViewSet') @@ -341,34 +363,49 @@ class ModelViewSet(_ModelViewSet): **kwargs ) - def list(self, *args, **kwargs): + def _validate_objects(self, instance): """ - Call to super to allow for caching + Check that the provided instance or list of instances are matched by the current queryset. This confirms that + any newly created or modified objects abide by the attributes granted by any applicable ObjectPermissions. """ - return super().list(*args, **kwargs) - - def retrieve(self, *args, **kwargs): - """ - Call to super to allow for caching - """ - return super().retrieve(*args, **kwargs) - - # - # Logging - # + if type(instance) is list: + # Check that all instances are still included in the view's queryset + conforming_count = self.queryset.filter(pk__in=[obj.pk for obj in instance]).count() + if conforming_count != len(instance): + raise ObjectDoesNotExist + else: + # Check that the instance is matched by the view's queryset + self.queryset.get(pk=instance.pk) def perform_create(self, serializer): - model = serializer.child.Meta.model if hasattr(serializer, 'many') else serializer.Meta.model + model = self.queryset.model logger = logging.getLogger('netbox.api.views.ModelViewSet') logger.info(f"Creating new {model._meta.verbose_name}") - return super().perform_create(serializer) + + # Enforce object-level permissions on save() + try: + with transaction.atomic(): + instance = serializer.save() + self._validate_objects(instance) + except ObjectDoesNotExist: + raise PermissionDenied() def perform_update(self, serializer): + model = self.queryset.model logger = logging.getLogger('netbox.api.views.ModelViewSet') - logger.info(f"Updating {serializer.instance} (PK: {serializer.instance.pk})") - return super().perform_update(serializer) + logger.info(f"Updating {model._meta.verbose_name} {serializer.instance} (PK: {serializer.instance.pk})") + + # Enforce object-level permissions on save() + try: + with transaction.atomic(): + instance = serializer.save() + self._validate_objects(instance) + except ObjectDoesNotExist: + raise PermissionDenied() def perform_destroy(self, instance): + model = self.queryset.model logger = logging.getLogger('netbox.api.views.ModelViewSet') - logger.info(f"Deleting {instance} (PK: {instance.pk})") + logger.info(f"Deleting {model._meta.verbose_name} {instance} (PK: {instance.pk})") + return super().perform_destroy(instance) diff --git a/netbox/utilities/auth_backends.py b/netbox/utilities/auth_backends.py deleted file mode 100644 index 6342bad2b..000000000 --- a/netbox/utilities/auth_backends.py +++ /dev/null @@ -1,73 +0,0 @@ -import logging - -from django.conf import settings -from django.contrib.auth.backends import ModelBackend, RemoteUserBackend as RemoteUserBackend_ -from django.contrib.auth.models import Group, Permission - - -class ViewExemptModelBackend(ModelBackend): - """ - Custom implementation of Django's stock ModelBackend which allows for the exemption of arbitrary models from view - permission enforcement. - """ - def has_perm(self, user_obj, perm, obj=None): - - # If this is a view permission, check whether the model has been exempted from enforcement - try: - app, codename = perm.split('.') - action, model = codename.split('_') - if action == 'view': - if ( - # All models are exempt from view permission enforcement - '*' in settings.EXEMPT_VIEW_PERMISSIONS - ) or ( - # This specific model is exempt from view permission enforcement - '{}.{}'.format(app, model) in settings.EXEMPT_VIEW_PERMISSIONS - ): - return True - except ValueError: - pass - - return super().has_perm(user_obj, perm, obj) - - -class RemoteUserBackend(ViewExemptModelBackend, RemoteUserBackend_): - """ - Custom implementation of Django's RemoteUserBackend which provides configuration hooks for basic customization. - """ - @property - def create_unknown_user(self): - return settings.REMOTE_AUTH_AUTO_CREATE_USER - - def configure_user(self, request, user): - logger = logging.getLogger('netbox.authentication.RemoteUserBackend') - - # Assign default groups to the user - group_list = [] - for name in settings.REMOTE_AUTH_DEFAULT_GROUPS: - try: - group_list.append(Group.objects.get(name=name)) - except Group.DoesNotExist: - logging.error(f"Could not assign group {name} to remotely-authenticated user {user}: Group not found") - if group_list: - user.groups.add(*group_list) - logger.debug(f"Assigned groups to remotely-authenticated user {user}: {group_list}") - - # Assign default permissions to the user - permissions_list = [] - for permission_name in settings.REMOTE_AUTH_DEFAULT_PERMISSIONS: - try: - app_label, codename = permission_name.split('.') - permissions_list.append( - Permission.objects.get(content_type__app_label=app_label, codename=codename) - ) - except (ValueError, Permission.DoesNotExist): - logging.error( - "Invalid permission name: '{permission_name}'. Permissions must be in the form " - "._. (Example: dcim.add_site)" - ) - if permissions_list: - user.user_permissions.add(*permissions_list) - logger.debug(f"Assigned permissions to remotely-authenticated user {user}: {permissions_list}") - - return user diff --git a/netbox/utilities/choices.py b/netbox/utilities/choices.py index aba64e63b..ce0929a8b 100644 --- a/netbox/utilities/choices.py +++ b/netbox/utilities/choices.py @@ -80,6 +80,70 @@ def unpack_grouped_choices(choices): return unpacked_choices +# +# Generic color choices +# + +class ColorChoices(ChoiceSet): + COLOR_DARK_RED = 'aa1409' + COLOR_RED = 'f44336' + COLOR_PINK = 'e91e63' + COLOR_ROSE = 'ffe4e1' + COLOR_FUCHSIA = 'ff66ff' + COLOR_PURPLE = '9c27b0' + COLOR_DARK_PURPLE = '673ab7' + COLOR_INDIGO = '3f51b5' + COLOR_BLUE = '2196f3' + COLOR_LIGHT_BLUE = '03a9f4' + COLOR_CYAN = '00bcd4' + COLOR_TEAL = '009688' + COLOR_AQUA = '00ffff' + COLOR_DARK_GREEN = '2f6a31' + COLOR_GREEN = '4caf50' + COLOR_LIGHT_GREEN = '8bc34a' + COLOR_LIME = 'cddc39' + COLOR_YELLOW = 'ffeb3b' + COLOR_AMBER = 'ffc107' + COLOR_ORANGE = 'ff9800' + COLOR_DARK_ORANGE = 'ff5722' + COLOR_BROWN = '795548' + COLOR_LIGHT_GREY = 'c0c0c0' + COLOR_GREY = '9e9e9e' + COLOR_DARK_GREY = '607d8b' + COLOR_BLACK = '111111' + COLOR_WHITE = 'ffffff' + + CHOICES = ( + (COLOR_DARK_RED, 'Dark red'), + (COLOR_RED, 'Red'), + (COLOR_PINK, 'Pink'), + (COLOR_ROSE, 'Rose'), + (COLOR_FUCHSIA, 'Fuchsia'), + (COLOR_PURPLE, 'Purple'), + (COLOR_DARK_PURPLE, 'Dark purple'), + (COLOR_INDIGO, 'Indigo'), + (COLOR_BLUE, 'Blue'), + (COLOR_LIGHT_BLUE, 'Light blue'), + (COLOR_CYAN, 'Cyan'), + (COLOR_TEAL, 'Teal'), + (COLOR_AQUA, 'Aqua'), + (COLOR_DARK_GREEN, 'Dark green'), + (COLOR_GREEN, 'Green'), + (COLOR_LIGHT_GREEN, 'Light green'), + (COLOR_LIME, 'Lime'), + (COLOR_YELLOW, 'Yellow'), + (COLOR_AMBER, 'Amber'), + (COLOR_ORANGE, 'Orange'), + (COLOR_DARK_ORANGE, 'Dark orange'), + (COLOR_BROWN, 'Brown'), + (COLOR_LIGHT_GREY, 'Light grey'), + (COLOR_GREY, 'Grey'), + (COLOR_DARK_GREY, 'Dark grey'), + (COLOR_BLACK, 'Black'), + (COLOR_WHITE, 'White'), + ) + + # # Button color choices # diff --git a/netbox/utilities/constants.py b/netbox/utilities/constants.py index bdcdeef11..9a3a7d028 100644 --- a/netbox/utilities/constants.py +++ b/netbox/utilities/constants.py @@ -1,34 +1,3 @@ -COLOR_CHOICES = ( - ('aa1409', 'Dark red'), - ('f44336', 'Red'), - ('e91e63', 'Pink'), - ('ffe4e1', 'Rose'), - ('ff66ff', 'Fuschia'), - ('9c27b0', 'Purple'), - ('673ab7', 'Dark purple'), - ('3f51b5', 'Indigo'), - ('2196f3', 'Blue'), - ('03a9f4', 'Light blue'), - ('00bcd4', 'Cyan'), - ('009688', 'Teal'), - ('00ffff', 'Aqua'), - ('2f6a31', 'Dark green'), - ('4caf50', 'Green'), - ('8bc34a', 'Light green'), - ('cddc39', 'Lime'), - ('ffeb3b', 'Yellow'), - ('ffc107', 'Amber'), - ('ff9800', 'Orange'), - ('ff5722', 'Dark orange'), - ('795548', 'Brown'), - ('c0c0c0', 'Light grey'), - ('9e9e9e', 'Grey'), - ('607d8b', 'Dark grey'), - ('111111', 'Black'), - ('ffffff', 'White'), -) - - # # Filter lookup expressions # diff --git a/netbox/utilities/custom_inspectors.py b/netbox/utilities/custom_inspectors.py index 2cbe1cfc5..38297838d 100644 --- a/netbox/utilities/custom_inspectors.py +++ b/netbox/utilities/custom_inspectors.py @@ -1,19 +1,12 @@ from django.contrib.postgres.fields import JSONField from drf_yasg import openapi -from drf_yasg.inspectors import FieldInspector, NotHandled, PaginatorInspector, FilterInspector, SwaggerAutoSchema +from drf_yasg.inspectors import FieldInspector, NotHandled, PaginatorInspector, SwaggerAutoSchema from drf_yasg.utils import get_serializer_ref_name from rest_framework.fields import ChoiceField from rest_framework.relations import ManyRelatedField -from taggit_serializer.serializers import TagListSerializerField -from dcim.api.serializers import InterfaceSerializer as DeviceInterfaceSerializer from extras.api.customfields import CustomFieldsSerializer from utilities.api import ChoiceField, SerializedPKRelatedField, WritableNestedSerializer -from virtualization.api.serializers import InterfaceSerializer as VirtualMachineInterfaceSerializer - -# this might be ugly, but it limits drf_yasg-specific code to this file -DeviceInterfaceSerializer.Meta.ref_name = 'DeviceInterface' -VirtualMachineInterfaceSerializer.Meta.ref_name = 'VirtualMachineInterface' class NetBoxSwaggerAutoSchema(SwaggerAutoSchema): @@ -56,19 +49,6 @@ class SerializedPKRelatedFieldInspector(FieldInspector): return NotHandled -class TagListFieldInspector(FieldInspector): - def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): - SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) - if isinstance(field, TagListSerializerField): - child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references) - return SwaggerType( - type=openapi.TYPE_ARRAY, - items=child_schema, - ) - - return NotHandled - - class CustomChoiceFieldInspector(FieldInspector): def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): # this returns a callable which extracts title, description and other stuff diff --git a/netbox/utilities/fields.py b/netbox/utilities/fields.py index 4eb19f539..a9b851def 100644 --- a/netbox/utilities/fields.py +++ b/netbox/utilities/fields.py @@ -68,6 +68,6 @@ class NaturalOrderingField(models.CharField): return ( self.name, 'utilities.fields.NaturalOrderingField', - ['target_field'], + [self.target_field], kwargs, ) diff --git a/netbox/utilities/forms.py b/netbox/utilities/forms.py index bfc783631..59e581ff4 100644 --- a/netbox/utilities/forms.py +++ b/netbox/utilities/forms.py @@ -7,6 +7,7 @@ import django_filters import yaml from django import forms from django.conf import settings +from django.contrib.postgres.forms import SimpleArrayField from django.contrib.postgres.forms.jsonb import JSONField as _JSONField, InvalidJSONInput from django.core.exceptions import MultipleObjectsReturned from django.db.models import Count @@ -14,8 +15,7 @@ from django.forms import BoundField from django.forms.models import fields_for_model from django.urls import reverse -from .choices import unpack_grouped_choices -from .constants import * +from .choices import ColorChoices, unpack_grouped_choices from .validators import EnhancedURLValidator NUMERIC_EXPANSION_PATTERN = r'\[((?:\d+[?:,-])+\d+)\]' @@ -163,7 +163,7 @@ class ColorSelect(forms.Select): option_template_name = 'widgets/colorselect_option.html' def __init__(self, *args, **kwargs): - kwargs['choices'] = add_blank_choice(COLOR_CHOICES) + kwargs['choices'] = add_blank_choice(ColorChoices) super().__init__(*args, **kwargs) self.attrs['class'] = 'netbox-select2-color-picker' @@ -244,24 +244,11 @@ class ContentTypeSelect(StaticSelect2): option_template_name = 'widgets/select_contenttype.html' -class ArrayFieldSelectMultiple(SelectWithDisabled, forms.SelectMultiple): - """ - MultiSelect widget for a SimpleArrayField. Choices must be populated on the widget. - """ - def __init__(self, *args, **kwargs): - self.delimiter = kwargs.pop('delimiter', ',') - super().__init__(*args, **kwargs) +class NumericArrayField(SimpleArrayField): - def optgroups(self, name, value, attrs=None): - # Split the delimited string of values into a list - if value: - value = value[0].split(self.delimiter) - return super().optgroups(name, value, attrs) - - def value_from_datadict(self, data, files, name): - # Condense the list of selected choices into a delimited string - data = super().value_from_datadict(data, files, name) - return self.delimiter.join(data) + def to_python(self, value): + value = ','.join([str(n) for n in parse_numeric_range(value)]) + return super().to_python(value) class APISelect(SelectWithDisabled): @@ -531,6 +518,8 @@ class ExpandableNameField(forms.CharField): """ def to_python(self, value): + if value is None: + return list() if re.search(ALPHANUMERIC_EXPANSION_PATTERN, value): return list(expand_alphanumeric_pattern(value)) return [value] @@ -596,8 +585,12 @@ class TagFilterField(forms.MultipleChoiceField): def __init__(self, model, *args, **kwargs): def get_choices(): - tags = model.tags.annotate(count=Count('extras_taggeditem_items')).order_by('name') - return [(str(tag.slug), '{} ({})'.format(tag.name, tag.count)) for tag in tags] + tags = model.tags.annotate( + count=Count('extras_taggeditem_items') + ).order_by('name') + return [ + (str(tag.slug), '{} ({})'.format(tag.name, tag.count)) for tag in tags + ] # Choices are fetched each time the form is initialized super().__init__(label='Tags', choices=get_choices, required=False, *args, **kwargs) @@ -607,15 +600,18 @@ class DynamicModelChoiceMixin: filter = django_filters.ModelChoiceFilter widget = APISelect - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def _get_initial_value(self, initial_data, field_name): + return initial_data.get(field_name) def get_bound_field(self, form, field_name): bound_field = BoundField(form, self, field_name) + # Override initial() to allow passing multiple values + bound_field.initial = self._get_initial_value(form.initial, field_name) + # Modify the QuerySet of the field before we return it. Limit choices to any data already bound: Options # will be populated on-demand via the APISelect widget. - data = self.prepare_value(bound_field.data or bound_field.initial) + data = bound_field.value() if data: filter = self.filter(field_name=self.to_field_name or 'pk', queryset=self.queryset) self.queryset = filter.filter(self.queryset, data) @@ -648,12 +644,17 @@ class DynamicModelMultipleChoiceField(DynamicModelChoiceMixin, forms.ModelMultip filter = django_filters.ModelMultipleChoiceFilter widget = APISelectMultiple + def _get_initial_value(self, initial_data, field_name): + # If a QueryDict has been passed as initial form data, get *all* listed values + if hasattr(initial_data, 'getlist'): + return initial_data.getlist(field_name) + return initial_data.get(field_name) + class LaxURLField(forms.URLField): """ - Modifies Django's built-in URLField in two ways: - 1) Allow any valid scheme per RFC 3986 section 3.1 - 2) Remove the requirement for fully-qualified domain names (e.g. http://myserver/ is valid) + Modifies Django's built-in URLField to remove the requirement for fully-qualified domain names + (e.g. http://myserver/ is valid) """ default_validators = [EnhancedURLValidator()] @@ -732,6 +733,30 @@ class BulkEditForm(forms.Form): self.nullable_fields = self.Meta.nullable_fields +class BulkRenameForm(forms.Form): + """ + An extendable form to be used for renaming objects in bulk. + """ + find = forms.CharField() + replace = forms.CharField() + use_regex = forms.BooleanField( + required=False, + initial=True, + label='Use regular expressions' + ) + + def clean(self): + + # Validate regular expression in "find" field + if self.cleaned_data['use_regex']: + try: + re.compile(self.cleaned_data['find']) + except re.error: + raise forms.ValidationError({ + 'find': "Invalid regular expression" + }) + + class CSVModelForm(forms.ModelForm): """ ModelForm used for the import of objects in CSV format. @@ -794,6 +819,31 @@ class ImportForm(BootstrapMixin, forms.Form): }) +class LabeledComponentForm(BootstrapMixin, forms.Form): + """ + Base form for adding label pattern validation to `Create` forms + """ + name_pattern = ExpandableNameField( + label='Name' + ) + label_pattern = ExpandableNameField( + label='Label', + required=False + ) + + def clean(self): + + # Validate that the number of components being created from both the name_pattern and label_pattern are equal + name_pattern_count = len(self.cleaned_data['name_pattern']) + label_pattern_count = len(self.cleaned_data['label_pattern']) + if label_pattern_count and name_pattern_count != label_pattern_count: + raise forms.ValidationError({ + 'label_pattern': 'The provided name pattern will create {} components, however {} labels will ' + 'be generated. These counts must match.'.format( + name_pattern_count, label_pattern_count) + }, code='label_pattern_mismatch') + + class TableConfigForm(BootstrapMixin, forms.Form): """ Form for configuring user's table preferences. diff --git a/netbox/utilities/mptt.py b/netbox/utilities/mptt.py new file mode 100644 index 000000000..1bae2053d --- /dev/null +++ b/netbox/utilities/mptt.py @@ -0,0 +1,19 @@ +from mptt.managers import TreeManager as TreeManager_ +from mptt.querysets import TreeQuerySet as TreeQuerySet_ + +from django.db.models import Manager +from .querysets import RestrictedQuerySet + + +class TreeQuerySet(TreeQuerySet_, RestrictedQuerySet): + """ + Mate django-mptt's TreeQuerySet with our RestrictedQuerySet for permissions enforcement. + """ + pass + + +class TreeManager(Manager.from_queryset(TreeQuerySet), TreeManager_): + """ + Extend django-mptt's TreeManager to incorporate RestrictedQuerySet(). + """ + pass diff --git a/netbox/utilities/paginator.py b/netbox/utilities/paginator.py index cef7c941f..cdad1f230 100644 --- a/netbox/utilities/paginator.py +++ b/netbox/utilities/paginator.py @@ -50,9 +50,12 @@ def get_paginate_count(request): if 'per_page' in request.GET: try: per_page = int(request.GET.get('per_page')) - request.user.config.set('pagination.per_page', per_page, commit=True) + if request.user.is_authenticated: + request.user.config.set('pagination.per_page', per_page, commit=True) return per_page except ValueError: pass - return request.user.config.get('pagination.per_page', settings.PAGINATE_COUNT) + if request.user.is_authenticated: + return request.user.config.get('pagination.per_page', settings.PAGINATE_COUNT) + return settings.PAGINATE_COUNT diff --git a/netbox/utilities/permissions.py b/netbox/utilities/permissions.py new file mode 100644 index 000000000..44c34942f --- /dev/null +++ b/netbox/utilities/permissions.py @@ -0,0 +1,74 @@ +from django.conf import settings +from django.contrib.contenttypes.models import ContentType + + +def get_permission_for_model(model, action): + """ + Resolve the named permission for a given model (or instance) and action (e.g. view or add). + + :param model: A model or instance + :param action: View, add, change, or delete (string) + """ + if action not in ('view', 'add', 'change', 'delete'): + raise ValueError(f"Unsupported action: {action}") + + return '{}.{}_{}'.format( + model._meta.app_label, + action, + model._meta.model_name + ) + + +def resolve_permission(name): + """ + Given a permission name, return the app_label, action, and model_name components. For example, "dcim.view_site" + returns ("dcim", "view", "site"). + + :param name: Permission name in the format ._ + """ + try: + app_label, codename = name.split('.') + action, model_name = codename.rsplit('_', 1) + except ValueError: + raise ValueError( + f"Invalid permission name: {name}. Must be in the format ._" + ) + + return app_label, action, model_name + + +def resolve_permission_ct(name): + """ + Given a permission name, return the relevant ContentType and action. For example, "dcim.view_site" returns + (Site, "view"). + + :param name: Permission name in the format ._ + """ + app_label, action, model_name = resolve_permission(name) + try: + content_type = ContentType.objects.get(app_label=app_label, model=model_name) + except ContentType.DoesNotExist: + raise ValueError(f"Unknown app_label/model_name for {name}") + + return content_type, action + + +def permission_is_exempt(name): + """ + Determine whether a specified permission is exempt from evaluation. + + :param name: Permission name in the format ._ + """ + app_label, action, model_name = resolve_permission(name) + + if action == 'view': + if ( + # All models are exempt from view permission enforcement + '*' in settings.EXEMPT_VIEW_PERMISSIONS + ) or ( + # This specific model is exempt from view permission enforcement + '{}.{}'.format(app_label, model_name) in settings.EXEMPT_VIEW_PERMISSIONS + ): + return True + + return False diff --git a/netbox/utilities/querysets.py b/netbox/utilities/querysets.py index 34b7a0cf3..04bc4c542 100644 --- a/netbox/utilities/querysets.py +++ b/netbox/utilities/querysets.py @@ -1,3 +1,10 @@ +import logging + +from django.db.models import Q, QuerySet + +from utilities.permissions import permission_is_exempt + + class DummyQuerySet: """ A fake QuerySet that can be used to cache relationships to objects that have been deleted. @@ -5,5 +12,83 @@ class DummyQuerySet: def __init__(self, queryset): self._cache = [obj for obj in queryset.all()] + def __iter__(self): + return iter(self._cache) + def all(self): return self._cache + + +class RestrictedQuerySet(QuerySet): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Initialize the allow_evaluation flag to False. This indicates that the QuerySet has not yet been restricted. + self.allow_evaluation = False + + def _check_restriction(self): + # Raise a warning if the QuerySet is evaluated without first calling restrict() or unrestricted(). + if not getattr(self, 'allow_evaluation', False): + logger = logging.getLogger('netbox.RestrictedQuerySet') + logger.warning( + f'Evaluation of RestrictedQuerySet prior to calling restrict() or unrestricted(): {self.model}' + ) + + def _clone(self): + + # Persist the allow_evaluation flag when cloning the QuerySet. + c = super()._clone() + c.allow_evaluation = self.allow_evaluation + + return c + + def _fetch_all(self): + self._check_restriction() + return super()._fetch_all() + + def count(self): + self._check_restriction() + return super().count() + + def unrestricted(self): + """ + Bypass restriction for the QuerySet. This is necessary in cases where we are not interacting with the objects + directly (e.g. when filtering by related object). + """ + self.allow_evaluation = True + return self + + def restrict(self, user, action): + """ + Filter the QuerySet to return only objects on which the specified user has been granted the specified + permission. + + :param user: User instance + :param action: The action which must be permitted (e.g. "view" for "dcim.view_site") + """ + # Resolve the full name of the required permission + app_label = self.model._meta.app_label + model_name = self.model._meta.model_name + permission_required = f'{app_label}.{action}_{model_name}' + + # Bypass restriction for superusers and exempt views + if user.is_superuser or permission_is_exempt(permission_required): + qs = self + + # User is anonymous or has not been granted the requisite permission + elif not user.is_authenticated or permission_required not in user.get_all_permissions(): + qs = self.none() + + # Filter the queryset to include only objects with allowed attributes + else: + attrs = Q() + for perm_attrs in user._object_perm_cache[permission_required]: + if perm_attrs: + attrs |= Q(**perm_attrs) + qs = self.filter(attrs) + + # Allow QuerySet evaluation + qs.allow_evaluation = True + + return qs diff --git a/netbox/utilities/tables.py b/netbox/utilities/tables.py index 0702936b5..5e277e633 100644 --- a/netbox/utilities/tables.py +++ b/netbox/utilities/tables.py @@ -1,6 +1,5 @@ import django_tables2 as tables from django.core.exceptions import FieldDoesNotExist -from django.db.models import ForeignKey from django.db.models.fields.related import RelatedField from django.utils.safestring import mark_safe from django_tables2.data import TableQuerysetData @@ -9,7 +8,13 @@ from django_tables2.data import TableQuerysetData class BaseTable(tables.Table): """ Default table for object lists + + :param add_prefetch: By default, modify the queryset passed to the table upon initialization to automatically + prefetch related data. Set this to False if it's necessary to avoid modifying the queryset (e.g. to + accommodate PrefixQuerySet.annotate_depth()). """ + add_prefetch = True + class Meta: attrs = { 'class': 'table table-hover table-headings', @@ -50,7 +55,7 @@ class BaseTable(tables.Table): self.sequence.append('actions') # Dynamically update the table's QuerySet to ensure related fields are pre-fetched - if isinstance(self.data, TableQuerysetData): + if self.add_prefetch and isinstance(self.data, TableQuerysetData): model = getattr(self.Meta, 'model') prefetch_fields = [] for column in self.columns: @@ -79,6 +84,10 @@ class BaseTable(tables.Table): return [name for name in self.sequence if self.columns[name].visible] +# +# Table columns +# + class ToggleColumn(tables.CheckBoxColumn): """ Extend CheckBoxColumn to add a "toggle all" checkbox in the column header. @@ -124,12 +133,25 @@ class ColorColumn(tables.Column): ) +class ColoredLabelColumn(tables.TemplateColumn): + """ + Render a colored label (e.g. for DeviceRoles). + """ + template_code = """ + {% load helpers %} + {% if value %}{% else %}—{% endif %} + """ + + def __init__(self, *args, **kwargs): + super().__init__(template_code=self.template_code, *args, **kwargs) + + class TagColumn(tables.TemplateColumn): """ Display a list of tags assigned to the object. """ template_code = """ - {% for tag in value.all %} + {% for tag in value.all.unrestricted %} {% include 'utilities/templatetags/tag.html' %} {% empty %} diff --git a/netbox/utilities/templatetags/buttons.py b/netbox/utilities/templatetags/buttons.py index 85f75f79e..da40ce9d5 100644 --- a/netbox/utilities/templatetags/buttons.py +++ b/netbox/utilities/templatetags/buttons.py @@ -97,7 +97,8 @@ def import_button(url): @register.inclusion_tag('buttons/export.html', takes_context=True) def export_button(context, content_type=None): if content_type is not None: - export_templates = ExportTemplate.objects.filter(content_type=content_type) + user = context['request'].user + export_templates = ExportTemplate.objects.restrict(user, 'view').filter(content_type=content_type) else: export_templates = [] diff --git a/netbox/utilities/templatetags/helpers.py b/netbox/utilities/templatetags/helpers.py index 8a82fc48b..a70e917d8 100644 --- a/netbox/utilities/templatetags/helpers.py +++ b/netbox/utilities/templatetags/helpers.py @@ -10,7 +10,6 @@ from django.utils.html import strip_tags from django.utils.safestring import mark_safe from markdown import markdown -from utilities.choices import unpack_grouped_choices from utilities.utils import foreground_color register = template.Library() @@ -39,6 +38,11 @@ def render_markdown(value): # Strip HTML tags value = strip_tags(value) + # Sanitize Markdown links + schemes = '|'.join(settings.ALLOWED_URL_SCHEMES) + pattern = fr'\[(.+)\]\((?!({schemes})).*:(.+)\)' + value = re.sub(pattern, '[\\1](\\3)', value, flags=re.IGNORECASE) + # Render Markdown html = markdown(value, extensions=['fenced_code', 'tables']) diff --git a/netbox/utilities/testing/__init__.py b/netbox/utilities/testing/__init__.py index 30e452215..1c18a3481 100644 --- a/netbox/utilities/testing/__init__.py +++ b/netbox/utilities/testing/__init__.py @@ -1,2 +1,3 @@ -from .testcases import * +from .api import * from .utils import * +from .views import * diff --git a/netbox/utilities/testing/api.py b/netbox/utilities/testing/api.py new file mode 100644 index 000000000..ce4f1d1e5 --- /dev/null +++ b/netbox/utilities/testing/api.py @@ -0,0 +1,282 @@ +from django.contrib.auth.models import User +from django.contrib.contenttypes.models import ContentType +from django.urls import reverse +from django.test import override_settings +from rest_framework import status +from rest_framework.test import APIClient + +from users.models import ObjectPermission, Token +from .utils import disable_warnings +from .views import TestCase + + +__all__ = ( + 'APITestCase', + 'APIViewTestCases', +) + + +# +# REST API Tests +# + +class APITestCase(TestCase): + client_class = APIClient + model = None + + def setUp(self): + """ + Create a superuser and token for API calls. + """ + # Create the test user and assign permissions + self.user = User.objects.create_user(username='testuser') + self.add_permissions(*self.user_permissions) + self.token = Token.objects.create(user=self.user) + self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)} + + def _get_detail_url(self, instance): + viewname = f'{instance._meta.app_label}-api:{instance._meta.model_name}-detail' + return reverse(viewname, kwargs={'pk': instance.pk}) + + def _get_list_url(self): + viewname = f'{self.model._meta.app_label}-api:{self.model._meta.model_name}-list' + return reverse(viewname) + + +class APIViewTestCases: + + class GetObjectViewTestCase(APITestCase): + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_get_object_anonymous(self): + """ + GET a single object as an unauthenticated user. + """ + url = self._get_detail_url(self.model.objects.first()) + response = self.client.get(url, **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_get_object_without_permission(self): + """ + GET a single object as an authenticated user without the required permission. + """ + url = self._get_detail_url(self.model.objects.first()) + + # Try GET without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_403_FORBIDDEN) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_get_object(self): + """ + GET a single object as an authenticated user with permission to view the object. + """ + self.assertGreaterEqual(self.model.objects.count(), 2, + f"Test requires the creation of at least two {self.model} instances") + instance1, instance2 = self.model.objects.all()[:2] + + # Add object-level permission + obj_perm = ObjectPermission( + constraints={'pk': instance1.pk}, + actions=['view'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET to permitted object + url = self._get_detail_url(instance1) + self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_200_OK) + + # Try GET to non-permitted object + url = self._get_detail_url(instance2) + self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_404_NOT_FOUND) + + class ListObjectsViewTestCase(APITestCase): + brief_fields = [] + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_list_objects_anonymous(self): + """ + GET a list of objects as an unauthenticated user. + """ + url = self._get_list_url() + response = self.client.get(url, **self.header) + + self.assertEqual(len(response.data['results']), self.model.objects.count()) + self.assertHttpStatus(response, status.HTTP_200_OK) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_list_objects_brief(self): + """ + GET a list of objects using the "brief" parameter as an unauthenticated user. + """ + url = f'{self._get_list_url()}?brief=1' + response = self.client.get(url, **self.header) + + self.assertEqual(len(response.data['results']), self.model.objects.count()) + self.assertEqual(sorted(response.data['results'][0]), self.brief_fields) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_list_objects_without_permission(self): + """ + GET a list of objects as an authenticated user without the required permission. + """ + url = self._get_list_url() + + # Try GET without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_403_FORBIDDEN) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_list_objects(self): + """ + GET a list of objects as an authenticated user with permission to view the objects. + """ + self.assertGreaterEqual(self.model.objects.count(), 3, + f"Test requires the creation of at least three {self.model} instances") + instance1, instance2 = self.model.objects.all()[:2] + + # Add object-level permission + obj_perm = ObjectPermission( + constraints={'pk__in': [instance1.pk, instance2.pk]}, + actions=['view'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET to permitted objects + response = self.client.get(self._get_list_url(), **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + self.assertEqual(len(response.data['results']), 2) + + class CreateObjectViewTestCase(APITestCase): + create_data = [] + + def test_create_object_without_permission(self): + """ + POST a single object without permission. + """ + url = self._get_list_url() + + # Try POST without permission + with disable_warnings('django.request'): + response = self.client.post(url, self.create_data[0], format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN) + + def test_create_object(self): + """ + POST a single object with permission. + """ + # Add object-level permission + obj_perm = ObjectPermission( + actions=['add'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + initial_count = self.model.objects.count() + response = self.client.post(self._get_list_url(), self.create_data[0], format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_201_CREATED) + self.assertEqual(self.model.objects.count(), initial_count + 1) + self.assertInstanceEqual(self.model.objects.get(pk=response.data['id']), self.create_data[0], api=True) + + def test_bulk_create_objects(self): + """ + POST a set of objects in a single request. + """ + # Add object-level permission + obj_perm = ObjectPermission( + actions=['add'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + initial_count = self.model.objects.count() + response = self.client.post(self._get_list_url(), self.create_data, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_201_CREATED) + self.assertEqual(len(response.data), len(self.create_data)) + self.assertEqual(self.model.objects.count(), initial_count + len(self.create_data)) + for i, obj in enumerate(response.data): + self.assertInstanceEqual(self.model.objects.get(pk=obj['id']), self.create_data[i], api=True) + + class UpdateObjectViewTestCase(APITestCase): + update_data = {} + + def test_update_object_without_permission(self): + """ + PATCH a single object without permission. + """ + url = self._get_detail_url(self.model.objects.first()) + update_data = self.update_data or getattr(self, 'create_data')[0] + + # Try PATCH without permission + with disable_warnings('django.request'): + response = self.client.patch(url, update_data, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN) + + def test_update_object(self): + """ + PATCH a single object identified by its numeric ID. + """ + instance = self.model.objects.first() + url = self._get_detail_url(instance) + update_data = self.update_data or getattr(self, 'create_data')[0] + + # Add object-level permission + obj_perm = ObjectPermission( + actions=['change'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + response = self.client.patch(url, update_data, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + instance.refresh_from_db() + self.assertInstanceEqual(instance, self.update_data, api=True) + + class DeleteObjectViewTestCase(APITestCase): + + def test_delete_object_without_permission(self): + """ + DELETE a single object without permission. + """ + url = self._get_detail_url(self.model.objects.first()) + + # Try DELETE without permission + with disable_warnings('django.request'): + response = self.client.delete(url, **self.header) + self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN) + + def test_delete_object(self): + """ + DELETE a single object identified by its numeric ID. + """ + instance = self.model.objects.first() + url = self._get_detail_url(instance) + + # Add object-level permission + obj_perm = ObjectPermission( + actions=['delete'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + response = self.client.delete(url, **self.header) + self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) + self.assertFalse(self.model.objects.filter(pk=instance.pk).exists()) + + class APIViewTestCase( + GetObjectViewTestCase, + ListObjectsViewTestCase, + CreateObjectViewTestCase, + UpdateObjectViewTestCase, + DeleteObjectViewTestCase + ): + pass diff --git a/netbox/utilities/testing/testcases.py b/netbox/utilities/testing/testcases.py deleted file mode 100644 index de8b93232..000000000 --- a/netbox/utilities/testing/testcases.py +++ /dev/null @@ -1,476 +0,0 @@ -from django.contrib.auth.models import Permission, User -from django.core.exceptions import ObjectDoesNotExist -from django.forms.models import model_to_dict -from django.test import Client, TestCase as _TestCase, override_settings -from django.urls import reverse, NoReverseMatch -from rest_framework.test import APIClient - -from users.models import Token -from .utils import disable_warnings, post_data - - -class TestCase(_TestCase): - user_permissions = () - - def setUp(self): - - # Create the test user and assign permissions - self.user = User.objects.create_user(username='testuser') - self.add_permissions(*self.user_permissions) - - # Initialize the test client - self.client = Client() - self.client.force_login(self.user) - - # - # Permissions management - # - - def add_permissions(self, *names): - """ - Assign a set of permissions to the test user. Accepts permission names in the form ._. - """ - for name in names: - app, codename = name.split('.') - perm = Permission.objects.get(content_type__app_label=app, codename=codename) - self.user.user_permissions.add(perm) - - def remove_permissions(self, *names): - """ - Remove a set of permissions from the test user, if assigned. - """ - for name in names: - app, codename = name.split('.') - perm = Permission.objects.get(content_type__app_label=app, codename=codename) - self.user.user_permissions.remove(perm) - - # - # Convenience methods - # - - def assertHttpStatus(self, response, expected_status): - """ - TestCase method. Provide more detail in the event of an unexpected HTTP response. - """ - err_message = "Expected HTTP status {}; received {}: {}" - self.assertEqual(response.status_code, expected_status, err_message.format( - expected_status, response.status_code, getattr(response, 'data', 'No data') - )) - - -class ModelViewTestCase(TestCase): - """ - Base TestCase for model views. Subclass to test individual views. - """ - model = None - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - if self.model is None: - raise Exception("Test case requires model to be defined") - - def _get_base_url(self): - """ - Return the base format for a URL for the test's model. Override this to test for a model which belongs - to a different app (e.g. testing Interfaces within the virtualization app). - """ - return '{}:{}_{{}}'.format( - self.model._meta.app_label, - self.model._meta.model_name - ) - - def _get_url(self, action, instance=None): - """ - Return the URL name for a specific action. An instance must be specified for - get/edit/delete views. - """ - url_format = self._get_base_url() - - if action in ('list', 'add', 'import', 'bulk_edit', 'bulk_delete'): - return reverse(url_format.format(action)) - - elif action in ('get', 'edit', 'delete'): - if instance is None: - raise Exception("Resolving {} URL requires specifying an instance".format(action)) - # Attempt to resolve using slug first - if hasattr(self.model, 'slug'): - try: - return reverse(url_format.format(action), kwargs={'slug': instance.slug}) - except NoReverseMatch: - pass - return reverse(url_format.format(action), kwargs={'pk': instance.pk}) - - else: - raise Exception("Invalid action for URL resolution: {}".format(action)) - - def assertInstanceEqual(self, instance, data): - """ - Compare a model instance to a dictionary, checking that its attribute values match those specified - in the dictionary. - """ - model_dict = model_to_dict(instance, fields=data.keys()) - - for key in list(model_dict.keys()): - - # TODO: Differentiate between tags assigned to the instance and a M2M field for tags (ex: ConfigContext) - if key == 'tags': - model_dict[key] = ','.join(sorted([tag.name for tag in model_dict['tags']])) - - # Convert ManyToManyField to list of instance PKs - elif model_dict[key] and type(model_dict[key]) in (list, tuple) and hasattr(model_dict[key][0], 'pk'): - model_dict[key] = [obj.pk for obj in model_dict[key]] - - # Omit any dictionary keys which are not instance attributes - relevant_data = { - k: v for k, v in data.items() if hasattr(instance, k) - } - - self.assertDictEqual(model_dict, relevant_data) - - -class APITestCase(TestCase): - client_class = APIClient - - def setUp(self): - """ - Create a superuser and token for API calls. - """ - self.user = User.objects.create(username='testuser', is_superuser=True) - self.token = Token.objects.create(user=self.user) - self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)} - - -class ViewTestCases: - """ - We keep any TestCases with test_* methods inside a class to prevent unittest from trying to run them. - """ - class GetObjectViewTestCase(ModelViewTestCase): - """ - Retrieve a single instance. - """ - @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) - def test_get_object(self): - instance = self.model.objects.first() - - # Attempt to make the request without required permissions - with disable_warnings('django.request'): - self.assertHttpStatus(self.client.get(instance.get_absolute_url()), 403) - - # Assign the required permission and submit again - self.add_permissions( - '{}.view_{}'.format(self.model._meta.app_label, self.model._meta.model_name) - ) - response = self.client.get(instance.get_absolute_url()) - self.assertHttpStatus(response, 200) - - class CreateObjectViewTestCase(ModelViewTestCase): - """ - Create a single new instance. - """ - form_data = {} - - @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) - def test_create_object(self): - - # Try GET without permission - with disable_warnings('django.request'): - self.assertHttpStatus(self.client.post(self._get_url('add')), 403) - - # Try GET with permission - self.add_permissions( - '{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name) - ) - response = self.client.get(path=self._get_url('add')) - self.assertHttpStatus(response, 200) - - # Try POST with permission - initial_count = self.model.objects.count() - request = { - 'path': self._get_url('add'), - 'data': post_data(self.form_data), - 'follow': False, # Do not follow 302 redirects - } - response = self.client.post(**request) - self.assertHttpStatus(response, 302) - - # Validate object creation - self.assertEqual(initial_count + 1, self.model.objects.count()) - instance = self.model.objects.order_by('-pk').first() - self.assertInstanceEqual(instance, self.form_data) - - class EditObjectViewTestCase(ModelViewTestCase): - """ - Edit a single existing instance. - """ - form_data = {} - - @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) - def test_edit_object(self): - instance = self.model.objects.first() - - # Try GET without permission - with disable_warnings('django.request'): - self.assertHttpStatus(self.client.post(self._get_url('edit', instance)), 403) - - # Try GET with permission - self.add_permissions( - '{}.change_{}'.format(self.model._meta.app_label, self.model._meta.model_name) - ) - response = self.client.get(path=self._get_url('edit', instance)) - self.assertHttpStatus(response, 200) - - # Try POST with permission - request = { - 'path': self._get_url('edit', instance), - 'data': post_data(self.form_data), - 'follow': False, # Do not follow 302 redirects - } - response = self.client.post(**request) - self.assertHttpStatus(response, 302) - - # Validate object modifications - instance = self.model.objects.get(pk=instance.pk) - self.assertInstanceEqual(instance, self.form_data) - - class DeleteObjectViewTestCase(ModelViewTestCase): - """ - Delete a single instance. - """ - @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) - def test_delete_object(self): - instance = self.model.objects.first() - - # Try GET without permissions - with disable_warnings('django.request'): - self.assertHttpStatus(self.client.post(self._get_url('delete', instance)), 403) - - # Try GET with permission - self.add_permissions( - '{}.delete_{}'.format(self.model._meta.app_label, self.model._meta.model_name) - ) - response = self.client.get(path=self._get_url('delete', instance)) - self.assertHttpStatus(response, 200) - - request = { - 'path': self._get_url('delete', instance), - 'data': {'confirm': True}, - 'follow': False, # Do not follow 302 redirects - } - response = self.client.post(**request) - self.assertHttpStatus(response, 302) - - # Validate object deletion - with self.assertRaises(ObjectDoesNotExist): - self.model.objects.get(pk=instance.pk) - - class ListObjectsViewTestCase(ModelViewTestCase): - """ - Retrieve multiple instances. - """ - @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) - def test_list_objects(self): - # Attempt to make the request without required permissions - with disable_warnings('django.request'): - self.assertHttpStatus(self.client.get(self._get_url('list')), 403) - - # Assign the required permission and submit again - self.add_permissions( - '{}.view_{}'.format(self.model._meta.app_label, self.model._meta.model_name) - ) - response = self.client.get(self._get_url('list')) - self.assertHttpStatus(response, 200) - - # Built-in CSV export - if hasattr(self.model, 'csv_headers'): - response = self.client.get('{}?export'.format(self._get_url('list'))) - self.assertHttpStatus(response, 200) - self.assertEqual(response.get('Content-Type'), 'text/csv') - - class BulkCreateObjectsViewTestCase(ModelViewTestCase): - """ - Create multiple instances using a single form. Expects the creation of three new instances by default. - """ - bulk_create_count = 3 - bulk_create_data = {} - - @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) - def test_bulk_create_objects(self): - initial_count = self.model.objects.count() - request = { - 'path': self._get_url('add'), - 'data': post_data(self.bulk_create_data), - 'follow': False, # Do not follow 302 redirects - } - - # Attempt to make the request without required permissions - with disable_warnings('django.request'): - self.assertHttpStatus(self.client.post(**request), 403) - - # Assign the required permission and submit again - self.add_permissions( - '{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name) - ) - response = self.client.post(**request) - self.assertHttpStatus(response, 302) - - self.assertEqual(initial_count + self.bulk_create_count, self.model.objects.count()) - for instance in self.model.objects.order_by('-pk')[:self.bulk_create_count]: - self.assertInstanceEqual(instance, self.bulk_create_data) - - class ImportObjectsViewTestCase(ModelViewTestCase): - """ - Create multiple instances from imported data. - """ - csv_data = () - - @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) - def test_import_objects(self): - - # Test GET without permission - with disable_warnings('django.request'): - self.assertHttpStatus(self.client.get(self._get_url('import')), 403) - - # Test GET with permission - self.add_permissions( - '{}.view_{}'.format(self.model._meta.app_label, self.model._meta.model_name), - '{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name) - ) - response = self.client.get(self._get_url('import')) - self.assertHttpStatus(response, 200) - - # Test POST with permission - initial_count = self.model.objects.count() - request = { - 'path': self._get_url('import'), - 'data': { - 'csv': '\n'.join(self.csv_data) - } - } - response = self.client.post(**request) - self.assertHttpStatus(response, 200) - - # Validate import of new objects - self.assertEqual(self.model.objects.count(), initial_count + len(self.csv_data) - 1) - - class BulkEditObjectsViewTestCase(ModelViewTestCase): - """ - Edit multiple instances. - """ - bulk_edit_data = {} - - @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) - def test_bulk_edit_objects(self): - # Bulk edit the first three objects only - pk_list = self.model.objects.values_list('pk', flat=True)[:3] - - request = { - 'path': self._get_url('bulk_edit'), - 'data': { - 'pk': pk_list, - '_apply': True, # Form button - }, - 'follow': False, # Do not follow 302 redirects - } - - # Append the form data to the request - request['data'].update(post_data(self.bulk_edit_data)) - - # Attempt to make the request without required permissions - with disable_warnings('django.request'): - self.assertHttpStatus(self.client.post(**request), 403) - - # Assign the required permission and submit again - self.add_permissions( - '{}.change_{}'.format(self.model._meta.app_label, self.model._meta.model_name) - ) - response = self.client.post(**request) - self.assertHttpStatus(response, 302) - - for i, instance in enumerate(self.model.objects.filter(pk__in=pk_list)): - self.assertInstanceEqual(instance, self.bulk_edit_data) - - class BulkDeleteObjectsViewTestCase(ModelViewTestCase): - """ - Delete multiple instances. - """ - @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) - def test_bulk_delete_objects(self): - pk_list = self.model.objects.values_list('pk', flat=True) - - request = { - 'path': self._get_url('bulk_delete'), - 'data': { - 'pk': pk_list, - 'confirm': True, - '_confirm': True, # Form button - }, - 'follow': False, # Do not follow 302 redirects - } - - # Attempt to make the request without required permissions - with disable_warnings('django.request'): - self.assertHttpStatus(self.client.post(**request), 403) - - # Assign the required permission and submit again - self.add_permissions( - '{}.delete_{}'.format(self.model._meta.app_label, self.model._meta.model_name) - ) - response = self.client.post(**request) - self.assertHttpStatus(response, 302) - - # Check that all objects were deleted - self.assertEqual(self.model.objects.count(), 0) - - class PrimaryObjectViewTestCase( - GetObjectViewTestCase, - CreateObjectViewTestCase, - EditObjectViewTestCase, - DeleteObjectViewTestCase, - ListObjectsViewTestCase, - ImportObjectsViewTestCase, - BulkEditObjectsViewTestCase, - BulkDeleteObjectsViewTestCase, - ): - """ - TestCase suitable for testing all standard View functions for primary objects - """ - maxDiff = None - - class OrganizationalObjectViewTestCase( - CreateObjectViewTestCase, - EditObjectViewTestCase, - ListObjectsViewTestCase, - ImportObjectsViewTestCase, - BulkDeleteObjectsViewTestCase, - ): - """ - TestCase suitable for all organizational objects - """ - maxDiff = None - - class DeviceComponentTemplateViewTestCase( - EditObjectViewTestCase, - DeleteObjectViewTestCase, - BulkCreateObjectsViewTestCase, - BulkEditObjectsViewTestCase, - BulkDeleteObjectsViewTestCase, - ): - """ - TestCase suitable for testing device component template models (ConsolePortTemplates, InterfaceTemplates, etc.) - """ - maxDiff = None - - class DeviceComponentViewTestCase( - EditObjectViewTestCase, - DeleteObjectViewTestCase, - ListObjectsViewTestCase, - BulkCreateObjectsViewTestCase, - ImportObjectsViewTestCase, - BulkEditObjectsViewTestCase, - BulkDeleteObjectsViewTestCase, - ): - """ - TestCase suitable for testing device component models (ConsolePorts, Interfaces, etc.) - """ - maxDiff = None diff --git a/netbox/utilities/testing/utils.py b/netbox/utilities/testing/utils.py index fd8c70f05..d763012f0 100644 --- a/netbox/utilities/testing/utils.py +++ b/netbox/utilities/testing/utils.py @@ -14,7 +14,14 @@ def post_data(data): if value is None: ret[key] = '' elif type(value) in (list, tuple): - ret[key] = value + if value and hasattr(value[0], 'pk'): + # Value is a list of instances + ret[key] = [v.pk for v in value] + else: + ret[key] = value + elif hasattr(value, 'pk'): + # Value is an instance + ret[key] = value.pk else: ret[key] = str(value) diff --git a/netbox/utilities/testing/views.py b/netbox/utilities/testing/views.py new file mode 100644 index 000000000..774ceac85 --- /dev/null +++ b/netbox/utilities/testing/views.py @@ -0,0 +1,931 @@ +from django.contrib.auth.models import User +from django.contrib.contenttypes.models import ContentType +from django.contrib.postgres.fields import ArrayField +from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist +from django.db.models import ManyToManyField +from django.forms.models import model_to_dict +from django.test import Client, TestCase as _TestCase, override_settings +from django.urls import reverse, NoReverseMatch +from django.utils.text import slugify +from netaddr import IPNetwork +from taggit.managers import TaggableManager + +from extras.models import Tag +from users.models import ObjectPermission +from utilities.permissions import resolve_permission_ct +from .utils import disable_warnings, post_data + + +__all__ = ( + 'TestCase', + 'ModelViewTestCase', + 'ViewTestCases', +) + + +class TestCase(_TestCase): + user_permissions = () + + def setUp(self): + + # Create the test user and assign permissions + self.user = User.objects.create_user(username='testuser') + self.add_permissions(*self.user_permissions) + + # Initialize the test client + self.client = Client() + self.client.force_login(self.user) + + def prepare_instance(self, instance): + """ + Test cases can override this method to perform any necessary manipulation of an instance prior to its evaluation + against test data. For example, it can be used to decrypt a Secret's plaintext attribute. + """ + return instance + + def model_to_dict(self, instance, fields, api=False): + """ + Return a dictionary representation of an instance. + """ + # Prepare the instance and call Django's model_to_dict() to extract all fields + model_dict = model_to_dict(self.prepare_instance(instance), fields=fields) + + # Map any additional (non-field) instance attributes that were specified + for attr in fields: + if hasattr(instance, attr) and attr not in model_dict: + model_dict[attr] = getattr(instance, attr) + + for key, value in list(model_dict.items()): + try: + field = instance._meta.get_field(key) + except FieldDoesNotExist: + # Attribute is not a model field + continue + + # Handle ManyToManyFields + if value and type(field) in (ManyToManyField, TaggableManager): + + if field.related_model is ContentType: + model_dict[key] = sorted([f'{ct.app_label}.{ct.model}' for ct in value]) + else: + model_dict[key] = sorted([obj.pk for obj in value]) + + if api: + + # Replace ContentType numeric IDs with . + if type(getattr(instance, key)) is ContentType: + ct = ContentType.objects.get(pk=value) + model_dict[key] = f'{ct.app_label}.{ct.model}' + + # Convert IPNetwork instances to strings + elif type(value) is IPNetwork: + model_dict[key] = str(value) + + else: + + # Convert ArrayFields to CSV strings + if type(instance._meta.get_field(key)) is ArrayField: + model_dict[key] = ','.join([str(v) for v in value]) + + return model_dict + + # + # Permissions management + # + + def add_permissions(self, *names): + """ + Assign a set of permissions to the test user. Accepts permission names in the form ._. + """ + for name in names: + ct, action = resolve_permission_ct(name) + obj_perm = ObjectPermission(actions=[action]) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ct) + + # + # Custom assertions + # + + def assertHttpStatus(self, response, expected_status): + """ + TestCase method. Provide more detail in the event of an unexpected HTTP response. + """ + err_message = "Expected HTTP status {}; received {}: {}" + self.assertEqual(response.status_code, expected_status, err_message.format( + expected_status, response.status_code, getattr(response, 'data', 'No data') + )) + + def assertInstanceEqual(self, instance, data, api=False): + """ + Compare a model instance to a dictionary, checking that its attribute values match those specified + in the dictionary. + + :instance: Python object instance + :data: Dictionary of test data used to define the instance + :api: Set to True is the data is a JSON representation of the instance + """ + model_dict = self.model_to_dict(instance, fields=data.keys(), api=api) + + # Omit any dictionary keys which are not instance attributes + relevant_data = { + k: v for k, v in data.items() if hasattr(instance, k) + } + + self.assertDictEqual(model_dict, relevant_data) + + # + # Convenience methods + # + + @classmethod + def create_tags(cls, *names): + """ + Create and return a Tag instance for each name given. + """ + tags = [Tag(name=name, slug=slugify(name)) for name in names] + Tag.objects.bulk_create(tags) + return tags + + +# +# UI Tests +# + +class ModelViewTestCase(TestCase): + """ + Base TestCase for model views. Subclass to test individual views. + """ + model = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self.model is None: + raise Exception("Test case requires model to be defined") + + def _get_base_url(self): + """ + Return the base format for a URL for the test's model. Override this to test for a model which belongs + to a different app (e.g. testing Interfaces within the virtualization app). + """ + return '{}:{}_{{}}'.format( + self.model._meta.app_label, + self.model._meta.model_name + ) + + def _get_url(self, action, instance=None): + """ + Return the URL name for a specific action. An instance must be specified for + get/edit/delete views. + """ + url_format = self._get_base_url() + + if action in ('list', 'add', 'import', 'bulk_edit', 'bulk_delete'): + return reverse(url_format.format(action)) + + elif action in ('get', 'edit', 'delete'): + if instance is None: + raise Exception("Resolving {} URL requires specifying an instance".format(action)) + # Attempt to resolve using slug first + if hasattr(self.model, 'slug'): + try: + return reverse(url_format.format(action), kwargs={'slug': instance.slug}) + except NoReverseMatch: + pass + return reverse(url_format.format(action), kwargs={'pk': instance.pk}) + + else: + raise Exception("Invalid action for URL resolution: {}".format(action)) + + +class ViewTestCases: + """ + We keep any TestCases with test_* methods inside a class to prevent unittest from trying to run them. + """ + class GetObjectViewTestCase(ModelViewTestCase): + """ + Retrieve a single instance. + """ + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_get_object_anonymous(self): + # Make the request as an unauthenticated user + self.client.logout() + response = self.client.get(self.model.objects.first().get_absolute_url()) + self.assertHttpStatus(response, 200) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_get_object_without_permission(self): + instance = self.model.objects.first() + + # Try GET without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.get(instance.get_absolute_url()), 403) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_get_object_with_permission(self): + instance = self.model.objects.first() + + # Add model-level permission + obj_perm = ObjectPermission( + actions=['view'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with model-level permission + self.assertHttpStatus(self.client.get(instance.get_absolute_url()), 200) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_get_object_with_constrained_permission(self): + instance1, instance2 = self.model.objects.all()[:2] + + # Add object-level permission + obj_perm = ObjectPermission( + constraints={'pk': instance1.pk}, + actions=['view'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET to permitted object + self.assertHttpStatus(self.client.get(instance1.get_absolute_url()), 200) + + # Try GET to non-permitted object + self.assertHttpStatus(self.client.get(instance2.get_absolute_url()), 404) + + class CreateObjectViewTestCase(ModelViewTestCase): + """ + Create a single new instance. + + :form_data: Data to be used when creating a new object. + """ + form_data = {} + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_create_object_without_permission(self): + + # Try GET without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.post(self._get_url('add')), 403) + + # Try POST without permission + request = { + 'path': self._get_url('add'), + 'data': post_data(self.form_data), + } + response = self.client.post(**request) + with disable_warnings('django.request'): + self.assertHttpStatus(response, 403) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_create_object_with_permission(self): + initial_count = self.model.objects.count() + + # Assign unconstrained permission + obj_perm = ObjectPermission( + actions=['add'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with model-level permission + self.assertHttpStatus(self.client.get(self._get_url('add')), 200) + + # Try POST with model-level permission + request = { + 'path': self._get_url('add'), + 'data': post_data(self.form_data), + } + self.assertHttpStatus(self.client.post(**request), 302) + self.assertEqual(initial_count + 1, self.model.objects.count()) + self.assertInstanceEqual(self.model.objects.order_by('pk').last(), self.form_data) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_create_object_with_constrained_permission(self): + initial_count = self.model.objects.count() + + # Assign constrained permission + obj_perm = ObjectPermission( + constraints={'pk': 0}, # Dummy permission to deny all + actions=['add'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with object-level permission + self.assertHttpStatus(self.client.get(self._get_url('add')), 200) + + # Try to create an object (not permitted) + request = { + 'path': self._get_url('add'), + 'data': post_data(self.form_data), + } + self.assertHttpStatus(self.client.post(**request), 200) + self.assertEqual(initial_count, self.model.objects.count()) # Check that no object was created + + # Update the ObjectPermission to allow creation + obj_perm.constraints = {'pk__gt': 0} + obj_perm.save() + + # Try to create an object (permitted) + request = { + 'path': self._get_url('add'), + 'data': post_data(self.form_data), + } + self.assertHttpStatus(self.client.post(**request), 302) + self.assertEqual(initial_count + 1, self.model.objects.count()) + self.assertInstanceEqual(self.model.objects.order_by('pk').last(), self.form_data) + + class EditObjectViewTestCase(ModelViewTestCase): + """ + Edit a single existing instance. + + :form_data: Data to be used when updating the first existing object. + """ + form_data = {} + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_edit_object_without_permission(self): + instance = self.model.objects.first() + + # Try GET without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.post(self._get_url('edit', instance)), 403) + + # Try POST without permission + request = { + 'path': self._get_url('edit', instance), + 'data': post_data(self.form_data), + } + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.post(**request), 403) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_edit_object_with_permission(self): + instance = self.model.objects.first() + + # Assign model-level permission + obj_perm = ObjectPermission( + actions=['change'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with model-level permission + self.assertHttpStatus(self.client.get(self._get_url('edit', instance)), 200) + + # Try POST with model-level permission + request = { + 'path': self._get_url('edit', instance), + 'data': post_data(self.form_data), + } + self.assertHttpStatus(self.client.post(**request), 302) + self.assertInstanceEqual(self.model.objects.get(pk=instance.pk), self.form_data) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_edit_object_with_constrained_permission(self): + instance1, instance2 = self.model.objects.all()[:2] + + # Assign constrained permission + obj_perm = ObjectPermission( + constraints={'pk': instance1.pk}, + actions=['change'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with a permitted object + self.assertHttpStatus(self.client.get(self._get_url('edit', instance1)), 200) + + # Try GET with a non-permitted object + self.assertHttpStatus(self.client.get(self._get_url('edit', instance2)), 404) + + # Try to edit a permitted object + request = { + 'path': self._get_url('edit', instance1), + 'data': post_data(self.form_data), + } + self.assertHttpStatus(self.client.post(**request), 302) + self.assertInstanceEqual(self.model.objects.get(pk=instance1.pk), self.form_data) + + # Try to edit a non-permitted object + request = { + 'path': self._get_url('edit', instance2), + 'data': post_data(self.form_data), + } + self.assertHttpStatus(self.client.post(**request), 404) + + class DeleteObjectViewTestCase(ModelViewTestCase): + """ + Delete a single instance. + """ + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_delete_object_without_permission(self): + instance = self.model.objects.first() + + # Try GET without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.get(self._get_url('delete', instance)), 403) + + # Try POST without permission + request = { + 'path': self._get_url('delete', instance), + 'data': post_data({'confirm': True}), + } + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.post(**request), 403) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_delete_object_with_permission(self): + instance = self.model.objects.first() + + # Assign model-level permission + obj_perm = ObjectPermission( + actions=['delete'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with model-level permission + self.assertHttpStatus(self.client.get(self._get_url('delete', instance)), 200) + + # Try POST with model-level permission + request = { + 'path': self._get_url('delete', instance), + 'data': post_data({'confirm': True}), + } + self.assertHttpStatus(self.client.post(**request), 302) + with self.assertRaises(ObjectDoesNotExist): + self.model.objects.get(pk=instance.pk) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_delete_object_with_constrained_permission(self): + instance1, instance2 = self.model.objects.all()[:2] + + # Assign object-level permission + obj_perm = ObjectPermission( + constraints={'pk': instance1.pk}, + actions=['delete'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with a permitted object + self.assertHttpStatus(self.client.get(self._get_url('delete', instance1)), 200) + + # Try GET with a non-permitted object + self.assertHttpStatus(self.client.get(self._get_url('delete', instance2)), 404) + + # Try to delete a permitted object + request = { + 'path': self._get_url('delete', instance1), + 'data': post_data({'confirm': True}), + } + self.assertHttpStatus(self.client.post(**request), 302) + with self.assertRaises(ObjectDoesNotExist): + self.model.objects.get(pk=instance1.pk) + + # Try to delete a non-permitted object + request = { + 'path': self._get_url('delete', instance2), + 'data': post_data({'confirm': True}), + } + self.assertHttpStatus(self.client.post(**request), 404) + self.assertTrue(self.model.objects.filter(pk=instance2.pk).exists()) + + class ListObjectsViewTestCase(ModelViewTestCase): + """ + Retrieve multiple instances. + """ + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_list_objects_anonymous(self): + # Make the request as an unauthenticated user + self.client.logout() + response = self.client.get(self._get_url('list')) + self.assertHttpStatus(response, 200) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_list_objects_without_permission(self): + + # Try GET without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.get(self._get_url('list')), 403) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_list_objects_with_permission(self): + + # Add model-level permission + obj_perm = ObjectPermission( + actions=['view'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with model-level permission + self.assertHttpStatus(self.client.get(self._get_url('list')), 200) + + # Built-in CSV export + if hasattr(self.model, 'csv_headers'): + response = self.client.get('{}?export'.format(self._get_url('list'))) + self.assertHttpStatus(response, 200) + self.assertEqual(response.get('Content-Type'), 'text/csv') + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_list_objects_with_constrained_permission(self): + instance1, instance2 = self.model.objects.all()[:2] + + # Add object-level permission + obj_perm = ObjectPermission( + constraints={'pk': instance1.pk}, + actions=['view'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with object-level permission + response = self.client.get(self._get_url('list')) + self.assertHttpStatus(response, 200) + content = str(response.content) + if hasattr(self.model, 'name'): + self.assertIn(instance1.name, content) + self.assertNotIn(instance2.name, content) + else: + self.assertIn(instance1.get_absolute_url(), content) + self.assertNotIn(instance2.get_absolute_url(), content) + + class BulkCreateObjectsViewTestCase(ModelViewTestCase): + """ + Create multiple instances using a single form. Expects the creation of three new instances by default. + + :bulk_create_count: The number of objects expected to be created (default: 3). + :bulk_create_data: A dictionary of data to be used for bulk object creation. + """ + bulk_create_count = 3 + bulk_create_data = {} + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_bulk_create_objects_without_permission(self): + request = { + 'path': self._get_url('add'), + 'data': post_data(self.bulk_create_data), + } + + # Try POST without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.post(**request), 403) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_bulk_create_objects_with_permission(self): + initial_count = self.model.objects.count() + request = { + 'path': self._get_url('add'), + 'data': post_data(self.bulk_create_data), + } + + # Assign non-constrained permission + obj_perm = ObjectPermission( + actions=['add'], + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Bulk create objects + response = self.client.post(**request) + self.assertHttpStatus(response, 302) + self.assertEqual(initial_count + self.bulk_create_count, self.model.objects.count()) + for instance in self.model.objects.order_by('-pk')[:self.bulk_create_count]: + self.assertInstanceEqual(instance, self.bulk_create_data) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_bulk_create_objects_with_constrained_permission(self): + initial_count = self.model.objects.count() + request = { + 'path': self._get_url('add'), + 'data': post_data(self.bulk_create_data), + } + + # Assign constrained permission + obj_perm = ObjectPermission( + actions=['add'], + constraints={'pk': 0} # Dummy constraint to deny all + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Attempt to make the request with unmet constraints + self.assertHttpStatus(self.client.post(**request), 200) + self.assertEqual(self.model.objects.count(), initial_count) + + # Update the ObjectPermission to allow creation + obj_perm.constraints = {'pk__gt': 0} # Dummy constraint to allow all + obj_perm.save() + + response = self.client.post(**request) + self.assertHttpStatus(response, 302) + self.assertEqual(initial_count + self.bulk_create_count, self.model.objects.count()) + for instance in self.model.objects.order_by('-pk')[:self.bulk_create_count]: + self.assertInstanceEqual(instance, self.bulk_create_data) + + class BulkImportObjectsViewTestCase(ModelViewTestCase): + """ + Create multiple instances from imported data. + + :csv_data: A list of CSV-formatted lines (starting with the headers) to be used for bulk object import. + """ + csv_data = () + + def _get_csv_data(self): + return '\n'.join(self.csv_data) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_bulk_import_objects_without_permission(self): + data = { + 'csv': self._get_csv_data(), + } + + # Test GET without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.get(self._get_url('import')), 403) + + # Try POST without permission + response = self.client.post(self._get_url('import'), data) + with disable_warnings('django.request'): + self.assertHttpStatus(response, 403) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_bulk_import_objects_with_permission(self): + initial_count = self.model.objects.count() + data = { + 'csv': self._get_csv_data(), + } + + # Assign model-level permission + obj_perm = ObjectPermission( + actions=['add'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with model-level permission + self.assertHttpStatus(self.client.get(self._get_url('import')), 200) + + # Test POST with permission + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200) + self.assertEqual(self.model.objects.count(), initial_count + len(self.csv_data) - 1) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_bulk_import_objects_with_constrained_permission(self): + initial_count = self.model.objects.count() + data = { + 'csv': self._get_csv_data(), + } + + # Assign constrained permission + obj_perm = ObjectPermission( + constraints={'pk': 0}, # Dummy permission to deny all + actions=['add'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Attempt to import non-permitted objects + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200) + self.assertEqual(self.model.objects.count(), initial_count) + + # Update permission constraints + obj_perm.constraints = {'pk__gt': 0} # Dummy permission to allow all + obj_perm.save() + + # Import permitted objects + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200) + self.assertEqual(self.model.objects.count(), initial_count + len(self.csv_data) - 1) + + class BulkEditObjectsViewTestCase(ModelViewTestCase): + """ + Edit multiple instances. + + :bulk_edit_data: A dictionary of data to be used when bulk editing a set of objects. This data should differ + from that used for initial object creation within setUpTestData(). + """ + bulk_edit_data = {} + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_bulk_edit_objects_without_permission(self): + pk_list = self.model.objects.values_list('pk', flat=True)[:3] + data = { + 'pk': pk_list, + '_apply': True, # Form button + } + + # Test GET without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.get(self._get_url('bulk_edit')), 403) + + # Try POST without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.post(self._get_url('bulk_edit'), data), 403) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_bulk_edit_objects_with_permission(self): + pk_list = self.model.objects.values_list('pk', flat=True)[:3] + data = { + 'pk': pk_list, + '_apply': True, # Form button + } + + # Append the form data to the request + data.update(post_data(self.bulk_edit_data)) + + # Assign model-level permission + obj_perm = ObjectPermission( + actions=['change'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try POST with model-level permission + self.assertHttpStatus(self.client.post(self._get_url('bulk_edit'), data), 302) + for i, instance in enumerate(self.model.objects.filter(pk__in=pk_list)): + self.assertInstanceEqual(instance, self.bulk_edit_data) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_bulk_edit_objects_with_constrained_permission(self): + initial_instances = self.model.objects.all()[:3] + pk_list = list(self.model.objects.values_list('pk', flat=True)[:3]) + data = { + 'pk': pk_list, + '_apply': True, # Form button + } + + # Append the form data to the request + data.update(post_data(self.bulk_edit_data)) + + # Dynamically determine a constraint that will *not* be matched by the updated objects. + attr_name = list(self.bulk_edit_data.keys())[0] + field = self.model._meta.get_field(attr_name) + value = field.value_from_object(self.model.objects.first()) + + # Assign constrained permission + obj_perm = ObjectPermission( + constraints={attr_name: value}, + actions=['change'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Attempt to bulk edit permitted objects into a non-permitted state + response = self.client.post(self._get_url('bulk_edit'), data) + self.assertHttpStatus(response, 200) + + # Update permission constraints + obj_perm.constraints = {'pk__gt': 0} + obj_perm.save() + + # Bulk edit permitted objects + self.assertHttpStatus(self.client.post(self._get_url('bulk_edit'), data), 302) + for i, instance in enumerate(self.model.objects.filter(pk__in=pk_list)): + self.assertInstanceEqual(instance, self.bulk_edit_data) + + class BulkDeleteObjectsViewTestCase(ModelViewTestCase): + """ + Delete multiple instances. + """ + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_bulk_delete_objects_without_permission(self): + pk_list = self.model.objects.values_list('pk', flat=True)[:3] + data = { + 'pk': pk_list, + 'confirm': True, + '_confirm': True, # Form button + } + + # Test GET without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.get(self._get_url('bulk_delete')), 403) + + # Try POST without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.post(self._get_url('bulk_delete'), data), 403) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_bulk_delete_objects_with_permission(self): + pk_list = self.model.objects.values_list('pk', flat=True) + data = { + 'pk': pk_list, + 'confirm': True, + '_confirm': True, # Form button + } + + # Assign unconstrained permission + obj_perm = ObjectPermission( + actions=['delete'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try POST with model-level permission + self.assertHttpStatus(self.client.post(self._get_url('bulk_delete'), data), 302) + self.assertEqual(self.model.objects.count(), 0) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=[]) + def test_bulk_delete_objects_with_constrained_permission(self): + initial_count = self.model.objects.count() + pk_list = self.model.objects.values_list('pk', flat=True) + data = { + 'pk': pk_list, + 'confirm': True, + '_confirm': True, # Form button + } + + # Assign constrained permission + obj_perm = ObjectPermission( + constraints={'pk': 0}, # Dummy permission to deny all + actions=['delete'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Attempt to bulk delete non-permitted objects + self.assertHttpStatus(self.client.post(self._get_url('bulk_delete'), data), 302) + self.assertEqual(self.model.objects.count(), initial_count) + + # Update permission constraints + obj_perm.constraints = {'pk__gt': 0} # Dummy permission to allow all + obj_perm.save() + + # Bulk delete permitted objects + self.assertHttpStatus(self.client.post(self._get_url('bulk_delete'), data), 302) + self.assertEqual(self.model.objects.count(), 0) + + class PrimaryObjectViewTestCase( + GetObjectViewTestCase, + CreateObjectViewTestCase, + EditObjectViewTestCase, + DeleteObjectViewTestCase, + ListObjectsViewTestCase, + BulkImportObjectsViewTestCase, + BulkEditObjectsViewTestCase, + BulkDeleteObjectsViewTestCase, + ): + """ + TestCase suitable for testing all standard View functions for primary objects + """ + maxDiff = None + + class OrganizationalObjectViewTestCase( + CreateObjectViewTestCase, + EditObjectViewTestCase, + ListObjectsViewTestCase, + BulkImportObjectsViewTestCase, + BulkDeleteObjectsViewTestCase, + ): + """ + TestCase suitable for all organizational objects + """ + maxDiff = None + + class DeviceComponentTemplateViewTestCase( + EditObjectViewTestCase, + DeleteObjectViewTestCase, + BulkCreateObjectsViewTestCase, + BulkEditObjectsViewTestCase, + BulkDeleteObjectsViewTestCase, + ): + """ + TestCase suitable for testing device component template models (ConsolePortTemplates, InterfaceTemplates, etc.) + """ + maxDiff = None + + class DeviceComponentViewTestCase( + EditObjectViewTestCase, + DeleteObjectViewTestCase, + ListObjectsViewTestCase, + BulkCreateObjectsViewTestCase, + BulkImportObjectsViewTestCase, + BulkEditObjectsViewTestCase, + BulkDeleteObjectsViewTestCase, + ): + """ + TestCase suitable for testing device component models (ConsolePorts, Interfaces, etc.) + """ + maxDiff = None diff --git a/netbox/utilities/tests/test_api.py b/netbox/utilities/tests/test_api.py index 469bb3150..01d4ab8f3 100644 --- a/netbox/utilities/tests/test_api.py +++ b/netbox/utilities/tests/test_api.py @@ -18,7 +18,6 @@ class WritableNestedSerializerTest(APITestCase): """ def setUp(self): - super().setUp() self.region_a = Region.objects.create(name='Region A', slug='region-a') @@ -26,39 +25,36 @@ class WritableNestedSerializerTest(APITestCase): self.site2 = Site.objects.create(region=self.region_a, name='Site 2', slug='site-2') def test_related_by_pk(self): - data = { 'vid': 100, 'name': 'Test VLAN 100', 'site': self.site1.pk, } - url = reverse('ipam-api:vlan-list') - response = self.client.post(url, data, format='json', **self.header) + self.add_permissions('ipam.add_vlan') + response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertEqual(response.data['site']['id'], self.site1.pk) vlan = VLAN.objects.get(pk=response.data['id']) self.assertEqual(vlan.site, self.site1) def test_related_by_pk_no_match(self): - data = { 'vid': 100, 'name': 'Test VLAN 100', 'site': 999, } - url = reverse('ipam-api:vlan-list') + self.add_permissions('ipam.add_vlan') + with disable_warnings('django.request'): response = self.client.post(url, data, format='json', **self.header) - self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) self.assertEqual(VLAN.objects.count(), 0) self.assertTrue(response.data['site'][0].startswith("Related object not found")) def test_related_by_attributes(self): - data = { 'vid': 100, 'name': 'Test VLAN 100', @@ -66,17 +62,16 @@ class WritableNestedSerializerTest(APITestCase): 'name': 'Site 1' }, } - url = reverse('ipam-api:vlan-list') - response = self.client.post(url, data, format='json', **self.header) + self.add_permissions('ipam.add_vlan') + response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertEqual(response.data['site']['id'], self.site1.pk) vlan = VLAN.objects.get(pk=response.data['id']) self.assertEqual(vlan.site, self.site1) def test_related_by_attributes_no_match(self): - data = { 'vid': 100, 'name': 'Test VLAN 100', @@ -84,17 +79,16 @@ class WritableNestedSerializerTest(APITestCase): 'name': 'Site X' }, } - url = reverse('ipam-api:vlan-list') + self.add_permissions('ipam.add_vlan') + with disable_warnings('django.request'): response = self.client.post(url, data, format='json', **self.header) - self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) self.assertEqual(VLAN.objects.count(), 0) self.assertTrue(response.data['site'][0].startswith("Related object not found")) def test_related_by_attributes_multiple_matches(self): - data = { 'vid': 100, 'name': 'Test VLAN 100', @@ -104,27 +98,26 @@ class WritableNestedSerializerTest(APITestCase): }, }, } - url = reverse('ipam-api:vlan-list') + self.add_permissions('ipam.add_vlan') + with disable_warnings('django.request'): response = self.client.post(url, data, format='json', **self.header) - self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) self.assertEqual(VLAN.objects.count(), 0) self.assertTrue(response.data['site'][0].startswith("Multiple objects match")) def test_related_by_invalid(self): - data = { 'vid': 100, 'name': 'Test VLAN 100', 'site': 'XXX', } - url = reverse('ipam-api:vlan-list') + self.add_permissions('ipam.add_vlan') + with disable_warnings('django.request'): response = self.client.post(url, data, format='json', **self.header) - self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) self.assertEqual(VLAN.objects.count(), 0) diff --git a/netbox/utilities/utils.py b/netbox/utilities/utils.py index 351b1fd68..4c07f5520 100644 --- a/netbox/utilities/utils.py +++ b/netbox/utilities/utils.py @@ -213,9 +213,9 @@ def prepare_cloned_fields(instance): if field_value not in (None, ''): params[field_name] = field_value - # Copy tags - if is_taggable(instance): - params['tags'] = ','.join([t.name for t in instance.tags.all()]) + # Copy tags + if is_taggable(instance): + params['tags'] = ','.join([t.name for t in instance.tags.all()]) # Concatenate parameters into a URL query string param_string = '&'.join( diff --git a/netbox/utilities/validators.py b/netbox/utilities/validators.py index 3b08733cd..517a567a9 100644 --- a/netbox/utilities/validators.py +++ b/netbox/utilities/validators.py @@ -1,31 +1,24 @@ import re +from django.conf import settings from django.core.validators import _lazy_re_compile, BaseValidator, URLValidator class EnhancedURLValidator(URLValidator): """ - Extends Django's built-in URLValidator to permit the use of hostnames with no domain extension. + Extends Django's built-in URLValidator to permit the use of hostnames with no domain extension and enforce allowed + schemes specified in the configuration. """ - class AnyURLScheme(object): - """ - A fake URL list which "contains" all scheme names abiding by the syntax defined in RFC 3986 section 3.1 - """ - def __contains__(self, item): - if not item or not re.match(r'^[a-z][0-9a-z+\-.]*$', item.lower()): - return False - return True - fqdn_re = URLValidator.hostname_re + URLValidator.domain_re + URLValidator.tld_re host_res = [URLValidator.ipv4_re, URLValidator.ipv6_re, fqdn_re, URLValidator.hostname_re] regex = _lazy_re_compile( - r'^(?:[a-z0-9\.\-\+]*)://' # Scheme (previously enforced by AnyURLScheme or schemes kwarg) + r'^(?:[a-z0-9\.\-\+]*)://' # Scheme (enforced separately) r'(?:\S+(?::\S*)?@)?' # HTTP basic authentication r'(?:' + '|'.join(host_res) + ')' # IPv4, IPv6, FQDN, or hostname r'(?::\d{2,5})?' # Port number r'(?:[/?#][^\s]*)?' # Path r'\Z', re.IGNORECASE) - schemes = AnyURLScheme() + schemes = settings.ALLOWED_URL_SCHEMES class ExclusionValidator(BaseValidator): diff --git a/netbox/utilities/views.py b/netbox/utilities/views.py index 076f2ad14..cf282a8c0 100644 --- a/netbox/utilities/views.py +++ b/netbox/utilities/views.py @@ -1,10 +1,13 @@ import logging +import re import sys from copy import deepcopy from django.contrib import messages +from django.contrib.auth.decorators import login_required from django.contrib.contenttypes.models import ContentType -from django.core.exceptions import FieldDoesNotExist, ValidationError +from django.contrib.auth.mixins import AccessMixin +from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured, ObjectDoesNotExist, ValidationError from django.db import transaction, IntegrityError from django.db.models import ManyToManyField, ProtectedError from django.forms import Form, ModelMultipleChoiceField, MultipleHiddenInput, Textarea @@ -13,6 +16,7 @@ from django.shortcuts import get_object_or_404, redirect, render from django.template import loader from django.template.exceptions import TemplateDoesNotExist from django.urls import reverse +from django.utils.decorators import method_decorator from django.utils.html import escape from django.utils.http import is_safe_url from django.utils.safestring import mark_safe @@ -25,12 +29,63 @@ from extras.models import CustomField, CustomFieldValue, ExportTemplate from extras.querysets import CustomFieldQueryset from utilities.exceptions import AbortTransaction from utilities.forms import BootstrapMixin, CSVDataField, TableConfigForm +from utilities.permissions import get_permission_for_model, resolve_permission from utilities.utils import csv_format, prepare_cloned_fields from .error_handlers import handle_protectederror from .forms import ConfirmationForm, ImportForm from .paginator import EnhancedPaginator, get_paginate_count +# +# Mixins +# + +class ObjectPermissionRequiredMixin(AccessMixin): + """ + Similar to Django's built-in PermissionRequiredMixin, but extended to check for both model-level and object-level + permission assignments. If the user has only object-level permissions assigned, the view's queryset is filtered + to return only those objects on which the user is permitted to perform the specified action. + + additional_permissions: An optional iterable of statically declared permissions to evaluate in addition to those + derived from the object type + """ + additional_permissions = list() + + def get_required_permission(self): + """ + Return the specific permission necessary to perform the requested action on an object. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement get_required_permission()") + + def has_permission(self): + user = self.request.user + permission_required = self.get_required_permission() + + # Check that the user has been granted the required permission(s). + if user.has_perms((permission_required, *self.additional_permissions)): + + # Update the view's QuerySet to filter only the permitted objects + action = resolve_permission(permission_required)[1] + self.queryset = self.queryset.restrict(user, action) + + return True + + return False + + def dispatch(self, request, *args, **kwargs): + + if not hasattr(self, 'queryset'): + raise ImproperlyConfigured( + '{} has no queryset defined. ObjectPermissionRequiredMixin may only be used on views which define ' + 'a base queryset'.format(self.__class__.__name__) + ) + + if not self.has_permission(): + return self.handle_no_permission() + + return super().dispatch(request, *args, **kwargs) + + class GetReturnURLMixin(object): """ Provides logic for determining where a user should be redirected after processing a form. @@ -57,7 +112,23 @@ class GetReturnURLMixin(object): return reverse('home') -class ObjectListView(View): +# +# Generic views +# + +class ObjectView(ObjectPermissionRequiredMixin, View): + """ + Retrieve a single object for display. + + queryset: The base queryset for retrieving the object. + """ + queryset = None + + def get_required_permission(self): + return get_permission_for_model(self.queryset.model, 'view') + + +class ObjectListView(ObjectPermissionRequiredMixin, View): """ List a series of objects. @@ -74,6 +145,9 @@ class ObjectListView(View): template_name = 'utilities/obj_list.html' action_buttons = ('add', 'import', 'export') + def get_required_permission(self): + return get_permission_for_model(self.queryset.model, 'view') + def queryset_to_yaml(self): """ Export the queryset of objects as concatenated YAML documents. @@ -160,11 +234,14 @@ class ObjectListView(View): # Compile a dictionary indicating which permissions are available to the current user for this model permissions = {} for action in ('add', 'change', 'delete', 'view'): - perm_name = '{}.{}_{}'.format(model._meta.app_label, action, model._meta.model_name) + perm_name = get_permission_for_model(model, action) permissions[action] = request.user.has_perm(perm_name) # Construct the table based on the user's permissions - columns = request.user.config.get(f"tables.{self.table.__name__}.columns") + if request.user.is_authenticated: + columns = request.user.config.get(f"tables.{self.table.__name__}.columns") + else: + columns = None table = self.table(self.queryset, columns=columns) if 'pk' in table.base_columns and (permissions['change'] or permissions['delete']): table.columns.show('pk') @@ -188,6 +265,7 @@ class ObjectListView(View): return render(request, self.template_name, context) + @method_decorator(login_required) def post(self, request): # Update the user's table configuration @@ -212,7 +290,7 @@ class ObjectListView(View): return {} -class ObjectEditView(GetReturnURLMixin, View): +class ObjectEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): """ Create or edit a single object. @@ -224,6 +302,11 @@ class ObjectEditView(GetReturnURLMixin, View): model_form = None template_name = 'utilities/obj_edit.html' + def get_required_permission(self): + # self._permission_action is set by dispatch() to either "add" or "change" depending on whether + # we are modifying an existing object or creating a new one. + return get_permission_for_model(self.queryset.model, self._permission_action) + def get_object(self, kwargs): # Look up an existing object by slug or PK, if provided. if 'slug' in kwargs: @@ -239,68 +322,87 @@ class ObjectEditView(GetReturnURLMixin, View): return obj def dispatch(self, request, *args, **kwargs): - self.obj = self.alter_obj(self.get_object(kwargs), request, args, kwargs) + # Determine required permission based on whether we are editing an existing object + self._permission_action = 'change' if kwargs else 'add' return super().dispatch(request, *args, **kwargs) def get(self, request, *args, **kwargs): + obj = self.alter_obj(self.get_object(kwargs), request, args, kwargs) + # Parse initial data manually to avoid setting field values as lists initial_data = {k: request.GET[k] for k in request.GET} - form = self.model_form(instance=self.obj, initial=initial_data) + form = self.model_form(instance=obj, initial=initial_data) return render(request, self.template_name, { - 'obj': self.obj, + 'obj': obj, 'obj_type': self.queryset.model._meta.verbose_name, 'form': form, - 'return_url': self.get_return_url(request, self.obj), + 'return_url': self.get_return_url(request, obj), }) def post(self, request, *args, **kwargs): logger = logging.getLogger('netbox.views.ObjectEditView') - form = self.model_form(request.POST, request.FILES, instance=self.obj) + obj = self.alter_obj(self.get_object(kwargs), request, args, kwargs) + form = self.model_form( + data=request.POST, + files=request.FILES, + instance=obj + ) if form.is_valid(): logger.debug("Form validation was successful") - obj = form.save() - msg = '{} {}'.format( - 'Created' if not form.instance.pk else 'Modified', - self.queryset.model._meta.verbose_name - ) - logger.info(f"{msg} {obj} (PK: {obj.pk})") - if hasattr(obj, 'get_absolute_url'): - msg = '{} {}'.format(msg, obj.get_absolute_url(), escape(obj)) - else: - msg = '{} {}'.format(msg, escape(obj)) - messages.success(request, mark_safe(msg)) + try: + with transaction.atomic(): + obj = form.save() - if '_addanother' in request.POST: + # Check that the new object conforms with any assigned object-level permissions + self.queryset.get(pk=obj.pk) - # If the object has clone_fields, pre-populate a new instance of the form - if hasattr(obj, 'clone_fields'): - url = '{}?{}'.format(request.path, prepare_cloned_fields(obj)) - return redirect(url) + msg = '{} {}'.format( + 'Created' if not form.instance.pk else 'Modified', + self.queryset.model._meta.verbose_name + ) + logger.info(f"{msg} {obj} (PK: {obj.pk})") + if hasattr(obj, 'get_absolute_url'): + msg = '{} {}'.format(msg, obj.get_absolute_url(), escape(obj)) + else: + msg = '{} {}'.format(msg, escape(obj)) + messages.success(request, mark_safe(msg)) - return redirect(request.get_full_path()) + if '_addanother' in request.POST: - return_url = form.cleaned_data.get('return_url') - if return_url is not None and is_safe_url(url=return_url, allowed_hosts=request.get_host()): - return redirect(return_url) - else: - return redirect(self.get_return_url(request, obj)) + # If the object has clone_fields, pre-populate a new instance of the form + if hasattr(obj, 'clone_fields'): + url = '{}?{}'.format(request.path, prepare_cloned_fields(obj)) + return redirect(url) + + return redirect(request.get_full_path()) + + return_url = form.cleaned_data.get('return_url') + if return_url is not None and is_safe_url(url=return_url, allowed_hosts=request.get_host()): + return redirect(return_url) + else: + return redirect(self.get_return_url(request, obj)) + + except ObjectDoesNotExist: + msg = "Object save failed due to object-level permissions violation" + logger.debug(msg) + form.add_error(None, msg) else: logger.debug("Form validation failed") return render(request, self.template_name, { - 'obj': self.obj, + 'obj': obj, 'obj_type': self.queryset.model._meta.verbose_name, 'form': form, - 'return_url': self.get_return_url(request, self.obj), + 'return_url': self.get_return_url(request, obj), }) -class ObjectDeleteView(GetReturnURLMixin, View): +class ObjectDeleteView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): """ Delete a single object. @@ -310,6 +412,9 @@ class ObjectDeleteView(GetReturnURLMixin, View): queryset = None template_name = 'utilities/obj_delete.html' + def get_required_permission(self): + return get_permission_for_model(self.queryset.model, 'delete') + def get_object(self, kwargs): # Look up object by slug if one has been provided. Otherwise, use PK. if 'slug' in kwargs: @@ -364,20 +469,25 @@ class ObjectDeleteView(GetReturnURLMixin, View): }) -class BulkCreateView(GetReturnURLMixin, View): +class BulkCreateView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): """ Create new objects in bulk. + queryset: Base queryset for the objects being created form: Form class which provides the `pattern` field model_form: The ModelForm used to create individual objects pattern_target: Name of the field to be evaluated as a pattern (if any) template_name: The name of the template """ + queryset = None form = None model_form = None pattern_target = '' template_name = None + def get_required_permission(self): + return get_permission_for_model(self.queryset.model, 'add') + def get(self, request): # Set initial values for visible form fields from query args initial = {} @@ -397,7 +507,7 @@ class BulkCreateView(GetReturnURLMixin, View): def post(self, request): logger = logging.getLogger('netbox.views.BulkCreateView') - model = self.model_form._meta.model + model = self.queryset.model form = self.form(request.POST) model_form = self.model_form(request.POST) @@ -430,6 +540,10 @@ class BulkCreateView(GetReturnURLMixin, View): # Raise an IntegrityError to break the for loop and abort the transaction. raise IntegrityError() + # Enforce object-level permissions + if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs): + raise ObjectDoesNotExist + # If we make it to this point, validation has succeeded on all new objects. msg = "Added {} {}".format(len(new_objs), model._meta.verbose_name_plural) logger.info(msg) @@ -442,6 +556,11 @@ class BulkCreateView(GetReturnURLMixin, View): except IntegrityError: pass + except ObjectDoesNotExist: + msg = "Object creation failed due to object-level permissions violation" + logger.debug(msg) + form.add_error(None, msg) + else: logger.debug("Form validation failed") @@ -453,21 +572,29 @@ class BulkCreateView(GetReturnURLMixin, View): }) -class ObjectImportView(GetReturnURLMixin, View): +class ObjectImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): """ Import a single object (YAML or JSON format). + + queryset: Base queryset for the objects being created + model_form: The ModelForm used to create individual objects + related_object_forms: A dictionary mapping of forms to be used for the creation of related (child) objects + template_name: The name of the template """ - model = None + queryset = None model_form = None related_object_forms = dict() template_name = 'utilities/obj_import.html' + def get_required_permission(self): + return get_permission_for_model(self.queryset.model, 'add') + def get(self, request): form = ImportForm() return render(request, self.template_name, { 'form': form, - 'obj_type': self.model._meta.verbose_name, + 'obj_type': self.queryset.model._meta.verbose_name, 'return_url': self.get_return_url(request), }) @@ -497,12 +624,17 @@ class ObjectImportView(GetReturnURLMixin, View): # Save the primary object obj = model_form.save() + + # Enforce object-level permissions + self.queryset.get(pk=obj.pk) + logger.debug(f"Created {obj} (PK: {obj.pk})") # Iterate through the related object forms (if any), validating and saving each instance. for field_name, related_object_form in self.related_object_forms.items(): logger.debug("Processing form for related objects: {related_object_form}") + related_obj_pks = [] for i, rel_obj_data in enumerate(data.get(field_name, list())): f = related_object_form(obj, rel_obj_data) @@ -512,7 +644,8 @@ class ObjectImportView(GetReturnURLMixin, View): f.data[subfield_name] = field.initial if f.is_valid(): - f.save() + related_obj = f.save() + related_obj_pks.append(related_obj.pk) else: # Replicate errors on the related object form to the primary form for display for subfield_name, errors in f.errors.items(): @@ -521,9 +654,19 @@ class ObjectImportView(GetReturnURLMixin, View): model_form.add_error(None, err_msg) raise AbortTransaction() + # Enforce object-level permissions on related objects + model = related_object_form.Meta.model + if model.objects.filter(pk__in=related_obj_pks).count() != len(related_obj_pks): + raise ObjectDoesNotExist + except AbortTransaction: pass + except ObjectDoesNotExist: + msg = "Object creation failed due to object-level permissions violation" + logger.debug(msg) + form.add_error(None, msg) + if not model_form.errors: logger.info(f"Import object {obj} (PK: {obj.pk})") messages.success(request, mark_safe('Imported object: {}'.format( @@ -555,20 +698,22 @@ class ObjectImportView(GetReturnURLMixin, View): return render(request, self.template_name, { 'form': form, - 'obj_type': self.model._meta.verbose_name, + 'obj_type': self.queryset.model._meta.verbose_name, 'return_url': self.get_return_url(request), }) -class BulkImportView(GetReturnURLMixin, View): +class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): """ Import objects in bulk (CSV format). + queryset: Base queryset for the model model_form: The form used to create each imported object table: The django-tables2 Table used to render the list of imported objects template_name: The name of the template widget_attrs: A dict of attributes to apply to the import widget (e.g. to require a session key) """ + queryset = None model_form = None table = None template_name = 'utilities/obj_bulk_import.html' @@ -590,6 +735,9 @@ class BulkImportView(GetReturnURLMixin, View): """ return obj_form.save() + def get_required_permission(self): + return get_permission_for_model(self.queryset.model, 'add') + def get(self, request): return render(request, self.template_name, { @@ -622,6 +770,10 @@ class BulkImportView(GetReturnURLMixin, View): form.add_error('csv', "Row {} {}: {}".format(row, field, err[0])) raise ValidationError("") + # Enforce object-level permissions + if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs): + raise ObjectDoesNotExist + # Compile a table containing the imported objects obj_table = self.table(new_objs) @@ -638,6 +790,11 @@ class BulkImportView(GetReturnURLMixin, View): except ValidationError: pass + except ObjectDoesNotExist: + msg = "Object import failed due to object-level permissions violation" + logger.debug(msg) + form.add_error(None, msg) + else: logger.debug("Form validation failed") @@ -649,7 +806,7 @@ class BulkImportView(GetReturnURLMixin, View): }) -class BulkEditView(GetReturnURLMixin, View): +class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): """ Edit objects in bulk. @@ -665,6 +822,9 @@ class BulkEditView(GetReturnURLMixin, View): form = None template_name = 'utilities/obj_bulk_edit.html' + def get_required_permission(self): + return get_permission_for_model(self.queryset.model, 'change') + def get(self, request): return redirect(self.get_return_url(request)) @@ -675,7 +835,7 @@ class BulkEditView(GetReturnURLMixin, View): # If we are editing *all* objects in the queryset, replace the PK list with all matched objects. if request.POST.get('_all') and self.filterset is not None: pk_list = [ - obj.pk for obj in self.filterset(request.GET, model.objects.only('pk')).qs + obj.pk for obj in self.filterset(request.GET, self.queryset.only('pk')).qs ] else: pk_list = request.POST.getlist('pk') @@ -695,8 +855,8 @@ class BulkEditView(GetReturnURLMixin, View): with transaction.atomic(): - updated_count = 0 - for obj in model.objects.filter(pk__in=form.cleaned_data['pk']): + updated_objects = [] + for obj in self.queryset.filter(pk__in=form.cleaned_data['pk']): # Update standard fields. If a field is listed in _nullify, delete its value. for name in standard_fields: @@ -724,6 +884,7 @@ class BulkEditView(GetReturnURLMixin, View): obj.full_clean() obj.save() + updated_objects.append(obj) logger.debug(f"Saved {obj} (PK: {obj.pk})") # Update custom fields @@ -753,10 +914,12 @@ class BulkEditView(GetReturnURLMixin, View): if form.cleaned_data.get('remove_tags', None): obj.tags.remove(*form.cleaned_data['remove_tags']) - updated_count += 1 + # Enforce object-level permissions + if self.queryset.filter(pk__in=[obj.pk for obj in updated_objects]).count() != len(updated_objects): + raise ObjectDoesNotExist - if updated_count: - msg = 'Updated {} {}'.format(updated_count, model._meta.verbose_name_plural) + if updated_objects: + msg = 'Updated {} {}'.format(len(updated_objects), model._meta.verbose_name_plural) logger.info(msg) messages.success(self.request, msg) @@ -765,6 +928,11 @@ class BulkEditView(GetReturnURLMixin, View): except ValidationError as e: messages.error(self.request, "{} failed validation: {}".format(obj, e)) + except ObjectDoesNotExist: + msg = "Object update failed due to object-level permissions violation" + logger.debug(msg) + form.add_error(None, msg) + else: logger.debug("Form validation failed") @@ -777,6 +945,8 @@ class BulkEditView(GetReturnURLMixin, View): # TODO: Find a better way to accomplish this if 'device' in request.GET: initial_data['device'] = request.GET.get('device') + elif 'device_type' in request.GET: + initial_data['device_type'] = request.GET.get('device_type') form = self.form(model, initial=initial_data) @@ -794,7 +964,59 @@ class BulkEditView(GetReturnURLMixin, View): }) -class BulkDeleteView(GetReturnURLMixin, View): +class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): + """ + An extendable view for renaming objects in bulk. + """ + queryset = None + form = None + template_name = 'utilities/obj_bulk_rename.html' + + def get_required_permission(self): + return get_permission_for_model(self.queryset.model, 'change') + + def post(self, request): + + if '_preview' in request.POST or '_apply' in request.POST: + form = self.form(request.POST, initial={'pk': request.POST.getlist('pk')}) + selected_objects = self.queryset.filter(pk__in=form.initial['pk']) + + if form.is_valid(): + for obj in selected_objects: + find = form.cleaned_data['find'] + replace = form.cleaned_data['replace'] + if form.cleaned_data['use_regex']: + try: + obj.new_name = re.sub(find, replace, obj.name) + # Catch regex group reference errors + except re.error: + obj.new_name = obj.name + else: + obj.new_name = obj.name.replace(find, replace) + + if '_apply' in request.POST: + for obj in selected_objects: + obj.name = obj.new_name + obj.save() + messages.success(request, "Renamed {} {}".format( + len(selected_objects), + self.queryset.model._meta.verbose_name_plural + )) + return redirect(self.get_return_url(request)) + + else: + form = self.form(initial={'pk': request.POST.getlist('pk')}) + selected_objects = self.queryset.filter(pk__in=form.initial['pk']) + + return render(request, self.template_name, { + 'form': form, + 'obj_type_plural': self.queryset.model._meta.verbose_name_plural, + 'selected_objects': selected_objects, + 'return_url': self.get_return_url(request), + }) + + +class BulkDeleteView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): """ Delete objects in bulk. @@ -810,6 +1032,9 @@ class BulkDeleteView(GetReturnURLMixin, View): form = None template_name = 'utilities/obj_bulk_delete.html' + def get_required_permission(self): + return get_permission_for_model(self.queryset.model, 'delete') + def get(self, request): return redirect(self.get_return_url(request)) @@ -834,7 +1059,7 @@ class BulkDeleteView(GetReturnURLMixin, View): logger.debug("Form validation was successful") # Delete objects - queryset = model.objects.filter(pk__in=pk_list) + queryset = self.queryset.filter(pk__in=pk_list) try: deleted_count = queryset.delete()[1][model._meta.label] except ProtectedError as e: @@ -887,37 +1112,44 @@ class BulkDeleteView(GetReturnURLMixin, View): # # TODO: Replace with BulkCreateView -class ComponentCreateView(GetReturnURLMixin, View): +class ComponentCreateView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): """ Add one or more components (e.g. interfaces, console ports, etc.) to a Device or VirtualMachine. """ - model = None + queryset = None form = None model_form = None template_name = None + def get_required_permission(self): + return get_permission_for_model(self.queryset.model, 'add') + def get(self, request): form = self.form(initial=request.GET) return render(request, self.template_name, { - 'component_type': self.model._meta.verbose_name, + 'component_type': self.queryset.model._meta.verbose_name, 'form': form, 'return_url': self.get_return_url(request), }) def post(self, request): - + logger = logging.getLogger('netbox.views.ComponentCreateView') form = self.form(request.POST, initial=request.GET) + if form.is_valid(): new_components = [] data = deepcopy(request.POST) - for i, name in enumerate(form.cleaned_data['name_pattern']): - + names = form.cleaned_data['name_pattern'] + labels = form.cleaned_data.get('label_pattern') + for i, name in enumerate(names): + label = labels[i] if labels else None # Initialize the individual component form data['name'] = name + data['label'] = label if hasattr(form, 'get_iterative_data'): data.update(form.get_iterative_data(i)) component_form = self.model_form(data) @@ -926,50 +1158,74 @@ class ComponentCreateView(GetReturnURLMixin, View): new_components.append(component_form) else: for field, errors in component_form.errors.as_data().items(): - # Assign errors on the child form's name field to name_pattern on the parent form + # Assign errors on the child form's name/label field to name_pattern/label_pattern on the parent form if field == 'name': field = 'name_pattern' + elif field == 'label': + field = 'label_pattern' for e in errors: form.add_error(field, '{}: {}'.format(name, ', '.join(e))) if not form.errors: - # Create the new components - for component_form in new_components: - component_form.save() + try: - messages.success(request, "Added {} {}".format( - len(new_components), self.model._meta.verbose_name_plural - )) - if '_addanother' in request.POST: - return redirect(request.get_full_path()) - else: - return redirect(self.get_return_url(request)) + with transaction.atomic(): + + # Create the new components + new_objs = [] + for component_form in new_components: + obj = component_form.save() + new_objs.append(obj) + + # Enforce object-level permissions + if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs): + raise ObjectDoesNotExist + + messages.success(request, "Added {} {}".format( + len(new_components), self.queryset.model._meta.verbose_name_plural + )) + if '_addanother' in request.POST: + return redirect(request.get_full_path()) + elif 'device_type' in form.cleaned_data: + return redirect(form.cleaned_data['device_type'].get_absolute_url()) + elif 'device' in form.cleaned_data: + return redirect(form.cleaned_data['device'].get_absolute_url()) + else: + return redirect(self.get_return_url(request)) + + except ObjectDoesNotExist: + msg = "Component creation failed due to object-level permissions violation" + logger.debug(msg) + form.add_error(None, msg) return render(request, self.template_name, { - 'component_type': self.model._meta.verbose_name, + 'component_type': self.queryset.model._meta.verbose_name, 'form': form, 'return_url': self.get_return_url(request), }) -class BulkComponentCreateView(GetReturnURLMixin, View): +class BulkComponentCreateView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): """ Add one or more components (e.g. interfaces, console ports, etc.) to a set of Devices or VirtualMachines. """ parent_model = None parent_field = None form = None - model = None + queryset = None model_form = None filterset = None table = None template_name = 'utilities/obj_bulk_add_component.html' + def get_required_permission(self): + return f'dcim.add_{self.queryset.model._meta.model_name}' + def post(self, request): logger = logging.getLogger('netbox.views.BulkComponentCreateView') parent_model_name = self.parent_model._meta.verbose_name_plural - model_name = self.model._meta.verbose_name_plural + model_name = self.queryset.model._meta.verbose_name_plural # Are we editing *all* objects in the queryset or just a selected subset? if request.POST.get('_all') and self.filterset is not None: @@ -998,10 +1254,14 @@ class BulkComponentCreateView(GetReturnURLMixin, View): for obj in data['pk']: names = data['name_pattern'] - for name in names: + labels = data['label_pattern'] + for i, name in enumerate(names): + label = labels[i] if labels else None + component_data = { self.parent_field: obj.pk, 'name': name, + 'label': label } component_data.update(data) component_form = self.model_form(component_data) @@ -1014,9 +1274,18 @@ class BulkComponentCreateView(GetReturnURLMixin, View): for e in errors: form.add_error(field, '{} {}: {}'.format(obj, name, ', '.join(e))) + # Enforce object-level permissions + if self.queryset.filter(pk__in=[obj.pk for obj in new_components]).count() != len(new_components): + raise ObjectDoesNotExist + except IntegrityError: pass + except ObjectDoesNotExist: + msg = "Component creation failed due to object-level permissions violation" + logger.debug(msg) + form.add_error(None, msg) + if not form.errors: msg = "Added {} {} to {} {}.".format( len(new_components), diff --git a/netbox/virtualization/api/nested_serializers.py b/netbox/virtualization/api/nested_serializers.py index 47b7e6442..de56e6e6a 100644 --- a/netbox/virtualization/api/nested_serializers.py +++ b/netbox/virtualization/api/nested_serializers.py @@ -8,7 +8,7 @@ __all__ = [ 'NestedClusterGroupSerializer', 'NestedClusterSerializer', 'NestedClusterTypeSerializer', - 'NestedInterfaceSerializer', + 'NestedVMInterfaceSerializer', 'NestedVirtualMachineSerializer', ] @@ -56,8 +56,8 @@ class NestedVirtualMachineSerializer(WritableNestedSerializer): fields = ['id', 'url', 'name'] -class NestedInterfaceSerializer(WritableNestedSerializer): - url = serializers.HyperlinkedIdentityField(view_name='virtualization-api:interface-detail') +class NestedVMInterfaceSerializer(WritableNestedSerializer): + url = serializers.HyperlinkedIdentityField(view_name='virtualization-api:vminterface-detail') virtual_machine = NestedVirtualMachineSerializer(read_only=True) class Meta: diff --git a/netbox/virtualization/api/serializers.py b/netbox/virtualization/api/serializers.py index 3cca95b22..5698791f8 100644 --- a/netbox/virtualization/api/serializers.py +++ b/netbox/virtualization/api/serializers.py @@ -1,17 +1,16 @@ from drf_yasg.utils import swagger_serializer_method from rest_framework import serializers -from taggit_serializer.serializers import TaggitSerializer, TagListSerializerField from dcim.api.nested_serializers import NestedDeviceRoleSerializer, NestedPlatformSerializer, NestedSiteSerializer -from dcim.choices import InterfaceModeChoices, InterfaceTypeChoices -from dcim.models import Interface +from dcim.choices import InterfaceModeChoices from extras.api.customfields import CustomFieldModelSerializer +from extras.api.serializers import TaggedObjectSerializer from ipam.api.nested_serializers import NestedIPAddressSerializer, NestedVLANSerializer from ipam.models import VLAN from tenancy.api.nested_serializers import NestedTenantSerializer from utilities.api import ChoiceField, SerializedPKRelatedField, ValidatedModelSerializer from virtualization.choices import * -from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine +from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface from .nested_serializers import * @@ -35,12 +34,11 @@ class ClusterGroupSerializer(ValidatedModelSerializer): fields = ['id', 'name', 'slug', 'description', 'cluster_count'] -class ClusterSerializer(TaggitSerializer, CustomFieldModelSerializer): +class ClusterSerializer(TaggedObjectSerializer, CustomFieldModelSerializer): type = NestedClusterTypeSerializer() group = NestedClusterGroupSerializer(required=False, allow_null=True) tenant = NestedTenantSerializer(required=False, allow_null=True) site = NestedSiteSerializer(required=False, allow_null=True) - tags = TagListSerializerField(required=False) device_count = serializers.IntegerField(read_only=True) virtualmachine_count = serializers.IntegerField(read_only=True) @@ -56,7 +54,7 @@ class ClusterSerializer(TaggitSerializer, CustomFieldModelSerializer): # Virtual machines # -class VirtualMachineSerializer(TaggitSerializer, CustomFieldModelSerializer): +class VirtualMachineSerializer(TaggedObjectSerializer, CustomFieldModelSerializer): status = ChoiceField(choices=VirtualMachineStatusChoices, required=False) site = NestedSiteSerializer(read_only=True) cluster = NestedClusterSerializer() @@ -66,7 +64,6 @@ class VirtualMachineSerializer(TaggitSerializer, CustomFieldModelSerializer): primary_ip = NestedIPAddressSerializer(read_only=True) primary_ip4 = NestedIPAddressSerializer(required=False, allow_null=True) primary_ip6 = NestedIPAddressSerializer(required=False, allow_null=True) - tags = TagListSerializerField(required=False) class Meta: model = VirtualMachine @@ -97,9 +94,8 @@ class VirtualMachineWithConfigContextSerializer(VirtualMachineSerializer): # VM interfaces # -class InterfaceSerializer(TaggitSerializer, ValidatedModelSerializer): +class VMInterfaceSerializer(TaggedObjectSerializer, ValidatedModelSerializer): virtual_machine = NestedVirtualMachineSerializer() - type = ChoiceField(choices=VMInterfaceTypeChoices, default=VMInterfaceTypeChoices.TYPE_VIRTUAL, required=False) mode = ChoiceField(choices=InterfaceModeChoices, allow_blank=True, required=False) untagged_vlan = NestedVLANSerializer(required=False, allow_null=True) tagged_vlans = SerializedPKRelatedField( @@ -108,11 +104,10 @@ class InterfaceSerializer(TaggitSerializer, ValidatedModelSerializer): required=False, many=True ) - tags = TagListSerializerField(required=False) class Meta: - model = Interface + model = VMInterface fields = [ - 'id', 'virtual_machine', 'name', 'type', 'enabled', 'mtu', 'mac_address', 'description', 'mode', - 'untagged_vlan', 'tagged_vlans', 'tags', + 'id', 'virtual_machine', 'name', 'enabled', 'mtu', 'mac_address', 'description', 'mode', 'untagged_vlan', + 'tagged_vlans', 'tags', ] diff --git a/netbox/virtualization/api/urls.py b/netbox/virtualization/api/urls.py index c237f1e68..3f6c56a48 100644 --- a/netbox/virtualization/api/urls.py +++ b/netbox/virtualization/api/urls.py @@ -21,7 +21,7 @@ router.register('clusters', views.ClusterViewSet) # VirtualMachines router.register('virtual-machines', views.VirtualMachineViewSet) -router.register('interfaces', views.InterfaceViewSet) +router.register('interfaces', views.VMInterfaceViewSet) app_name = 'virtualization-api' urlpatterns = router.urls diff --git a/netbox/virtualization/api/views.py b/netbox/virtualization/api/views.py index 2a1d7c3a9..f2a689f12 100644 --- a/netbox/virtualization/api/views.py +++ b/netbox/virtualization/api/views.py @@ -1,11 +1,11 @@ from django.db.models import Count -from dcim.models import Device, Interface +from dcim.models import Device from extras.api.views import CustomFieldModelViewSet from utilities.api import ModelViewSet from utilities.utils import get_subquery from virtualization import filters -from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine +from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface from . import serializers @@ -71,18 +71,11 @@ class VirtualMachineViewSet(CustomFieldModelViewSet): return serializers.VirtualMachineWithConfigContextSerializer -class InterfaceViewSet(ModelViewSet): - queryset = Interface.objects.filter( +class VMInterfaceViewSet(ModelViewSet): + queryset = VMInterface.objects.filter( virtual_machine__isnull=False ).prefetch_related( 'virtual_machine', 'tags' ) - serializer_class = serializers.InterfaceSerializer - filterset_class = filters.InterfaceFilterSet - - def get_serializer_class(self): - request = self.get_serializer_context()['request'] - if request.query_params.get('brief', False): - # Override get_serializer_for_model(), which will return the DCIM NestedInterfaceSerializer - return serializers.NestedInterfaceSerializer - return serializers.InterfaceSerializer + serializer_class = serializers.VMInterfaceSerializer + filterset_class = filters.VMInterfaceFilterSet diff --git a/netbox/virtualization/choices.py b/netbox/virtualization/choices.py index 1dae88e1d..3795ddb76 100644 --- a/netbox/virtualization/choices.py +++ b/netbox/virtualization/choices.py @@ -1,4 +1,3 @@ -from dcim.choices import InterfaceTypeChoices from utilities.choices import ChoiceSet @@ -29,16 +28,3 @@ class VirtualMachineStatusChoices(ChoiceSet): STATUS_ACTIVE: 1, STATUS_STAGED: 3, } - - -# -# Interface types (for VirtualMachines) -# - -class VMInterfaceTypeChoices(ChoiceSet): - - TYPE_VIRTUAL = InterfaceTypeChoices.TYPE_VIRTUAL - - CHOICES = ( - (TYPE_VIRTUAL, 'Virtual'), - ) diff --git a/netbox/virtualization/filters.py b/netbox/virtualization/filters.py index a54b6ab28..33ca44a22 100644 --- a/netbox/virtualization/filters.py +++ b/netbox/virtualization/filters.py @@ -1,7 +1,7 @@ import django_filters from django.db.models import Q -from dcim.models import DeviceRole, Interface, Platform, Region, Site +from dcim.models import DeviceRole, Platform, Region, Site from extras.filters import CustomFieldFilterSet, CreatedUpdatedFilterSet, LocalConfigContextFilterSet from tenancy.filters import TenancyFilterSet from utilities.filters import ( @@ -9,14 +9,14 @@ from utilities.filters import ( TreeNodeMultipleChoiceFilter, ) from .choices import * -from .models import Cluster, ClusterGroup, ClusterType, VirtualMachine +from .models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface __all__ = ( 'ClusterFilterSet', 'ClusterGroupFilterSet', 'ClusterTypeFilterSet', - 'InterfaceFilterSet', 'VirtualMachineFilterSet', + 'VMInterfaceFilterSet', ) @@ -40,45 +40,45 @@ class ClusterFilterSet(BaseFilterSet, TenancyFilterSet, CustomFieldFilterSet, Cr label='Search', ) region_id = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='site__region', lookup_expr='in', label='Region (ID)', ) region = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='site__region', lookup_expr='in', to_field_name='slug', label='Region (slug)', ) site_id = django_filters.ModelMultipleChoiceFilter( - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site (ID)', ) site = django_filters.ModelMultipleChoiceFilter( field_name='site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site (slug)', ) group_id = django_filters.ModelMultipleChoiceFilter( - queryset=ClusterGroup.objects.all(), + queryset=ClusterGroup.objects.unrestricted(), label='Parent group (ID)', ) group = django_filters.ModelMultipleChoiceFilter( field_name='group__slug', - queryset=ClusterGroup.objects.all(), + queryset=ClusterGroup.objects.unrestricted(), to_field_name='slug', label='Parent group (slug)', ) type_id = django_filters.ModelMultipleChoiceFilter( - queryset=ClusterType.objects.all(), + queryset=ClusterType.objects.unrestricted(), label='Cluster type (ID)', ) type = django_filters.ModelMultipleChoiceFilter( field_name='type__slug', - queryset=ClusterType.objects.all(), + queryset=ClusterType.objects.unrestricted(), to_field_name='slug', label='Cluster type (slug)', ) @@ -114,38 +114,38 @@ class VirtualMachineFilterSet( ) cluster_group_id = django_filters.ModelMultipleChoiceFilter( field_name='cluster__group', - queryset=ClusterGroup.objects.all(), + queryset=ClusterGroup.objects.unrestricted(), label='Cluster group (ID)', ) cluster_group = django_filters.ModelMultipleChoiceFilter( field_name='cluster__group__slug', - queryset=ClusterGroup.objects.all(), + queryset=ClusterGroup.objects.unrestricted(), to_field_name='slug', label='Cluster group (slug)', ) cluster_type_id = django_filters.ModelMultipleChoiceFilter( field_name='cluster__type', - queryset=ClusterType.objects.all(), + queryset=ClusterType.objects.unrestricted(), label='Cluster type (ID)', ) cluster_type = django_filters.ModelMultipleChoiceFilter( field_name='cluster__type__slug', - queryset=ClusterType.objects.all(), + queryset=ClusterType.objects.unrestricted(), to_field_name='slug', label='Cluster type (slug)', ) cluster_id = django_filters.ModelMultipleChoiceFilter( - queryset=Cluster.objects.all(), + queryset=Cluster.objects.unrestricted(), label='Cluster (ID)', ) region_id = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='cluster__site__region', lookup_expr='in', label='Region (ID)', ) region = TreeNodeMultipleChoiceFilter( - queryset=Region.objects.all(), + queryset=Region.objects.unrestricted(), field_name='cluster__site__region', lookup_expr='in', to_field_name='slug', @@ -153,32 +153,32 @@ class VirtualMachineFilterSet( ) site_id = django_filters.ModelMultipleChoiceFilter( field_name='cluster__site', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), label='Site (ID)', ) site = django_filters.ModelMultipleChoiceFilter( field_name='cluster__site__slug', - queryset=Site.objects.all(), + queryset=Site.objects.unrestricted(), to_field_name='slug', label='Site (slug)', ) role_id = django_filters.ModelMultipleChoiceFilter( - queryset=DeviceRole.objects.all(), + queryset=DeviceRole.objects.unrestricted(), label='Role (ID)', ) role = django_filters.ModelMultipleChoiceFilter( field_name='role__slug', - queryset=DeviceRole.objects.all(), + queryset=DeviceRole.objects.unrestricted(), to_field_name='slug', label='Role (slug)', ) platform_id = django_filters.ModelMultipleChoiceFilter( - queryset=Platform.objects.all(), + queryset=Platform.objects.unrestricted(), label='Platform (ID)', ) platform = django_filters.ModelMultipleChoiceFilter( field_name='platform__slug', - queryset=Platform.objects.all(), + queryset=Platform.objects.unrestricted(), to_field_name='slug', label='Platform (slug)', ) @@ -201,19 +201,19 @@ class VirtualMachineFilterSet( ) -class InterfaceFilterSet(BaseFilterSet): +class VMInterfaceFilterSet(BaseFilterSet): q = django_filters.CharFilter( method='search', label='Search', ) virtual_machine_id = django_filters.ModelMultipleChoiceFilter( field_name='virtual_machine', - queryset=VirtualMachine.objects.all(), + queryset=VirtualMachine.objects.unrestricted(), label='Virtual machine (ID)', ) virtual_machine = django_filters.ModelMultipleChoiceFilter( field_name='virtual_machine__name', - queryset=VirtualMachine.objects.all(), + queryset=VirtualMachine.objects.unrestricted(), to_field_name='name', label='Virtual machine', ) @@ -222,7 +222,7 @@ class InterfaceFilterSet(BaseFilterSet): ) class Meta: - model = Interface + model = VMInterface fields = ['id', 'name', 'enabled', 'mtu'] def search(self, queryset, name, value): diff --git a/netbox/virtualization/forms.py b/netbox/virtualization/forms.py index 0983b2432..ce6eea1e8 100644 --- a/netbox/virtualization/forms.py +++ b/netbox/virtualization/forms.py @@ -1,25 +1,25 @@ from django import forms from django.core.exceptions import ValidationError -from taggit.forms import TagField from dcim.choices import InterfaceModeChoices from dcim.constants import INTERFACE_MTU_MAX, INTERFACE_MTU_MIN from dcim.forms import INTERFACE_MODE_HELP_TEXT -from dcim.models import Device, DeviceRole, Interface, Platform, Rack, Region, Site +from dcim.models import Device, DeviceRole, Platform, Rack, Region, Site from extras.forms import ( AddRemoveTagsForm, CustomFieldBulkEditForm, CustomFieldModelCSVForm, CustomFieldModelForm, CustomFieldFilterForm, ) +from extras.models import Tag from ipam.models import IPAddress, VLAN from tenancy.forms import TenancyFilterForm, TenancyForm from tenancy.models import Tenant from utilities.forms import ( add_blank_choice, APISelect, APISelectMultiple, BootstrapMixin, BulkEditForm, BulkEditNullBooleanSelect, - CommentField, ConfirmationForm, CSVChoiceField, CSVModelChoiceField, CSVModelForm, DynamicModelChoiceField, - DynamicModelMultipleChoiceField, ExpandableNameField, form_from_model, JSONField, SlugField, SmallTextarea, - StaticSelect2, StaticSelect2Multiple, TagFilterField, + BulkRenameForm, CommentField, ConfirmationForm, CSVChoiceField, CSVModelChoiceField, CSVModelForm, + DynamicModelChoiceField, DynamicModelMultipleChoiceField, ExpandableNameField, form_from_model, JSONField, + SlugField, SmallTextarea, StaticSelect2, StaticSelect2Multiple, TagFilterField, BOOLEAN_WITH_BLANK_CHOICES, ) from .choices import * -from .models import Cluster, ClusterGroup, ClusterType, VirtualMachine +from .models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface # @@ -83,7 +83,8 @@ class ClusterForm(BootstrapMixin, TenancyForm, CustomFieldModelForm): required=False ) comments = CommentField() - tags = TagField( + tags = DynamicModelMultipleChoiceField( + queryset=Tag.objects.all(), required=False ) @@ -312,13 +313,14 @@ class VirtualMachineForm(BootstrapMixin, TenancyForm, CustomFieldModelForm): queryset=Platform.objects.all(), required=False ) - tags = TagField( - required=False - ) local_context_data = JSONField( required=False, label='' ) + tags = DynamicModelMultipleChoiceField( + queryset=Tag.objects.all(), + required=False + ) class Meta: model = VirtualMachine @@ -354,7 +356,8 @@ class VirtualMachineForm(BootstrapMixin, TenancyForm, CustomFieldModelForm): ip_choices = [(None, '---------')] # Collect interface IPs interface_ips = IPAddress.objects.prefetch_related('interface').filter( - address__family=family, interface__virtual_machine=self.instance + address__family=family, + vminterface__in=self.instance.interfaces.values_list('id', flat=True) ) if interface_ips: ip_choices.append( @@ -364,7 +367,8 @@ class VirtualMachineForm(BootstrapMixin, TenancyForm, CustomFieldModelForm): ) # Collect NAT IPs nat_ips = IPAddress.objects.prefetch_related('nat_inside').filter( - address__family=family, nat_inside__interface__virtual_machine=self.instance + address__family=family, + nat_inside__vminterface__in=self.instance.interfaces.values_list('id', flat=True) ) if nat_ips: ip_choices.append( @@ -567,7 +571,7 @@ class VirtualMachineFilterForm(BootstrapMixin, TenancyFilterForm, CustomFieldFil # VM interfaces # -class InterfaceForm(BootstrapMixin, forms.ModelForm): +class VMInterfaceForm(BootstrapMixin, forms.ModelForm): untagged_vlan = DynamicModelChoiceField( queryset=VLAN.objects.all(), required=False, @@ -590,19 +594,19 @@ class InterfaceForm(BootstrapMixin, forms.ModelForm): }, ) ) - tags = TagField( + tags = DynamicModelMultipleChoiceField( + queryset=Tag.objects.all(), required=False ) class Meta: - model = Interface + model = VMInterface fields = [ - 'virtual_machine', 'name', 'type', 'enabled', 'mac_address', 'mtu', 'description', 'mode', 'tags', - 'untagged_vlan', 'tagged_vlans', + 'virtual_machine', 'name', 'enabled', 'mac_address', 'mtu', 'description', 'mode', 'tags', 'untagged_vlan', + 'tagged_vlans', ] widgets = { 'virtual_machine': forms.HiddenInput(), - 'type': forms.HiddenInput(), 'mode': StaticSelect2() } labels = { @@ -615,10 +619,13 @@ class InterfaceForm(BootstrapMixin, forms.ModelForm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + virtual_machine = VirtualMachine.objects.get( + pk=self.initial.get('virtual_machine') or self.data.get('virtual_machine') + ) + # Add current site to VLANs query params - site = getattr(self.instance.parent, 'site', None) - if site is not None: - # Add current site to VLANs query params + site = virtual_machine.site + if site: self.fields['untagged_vlan'].widget.add_additional_query_param('site_id', site.pk) self.fields['tagged_vlans'].widget.add_additional_query_param('site_id', site.pk) @@ -639,19 +646,13 @@ class InterfaceForm(BootstrapMixin, forms.ModelForm): self.cleaned_data['tagged_vlans'] = [] -class InterfaceCreateForm(BootstrapMixin, forms.Form): - virtual_machine = forms.ModelChoiceField( - queryset=VirtualMachine.objects.all(), - widget=forms.HiddenInput() +class VMInterfaceCreateForm(BootstrapMixin, forms.Form): + virtual_machine = DynamicModelChoiceField( + queryset=VirtualMachine.objects.all() ) name_pattern = ExpandableNameField( label='Name' ) - type = forms.ChoiceField( - choices=VMInterfaceTypeChoices, - initial=VMInterfaceTypeChoices.TYPE_VIRTUAL, - widget=forms.HiddenInput() - ) enabled = forms.BooleanField( required=False, initial=True @@ -697,7 +698,8 @@ class InterfaceCreateForm(BootstrapMixin, forms.Form): }, ) ) - tags = TagField( + tags = DynamicModelMultipleChoiceField( + queryset=Tag.objects.all(), required=False ) @@ -708,16 +710,39 @@ class InterfaceCreateForm(BootstrapMixin, forms.Form): pk=self.initial.get('virtual_machine') or self.data.get('virtual_machine') ) - site = getattr(virtual_machine.cluster, 'site', None) - if site is not None: - # Add current site to VLANs query params + # Add current site to VLANs query params + site = virtual_machine.site + if site: self.fields['untagged_vlan'].widget.add_additional_query_param('site_id', site.pk) self.fields['tagged_vlans'].widget.add_additional_query_param('site_id', site.pk) -class InterfaceBulkEditForm(BootstrapMixin, BulkEditForm): +class VMInterfaceCSVForm(CSVModelForm): + virtual_machine = CSVModelChoiceField( + queryset=VirtualMachine.objects.all(), + to_field_name='name' + ) + mode = CSVChoiceField( + choices=InterfaceModeChoices, + required=False, + help_text='IEEE 802.1Q operational mode (for L2 interfaces)' + ) + + class Meta: + model = VMInterface + fields = VMInterface.csv_headers + + def clean_enabled(self): + # Make sure enabled is True when it's not included in the uploaded data + if 'enabled' not in self.data: + return True + else: + return self.cleaned_data['enabled'] + + +class VMInterfaceBulkEditForm(BootstrapMixin, BulkEditForm): pk = forms.ModelMultipleChoiceField( - queryset=Interface.objects.all(), + queryset=VMInterface.objects.all(), widget=forms.MultipleHiddenInput() ) virtual_machine = forms.ModelChoiceField( @@ -785,6 +810,24 @@ class InterfaceBulkEditForm(BootstrapMixin, BulkEditForm): self.fields['tagged_vlans'].widget.add_additional_query_param('site_id', site.pk) +class VMInterfaceBulkRenameForm(BulkRenameForm): + pk = forms.ModelMultipleChoiceField( + queryset=VMInterface.objects.all(), + widget=forms.MultipleHiddenInput() + ) + + +class VMInterfaceFilterForm(forms.Form): + model = VMInterface + enabled = forms.NullBooleanField( + required=False, + widget=StaticSelect2( + choices=BOOLEAN_WITH_BLANK_CHOICES + ) + ) + tag = TagFilterField(model) + + # # Bulk VirtualMachine component creation # @@ -804,12 +847,8 @@ class VirtualMachineBulkAddComponentForm(BootstrapMixin, forms.Form): return ','.join(self.cleaned_data.get('tags')) -class InterfaceBulkCreateForm( - form_from_model(Interface, ['enabled', 'mtu', 'description', 'tags']), +class VMInterfaceBulkCreateForm( + form_from_model(VMInterface, ['enabled', 'mtu', 'description', 'tags']), VirtualMachineBulkAddComponentForm ): - type = forms.ChoiceField( - choices=VMInterfaceTypeChoices, - initial=VMInterfaceTypeChoices.TYPE_VIRTUAL, - widget=forms.HiddenInput() - ) + pass diff --git a/netbox/virtualization/migrations/0015_vminterface.py b/netbox/virtualization/migrations/0015_vminterface.py new file mode 100644 index 000000000..6c5207226 --- /dev/null +++ b/netbox/virtualization/migrations/0015_vminterface.py @@ -0,0 +1,44 @@ +# Generated by Django 3.0.6 on 2020-06-18 20:21 + +import dcim.fields +import django.core.validators +from django.db import migrations, models +import django.db.models.deletion +import taggit.managers +import utilities.fields +import utilities.ordering +import utilities.query_functions + + +class Migration(migrations.Migration): + + dependencies = [ + ('ipam', '0036_standardize_description'), + ('extras', '0042_customfield_manager'), + ('virtualization', '0014_standardize_description'), + ] + + operations = [ + migrations.CreateModel( + name='VMInterface', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False)), + ('name', models.CharField(max_length=64)), + ('_name', utilities.fields.NaturalOrderingField('name', blank=True, max_length=100, naturalize_function=utilities.ordering.naturalize_interface)), + ('enabled', models.BooleanField(default=True)), + ('mac_address', dcim.fields.MACAddressField(blank=True, null=True)), + ('mtu', models.PositiveIntegerField(blank=True, null=True, validators=[django.core.validators.MinValueValidator(1), django.core.validators.MaxValueValidator(65536)])), + ('mode', models.CharField(blank=True, max_length=50)), + ('description', models.CharField(blank=True, max_length=200)), + ('tagged_vlans', models.ManyToManyField(blank=True, related_name='vminterfaces_as_tagged', to='ipam.VLAN')), + ('tags', taggit.managers.TaggableManager(related_name='vminterface', through='extras.TaggedItem', to='extras.Tag')), + ('untagged_vlan', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='vminterfaces_as_untagged', to='ipam.VLAN')), + ('virtual_machine', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='interfaces', to='virtualization.VirtualMachine')), + ], + options={ + 'ordering': ('virtual_machine', utilities.query_functions.CollateAsChar('_name')), + 'unique_together': {('virtual_machine', 'name')}, + 'verbose_name': 'interface', + }, + ), + ] diff --git a/netbox/virtualization/migrations/0016_replicate_interfaces.py b/netbox/virtualization/migrations/0016_replicate_interfaces.py new file mode 100644 index 000000000..d6c0b0217 --- /dev/null +++ b/netbox/virtualization/migrations/0016_replicate_interfaces.py @@ -0,0 +1,69 @@ +import sys + +from django.db import migrations + + +def replicate_interfaces(apps, schema_editor): + ContentType = apps.get_model('contenttypes', 'ContentType') + TaggedItem = apps.get_model('extras', 'TaggedItem') + Interface = apps.get_model('dcim', 'Interface') + IPAddress = apps.get_model('ipam', 'IPAddress') + VMInterface = apps.get_model('virtualization', 'VMInterface') + + interface_ct = ContentType.objects.get_for_model(Interface) + vminterface_ct = ContentType.objects.get_for_model(VMInterface) + + # Replicate dcim.Interface instances assigned to VirtualMachines + original_interfaces = Interface.objects.filter(virtual_machine__isnull=False) + for interface in original_interfaces: + vminterface = VMInterface( + virtual_machine=interface.virtual_machine, + name=interface.name, + enabled=interface.enabled, + mac_address=interface.mac_address, + mtu=interface.mtu, + mode=interface.mode, + description=interface.description, + untagged_vlan=interface.untagged_vlan, + ) + vminterface.save() + + # Copy tagged VLANs + vminterface.tagged_vlans.set(interface.tagged_vlans.all()) + + # Reassign tags to the new instance + TaggedItem.objects.filter( + content_type=interface_ct, object_id=interface.pk + ).update( + content_type=vminterface_ct, object_id=vminterface.pk + ) + + # Update any assigned IPAddresses + IPAddress.objects.filter(assigned_object_id=interface.pk).update( + assigned_object_type=vminterface_ct, + assigned_object_id=vminterface.pk + ) + + replicated_count = VMInterface.objects.count() + if 'test' not in sys.argv: + print(f"\n Replicated {replicated_count} interfaces ", end='', flush=True) + + # Verify that all interfaces have been replicated + assert replicated_count == original_interfaces.count(), "Replicated interfaces count does not match original count!" + + # Delete original VM interfaces + original_interfaces.delete() + + +class Migration(migrations.Migration): + + dependencies = [ + ('ipam', '0037_ipaddress_assignment'), + ('virtualization', '0015_vminterface'), + ] + + operations = [ + migrations.RunPython( + code=replicate_interfaces + ), + ] diff --git a/netbox/virtualization/models.py b/netbox/virtualization/models.py index 3daeff013..4a753561a 100644 --- a/netbox/virtualization/models.py +++ b/netbox/virtualization/models.py @@ -5,10 +5,14 @@ from django.db import models from django.urls import reverse from taggit.managers import TaggableManager -from dcim.models import Device -from extras.models import ConfigContextModel, CustomFieldModel, TaggedItem +from dcim.choices import InterfaceModeChoices +from dcim.models import BaseInterface, Device +from extras.models import ConfigContextModel, CustomFieldModel, ObjectChange, TaggedItem from extras.utils import extras_features from utilities.models import ChangeLoggedModel +from utilities.query_functions import CollateAsChar +from utilities.querysets import RestrictedQuerySet +from utilities.utils import serialize_object from .choices import * @@ -17,6 +21,7 @@ __all__ = ( 'ClusterGroup', 'ClusterType', 'VirtualMachine', + 'VMInterface', ) @@ -40,6 +45,8 @@ class ClusterType(ChangeLoggedModel): blank=True ) + objects = RestrictedQuerySet.as_manager() + csv_headers = ['name', 'slug', 'description'] class Meta: @@ -79,6 +86,8 @@ class ClusterGroup(ChangeLoggedModel): blank=True ) + objects = RestrictedQuerySet.as_manager() + csv_headers = ['name', 'slug', 'description'] class Meta: @@ -145,9 +154,10 @@ class Cluster(ChangeLoggedModel, CustomFieldModel): content_type_field='obj_type', object_id_field='obj_id' ) - tags = TaggableManager(through=TaggedItem) + objects = RestrictedQuerySet.as_manager() + csv_headers = ['name', 'type', 'group', 'site', 'comments'] clone_fields = [ 'type', 'group', 'tenant', 'site', @@ -269,9 +279,10 @@ class VirtualMachine(ChangeLoggedModel, ConfigContextModel, CustomFieldModel): content_type_field='obj_type', object_id_field='obj_id' ) - tags = TaggableManager(through=TaggedItem) + objects = RestrictedQuerySet.as_manager() + csv_headers = [ 'name', 'status', 'role', 'cluster', 'tenant', 'platform', 'vcpus', 'memory', 'disk', 'comments', ] @@ -363,3 +374,111 @@ class VirtualMachine(ChangeLoggedModel, ConfigContextModel, CustomFieldModel): @property def site(self): return self.cluster.site + + +# +# Interfaces +# + +@extras_features('graphs', 'export_templates', 'webhooks') +class VMInterface(BaseInterface): + virtual_machine = models.ForeignKey( + to='virtualization.VirtualMachine', + on_delete=models.CASCADE, + related_name='interfaces' + ) + description = models.CharField( + max_length=200, + blank=True + ) + untagged_vlan = models.ForeignKey( + to='ipam.VLAN', + on_delete=models.SET_NULL, + related_name='vminterfaces_as_untagged', + null=True, + blank=True, + verbose_name='Untagged VLAN' + ) + tagged_vlans = models.ManyToManyField( + to='ipam.VLAN', + related_name='vminterfaces_as_tagged', + blank=True, + verbose_name='Tagged VLANs' + ) + ip_addresses = GenericRelation( + to='ipam.IPAddress', + content_type_field='assigned_object_type', + object_id_field='assigned_object_id', + related_query_name='vminterface' + ) + tags = TaggableManager( + through=TaggedItem, + related_name='vminterface' + ) + + objects = RestrictedQuerySet.as_manager() + + csv_headers = [ + 'virtual_machine', 'name', 'enabled', 'mac_address', 'mtu', 'description', 'mode', + ] + + class Meta: + verbose_name = 'interface' + ordering = ('virtual_machine', CollateAsChar('_name')) + unique_together = ('virtual_machine', 'name') + + def __str__(self): + return self.name + + def get_absolute_url(self): + return reverse('virtualization:vminterface', kwargs={'pk': self.pk}) + + def to_csv(self): + return ( + self.virtual_machine.name, + self.name, + self.enabled, + self.mac_address, + self.mtu, + self.description, + self.get_mode_display(), + ) + + def clean(self): + + # Validate untagged VLAN + if self.untagged_vlan and self.untagged_vlan.site not in [self.virtual_machine.site, None]: + raise ValidationError({ + 'untagged_vlan': "The untagged VLAN ({}) must belong to the same site as the interface's parent " + "virtual machine, or it must be global".format(self.untagged_vlan) + }) + + def save(self, *args, **kwargs): + + # Remove untagged VLAN assignment for non-802.1Q interfaces + if self.mode is None: + self.untagged_vlan = None + + # Only "tagged" interfaces may have tagged VLANs assigned. ("tagged all" implies all VLANs are assigned.) + if self.pk and self.mode != InterfaceModeChoices.MODE_TAGGED: + self.tagged_vlans.clear() + + return super().save(*args, **kwargs) + + def to_objectchange(self, action): + # Annotate the parent VirtualMachine + return ObjectChange( + changed_object=self, + object_repr=str(self), + action=action, + related_object=self.virtual_machine, + object_data=serialize_object(self) + ) + + @property + def parent(self): + return self.virtual_machine + + @property + def count_ipaddresses(self): + return self.ip_addresses.count() diff --git a/netbox/virtualization/tables.py b/netbox/virtualization/tables.py index 077add945..de319361c 100644 --- a/netbox/virtualization/tables.py +++ b/netbox/virtualization/tables.py @@ -1,10 +1,9 @@ import django_tables2 as tables from django_tables2.utils import Accessor -from dcim.models import Interface from tenancy.tables import COL_TENANT -from utilities.tables import BaseTable, TagColumn, ToggleColumn -from .models import Cluster, ClusterGroup, ClusterType, VirtualMachine +from utilities.tables import BaseTable, ColoredLabelColumn, TagColumn, ToggleColumn +from .models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface CLUSTERTYPE_ACTIONS = """ @@ -28,10 +27,6 @@ VIRTUALMACHINE_STATUS = """ {{ record.get_status_display }} """ -VIRTUALMACHINE_ROLE = """ -{% if record.role %}{% else %}—{% endif %} -""" - VIRTUALMACHINE_PRIMARY_IP = """ {{ record.primary_ip6.address.ip|default:"" }} {% if record.primary_ip6 and record.primary_ip4 %}
    {% endif %} @@ -132,9 +127,7 @@ class VirtualMachineTable(BaseTable): viewname='virtualization:cluster', args=[Accessor('cluster.pk')] ) - role = tables.TemplateColumn( - template_code=VIRTUALMACHINE_ROLE - ) + role = ColoredLabelColumn() tenant = tables.TemplateColumn( template_code=COL_TENANT ) @@ -179,8 +172,12 @@ class VirtualMachineDetailTable(VirtualMachineTable): # VM components # -class InterfaceTable(BaseTable): +class VMInterfaceTable(BaseTable): + virtual_machine = tables.LinkColumn() + name = tables.Column( + linkify=True + ) class Meta(BaseTable.Meta): - model = Interface - fields = ('name', 'enabled', 'description') + model = VMInterface + fields = ('virtual_machine', 'name', 'enabled', 'mac_address', 'mtu', 'description') diff --git a/netbox/virtualization/tests/test_api.py b/netbox/virtualization/tests/test_api.py index 8568e21e9..8d525f4fe 100644 --- a/netbox/virtualization/tests/test_api.py +++ b/netbox/virtualization/tests/test_api.py @@ -1,13 +1,10 @@ from django.urls import reverse -from netaddr import IPNetwork from rest_framework import status from dcim.choices import InterfaceModeChoices -from dcim.models import Interface -from ipam.models import IPAddress, VLAN -from utilities.testing import APITestCase, disable_warnings -from virtualization.choices import * -from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine +from ipam.models import VLAN +from utilities.testing import APITestCase, APIViewTestCases +from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface class AppTest(APITestCase): @@ -20,488 +17,184 @@ class AppTest(APITestCase): self.assertEqual(response.status_code, 200) -class ClusterTypeTest(APITestCase): +class ClusterTypeTest(APIViewTestCases.APIViewTestCase): + model = ClusterType + brief_fields = ['cluster_count', 'id', 'name', 'slug', 'url'] + create_data = [ + { + 'name': 'Cluster Type 4', + 'slug': 'cluster-type-4', + }, + { + 'name': 'Cluster Type 5', + 'slug': 'cluster-type-5', + }, + { + 'name': 'Cluster Type 6', + 'slug': 'cluster-type-6', + }, + ] - def setUp(self): + @classmethod + def setUpTestData(cls): - super().setUp() - - self.clustertype1 = ClusterType.objects.create(name='Test Cluster Type 1', slug='test-cluster-type-1') - self.clustertype2 = ClusterType.objects.create(name='Test Cluster Type 2', slug='test-cluster-type-2') - self.clustertype3 = ClusterType.objects.create(name='Test Cluster Type 3', slug='test-cluster-type-3') - - def test_get_clustertype(self): - - url = reverse('virtualization-api:clustertype-detail', kwargs={'pk': self.clustertype1.pk}) - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['name'], self.clustertype1.name) - - def test_list_clustertypes(self): - - url = reverse('virtualization-api:clustertype-list') - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['count'], 3) - - def test_list_clustertypes_brief(self): - - url = reverse('virtualization-api:clustertype-list') - response = self.client.get('{}?brief=1'.format(url), **self.header) - - self.assertEqual( - sorted(response.data['results'][0]), - ['cluster_count', 'id', 'name', 'slug', 'url'] + cluster_types = ( + ClusterType(name='Cluster Type 1', slug='cluster-type-1'), + ClusterType(name='Cluster Type 2', slug='cluster-type-2'), + ClusterType(name='Cluster Type 3', slug='cluster-type-3'), ) + ClusterType.objects.bulk_create(cluster_types) - def test_create_clustertype(self): - data = { - 'name': 'Test Cluster Type 4', - 'slug': 'test-cluster-type-4', - } +class ClusterGroupTest(APIViewTestCases.APIViewTestCase): + model = ClusterGroup + brief_fields = ['cluster_count', 'id', 'name', 'slug', 'url'] + create_data = [ + { + 'name': 'Cluster Group 4', + 'slug': 'cluster-type-4', + }, + { + 'name': 'Cluster Group 5', + 'slug': 'cluster-type-5', + }, + { + 'name': 'Cluster Group 6', + 'slug': 'cluster-type-6', + }, + ] - url = reverse('virtualization-api:clustertype-list') - response = self.client.post(url, data, format='json', **self.header) + @classmethod + def setUpTestData(cls): - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(ClusterType.objects.count(), 4) - clustertype4 = ClusterType.objects.get(pk=response.data['id']) - self.assertEqual(clustertype4.name, data['name']) - self.assertEqual(clustertype4.slug, data['slug']) + cluster_Groups = ( + ClusterGroup(name='Cluster Group 1', slug='cluster-type-1'), + ClusterGroup(name='Cluster Group 2', slug='cluster-type-2'), + ClusterGroup(name='Cluster Group 3', slug='cluster-type-3'), + ) + ClusterGroup.objects.bulk_create(cluster_Groups) - def test_create_clustertype_bulk(self): - data = [ +class ClusterTest(APIViewTestCases.APIViewTestCase): + model = Cluster + brief_fields = ['id', 'name', 'url', 'virtualmachine_count'] + + @classmethod + def setUpTestData(cls): + + cluster_types = ( + ClusterType(name='Cluster Type 1', slug='cluster-type-1'), + ClusterType(name='Cluster Type 2', slug='cluster-type-2'), + ) + ClusterType.objects.bulk_create(cluster_types) + + cluster_groups = ( + ClusterGroup(name='Cluster Group 1', slug='cluster-group-1'), + ClusterGroup(name='Cluster Group 2', slug='cluster-group-2'), + ) + ClusterGroup.objects.bulk_create(cluster_groups) + + clusters = ( + Cluster(name='Cluster 1', type=cluster_types[0], group=cluster_groups[0]), + Cluster(name='Cluster 2', type=cluster_types[0], group=cluster_groups[0]), + Cluster(name='Cluster 3', type=cluster_types[0], group=cluster_groups[0]), + ) + Cluster.objects.bulk_create(clusters) + + cls.create_data = [ { - 'name': 'Test Cluster Type 4', - 'slug': 'test-cluster-type-4', + 'name': 'Cluster 4', + 'type': cluster_types[1].pk, + 'group': cluster_groups[1].pk, }, { - 'name': 'Test Cluster Type 5', - 'slug': 'test-cluster-type-5', + 'name': 'Cluster 5', + 'type': cluster_types[1].pk, + 'group': cluster_groups[1].pk, }, { - 'name': 'Test Cluster Type 6', - 'slug': 'test-cluster-type-6', + 'name': 'Cluster 6', + 'type': cluster_types[1].pk, + 'group': cluster_groups[1].pk, }, ] - url = reverse('virtualization-api:clustertype-list') - response = self.client.post(url, data, format='json', **self.header) - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(ClusterType.objects.count(), 6) - self.assertEqual(response.data[0]['name'], data[0]['name']) - self.assertEqual(response.data[1]['name'], data[1]['name']) - self.assertEqual(response.data[2]['name'], data[2]['name']) +class VirtualMachineTest(APIViewTestCases.APIViewTestCase): + model = VirtualMachine + brief_fields = ['id', 'name', 'url'] - def test_update_clustertype(self): + @classmethod + def setUpTestData(cls): + clustertype = ClusterType.objects.create(name='Cluster Type 1', slug='cluster-type-1') + clustergroup = ClusterGroup.objects.create(name='Cluster Group 1', slug='cluster-group-1') - data = { - 'name': 'Test Cluster Type X', - 'slug': 'test-cluster-type-x', - } - - url = reverse('virtualization-api:clustertype-detail', kwargs={'pk': self.clustertype1.pk}) - response = self.client.put(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(ClusterType.objects.count(), 3) - clustertype1 = ClusterType.objects.get(pk=response.data['id']) - self.assertEqual(clustertype1.name, data['name']) - self.assertEqual(clustertype1.slug, data['slug']) - - def test_delete_clustertype(self): - - url = reverse('virtualization-api:clustertype-detail', kwargs={'pk': self.clustertype1.pk}) - response = self.client.delete(url, **self.header) - - self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(ClusterType.objects.count(), 2) - - -class ClusterGroupTest(APITestCase): - - def setUp(self): - - super().setUp() - - self.clustergroup1 = ClusterGroup.objects.create(name='Test Cluster Group 1', slug='test-cluster-group-1') - self.clustergroup2 = ClusterGroup.objects.create(name='Test Cluster Group 2', slug='test-cluster-group-2') - self.clustergroup3 = ClusterGroup.objects.create(name='Test Cluster Group 3', slug='test-cluster-group-3') - - def test_get_clustergroup(self): - - url = reverse('virtualization-api:clustergroup-detail', kwargs={'pk': self.clustergroup1.pk}) - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['name'], self.clustergroup1.name) - - def test_list_clustergroups(self): - - url = reverse('virtualization-api:clustergroup-list') - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['count'], 3) - - def test_list_clustergroups_brief(self): - - url = reverse('virtualization-api:clustergroup-list') - response = self.client.get('{}?brief=1'.format(url), **self.header) - - self.assertEqual( - sorted(response.data['results'][0]), - ['cluster_count', 'id', 'name', 'slug', 'url'] + clusters = ( + Cluster(name='Cluster 1', type=clustertype, group=clustergroup), + Cluster(name='Cluster 2', type=clustertype, group=clustergroup), ) + Cluster.objects.bulk_create(clusters) - def test_create_clustergroup(self): + virtual_machines = ( + VirtualMachine(name='Virtual Machine 1', cluster=clusters[0], local_context_data={'A': 1}), + VirtualMachine(name='Virtual Machine 2', cluster=clusters[0], local_context_data={'B': 2}), + VirtualMachine(name='Virtual Machine 3', cluster=clusters[0], local_context_data={'C': 3}), + ) + VirtualMachine.objects.bulk_create(virtual_machines) - data = { - 'name': 'Test Cluster Group 4', - 'slug': 'test-cluster-group-4', - } - - url = reverse('virtualization-api:clustergroup-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(ClusterGroup.objects.count(), 4) - clustergroup4 = ClusterGroup.objects.get(pk=response.data['id']) - self.assertEqual(clustergroup4.name, data['name']) - self.assertEqual(clustergroup4.slug, data['slug']) - - def test_create_clustergroup_bulk(self): - - data = [ + cls.create_data = [ { - 'name': 'Test Cluster Group 4', - 'slug': 'test-cluster-group-4', + 'name': 'Virtual Machine 4', + 'cluster': clusters[1].pk, }, { - 'name': 'Test Cluster Group 5', - 'slug': 'test-cluster-group-5', + 'name': 'Virtual Machine 5', + 'cluster': clusters[1].pk, }, { - 'name': 'Test Cluster Group 6', - 'slug': 'test-cluster-group-6', + 'name': 'Virtual Machine 6', + 'cluster': clusters[1].pk, }, ] - url = reverse('virtualization-api:clustergroup-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(ClusterGroup.objects.count(), 6) - self.assertEqual(response.data[0]['name'], data[0]['name']) - self.assertEqual(response.data[1]['name'], data[1]['name']) - self.assertEqual(response.data[2]['name'], data[2]['name']) - - def test_update_clustergroup(self): - - data = { - 'name': 'Test Cluster Group X', - 'slug': 'test-cluster-group-x', - } - - url = reverse('virtualization-api:clustergroup-detail', kwargs={'pk': self.clustergroup1.pk}) - response = self.client.put(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(ClusterGroup.objects.count(), 3) - clustergroup1 = ClusterGroup.objects.get(pk=response.data['id']) - self.assertEqual(clustergroup1.name, data['name']) - self.assertEqual(clustergroup1.slug, data['slug']) - - def test_delete_clustergroup(self): - - url = reverse('virtualization-api:clustergroup-detail', kwargs={'pk': self.clustergroup1.pk}) - response = self.client.delete(url, **self.header) - - self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(ClusterGroup.objects.count(), 2) - - -class ClusterTest(APITestCase): - - def setUp(self): - - super().setUp() - - cluster_type = ClusterType.objects.create(name='Test Cluster Type 1', slug='test-cluster-type-1') - cluster_group = ClusterGroup.objects.create(name='Test Cluster Group 1', slug='test-cluster-group-1') - - self.cluster1 = Cluster.objects.create(name='Test Cluster 1', type=cluster_type, group=cluster_group) - self.cluster2 = Cluster.objects.create(name='Test Cluster 2', type=cluster_type, group=cluster_group) - self.cluster3 = Cluster.objects.create(name='Test Cluster 3', type=cluster_type, group=cluster_group) - - def test_get_cluster(self): - - url = reverse('virtualization-api:cluster-detail', kwargs={'pk': self.cluster1.pk}) - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['name'], self.cluster1.name) - - def test_list_clusters(self): - - url = reverse('virtualization-api:cluster-list') - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['count'], 3) - - def test_list_clusters_brief(self): - - url = reverse('virtualization-api:cluster-list') - response = self.client.get('{}?brief=1'.format(url), **self.header) - - self.assertEqual( - sorted(response.data['results'][0]), - ['id', 'name', 'url', 'virtualmachine_count'] - ) - - def test_create_cluster(self): - - data = { - 'name': 'Test Cluster 4', - 'type': ClusterType.objects.first().pk, - 'group': ClusterGroup.objects.first().pk, - } - - url = reverse('virtualization-api:cluster-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Cluster.objects.count(), 4) - cluster4 = Cluster.objects.get(pk=response.data['id']) - self.assertEqual(cluster4.name, data['name']) - self.assertEqual(cluster4.type.pk, data['type']) - self.assertEqual(cluster4.group.pk, data['group']) - - def test_create_cluster_bulk(self): - - data = [ - { - 'name': 'Test Cluster 4', - 'type': ClusterType.objects.first().pk, - 'group': ClusterGroup.objects.first().pk, - }, - { - 'name': 'Test Cluster 5', - 'type': ClusterType.objects.first().pk, - 'group': ClusterGroup.objects.first().pk, - }, - { - 'name': 'Test Cluster 6', - 'type': ClusterType.objects.first().pk, - 'group': ClusterGroup.objects.first().pk, - }, - ] - - url = reverse('virtualization-api:cluster-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Cluster.objects.count(), 6) - self.assertEqual(response.data[0]['name'], data[0]['name']) - self.assertEqual(response.data[1]['name'], data[1]['name']) - self.assertEqual(response.data[2]['name'], data[2]['name']) - - def test_update_cluster(self): - - cluster_type2 = ClusterType.objects.create(name='Test Cluster Type 2', slug='test-cluster-type-2') - cluster_group2 = ClusterGroup.objects.create(name='Test Cluster Group 2', slug='test-cluster-group-2') - data = { - 'name': 'Test Cluster X', - 'type': cluster_type2.pk, - 'group': cluster_group2.pk, - } - - url = reverse('virtualization-api:cluster-detail', kwargs={'pk': self.cluster1.pk}) - response = self.client.put(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(Cluster.objects.count(), 3) - cluster1 = Cluster.objects.get(pk=response.data['id']) - self.assertEqual(cluster1.name, data['name']) - self.assertEqual(cluster1.type.pk, data['type']) - self.assertEqual(cluster1.group.pk, data['group']) - - def test_delete_cluster(self): - - url = reverse('virtualization-api:cluster-detail', kwargs={'pk': self.cluster1.pk}) - response = self.client.delete(url, **self.header) - - self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(Cluster.objects.count(), 2) - - -class VirtualMachineTest(APITestCase): - - def setUp(self): - - super().setUp() - - cluster_type = ClusterType.objects.create(name='Test Cluster Type 1', slug='test-cluster-type-1') - cluster_group = ClusterGroup.objects.create(name='Test Cluster Group 1', slug='test-cluster-group-1') - self.cluster1 = Cluster.objects.create(name='Test Cluster 1', type=cluster_type, group=cluster_group) - - self.virtualmachine1 = VirtualMachine.objects.create(name='Test Virtual Machine 1', cluster=self.cluster1) - self.virtualmachine2 = VirtualMachine.objects.create(name='Test Virtual Machine 2', cluster=self.cluster1) - self.virtualmachine3 = VirtualMachine.objects.create(name='Test Virtual Machine 3', cluster=self.cluster1) - self.virtualmachine_with_context_data = VirtualMachine.objects.create( - name='VM with context data', - cluster=self.cluster1, - local_context_data={ - 'A': 1, - 'B': 2 - } - ) - - def test_get_virtualmachine(self): - - url = reverse('virtualization-api:virtualmachine-detail', kwargs={'pk': self.virtualmachine1.pk}) - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['name'], self.virtualmachine1.name) - - def test_list_virtualmachines(self): - - url = reverse('virtualization-api:virtualmachine-list') - response = self.client.get(url, **self.header) - - self.assertEqual(response.data['count'], 4) - - def test_list_virtualmachines_brief(self): - - url = reverse('virtualization-api:virtualmachine-list') - response = self.client.get('{}?brief=1'.format(url), **self.header) - - self.assertEqual( - sorted(response.data['results'][0]), - ['id', 'name', 'url'] - ) - - def test_create_virtualmachine(self): - - data = { - 'name': 'Test Virtual Machine 4', - 'cluster': self.cluster1.pk, - } - - url = reverse('virtualization-api:virtualmachine-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(VirtualMachine.objects.count(), 5) - virtualmachine4 = VirtualMachine.objects.get(pk=response.data['id']) - self.assertEqual(virtualmachine4.name, data['name']) - self.assertEqual(virtualmachine4.cluster.pk, data['cluster']) - - def test_create_virtualmachine_without_cluster(self): - - data = { - 'name': 'Test Virtual Machine 4', - } - - url = reverse('virtualization-api:virtualmachine-list') - with disable_warnings('django.request'): - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) - self.assertEqual(VirtualMachine.objects.count(), 4) - - def test_create_virtualmachine_bulk(self): - - data = [ - { - 'name': 'Test Virtual Machine 4', - 'cluster': self.cluster1.pk, - }, - { - 'name': 'Test Virtual Machine 5', - 'cluster': self.cluster1.pk, - }, - { - 'name': 'Test Virtual Machine 6', - 'cluster': self.cluster1.pk, - }, - ] - - url = reverse('virtualization-api:virtualmachine-list') - response = self.client.post(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(VirtualMachine.objects.count(), 7) - self.assertEqual(response.data[0]['name'], data[0]['name']) - self.assertEqual(response.data[1]['name'], data[1]['name']) - self.assertEqual(response.data[2]['name'], data[2]['name']) - - def test_update_virtualmachine(self): - - interface = Interface.objects.create(name='Test Interface 1', virtual_machine=self.virtualmachine1) - ip4_address = IPAddress.objects.create(address=IPNetwork('192.0.2.1/24'), interface=interface) - ip6_address = IPAddress.objects.create(address=IPNetwork('2001:db8::1/64'), interface=interface) - - cluster2 = Cluster.objects.create( - name='Test Cluster 2', - type=ClusterType.objects.first(), - group=ClusterGroup.objects.first() - ) - data = { - 'name': 'Test Virtual Machine X', - 'cluster': cluster2.pk, - 'primary_ip4': ip4_address.pk, - 'primary_ip6': ip6_address.pk, - } - - url = reverse('virtualization-api:virtualmachine-detail', kwargs={'pk': self.virtualmachine1.pk}) - response = self.client.put(url, data, format='json', **self.header) - - self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(VirtualMachine.objects.count(), 4) - virtualmachine1 = VirtualMachine.objects.get(pk=response.data['id']) - self.assertEqual(virtualmachine1.name, data['name']) - self.assertEqual(virtualmachine1.cluster.pk, data['cluster']) - self.assertEqual(virtualmachine1.primary_ip4.pk, data['primary_ip4']) - self.assertEqual(virtualmachine1.primary_ip6.pk, data['primary_ip6']) - - def test_delete_virtualmachine(self): - - url = reverse('virtualization-api:virtualmachine-detail', kwargs={'pk': self.virtualmachine1.pk}) - response = self.client.delete(url, **self.header) - - self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(VirtualMachine.objects.count(), 3) - def test_config_context_included_by_default_in_list_view(self): + """ + Check that config context data is included by default in the virtual machines list. + """ + virtualmachine = VirtualMachine.objects.first() + url = '{}?id={}'.format(reverse('virtualization-api:virtualmachine-list'), virtualmachine.pk) + self.add_permissions('virtualization.view_virtualmachine') - url = reverse('virtualization-api:virtualmachine-list') - url = '{}?id={}'.format(url, self.virtualmachine_with_context_data.pk) response = self.client.get(url, **self.header) - self.assertEqual(response.data['results'][0].get('config_context', {}).get('A'), 1) def test_config_context_excluded(self): - + """ + Check that config context data can be excluded by passing ?exclude=config_context. + """ url = reverse('virtualization-api:virtualmachine-list') + '?exclude=config_context' - response = self.client.get(url, **self.header) + self.add_permissions('virtualization.view_virtualmachine') + response = self.client.get(url, **self.header) self.assertFalse('config_context' in response.data['results'][0]) def test_unique_name_per_cluster_constraint(self): - + """ + Check that creating a virtual machine with a duplicate name fails. + """ data = { - 'name': 'Test Virtual Machine 1', - 'cluster': self.cluster1.pk, + 'name': 'Virtual Machine 1', + 'cluster': Cluster.objects.first().pk, } - url = reverse('virtualization-api:virtualmachine-list') - response = self.client.post(url, data, format='json', **self.header) + self.add_permissions('virtualization.add_virtualmachine') + response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) -class InterfaceTest(APITestCase): +# TODO: Standardize InterfaceTest (pending #4721) +class VMInterfaceTest(APITestCase): def setUp(self): @@ -510,20 +203,17 @@ class InterfaceTest(APITestCase): clustertype = ClusterType.objects.create(name='Test Cluster Type 1', slug='test-cluster-type-1') cluster = Cluster.objects.create(name='Test Cluster 1', type=clustertype) self.virtualmachine = VirtualMachine.objects.create(cluster=cluster, name='Test VM 1') - self.interface1 = Interface.objects.create( + self.interface1 = VMInterface.objects.create( virtual_machine=self.virtualmachine, - name='Test Interface 1', - type=InterfaceTypeChoices.TYPE_VIRTUAL + name='Test Interface 1' ) - self.interface2 = Interface.objects.create( + self.interface2 = VMInterface.objects.create( virtual_machine=self.virtualmachine, - name='Test Interface 2', - type=InterfaceTypeChoices.TYPE_VIRTUAL + name='Test Interface 2' ) - self.interface3 = Interface.objects.create( + self.interface3 = VMInterface.objects.create( virtual_machine=self.virtualmachine, - name='Test Interface 3', - type=InterfaceTypeChoices.TYPE_VIRTUAL + name='Test Interface 3' ) self.vlan1 = VLAN.objects.create(name="Test VLAN 1", vid=1) @@ -531,47 +221,45 @@ class InterfaceTest(APITestCase): self.vlan3 = VLAN.objects.create(name="Test VLAN 3", vid=3) def test_get_interface(self): + url = reverse('virtualization-api:vminterface-detail', kwargs={'pk': self.interface1.pk}) + self.add_permissions('virtualization.view_vminterface') - url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk}) response = self.client.get(url, **self.header) - self.assertEqual(response.data['name'], self.interface1.name) def test_list_interfaces(self): + url = reverse('virtualization-api:vminterface-list') + self.add_permissions('virtualization.view_vminterface') - url = reverse('virtualization-api:interface-list') response = self.client.get(url, **self.header) - self.assertEqual(response.data['count'], 3) def test_list_interfaces_brief(self): + url = reverse('virtualization-api:vminterface-list') + self.add_permissions('virtualization.view_vminterface') - url = reverse('virtualization-api:interface-list') response = self.client.get('{}?brief=1'.format(url), **self.header) - self.assertEqual( sorted(response.data['results'][0]), ['id', 'name', 'url', 'virtual_machine'] ) def test_create_interface(self): - data = { 'virtual_machine': self.virtualmachine.pk, 'name': 'Test Interface 4', } + url = reverse('virtualization-api:vminterface-list') + self.add_permissions('virtualization.add_vminterface') - url = reverse('virtualization-api:interface-list') response = self.client.post(url, data, format='json', **self.header) - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Interface.objects.count(), 4) - interface4 = Interface.objects.get(pk=response.data['id']) + self.assertEqual(VMInterface.objects.count(), 4) + interface4 = VMInterface.objects.get(pk=response.data['id']) self.assertEqual(interface4.virtual_machine_id, data['virtual_machine']) self.assertEqual(interface4.name, data['name']) def test_create_interface_with_802_1q(self): - data = { 'virtual_machine': self.virtualmachine.pk, 'name': 'Test Interface 4', @@ -579,19 +267,18 @@ class InterfaceTest(APITestCase): 'untagged_vlan': self.vlan3.id, 'tagged_vlans': [self.vlan1.id, self.vlan2.id], } + url = reverse('virtualization-api:vminterface-list') + self.add_permissions('virtualization.add_vminterface') - url = reverse('virtualization-api:interface-list') response = self.client.post(url, data, format='json', **self.header) - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Interface.objects.count(), 4) + self.assertEqual(VMInterface.objects.count(), 4) self.assertEqual(response.data['virtual_machine']['id'], data['virtual_machine']) self.assertEqual(response.data['name'], data['name']) self.assertEqual(response.data['untagged_vlan']['id'], data['untagged_vlan']) self.assertEqual([v['id'] for v in response.data['tagged_vlans']], data['tagged_vlans']) def test_create_interface_bulk(self): - data = [ { 'virtual_machine': self.virtualmachine.pk, @@ -606,18 +293,17 @@ class InterfaceTest(APITestCase): 'name': 'Test Interface 6', }, ] + url = reverse('virtualization-api:vminterface-list') + self.add_permissions('virtualization.add_vminterface') - url = reverse('virtualization-api:interface-list') response = self.client.post(url, data, format='json', **self.header) - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Interface.objects.count(), 6) + self.assertEqual(VMInterface.objects.count(), 6) self.assertEqual(response.data[0]['name'], data[0]['name']) self.assertEqual(response.data[1]['name'], data[1]['name']) self.assertEqual(response.data[2]['name'], data[2]['name']) def test_create_interface_802_1q_bulk(self): - data = [ { 'virtual_machine': self.virtualmachine.pk, @@ -641,36 +327,35 @@ class InterfaceTest(APITestCase): 'tagged_vlans': [self.vlan1.id], }, ] + url = reverse('virtualization-api:vminterface-list') + self.add_permissions('virtualization.add_vminterface') - url = reverse('virtualization-api:interface-list') response = self.client.post(url, data, format='json', **self.header) - self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Interface.objects.count(), 6) + self.assertEqual(VMInterface.objects.count(), 6) for i in range(0, 3): self.assertEqual(response.data[i]['name'], data[i]['name']) self.assertEqual([v['id'] for v in response.data[i]['tagged_vlans']], data[i]['tagged_vlans']) self.assertEqual(response.data[i]['untagged_vlan']['id'], data[i]['untagged_vlan']) def test_update_interface(self): - data = { 'virtual_machine': self.virtualmachine.pk, 'name': 'Test Interface X', } + url = reverse('virtualization-api:vminterface-detail', kwargs={'pk': self.interface1.pk}) + self.add_permissions('virtualization.change_vminterface') - url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk}) response = self.client.put(url, data, format='json', **self.header) - self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(Interface.objects.count(), 3) - interface1 = Interface.objects.get(pk=response.data['id']) + self.assertEqual(VMInterface.objects.count(), 3) + interface1 = VMInterface.objects.get(pk=response.data['id']) self.assertEqual(interface1.name, data['name']) def test_delete_interface(self): + url = reverse('virtualization-api:vminterface-detail', kwargs={'pk': self.interface1.pk}) + self.add_permissions('virtualization.delete_vminterface') - url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk}) response = self.client.delete(url, **self.header) - self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(Interface.objects.count(), 2) + self.assertEqual(VMInterface.objects.count(), 2) diff --git a/netbox/virtualization/tests/test_filters.py b/netbox/virtualization/tests/test_filters.py index 51c7c6e8d..ad452ec51 100644 --- a/netbox/virtualization/tests/test_filters.py +++ b/netbox/virtualization/tests/test_filters.py @@ -1,10 +1,10 @@ from django.test import TestCase -from dcim.models import DeviceRole, Interface, Platform, Region, Site +from dcim.models import DeviceRole, Platform, Region, Site from tenancy.models import Tenant, TenantGroup from virtualization.choices import * from virtualization.filters import * -from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine +from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface class ClusterTypeTestCase(TestCase): @@ -260,11 +260,11 @@ class VirtualMachineTestCase(TestCase): VirtualMachine.objects.bulk_create(vms) interfaces = ( - Interface(virtual_machine=vms[0], name='Interface 1', mac_address='00-00-00-00-00-01'), - Interface(virtual_machine=vms[1], name='Interface 2', mac_address='00-00-00-00-00-02'), - Interface(virtual_machine=vms[2], name='Interface 3', mac_address='00-00-00-00-00-03'), + VMInterface(virtual_machine=vms[0], name='Interface 1', mac_address='00-00-00-00-00-01'), + VMInterface(virtual_machine=vms[1], name='Interface 2', mac_address='00-00-00-00-00-02'), + VMInterface(virtual_machine=vms[2], name='Interface 3', mac_address='00-00-00-00-00-03'), ) - Interface.objects.bulk_create(interfaces) + VMInterface.objects.bulk_create(interfaces) def test_id(self): params = {'id': self.queryset.values_list('pk', flat=True)[:2]} @@ -365,9 +365,9 @@ class VirtualMachineTestCase(TestCase): self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) -class InterfaceTestCase(TestCase): - queryset = Interface.objects.all() - filterset = InterfaceFilterSet +class VMInterfaceTestCase(TestCase): + queryset = VMInterface.objects.all() + filterset = VMInterfaceFilterSet @classmethod def setUpTestData(cls): @@ -394,11 +394,11 @@ class InterfaceTestCase(TestCase): VirtualMachine.objects.bulk_create(vms) interfaces = ( - Interface(virtual_machine=vms[0], name='Interface 1', enabled=True, mtu=100, mac_address='00-00-00-00-00-01'), - Interface(virtual_machine=vms[1], name='Interface 2', enabled=True, mtu=200, mac_address='00-00-00-00-00-02'), - Interface(virtual_machine=vms[2], name='Interface 3', enabled=False, mtu=300, mac_address='00-00-00-00-00-03'), + VMInterface(virtual_machine=vms[0], name='Interface 1', enabled=True, mtu=100, mac_address='00-00-00-00-00-01'), + VMInterface(virtual_machine=vms[1], name='Interface 2', enabled=True, mtu=200, mac_address='00-00-00-00-00-02'), + VMInterface(virtual_machine=vms[2], name='Interface 3', enabled=False, mtu=300, mac_address='00-00-00-00-00-03'), ) - Interface.objects.bulk_create(interfaces) + VMInterface.objects.bulk_create(interfaces) def test_id(self): id_list = self.queryset.values_list('id', flat=True)[:2] diff --git a/netbox/virtualization/tests/test_views.py b/netbox/virtualization/tests/test_views.py index e7bb19285..408558779 100644 --- a/netbox/virtualization/tests/test_views.py +++ b/netbox/virtualization/tests/test_views.py @@ -1,11 +1,11 @@ from netaddr import EUI from dcim.choices import InterfaceModeChoices -from dcim.models import DeviceRole, Interface, Platform, Site +from dcim.models import DeviceRole, Platform, Site from ipam.models import VLAN from utilities.testing import ViewTestCases from virtualization.choices import * -from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine +from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface class ClusterGroupTestCase(ViewTestCases.OrganizationalObjectViewTestCase): @@ -90,6 +90,8 @@ class ClusterTestCase(ViewTestCases.PrimaryObjectViewTestCase): Cluster(name='Cluster 3', group=clustergroups[0], type=clustertypes[0], site=sites[0]), ]) + tags = cls.create_tags('Alpha', 'Bravo', 'Charlie') + cls.form_data = { 'name': 'Cluster X', 'group': clustergroups[1].pk, @@ -97,7 +99,7 @@ class ClusterTestCase(ViewTestCases.PrimaryObjectViewTestCase): 'tenant': None, 'site': sites[1].pk, 'comments': 'Some comments', - 'tags': 'Alpha,Bravo,Charlie', + 'tags': [t.pk for t in tags], } cls.csv_data = ( @@ -148,6 +150,8 @@ class VirtualMachineTestCase(ViewTestCases.PrimaryObjectViewTestCase): VirtualMachine(name='Virtual Machine 3', cluster=clusters[0], role=deviceroles[0], platform=platforms[0]), ]) + tags = cls.create_tags('Alpha', 'Bravo', 'Charlie') + cls.form_data = { 'cluster': clusters[1].pk, 'tenant': None, @@ -161,7 +165,7 @@ class VirtualMachineTestCase(ViewTestCases.PrimaryObjectViewTestCase): 'memory': 32768, 'disk': 4000, 'comments': 'Some comments', - 'tags': 'Alpha,Bravo,Charlie', + 'tags': [t.pk for t in tags], 'local_context_data': None, } @@ -185,19 +189,11 @@ class VirtualMachineTestCase(ViewTestCases.PrimaryObjectViewTestCase): } -class InterfaceTestCase( +class VMInterfaceTestCase( ViewTestCases.GetObjectViewTestCase, ViewTestCases.DeviceComponentViewTestCase, ): - model = Interface - - # Disable inapplicable tests - test_list_objects = None - test_import_objects = None - - def _get_base_url(self): - # Interface belongs to the DCIM app, so we have to override the base URL - return 'virtualization:interface_{}' + model = VMInterface @classmethod def setUpTestData(cls): @@ -212,10 +208,10 @@ class InterfaceTestCase( ) VirtualMachine.objects.bulk_create(virtualmachines) - Interface.objects.bulk_create([ - Interface(virtual_machine=virtualmachines[0], name='Interface 1', type=InterfaceTypeChoices.TYPE_VIRTUAL), - Interface(virtual_machine=virtualmachines[0], name='Interface 2', type=InterfaceTypeChoices.TYPE_VIRTUAL), - Interface(virtual_machine=virtualmachines[0], name='Interface 3', type=InterfaceTypeChoices.TYPE_VIRTUAL), + VMInterface.objects.bulk_create([ + VMInterface(virtual_machine=virtualmachines[0], name='Interface 1'), + VMInterface(virtual_machine=virtualmachines[0], name='Interface 2'), + VMInterface(virtual_machine=virtualmachines[0], name='Interface 3'), ]) vlans = ( @@ -226,49 +222,47 @@ class InterfaceTestCase( ) VLAN.objects.bulk_create(vlans) + tags = cls.create_tags('Alpha', 'Bravo', 'Charlie') + cls.form_data = { 'virtual_machine': virtualmachines[1].pk, 'name': 'Interface X', - 'type': InterfaceTypeChoices.TYPE_VIRTUAL, 'enabled': False, - 'mgmt_only': False, 'mac_address': EUI('01-02-03-04-05-06'), 'mtu': 2000, 'description': 'New description', 'mode': InterfaceModeChoices.MODE_TAGGED, 'untagged_vlan': vlans[0].pk, 'tagged_vlans': [v.pk for v in vlans[1:4]], - 'tags': 'Alpha,Bravo,Charlie', + 'tags': [t.pk for t in tags], } cls.bulk_create_data = { 'virtual_machine': virtualmachines[1].pk, 'name_pattern': 'Interface [4-6]', - 'type': InterfaceTypeChoices.TYPE_VIRTUAL, 'enabled': False, - 'mgmt_only': False, 'mac_address': EUI('01-02-03-04-05-06'), 'mtu': 2000, 'description': 'New description', 'mode': InterfaceModeChoices.MODE_TAGGED, 'untagged_vlan': vlans[0].pk, 'tagged_vlans': [v.pk for v in vlans[1:4]], - 'tags': 'Alpha,Bravo,Charlie', + 'tags': [t.pk for t in tags], } + cls.csv_data = ( + "virtual_machine,name", + "Virtual Machine 2,Interface 4", + "Virtual Machine 2,Interface 5", + "Virtual Machine 2,Interface 6", + ) + cls.bulk_edit_data = { 'virtual_machine': virtualmachines[1].pk, 'enabled': False, 'mtu': 2000, 'description': 'New description', 'mode': InterfaceModeChoices.MODE_TAGGED, - # 'untagged_vlan': vlans[0].pk, - # 'tagged_vlans': [v.pk for v in vlans[1:4]], + 'untagged_vlan': vlans[0].pk, + 'tagged_vlans': [v.pk for v in vlans[1:4]], } - - cls.csv_data = ( - "device,name,type", - "Device 1,Interface 4,1000BASE-T (1GE)", - "Device 1,Interface 5,1000BASE-T (1GE)", - "Device 1,Interface 6,1000BASE-T (1GE)", - ) diff --git a/netbox/virtualization/urls.py b/netbox/virtualization/urls.py index 557f8a9ca..34172ee88 100644 --- a/netbox/virtualization/urls.py +++ b/netbox/virtualization/urls.py @@ -1,16 +1,16 @@ from django.urls import path from extras.views import ObjectChangeLogView -from ipam.views import ServiceCreateView +from ipam.views import ServiceEditView from . import views -from .models import Cluster, ClusterGroup, ClusterType, VirtualMachine +from .models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface app_name = 'virtualization' urlpatterns = [ # Cluster types path('cluster-types/', views.ClusterTypeListView.as_view(), name='clustertype_list'), - path('cluster-types/add/', views.ClusterTypeCreateView.as_view(), name='clustertype_add'), + path('cluster-types/add/', views.ClusterTypeEditView.as_view(), name='clustertype_add'), path('cluster-types/import/', views.ClusterTypeBulkImportView.as_view(), name='clustertype_import'), path('cluster-types/delete/', views.ClusterTypeBulkDeleteView.as_view(), name='clustertype_bulk_delete'), path('cluster-types//edit/', views.ClusterTypeEditView.as_view(), name='clustertype_edit'), @@ -18,7 +18,7 @@ urlpatterns = [ # Cluster groups path('cluster-groups/', views.ClusterGroupListView.as_view(), name='clustergroup_list'), - path('cluster-groups/add/', views.ClusterGroupCreateView.as_view(), name='clustergroup_add'), + path('cluster-groups/add/', views.ClusterGroupEditView.as_view(), name='clustergroup_add'), path('cluster-groups/import/', views.ClusterGroupBulkImportView.as_view(), name='clustergroup_import'), path('cluster-groups/delete/', views.ClusterGroupBulkDeleteView.as_view(), name='clustergroup_bulk_delete'), path('cluster-groups//edit/', views.ClusterGroupEditView.as_view(), name='clustergroup_edit'), @@ -26,7 +26,7 @@ urlpatterns = [ # Clusters path('clusters/', views.ClusterListView.as_view(), name='cluster_list'), - path('clusters/add/', views.ClusterCreateView.as_view(), name='cluster_add'), + path('clusters/add/', views.ClusterEditView.as_view(), name='cluster_add'), path('clusters/import/', views.ClusterBulkImportView.as_view(), name='cluster_import'), path('clusters/edit/', views.ClusterBulkEditView.as_view(), name='cluster_bulk_edit'), path('clusters/delete/', views.ClusterBulkDeleteView.as_view(), name='cluster_bulk_delete'), @@ -39,7 +39,7 @@ urlpatterns = [ # Virtual machines path('virtual-machines/', views.VirtualMachineListView.as_view(), name='virtualmachine_list'), - path('virtual-machines/add/', views.VirtualMachineCreateView.as_view(), name='virtualmachine_add'), + path('virtual-machines/add/', views.VirtualMachineEditView.as_view(), name='virtualmachine_add'), path('virtual-machines/import/', views.VirtualMachineBulkImportView.as_view(), name='virtualmachine_import'), path('virtual-machines/edit/', views.VirtualMachineBulkEditView.as_view(), name='virtualmachine_bulk_edit'), path('virtual-machines/delete/', views.VirtualMachineBulkDeleteView.as_view(), name='virtualmachine_bulk_delete'), @@ -48,14 +48,19 @@ urlpatterns = [ path('virtual-machines//delete/', views.VirtualMachineDeleteView.as_view(), name='virtualmachine_delete'), path('virtual-machines//config-context/', views.VirtualMachineConfigContextView.as_view(), name='virtualmachine_configcontext'), path('virtual-machines//changelog/', ObjectChangeLogView.as_view(), name='virtualmachine_changelog', kwargs={'model': VirtualMachine}), - path('virtual-machines//services/assign/', ServiceCreateView.as_view(), name='virtualmachine_service_assign'), + path('virtual-machines//services/assign/', ServiceEditView.as_view(), name='virtualmachine_service_assign'), # VM interfaces - path('interfaces/add/', views.InterfaceCreateView.as_view(), name='interface_add'), - path('interfaces/edit/', views.InterfaceBulkEditView.as_view(), name='interface_bulk_edit'), - path('interfaces/delete/', views.InterfaceBulkDeleteView.as_view(), name='interface_bulk_delete'), - path('interfaces//edit/', views.InterfaceEditView.as_view(), name='interface_edit'), - path('interfaces//delete/', views.InterfaceDeleteView.as_view(), name='interface_delete'), - path('virtual-machines/interfaces/add/', views.VirtualMachineBulkAddInterfaceView.as_view(), name='virtualmachine_bulk_add_interface'), + path('interfaces/', views.VMInterfaceListView.as_view(), name='vminterface_list'), + path('interfaces/add/', views.VMInterfaceCreateView.as_view(), name='vminterface_add'), + path('interfaces/import/', views.VMInterfaceBulkImportView.as_view(), name='vminterface_import'), + path('interfaces/edit/', views.VMInterfaceBulkEditView.as_view(), name='vminterface_bulk_edit'), + path('interfaces/rename/', views.VMInterfaceBulkRenameView.as_view(), name='vminterface_bulk_rename'), + path('interfaces/delete/', views.VMInterfaceBulkDeleteView.as_view(), name='vminterface_bulk_delete'), + path('interfaces//', views.VMInterfaceView.as_view(), name='vminterface'), + path('interfaces//edit/', views.VMInterfaceEditView.as_view(), name='vminterface_edit'), + path('interfaces//delete/', views.VMInterfaceDeleteView.as_view(), name='vminterface_delete'), + path('interfaces//changelog/', ObjectChangeLogView.as_view(), name='vminterface_changelog', kwargs={'model': VMInterface}), + path('virtual-machines/interfaces/add/', views.VirtualMachineBulkAddInterfaceView.as_view(), name='virtualmachine_bulk_add_vminterface'), ] diff --git a/netbox/virtualization/views.py b/netbox/virtualization/views.py index 68a2443ae..60b5f766a 100644 --- a/netbox/virtualization/views.py +++ b/netbox/virtualization/views.py @@ -1,53 +1,45 @@ from django.contrib import messages -from django.contrib.auth.mixins import PermissionRequiredMixin from django.db import transaction from django.db.models import Count from django.shortcuts import get_object_or_404, redirect, render from django.urls import reverse -from django.views.generic import View -from dcim.models import Device, Interface +from dcim.models import Device from dcim.tables import DeviceTable from extras.views import ObjectConfigContextView from ipam.models import Service +from ipam.tables import InterfaceIPAddressTable, InterfaceVLANTable from utilities.views import ( - BulkComponentCreateView, BulkDeleteView, BulkEditView, BulkImportView, ComponentCreateView, ObjectDeleteView, - ObjectEditView, ObjectListView, + BulkComponentCreateView, BulkDeleteView, BulkEditView, BulkImportView, BulkRenameView, ComponentCreateView, + ObjectView, ObjectDeleteView, ObjectEditView, ObjectListView, ) from . import filters, forms, tables -from .models import Cluster, ClusterGroup, ClusterType, VirtualMachine +from .models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface # # Cluster types # -class ClusterTypeListView(PermissionRequiredMixin, ObjectListView): - permission_required = 'virtualization.view_clustertype' +class ClusterTypeListView(ObjectListView): queryset = ClusterType.objects.annotate(cluster_count=Count('clusters')) table = tables.ClusterTypeTable -class ClusterTypeCreateView(PermissionRequiredMixin, ObjectEditView): - permission_required = 'virtualization.add_clustertype' +class ClusterTypeEditView(ObjectEditView): queryset = ClusterType.objects.all() model_form = forms.ClusterTypeForm default_return_url = 'virtualization:clustertype_list' -class ClusterTypeEditView(ClusterTypeCreateView): - permission_required = 'virtualization.change_clustertype' - - -class ClusterTypeBulkImportView(PermissionRequiredMixin, BulkImportView): - permission_required = 'virtualization.add_clustertype' +class ClusterTypeBulkImportView(BulkImportView): + queryset = ClusterType.objects.all() model_form = forms.ClusterTypeCSVForm table = tables.ClusterTypeTable default_return_url = 'virtualization:clustertype_list' -class ClusterTypeBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): - permission_required = 'virtualization.delete_clustertype' +class ClusterTypeBulkDeleteView(BulkDeleteView): queryset = ClusterType.objects.annotate(cluster_count=Count('clusters')) table = tables.ClusterTypeTable default_return_url = 'virtualization:clustertype_list' @@ -57,32 +49,25 @@ class ClusterTypeBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): # Cluster groups # -class ClusterGroupListView(PermissionRequiredMixin, ObjectListView): - permission_required = 'virtualization.view_clustergroup' +class ClusterGroupListView(ObjectListView): queryset = ClusterGroup.objects.annotate(cluster_count=Count('clusters')) table = tables.ClusterGroupTable -class ClusterGroupCreateView(PermissionRequiredMixin, ObjectEditView): - permission_required = 'virtualization.add_clustergroup' +class ClusterGroupEditView(ObjectEditView): queryset = ClusterGroup.objects.all() model_form = forms.ClusterGroupForm default_return_url = 'virtualization:clustergroup_list' -class ClusterGroupEditView(ClusterGroupCreateView): - permission_required = 'virtualization.change_clustergroup' - - -class ClusterGroupBulkImportView(PermissionRequiredMixin, BulkImportView): - permission_required = 'virtualization.add_clustergroup' +class ClusterGroupBulkImportView(BulkImportView): + queryset = ClusterGroup.objects.all() model_form = forms.ClusterGroupCSVForm table = tables.ClusterGroupTable default_return_url = 'virtualization:clustergroup_list' -class ClusterGroupBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): - permission_required = 'virtualization.delete_clustergroup' +class ClusterGroupBulkDeleteView(BulkDeleteView): queryset = ClusterGroup.objects.annotate(cluster_count=Count('clusters')) table = tables.ClusterGroupTable default_return_url = 'virtualization:clustergroup_list' @@ -92,21 +77,20 @@ class ClusterGroupBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): # Clusters # -class ClusterListView(PermissionRequiredMixin, ObjectListView): - permission_required = 'virtualization.view_cluster' +class ClusterListView(ObjectListView): queryset = Cluster.objects.prefetch_related('type', 'group', 'site', 'tenant') table = tables.ClusterTable filterset = filters.ClusterFilterSet filterset_form = forms.ClusterFilterForm -class ClusterView(PermissionRequiredMixin, View): - permission_required = 'virtualization.view_cluster' +class ClusterView(ObjectView): + queryset = Cluster.objects.all() def get(self, request, pk): - cluster = get_object_or_404(Cluster, pk=pk) - devices = Device.objects.filter(cluster=cluster).prefetch_related( + cluster = get_object_or_404(self.queryset, pk=pk) + devices = Device.objects.restrict(request.user, 'view').filter(cluster=cluster).prefetch_related( 'site', 'rack', 'tenant', 'device_type__manufacturer' ) device_table = DeviceTable(list(devices), orderable=False) @@ -119,32 +103,25 @@ class ClusterView(PermissionRequiredMixin, View): }) -class ClusterCreateView(PermissionRequiredMixin, ObjectEditView): - permission_required = 'virtualization.add_cluster' +class ClusterEditView(ObjectEditView): template_name = 'virtualization/cluster_edit.html' queryset = Cluster.objects.all() model_form = forms.ClusterForm -class ClusterEditView(ClusterCreateView): - permission_required = 'virtualization.change_cluster' - - -class ClusterDeleteView(PermissionRequiredMixin, ObjectDeleteView): - permission_required = 'virtualization.delete_cluster' +class ClusterDeleteView(ObjectDeleteView): queryset = Cluster.objects.all() default_return_url = 'virtualization:cluster_list' -class ClusterBulkImportView(PermissionRequiredMixin, BulkImportView): - permission_required = 'virtualization.add_cluster' +class ClusterBulkImportView(BulkImportView): + queryset = Cluster.objects.all() model_form = forms.ClusterCSVForm table = tables.ClusterTable default_return_url = 'virtualization:cluster_list' -class ClusterBulkEditView(PermissionRequiredMixin, BulkEditView): - permission_required = 'virtualization.change_cluster' +class ClusterBulkEditView(BulkEditView): queryset = Cluster.objects.prefetch_related('type', 'group', 'site') filterset = filters.ClusterFilterSet table = tables.ClusterTable @@ -152,22 +129,20 @@ class ClusterBulkEditView(PermissionRequiredMixin, BulkEditView): default_return_url = 'virtualization:cluster_list' -class ClusterBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): - permission_required = 'virtualization.delete_cluster' +class ClusterBulkDeleteView(BulkDeleteView): queryset = Cluster.objects.prefetch_related('type', 'group', 'site') filterset = filters.ClusterFilterSet table = tables.ClusterTable default_return_url = 'virtualization:cluster_list' -class ClusterAddDevicesView(PermissionRequiredMixin, View): - permission_required = 'virtualization.change_cluster' +class ClusterAddDevicesView(ObjectEditView): + queryset = Cluster.objects.all() form = forms.ClusterAddDevicesForm template_name = 'virtualization/cluster_add_devices.html' def get(self, request, pk): - - cluster = get_object_or_404(Cluster, pk=pk) + cluster = get_object_or_404(self.queryset, pk=pk) form = self.form(cluster, initial=request.GET) return render(request, self.template_name, { @@ -177,8 +152,7 @@ class ClusterAddDevicesView(PermissionRequiredMixin, View): }) def post(self, request, pk): - - cluster = get_object_or_404(Cluster, pk=pk) + cluster = get_object_or_404(self.queryset, pk=pk) form = self.form(cluster, request.POST) if form.is_valid(): @@ -203,14 +177,14 @@ class ClusterAddDevicesView(PermissionRequiredMixin, View): }) -class ClusterRemoveDevicesView(PermissionRequiredMixin, View): - permission_required = 'virtualization.change_cluster' +class ClusterRemoveDevicesView(ObjectEditView): + queryset = Cluster.objects.all() form = forms.ClusterRemoveDevicesForm template_name = 'utilities/obj_bulk_remove.html' def post(self, request, pk): - cluster = get_object_or_404(Cluster, pk=pk) + cluster = get_object_or_404(self.queryset, pk=pk) if '_confirm' in request.POST: form = self.form(request.POST) @@ -248,8 +222,7 @@ class ClusterRemoveDevicesView(PermissionRequiredMixin, View): # Virtual machines # -class VirtualMachineListView(PermissionRequiredMixin, ObjectListView): - permission_required = 'virtualization.view_virtualmachine' +class VirtualMachineListView(ObjectListView): queryset = VirtualMachine.objects.prefetch_related('cluster', 'tenant', 'role', 'primary_ip4', 'primary_ip6') filterset = filters.VirtualMachineFilterSet filterset_form = forms.VirtualMachineFilterForm @@ -257,14 +230,14 @@ class VirtualMachineListView(PermissionRequiredMixin, ObjectListView): template_name = 'virtualization/virtualmachine_list.html' -class VirtualMachineView(PermissionRequiredMixin, View): - permission_required = 'virtualization.view_virtualmachine' +class VirtualMachineView(ObjectView): + queryset = VirtualMachine.objects.prefetch_related('tenant__group') def get(self, request, pk): - virtualmachine = get_object_or_404(VirtualMachine.objects.prefetch_related('tenant__group'), pk=pk) - interfaces = Interface.objects.filter(virtual_machine=virtualmachine) - services = Service.objects.filter(virtual_machine=virtualmachine) + virtualmachine = get_object_or_404(self.queryset, pk=pk) + interfaces = VMInterface.objects.restrict(request.user, 'view').filter(virtual_machine=virtualmachine) + services = Service.objects.restrict(request.user, 'view').filter(virtual_machine=virtualmachine) return render(request, 'virtualization/virtualmachine.html', { 'virtualmachine': virtualmachine, @@ -273,39 +246,31 @@ class VirtualMachineView(PermissionRequiredMixin, View): }) -class VirtualMachineConfigContextView(PermissionRequiredMixin, ObjectConfigContextView): - permission_required = 'virtualization.view_virtualmachine' - object_class = VirtualMachine +class VirtualMachineConfigContextView(ObjectConfigContextView): + queryset = VirtualMachine.objects.all() base_template = 'virtualization/virtualmachine.html' -class VirtualMachineCreateView(PermissionRequiredMixin, ObjectEditView): - permission_required = 'virtualization.add_virtualmachine' +class VirtualMachineEditView(ObjectEditView): queryset = VirtualMachine.objects.all() model_form = forms.VirtualMachineForm template_name = 'virtualization/virtualmachine_edit.html' default_return_url = 'virtualization:virtualmachine_list' -class VirtualMachineEditView(VirtualMachineCreateView): - permission_required = 'virtualization.change_virtualmachine' - - -class VirtualMachineDeleteView(PermissionRequiredMixin, ObjectDeleteView): - permission_required = 'virtualization.delete_virtualmachine' +class VirtualMachineDeleteView(ObjectDeleteView): queryset = VirtualMachine.objects.all() default_return_url = 'virtualization:virtualmachine_list' -class VirtualMachineBulkImportView(PermissionRequiredMixin, BulkImportView): - permission_required = 'virtualization.add_virtualmachine' +class VirtualMachineBulkImportView(BulkImportView): + queryset = VirtualMachine.objects.all() model_form = forms.VirtualMachineCSVForm table = tables.VirtualMachineTable default_return_url = 'virtualization:virtualmachine_list' -class VirtualMachineBulkEditView(PermissionRequiredMixin, BulkEditView): - permission_required = 'virtualization.change_virtualmachine' +class VirtualMachineBulkEditView(BulkEditView): queryset = VirtualMachine.objects.prefetch_related('cluster', 'tenant', 'role') filterset = filters.VirtualMachineFilterSet table = tables.VirtualMachineTable @@ -313,8 +278,7 @@ class VirtualMachineBulkEditView(PermissionRequiredMixin, BulkEditView): default_return_url = 'virtualization:virtualmachine_list' -class VirtualMachineBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): - permission_required = 'virtualization.delete_virtualmachine' +class VirtualMachineBulkDeleteView(BulkDeleteView): queryset = VirtualMachine.objects.prefetch_related('cluster', 'tenant', 'role') filterset = filters.VirtualMachineFilterSet table = tables.VirtualMachineTable @@ -325,50 +289,99 @@ class VirtualMachineBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): # VM interfaces # -class InterfaceCreateView(PermissionRequiredMixin, ComponentCreateView): - permission_required = 'dcim.add_interface' - model = Interface - form = forms.InterfaceCreateForm - model_form = forms.InterfaceForm +class VMInterfaceListView(ObjectListView): + queryset = VMInterface.objects.prefetch_related('virtual_machine') + filterset = filters.VMInterfaceFilterSet + filterset_form = forms.VMInterfaceFilterForm + table = tables.VMInterfaceTable + action_buttons = ('export',) + + +class VMInterfaceView(ObjectView): + queryset = VMInterface.objects.all() + + def get(self, request, pk): + + vminterface = get_object_or_404(self.queryset, pk=pk) + + # Get assigned IP addresses + ipaddress_table = InterfaceIPAddressTable( + data=vminterface.ip_addresses.restrict(request.user, 'view').prefetch_related('vrf', 'tenant'), + orderable=False + ) + + # Get assigned VLANs and annotate whether each is tagged or untagged + vlans = [] + if vminterface.untagged_vlan is not None: + vlans.append(vminterface.untagged_vlan) + vlans[0].tagged = False + for vlan in vminterface.tagged_vlans.prefetch_related('site', 'group', 'tenant', 'role'): + vlan.tagged = True + vlans.append(vlan) + vlan_table = InterfaceVLANTable( + interface=vminterface, + data=vlans, + orderable=False + ) + + return render(request, 'virtualization/vminterface.html', { + 'vminterface': vminterface, + 'ipaddress_table': ipaddress_table, + 'vlan_table': vlan_table, + }) + + +# TODO: This should not use ComponentCreateView +class VMInterfaceCreateView(ComponentCreateView): + queryset = VMInterface.objects.all() + form = forms.VMInterfaceCreateForm + model_form = forms.VMInterfaceForm template_name = 'virtualization/virtualmachine_component_add.html' -class InterfaceEditView(PermissionRequiredMixin, ObjectEditView): - permission_required = 'dcim.change_interface' - queryset = Interface.objects.all() - model_form = forms.InterfaceForm - template_name = 'virtualization/interface_edit.html' +class VMInterfaceEditView(ObjectEditView): + queryset = VMInterface.objects.all() + model_form = forms.VMInterfaceForm + template_name = 'virtualization/vminterface_edit.html' -class InterfaceDeleteView(PermissionRequiredMixin, ObjectDeleteView): - permission_required = 'dcim.delete_interface' - queryset = Interface.objects.all() +class VMInterfaceDeleteView(ObjectDeleteView): + queryset = VMInterface.objects.all() -class InterfaceBulkEditView(PermissionRequiredMixin, BulkEditView): - permission_required = 'dcim.change_interface' - queryset = Interface.objects.all() - table = tables.InterfaceTable - form = forms.InterfaceBulkEditForm +class VMInterfaceBulkImportView(BulkImportView): + queryset = VMInterface.objects.all() + model_form = forms.VMInterfaceCSVForm + table = tables.VMInterfaceTable + default_return_url = 'virtualization:vminterface_list' -class InterfaceBulkDeleteView(PermissionRequiredMixin, BulkDeleteView): - permission_required = 'dcim.delete_interface' - queryset = Interface.objects.all() - table = tables.InterfaceTable +class VMInterfaceBulkEditView(BulkEditView): + queryset = VMInterface.objects.all() + table = tables.VMInterfaceTable + form = forms.VMInterfaceBulkEditForm + + +class VMInterfaceBulkRenameView(BulkRenameView): + queryset = VMInterface.objects.all() + form = forms.VMInterfaceBulkRenameForm + + +class VMInterfaceBulkDeleteView(BulkDeleteView): + queryset = VMInterface.objects.all() + table = tables.VMInterfaceTable # # Bulk Device component creation # -class VirtualMachineBulkAddInterfaceView(PermissionRequiredMixin, BulkComponentCreateView): - permission_required = 'dcim.add_interface' +class VirtualMachineBulkAddInterfaceView(BulkComponentCreateView): parent_model = VirtualMachine parent_field = 'virtual_machine' - form = forms.InterfaceBulkCreateForm - model = Interface - model_form = forms.InterfaceForm + form = forms.VMInterfaceBulkCreateForm + queryset = VMInterface.objects.all() + model_form = forms.VMInterfaceForm filterset = filters.VirtualMachineFilterSet table = tables.VirtualMachineTable default_return_url = 'virtualization:virtualmachine_list' diff --git a/requirements.txt b/requirements.txt index c9f51cff0..eac5ca9d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,10 +6,9 @@ django-filter==2.2.0 django-mptt==0.11.0 django-pglocks==1.0.4 django-prometheus==2.0.0 -django-rq==2.3.1 +django-rq==2.3.2 django-tables2==2.3.1 django-taggit==1.2.0 -django-taggit-serializer==0.1.7 django-timezone-field==4.0 djangorestframework==3.11.0 drf-yasg[validation]==1.17.1