From b75b9e01eb588447e2c45e24ab8a60a45bab9585 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 19 Mar 2024 10:04:33 -0700 Subject: [PATCH] 9856 review changes --- base_requirements.txt | 4 + netbox/circuits/graphql/schema.py | 12 +- netbox/core/graphql/schema.py | 4 +- netbox/core/graphql/types.py | 2 +- netbox/dcim/graphql/schema.py | 54 +++--- netbox/extras/graphql/schema.py | 24 +-- netbox/ipam/graphql/schema.py | 32 ++-- netbox/netbox/graphql/filter_mixins.py | 236 ++++++++++++------------ netbox/tenancy/graphql/schema.py | 12 +- netbox/tenancy/graphql/types.py | 9 +- netbox/users/graphql/schema.py | 4 +- netbox/users/graphql/types.py | 8 +- netbox/utilities/testing/api.py | 8 +- netbox/virtualization/graphql/schema.py | 12 +- netbox/vpn/graphql/schema.py | 20 +- netbox/wireless/graphql/schema.py | 6 +- requirements.txt | 3 +- 17 files changed, 224 insertions(+), 226 deletions(-) diff --git a/base_requirements.txt b/base_requirements.txt index 1d6a2e97d..30c3e3302 100644 --- a/base_requirements.txt +++ b/base_requirements.txt @@ -131,6 +131,10 @@ social-auth-core # https://github.com/python-social-auth/social-app-django/blob/master/CHANGELOG.md social-auth-app-django +# Strawberry GraphQL +# https://github.com/strawberry-graphql/strawberry/blob/main/CHANGELOG.md +strawberry-graphql + # Strawberry GraphQL Django extension # https://github.com/strawberry-graphql/strawberry-django/blob/main/CHANGELOG.md strawberry-django diff --git a/netbox/circuits/graphql/schema.py b/netbox/circuits/graphql/schema.py index 9aa7a0228..ac8626cc5 100644 --- a/netbox/circuits/graphql/schema.py +++ b/netbox/circuits/graphql/schema.py @@ -11,30 +11,30 @@ from .types import * class CircuitsQuery: @strawberry.field def circuit(self, id: int) -> CircuitType: - return models.Circuit.objects.get(id=id) + return models.Circuit.objects.get(pk=id) circuit_list: List[CircuitType] = strawberry_django.field() @strawberry.field def circuit_termination(self, id: int) -> CircuitTerminationType: - return models.CircuitTermination.objects.get(id=id) + return models.CircuitTermination.objects.get(pk=id) circuit_termination_list: List[CircuitTerminationType] = strawberry_django.field() @strawberry.field def circuit_type(self, id: int) -> CircuitTypeType: - return models.CircuitType.objects.get(id=id) + return models.CircuitType.objects.get(pk=id) circuit_type_list: List[CircuitTypeType] = strawberry_django.field() @strawberry.field def provider(self, id: int) -> ProviderType: - return models.Provider.objects.get(id=id) + return models.Provider.objects.get(pk=id) provider_list: List[ProviderType] = strawberry_django.field() @strawberry.field def provider_account(self, id: int) -> ProviderAccountType: - return models.ProviderAccount.objects.get(id=id) + return models.ProviderAccount.objects.get(pk=id) provider_account_list: List[ProviderAccountType] = strawberry_django.field() @strawberry.field def provider_network(self, id: int) -> ProviderNetworkType: - return models.ProviderNetwork.objects.get(id=id) + return models.ProviderNetwork.objects.get(pk=id) provider_network_list: List[ProviderNetworkType] = strawberry_django.field() diff --git a/netbox/core/graphql/schema.py b/netbox/core/graphql/schema.py index 64ed87985..34135cd47 100644 --- a/netbox/core/graphql/schema.py +++ b/netbox/core/graphql/schema.py @@ -11,10 +11,10 @@ from .types import * class CoreQuery: @strawberry.field def data_file(self, id: int) -> DataFileType: - return models.DataFile.objects.get(id=id) + return models.DataFile.objects.get(pk=id) data_file_list: List[DataFileType] = strawberry_django.field() @strawberry.field def data_source(self, id: int) -> DataSourceType: - return models.DataSource.objects.get(id=id) + return models.DataSource.objects.get(pk=id) data_source_list: List[DataSourceType] = strawberry_django.field() diff --git a/netbox/core/graphql/types.py b/netbox/core/graphql/types.py index 013dcd416..676c2aeec 100644 --- a/netbox/core/graphql/types.py +++ b/netbox/core/graphql/types.py @@ -15,7 +15,7 @@ __all__ = ( @strawberry_django.type( models.DataFile, - exclude=('data',), + exclude=['data',], filters=DataFileFilter ) class DataFileType(BaseObjectType): diff --git a/netbox/dcim/graphql/schema.py b/netbox/dcim/graphql/schema.py index c8c0ee777..c3962a87a 100644 --- a/netbox/dcim/graphql/schema.py +++ b/netbox/dcim/graphql/schema.py @@ -11,137 +11,137 @@ from .types import * class DCIMQuery: @strawberry.field def cable(self, id: int) -> CableType: - return models.Cable.objects.get(id=id) + return models.Cable.objects.get(pk=id) cable_list: List[CableType] = strawberry_django.field() @strawberry.field def console_port(self, id: int) -> ConsolePortType: - return models.ConsolePort.objects.get(id=id) + return models.ConsolePort.objects.get(pk=id) console_port_list: List[ConsolePortType] = strawberry_django.field() @strawberry.field def console_port_template(self, id: int) -> ConsolePortTemplateType: - return models.ConsolePortTemplate.objects.get(id=id) + return models.ConsolePortTemplate.objects.get(pk=id) console_port_template_list: List[ConsolePortTemplateType] = strawberry_django.field() @strawberry.field def console_server_port(self, id: int) -> ConsoleServerPortType: - return models.ConsoleServerPort.objects.get(id=id) + return models.ConsoleServerPort.objects.get(pk=id) console_server_port_list: List[ConsoleServerPortType] = strawberry_django.field() @strawberry.field def console_server_port_template(self, id: int) -> ConsoleServerPortTemplateType: - return models.ConsoleServerPortTemplate.objects.get(id=id) + return models.ConsoleServerPortTemplate.objects.get(pk=id) console_server_port_template_list: List[ConsoleServerPortTemplateType] = strawberry_django.field() @strawberry.field def device(self, id: int) -> DeviceType: - return models.Device.objects.get(id=id) + return models.Device.objects.get(pk=id) device_list: List[DeviceType] = strawberry_django.field() @strawberry.field def device_bay(self, id: int) -> DeviceBayType: - return models.DeviceBay.objects.get(id=id) + return models.DeviceBay.objects.get(pk=id) device_bay_list: List[DeviceBayType] = strawberry_django.field() @strawberry.field def device_bay_template(self, id: int) -> DeviceBayTemplateType: - return models.DeviceBayTemplate.objects.get(id=id) + return models.DeviceBayTemplate.objects.get(pk=id) device_bay_template_list: List[DeviceBayTemplateType] = strawberry_django.field() @strawberry.field def device_role(self, id: int) -> DeviceRoleType: - return models.DeviceRole.objects.get(id=id) + return models.DeviceRole.objects.get(pk=id) device_role_list: List[DeviceRoleType] = strawberry_django.field() @strawberry.field def device_type(self, id: int) -> DeviceTypeType: - return models.DeviceType.objects.get(id=id) + return models.DeviceType.objects.get(pk=id) device_type_list: List[DeviceTypeType] = strawberry_django.field() @strawberry.field def front_port(self, id: int) -> FrontPortType: - return models.FrontPort.objects.get(id=id) + return models.FrontPort.objects.get(pk=id) front_port_list: List[FrontPortType] = strawberry_django.field() @strawberry.field def front_port_template(self, id: int) -> FrontPortTemplateType: - return models.FrontPortTemplate.objects.get(id=id) + return models.FrontPortTemplate.objects.get(pk=id) front_port_template_list: List[FrontPortTemplateType] = strawberry_django.field() @strawberry.field def interface(self, id: int) -> InterfaceType: - return models.Interface.objects.get(id=id) + return models.Interface.objects.get(pk=id) interface_list: List[InterfaceType] = strawberry_django.field() @strawberry.field def interface_template(self, id: int) -> InterfaceTemplateType: - return models.InterfaceTemplate.objects.get(id=id) + return models.InterfaceTemplate.objects.get(pk=id) interface_template_list: List[InterfaceTemplateType] = strawberry_django.field() @strawberry.field def inventory_item(self, id: int) -> InventoryItemType: - return models.InventoryItem.objects.get(id=id) + return models.InventoryItem.objects.get(pk=id) inventory_item_list: List[InventoryItemType] = strawberry_django.field() @strawberry.field def inventory_item_role(self, id: int) -> InventoryItemRoleType: - return models.InventoryItemRole.objects.get(id=id) + return models.InventoryItemRole.objects.get(pk=id) inventory_item_role_list: List[InventoryItemRoleType] = strawberry_django.field() @strawberry.field def inventory_item_template(self, id: int) -> InventoryItemTemplateType: - return models.InventoryItemTemplate.objects.get(id=id) + return models.InventoryItemTemplate.objects.get(pk=id) inventory_item_template_list: List[InventoryItemTemplateType] = strawberry_django.field() @strawberry.field def location(self, id: int) -> LocationType: - return models.Location.objects.get(id=id) + return models.Location.objects.get(pk=id) location_list: List[LocationType] = strawberry_django.field() @strawberry.field def manufacturer(self, id: int) -> ManufacturerType: - return models.Manufacturer.objects.get(id=id) + return models.Manufacturer.objects.get(pk=id) manufacturer_list: List[ManufacturerType] = strawberry_django.field() @strawberry.field def module(self, id: int) -> ModuleType: - return models.Module.objects.get(id=id) + return models.Module.objects.get(pk=id) module_list: List[ModuleType] = strawberry_django.field() @strawberry.field def module_bay(self, id: int) -> ModuleBayType: - return models.ModuleBay.objects.get(id=id) + return models.ModuleBay.objects.get(pk=id) module_bay_list: List[ModuleBayType] = strawberry_django.field() @strawberry.field def module_bay_template(self, id: int) -> ModuleBayTemplateType: - return models.ModuleBayTemplate.objects.get(id=id) + return models.ModuleBayTemplate.objects.get(pk=id) module_bay_template_list: List[ModuleBayTemplateType] = strawberry_django.field() @strawberry.field def module_type(self, id: int) -> ModuleTypeType: - return models.ModuleType.objects.get(id=id) + return models.ModuleType.objects.get(pk=id) module_type_list: List[ModuleTypeType] = strawberry_django.field() @strawberry.field def platform(self, id: int) -> PlatformType: - return models.Platform.objects.get(id=id) + return models.Platform.objects.get(pk=id) platform_list: List[PlatformType] = strawberry_django.field() @strawberry.field def power_feed(self, id: int) -> PowerFeedType: - return models.PowerFeed.objects.get(id=id) + return models.PowerFeed.objects.get(pk=id) power_feed_list: List[PowerFeedType] = strawberry_django.field() @strawberry.field def power_outlet(self, id: int) -> PowerOutletType: - return models.PowerOutlet.objects.get(id=id) + return models.PowerOutlet.objects.get(pk=id) power_outlet_list: List[PowerOutletType] = strawberry_django.field() @strawberry.field def power_outlet_template(self, id: int) -> PowerOutletTemplateType: - return models.PowerOutletTemplate.objects.get(id=id) + return models.PowerOutletTemplate.objects.get(pk=id) power_outlet_template_list: List[PowerOutletTemplateType] = strawberry_django.field() @strawberry.field diff --git a/netbox/extras/graphql/schema.py b/netbox/extras/graphql/schema.py index a607882b2..f78285035 100644 --- a/netbox/extras/graphql/schema.py +++ b/netbox/extras/graphql/schema.py @@ -11,60 +11,60 @@ from .types import * class ExtrasQuery: @strawberry.field def config_context(self, id: int) -> ConfigContextType: - return models.ConfigContext.objects.get(id=id) + return models.ConfigContext.objects.get(pk=id) config_context_list: List[ConfigContextType] = strawberry_django.field() @strawberry.field def config_template(self, id: int) -> ConfigTemplateType: - return models.ConfigTemplate.objects.get(id=id) + return models.ConfigTemplate.objects.get(pk=id) config_template_list: List[ConfigTemplateType] = strawberry_django.field() @strawberry.field def custom_field(self, id: int) -> CustomFieldType: - return models.CustomField.objects.get(id=id) + return models.CustomField.objects.get(pk=id) custom_field_list: List[CustomFieldType] = strawberry_django.field() @strawberry.field def custom_field_choice_set(self, id: int) -> CustomFieldChoiceSetType: - return models.CustomFieldChoiceSet.objects.get(id=id) + return models.CustomFieldChoiceSet.objects.get(pk=id) custom_field_choice_set_list: List[CustomFieldChoiceSetType] = strawberry_django.field() @strawberry.field def custom_link(self, id: int) -> CustomLinkType: - return models.CustomLink.objects.get(id=id) + return models.CustomLink.objects.get(pk=id) custom_link_list: List[CustomLinkType] = strawberry_django.field() @strawberry.field def export_template(self, id: int) -> ExportTemplateType: - return models.ExportTemplate.objects.get(id=id) + return models.ExportTemplate.objects.get(pk=id) export_template_list: List[ExportTemplateType] = strawberry_django.field() @strawberry.field def image_attachment(self, id: int) -> ImageAttachmentType: - return models.ImageAttachment.objects.get(id=id) + return models.ImageAttachment.objects.get(pk=id) image_attachment_list: List[ImageAttachmentType] = strawberry_django.field() @strawberry.field def saved_filter(self, id: int) -> SavedFilterType: - return models.SavedFilter.objects.get(id=id) + return models.SavedFilter.objects.get(pk=id) saved_filter_list: List[SavedFilterType] = strawberry_django.field() @strawberry.field def journal_entry(self, id: int) -> JournalEntryType: - return models.JournalEntry.objects.get(id=id) + return models.JournalEntry.objects.get(pk=id) journal_entry_list: List[JournalEntryType] = strawberry_django.field() @strawberry.field def tag(self, id: int) -> TagType: - return models.Tag.objects.get(id=id) + return models.Tag.objects.get(pk=id) tag_list: List[TagType] = strawberry_django.field() @strawberry.field def webhook(self, id: int) -> WebhookType: - return models.Webhook.objects.get(id=id) + return models.Webhook.objects.get(pk=id) webhook_list: List[WebhookType] = strawberry_django.field() @strawberry.field def event_rule(self, id: int) -> EventRuleType: - return models.EventRule.objects.get(id=id) + return models.EventRule.objects.get(pk=id) event_rule_list: List[EventRuleType] = strawberry_django.field() diff --git a/netbox/ipam/graphql/schema.py b/netbox/ipam/graphql/schema.py index 4a977d07d..c02788c3a 100644 --- a/netbox/ipam/graphql/schema.py +++ b/netbox/ipam/graphql/schema.py @@ -11,80 +11,80 @@ from .types import * class IPAMQuery: @strawberry.field def asn(self, id: int) -> ASNType: - return models.ASN.objects.get(id=id) + return models.ASN.objects.get(pk=id) asn_list: List[ASNType] = strawberry_django.field() @strawberry.field def asn_range(self, id: int) -> ASNRangeType: - return models.ASNRange.objects.get(id=id) + return models.ASNRange.objects.get(pk=id) asn_range_list: List[ASNRangeType] = strawberry_django.field() @strawberry.field def aggregate(self, id: int) -> AggregateType: - return models.Aggregate.objects.get(id=id) + return models.Aggregate.objects.get(pk=id) aggregate_list: List[AggregateType] = strawberry_django.field() @strawberry.field def ip_address(self, id: int) -> IPAddressType: - return models.IPAddress.objects.get(id=id) + return models.IPAddress.objects.get(pk=id) ip_address_list: List[IPAddressType] = strawberry_django.field() @strawberry.field def ip_range(self, id: int) -> IPRangeType: - return models.IPRange.objects.get(id=id) + return models.IPRange.objects.get(pk=id) ip_range_list: List[IPRangeType] = strawberry_django.field() @strawberry.field def prefix(self, id: int) -> PrefixType: - return models.Prefix.objects.get(id=id) + return models.Prefix.objects.get(pk=id) prefix_list: List[PrefixType] = strawberry_django.field() @strawberry.field def rir(self, id: int) -> RIRType: - return models.RIR.objects.get(id=id) + return models.RIR.objects.get(pk=id) rir_list: List[RIRType] = strawberry_django.field() @strawberry.field def role(self, id: int) -> RoleType: - return models.Role.objects.get(id=id) + return models.Role.objects.get(pk=id) role_list: List[RoleType] = strawberry_django.field() @strawberry.field def route_target(self, id: int) -> RouteTargetType: - return models.RouteTarget.objects.get(id=id) + return models.RouteTarget.objects.get(pk=id) route_target_list: List[RouteTargetType] = strawberry_django.field() @strawberry.field def service(self, id: int) -> ServiceType: - return models.Service.objects.get(id=id) + return models.Service.objects.get(pk=id) service_list: List[ServiceType] = strawberry_django.field() @strawberry.field def service_template(self, id: int) -> ServiceTemplateType: - return models.ServiceTemplate.objects.get(id=id) + return models.ServiceTemplate.objects.get(pk=id) service_template_list: List[ServiceTemplateType] = strawberry_django.field() @strawberry.field def fhrp_group(self, id: int) -> FHRPGroupType: - return models.FHRPGroup.objects.get(id=id) + return models.FHRPGroup.objects.get(pk=id) fhrp_group_list: List[FHRPGroupType] = strawberry_django.field() @strawberry.field def fhrp_group_assignment(self, id: int) -> FHRPGroupAssignmentType: - return models.FHRPGroupAssignment.objects.get(id=id) + return models.FHRPGroupAssignment.objects.get(pk=id) fhrp_group_assignment_list: List[FHRPGroupAssignmentType] = strawberry_django.field() @strawberry.field def vlan(self, id: int) -> VLANType: - return models.VLAN.objects.get(id=id) + return models.VLAN.objects.get(pk=id) vlan_list: List[VLANType] = strawberry_django.field() @strawberry.field def vlan_group(self, id: int) -> VLANGroupType: - return models.VLANGroup.objects.get(id=id) + return models.VLANGroup.objects.get(pk=id) vlan_group_list: List[VLANGroupType] = strawberry_django.field() @strawberry.field def vrf(self, id: int) -> VRFType: - return models.VRF.objects.get(id=id) + return models.VRF.objects.get(pk=id) vrf_list: List[VRFType] = strawberry_django.field() diff --git a/netbox/netbox/graphql/filter_mixins.py b/netbox/netbox/graphql/filter_mixins.py index 707e3bfee..363e4fe84 100644 --- a/netbox/netbox/graphql/filter_mixins.py +++ b/netbox/netbox/graphql/filter_mixins.py @@ -11,6 +11,123 @@ from utilities.fields import ColorField, CounterCacheField from utilities.filters import * +def map_strawberry_type(field): + should_create_function = False + attr_type = None + + # NetBox Filter types - put base classes after derived classes + if isinstance(field, ContentTypeFilter): + should_create_function = True + attr_type = str | None + elif isinstance(field, MACAddressFilter): + pass + elif isinstance(field, MultiValueArrayFilter): + pass + elif isinstance(field, MultiValueCharFilter): + should_create_function = True + attr_type = List[str] | None + elif isinstance(field, MultiValueDateFilter): + attr_type = auto + elif isinstance(field, MultiValueDateTimeFilter): + attr_type = auto + elif isinstance(field, MultiValueDecimalFilter): + pass + elif isinstance(field, MultiValueMACAddressFilter): + should_create_function = True + attr_type = List[str] | None + elif isinstance(field, MultiValueNumberFilter): + should_create_function = True + attr_type = List[str] | None + elif isinstance(field, MultiValueTimeFilter): + pass + elif isinstance(field, MultiValueWWNFilter): + should_create_function = True + attr_type = List[str] | None + elif isinstance(field, NullableCharFieldFilter): + pass + elif isinstance(field, NumericArrayFilter): + should_create_function = True + attr_type = int + elif isinstance(field, TreeNodeMultipleChoiceFilter): + should_create_function = True + attr_type = List[str] | None + + # From django_filters - ordering of these matters as base classes must + # come after derived classes so the base class doesn't get matched first + # a pass for the check (no attr_type) means we don't currently handle + # or use that type + elif issubclass(type(field), django_filters.OrderingFilter): + pass + elif issubclass(type(field), django_filters.BaseRangeFilter): + pass + elif issubclass(type(field), django_filters.BaseInFilter): + pass + elif issubclass(type(field), django_filters.LookupChoiceFilter): + pass + elif issubclass(type(field), django_filters.AllValuesMultipleFilter): + pass + elif issubclass(type(field), django_filters.AllValuesFilter): + pass + elif issubclass(type(field), django_filters.TimeRangeFilter): + pass + elif issubclass(type(field), django_filters.IsoDateTimeFromToRangeFilter): + should_create_function = True + attr_type = str | None + elif issubclass(type(field), django_filters.DateTimeFromToRangeFilter): + should_create_function = True + attr_type = str | None + elif issubclass(type(field), django_filters.DateFromToRangeFilter): + should_create_function = True + attr_type = str | None + elif issubclass(type(field), django_filters.DateRangeFilter): + should_create_function = True + attr_type = str | None + elif issubclass(type(field), django_filters.RangeFilter): + pass + elif issubclass(type(field), django_filters.NumericRangeFilter): + pass + elif issubclass(type(field), django_filters.NumberFilter): + should_create_function = True + attr_type = int + elif issubclass(type(field), django_filters.ModelMultipleChoiceFilter): + should_create_function = True + attr_type = List[str] | None + elif issubclass(type(field), django_filters.ModelChoiceFilter): + should_create_function = True + attr_type = str | None + elif issubclass(type(field), django_filters.DurationFilter): + pass + elif issubclass(type(field), django_filters.IsoDateTimeFilter): + pass + elif issubclass(type(field), django_filters.DateTimeFilter): + attr_type = auto + elif issubclass(type(field), django_filters.TimeFilter): + attr_type = auto + elif issubclass(type(field), django_filters.DateFilter): + attr_type = auto + elif issubclass(type(field), django_filters.TypedMultipleChoiceFilter): + pass + elif issubclass(type(field), django_filters.MultipleChoiceFilter): + should_create_function = True + attr_type = List[str] | None + elif issubclass(type(field), django_filters.TypedChoiceFilter): + pass + elif issubclass(type(field), django_filters.ChoiceFilter): + pass + elif issubclass(type(field), django_filters.BooleanFilter): + should_create_function = True + attr_type = bool | None + elif issubclass(type(field), django_filters.UUIDFilter): + should_create_function = True + attr_type = str | None + elif issubclass(type(field), django_filters.CharFilter): + # looks like only used by 'q' + should_create_function = True + attr_type = str | None + + return should_create_function, attr_type + + def autotype_decorator(filterset): """ Decorator used to auto creates a dataclass used by Strawberry based on a filterset. @@ -36,10 +153,10 @@ def autotype_decorator(filterset): if fieldname not in cls.__annotations__ and attr_type: cls.__annotations__[fieldname] = attr_type - fname = f"filter_{fieldname}" - if should_create_function and not hasattr(cls, fname): + filter_name = f"filter_{fieldname}" + if should_create_function and not hasattr(cls, filter_name): filter_by_filterset = getattr(cls, 'filter_by_filterset') - setattr(cls, fname, partialmethod(filter_by_filterset, key=fieldname)) + setattr(cls, filter_name, partialmethod(filter_by_filterset, key=fieldname)) def wrapper(cls): cls.filterset = filterset @@ -64,119 +181,8 @@ def autotype_decorator(filterset): declared_filters = filterset.declared_filters for fieldname, field in declared_filters.items(): - should_create_function = False - attr_type = None - - # NetBox Filter types - put base classes after derived classes - if isinstance(field, ContentTypeFilter): - should_create_function = True - attr_type = str | None - elif isinstance(field, MACAddressFilter): - pass - elif isinstance(field, MultiValueArrayFilter): - pass - elif isinstance(field, MultiValueCharFilter): - should_create_function = True - attr_type = List[str] | None - elif isinstance(field, MultiValueDateFilter): - attr_type = auto - elif isinstance(field, MultiValueDateTimeFilter): - attr_type = auto - elif isinstance(field, MultiValueDecimalFilter): - pass - elif isinstance(field, MultiValueMACAddressFilter): - should_create_function = True - attr_type = List[str] | None - elif isinstance(field, MultiValueNumberFilter): - should_create_function = True - attr_type = List[str] | None - elif isinstance(field, MultiValueTimeFilter): - pass - elif isinstance(field, MultiValueWWNFilter): - should_create_function = True - attr_type = List[str] | None - elif isinstance(field, NullableCharFieldFilter): - pass - elif isinstance(field, NumericArrayFilter): - should_create_function = True - attr_type = int - elif isinstance(field, TreeNodeMultipleChoiceFilter): - should_create_function = True - attr_type = List[str] | None - - # From django_filters - ordering of these matters as base classes must - # come after derived classes so the base class doesn't get matched first - # a pass for the check (no attr_type) means we don't currently handle - # or use that type - elif issubclass(type(field), django_filters.OrderingFilter): - pass - elif issubclass(type(field), django_filters.BaseRangeFilter): - pass - elif issubclass(type(field), django_filters.BaseInFilter): - pass - elif issubclass(type(field), django_filters.LookupChoiceFilter): - pass - elif issubclass(type(field), django_filters.AllValuesMultipleFilter): - pass - elif issubclass(type(field), django_filters.AllValuesFilter): - pass - elif issubclass(type(field), django_filters.TimeRangeFilter): - pass - elif issubclass(type(field), django_filters.IsoDateTimeFromToRangeFilter): - should_create_function = True - attr_type = str | None - elif issubclass(type(field), django_filters.DateTimeFromToRangeFilter): - should_create_function = True - attr_type = str | None - elif issubclass(type(field), django_filters.DateFromToRangeFilter): - should_create_function = True - attr_type = str | None - elif issubclass(type(field), django_filters.DateRangeFilter): - should_create_function = True - attr_type = str | None - elif issubclass(type(field), django_filters.RangeFilter): - pass - elif issubclass(type(field), django_filters.NumericRangeFilter): - pass - elif issubclass(type(field), django_filters.NumberFilter): - should_create_function = True - attr_type = int - elif issubclass(type(field), django_filters.ModelMultipleChoiceFilter): - should_create_function = True - attr_type = List[str] | None - elif issubclass(type(field), django_filters.ModelChoiceFilter): - should_create_function = True - attr_type = str | None - elif issubclass(type(field), django_filters.DurationFilter): - pass - elif issubclass(type(field), django_filters.IsoDateTimeFilter): - pass - elif issubclass(type(field), django_filters.DateTimeFilter): - attr_type = auto - elif issubclass(type(field), django_filters.TimeFilter): - attr_type = auto - elif issubclass(type(field), django_filters.DateFilter): - attr_type = auto - elif issubclass(type(field), django_filters.TypedMultipleChoiceFilter): - pass - elif issubclass(type(field), django_filters.MultipleChoiceFilter): - should_create_function = True - attr_type = List[str] | None - elif issubclass(type(field), django_filters.TypedChoiceFilter): - pass - elif issubclass(type(field), django_filters.ChoiceFilter): - pass - elif issubclass(type(field), django_filters.BooleanFilter): - should_create_function = True - attr_type = bool | None - elif issubclass(type(field), django_filters.UUIDFilter): - should_create_function = True - attr_type = str | None - elif issubclass(type(field), django_filters.CharFilter): - # looks like only used by 'q' - should_create_function = True - attr_type = str | None + should_create_function, attr_type = map_strawberry_type(field) if attr_type is None: raise NotImplementedError(f"GraphQL Filter field unknown: {fieldname}: {field}") diff --git a/netbox/tenancy/graphql/schema.py b/netbox/tenancy/graphql/schema.py index f33c4d6c1..79f8660d4 100644 --- a/netbox/tenancy/graphql/schema.py +++ b/netbox/tenancy/graphql/schema.py @@ -11,30 +11,30 @@ from .types import * class TenancyQuery: @strawberry.field def tenant(self, id: int) -> TenantType: - return models.Tenant.objects.get(id=id) + return models.Tenant.objects.get(pk=id) tenant_list: List[TenantType] = strawberry_django.field() @strawberry.field def tenant_group(self, id: int) -> TenantGroupType: - return models.TenantGroup.objects.get(id=id) + return models.TenantGroup.objects.get(pk=id) tenant_group_list: List[TenantGroupType] = strawberry_django.field() @strawberry.field def contact(self, id: int) -> ContactType: - return models.Contact.objects.get(id=id) + return models.Contact.objects.get(pk=id) contact_list: List[ContactType] = strawberry_django.field() @strawberry.field def contact_role(self, id: int) -> ContactRoleType: - return models.ContactRole.objects.get(id=id) + return models.ContactRole.objects.get(pk=id) contact_role_list: List[ContactRoleType] = strawberry_django.field() @strawberry.field def contact_group(self, id: int) -> ContactGroupType: - return models.ContactGroup.objects.get(id=id) + return models.ContactGroup.objects.get(pk=id) contact_group_list: List[ContactGroupType] = strawberry_django.field() @strawberry.field def contact_assignment(self, id: int) -> ContactAssignmentType: - return models.ContactAssignment.objects.get(id=id) + return models.ContactAssignment.objects.get(pk=id) contact_assignment_list: List[ContactAssignmentType] = strawberry_django.field() diff --git a/netbox/tenancy/graphql/types.py b/netbox/tenancy/graphql/types.py index 8417ad1d5..7c7cd462a 100644 --- a/netbox/tenancy/graphql/types.py +++ b/netbox/tenancy/graphql/types.py @@ -6,6 +6,7 @@ import strawberry_django from extras.graphql.mixins import CustomFieldsMixin, TagsMixin from netbox.graphql.types import BaseObjectType, OrganizationalObjectType, NetBoxObjectType from tenancy import models +from .mixins import ContactAssignmentsMixin from .filters import * __all__ = ( @@ -18,14 +19,6 @@ __all__ = ( ) -@strawberry.type -class ContactAssignmentsMixin: - - @strawberry_django.field - def assignments(self) -> List[Annotated["ContactAssignmentType", strawberry.lazy('tenancy.graphql.types')]]: - return self.assignments.all() - - # # Tenants # diff --git a/netbox/users/graphql/schema.py b/netbox/users/graphql/schema.py index 66a9e8c93..840887ad2 100644 --- a/netbox/users/graphql/schema.py +++ b/netbox/users/graphql/schema.py @@ -12,10 +12,10 @@ from .types import * class UsersQuery: @strawberry.field def group(self, id: int) -> GroupType: - return models.Group.objects.get(id=id) + return models.Group.objects.get(pk=id) group_list: List[GroupType] = strawberry_django.field() @strawberry.field def user(self, id: int) -> UserType: - return models.User.objects.get(id=id) + return models.User.objects.get(pk=id) user_list: List[UserType] = strawberry_django.field() diff --git a/netbox/users/graphql/types.py b/netbox/users/graphql/types.py index 89a5d99da..6fa15bacb 100644 --- a/netbox/users/graphql/types.py +++ b/netbox/users/graphql/types.py @@ -22,9 +22,7 @@ __all__ = ( filters=GroupFilter ) class GroupType: - @classmethod - def get_queryset(cls, queryset, info, **kwargs): - return RestrictedQuerySet(model=Group).restrict(info.context.request.user, 'view') + pass @strawberry_django.type( @@ -36,10 +34,6 @@ class GroupType: filters=UserFilter ) class UserType: - @classmethod - def get_queryset(cls, queryset, info, **kwargs): - return RestrictedQuerySet(model=get_user_model()).restrict(info.context.request.user, 'view') - @strawberry_django.field def groups(self) -> List[GroupType]: return self.groups.all() diff --git a/netbox/utilities/testing/api.py b/netbox/utilities/testing/api.py index 11007f77f..a30235d93 100644 --- a/netbox/utilities/testing/api.py +++ b/netbox/utilities/testing/api.py @@ -451,12 +451,12 @@ class APIViewTestCases: # Compile list of fields to include fields_string = '' + file_fields = (strawberry_django.fields.types.DjangoFileType, strawberry_django.fields.types.DjangoImageType) for field in type_class.__strawberry_definition__.fields: if ( - field.type in ( - strawberry_django.fields.types.DjangoFileType, strawberry_django.fields.types.DjangoImageType) or - type(field.type) is StrawberryOptional and field.type.of_type in ( - strawberry_django.fields.types.DjangoFileType, strawberry_django.fields.types.DjangoImageType) + field.type in file_fields or ( + type(field.type) is StrawberryOptional and field.type.of_type in file_fields + ) ): # image / file fields nullable or not... fields_string += f'{field.name} {{ name }}\n' diff --git a/netbox/virtualization/graphql/schema.py b/netbox/virtualization/graphql/schema.py index 02dd888d7..72d83155d 100644 --- a/netbox/virtualization/graphql/schema.py +++ b/netbox/virtualization/graphql/schema.py @@ -11,30 +11,30 @@ from .types import * class VirtualizationQuery: @strawberry.field def cluster(self, id: int) -> ClusterType: - return models.Cluster.objects.get(id=id) + return models.Cluster.objects.get(pk=id) cluster_list: List[ClusterType] = strawberry_django.field() @strawberry.field def cluster_group(self, id: int) -> ClusterGroupType: - return models.ClusterGroup.objects.get(id=id) + return models.ClusterGroup.objects.get(pk=id) cluster_group_list: List[ClusterGroupType] = strawberry_django.field() @strawberry.field def cluster_type(self, id: int) -> ClusterTypeType: - return models.ClusterType.objects.get(id=id) + return models.ClusterType.objects.get(pk=id) cluster_type_list: List[ClusterTypeType] = strawberry_django.field() @strawberry.field def virtual_machine(self, id: int) -> VirtualMachineType: - return models.VirtualMachine.objects.get(id=id) + return models.VirtualMachine.objects.get(pk=id) virtual_machine_list: List[VirtualMachineType] = strawberry_django.field() @strawberry.field def vm_interface(self, id: int) -> VMInterfaceType: - return models.VMInterface.objects.get(id=id) + return models.VMInterface.objects.get(pk=id) vm_interface_list: List[VMInterfaceType] = strawberry_django.field() @strawberry.field def virtual_disk(self, id: int) -> VirtualDiskType: - return models.VirtualDisk.objects.get(id=id) + return models.VirtualDisk.objects.get(pk=id) virtual_disk_list: List[VirtualDiskType] = strawberry_django.field() diff --git a/netbox/vpn/graphql/schema.py b/netbox/vpn/graphql/schema.py index 93c6ded77..f37e444a2 100644 --- a/netbox/vpn/graphql/schema.py +++ b/netbox/vpn/graphql/schema.py @@ -11,50 +11,50 @@ from .types import * class VPNQuery: @strawberry.field def ike_policy(self, id: int) -> IKEPolicyType: - return models.IKEPolicy.objects.get(id=id) + return models.IKEPolicy.objects.get(pk=id) ike_policy_list: List[IKEPolicyType] = strawberry_django.field() @strawberry.field def ike_proposal(self, id: int) -> IKEProposalType: - return models.IKEProposal.objects.get(id=id) + return models.IKEProposal.objects.get(pk=id) ike_proposal_list: List[IKEProposalType] = strawberry_django.field() @strawberry.field def ipsec_policy(self, id: int) -> IPSecPolicyType: - return models.IPSecPolicy.objects.get(id=id) + return models.IPSecPolicy.objects.get(pk=id) ipsec_policy_list: List[IPSecPolicyType] = strawberry_django.field() @strawberry.field def ipsec_profile(self, id: int) -> IPSecProfileType: - return models.IPSecProfile.objects.get(id=id) + return models.IPSecProfile.objects.get(pk=id) ipsec_profile_list: List[IPSecProfileType] = strawberry_django.field() @strawberry.field def ipsec_proposal(self, id: int) -> IPSecProposalType: - return models.IPSecProposal.objects.get(id=id) + return models.IPSecProposal.objects.get(pk=id) ipsec_proposal_list: List[IPSecProposalType] = strawberry_django.field() @strawberry.field def l2vpn(self, id: int) -> L2VPNType: - return models.L2VPN.objects.get(id=id) + return models.L2VPN.objects.get(pk=id) l2vpn_list: List[L2VPNType] = strawberry_django.field() @strawberry.field def l2vpn_termination(self, id: int) -> L2VPNTerminationType: - return models.L2VPNTermination.objects.get(id=id) + return models.L2VPNTermination.objects.get(pk=id) l2vpn_termination_list: List[L2VPNTerminationType] = strawberry_django.field() @strawberry.field def tunnel(self, id: int) -> TunnelType: - return models.Tunnel.objects.get(id=id) + return models.Tunnel.objects.get(pk=id) tunnel_list: List[TunnelType] = strawberry_django.field() @strawberry.field def tunnel_group(self, id: int) -> TunnelGroupType: - return models.TunnelGroup.objects.get(id=id) + return models.TunnelGroup.objects.get(pk=id) tunnel_group_list: List[TunnelGroupType] = strawberry_django.field() @strawberry.field def tunnel_termination(self, id: int) -> TunnelTerminationType: - return models.TunnelTermination.objects.get(id=id) + return models.TunnelTermination.objects.get(pk=id) tunnel_termination_list: List[TunnelTerminationType] = strawberry_django.field() diff --git a/netbox/wireless/graphql/schema.py b/netbox/wireless/graphql/schema.py index 38f0b9c4e..80a40c063 100644 --- a/netbox/wireless/graphql/schema.py +++ b/netbox/wireless/graphql/schema.py @@ -11,15 +11,15 @@ from .types import * class WirelessQuery: @strawberry.field def wireless_lan(self, id: int) -> WirelessLANType: - return models.WirelessLAN.objects.get(id=id) + return models.WirelessLAN.objects.get(pk=id) wireless_lan_list: List[WirelessLANType] = strawberry_django.field() @strawberry.field def wireless_lan_group(self, id: int) -> WirelessLANGroupType: - return models.WirelessLANGroup.objects.get(id=id) + return models.WirelessLANGroup.objects.get(pk=id) wireless_lan_group_list: List[WirelessLANGroupType] = strawberry_django.field() @strawberry.field def wireless_link(self, id: int) -> WirelessLinkType: - return models.WirelessLink.objects.get(id=id) + return models.WirelessLink.objects.get(pk=id) wireless_link_list: List[WirelessLinkType] = strawberry_django.field() diff --git a/requirements.txt b/requirements.txt index 93aaed30c..c4b8c4173 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,7 +30,8 @@ PyYAML==6.0.1 requests==2.31.0 social-auth-app-django==5.4.0 social-auth-core[openidconnect]==4.5.3 -strawberry-graphql-django==0.33.0 +strawberry-graphql==0.220.0 +strawberry-graphql-django==0.35.1 svgwrite==1.4.3 tablib==3.5.0 tzdata==2024.1