diff --git a/netbox/circuits/graphql/schema.py b/netbox/circuits/graphql/schema.py index ac8626cc5..d532d0d4b 100644 --- a/netbox/circuits/graphql/schema.py +++ b/netbox/circuits/graphql/schema.py @@ -3,38 +3,25 @@ from typing import List import strawberry import strawberry_django -from circuits import models from .types import * -@strawberry.type +@strawberry.type(name="Query") class CircuitsQuery: - @strawberry.field - def circuit(self, id: int) -> CircuitType: - return models.Circuit.objects.get(pk=id) + circuit: CircuitType = strawberry_django.field() circuit_list: List[CircuitType] = strawberry_django.field() - @strawberry.field - def circuit_termination(self, id: int) -> CircuitTerminationType: - return models.CircuitTermination.objects.get(pk=id) + circuit_termination: CircuitTerminationType = strawberry_django.field() circuit_termination_list: List[CircuitTerminationType] = strawberry_django.field() - @strawberry.field - def circuit_type(self, id: int) -> CircuitTypeType: - return models.CircuitType.objects.get(pk=id) + circuit_type: CircuitTypeType = strawberry_django.field() circuit_type_list: List[CircuitTypeType] = strawberry_django.field() - @strawberry.field - def provider(self, id: int) -> ProviderType: - return models.Provider.objects.get(pk=id) + provider: ProviderType = strawberry_django.field() provider_list: List[ProviderType] = strawberry_django.field() - @strawberry.field - def provider_account(self, id: int) -> ProviderAccountType: - return models.ProviderAccount.objects.get(pk=id) + provider_account: ProviderAccountType = strawberry_django.field() provider_account_list: List[ProviderAccountType] = strawberry_django.field() - @strawberry.field - def provider_network(self, id: int) -> ProviderNetworkType: - return models.ProviderNetwork.objects.get(pk=id) + provider_network: ProviderNetworkType = strawberry_django.field() provider_network_list: List[ProviderNetworkType] = strawberry_django.field() diff --git a/netbox/core/graphql/schema.py b/netbox/core/graphql/schema.py index 34135cd47..a77c57c86 100644 --- a/netbox/core/graphql/schema.py +++ b/netbox/core/graphql/schema.py @@ -3,18 +3,13 @@ from typing import List import strawberry import strawberry_django -from core import models from .types import * -@strawberry.type +@strawberry.type(name="Query") class CoreQuery: - @strawberry.field - def data_file(self, id: int) -> DataFileType: - return models.DataFile.objects.get(pk=id) + data_file: DataFileType = strawberry_django.field() data_file_list: List[DataFileType] = strawberry_django.field() - @strawberry.field - def data_source(self, id: int) -> DataSourceType: - return models.DataSource.objects.get(pk=id) + data_source: DataSourceType = strawberry_django.field() data_source_list: List[DataSourceType] = strawberry_django.field() diff --git a/netbox/dcim/graphql/schema.py b/netbox/dcim/graphql/schema.py index c3962a87a..803970293 100644 --- a/netbox/dcim/graphql/schema.py +++ b/netbox/dcim/graphql/schema.py @@ -3,208 +3,127 @@ from typing import List import strawberry import strawberry_django -from dcim import models from .types import * -@strawberry.type +@strawberry.type(name="Query") class DCIMQuery: - @strawberry.field - def cable(self, id: int) -> CableType: - return models.Cable.objects.get(pk=id) + cable: CableType = strawberry_django.field() cable_list: List[CableType] = strawberry_django.field() - @strawberry.field - def console_port(self, id: int) -> ConsolePortType: - return models.ConsolePort.objects.get(pk=id) + console_port: ConsolePortType = strawberry_django.field() console_port_list: List[ConsolePortType] = strawberry_django.field() - @strawberry.field - def console_port_template(self, id: int) -> ConsolePortTemplateType: - return models.ConsolePortTemplate.objects.get(pk=id) + console_port_template: ConsolePortTemplateType = strawberry_django.field() console_port_template_list: List[ConsolePortTemplateType] = strawberry_django.field() - @strawberry.field - def console_server_port(self, id: int) -> ConsoleServerPortType: - return models.ConsoleServerPort.objects.get(pk=id) + console_server_port: ConsoleServerPortType = strawberry_django.field() console_server_port_list: List[ConsoleServerPortType] = strawberry_django.field() - @strawberry.field - def console_server_port_template(self, id: int) -> ConsoleServerPortTemplateType: - return models.ConsoleServerPortTemplate.objects.get(pk=id) + console_server_port_template: ConsoleServerPortTemplateType = strawberry_django.field() console_server_port_template_list: List[ConsoleServerPortTemplateType] = strawberry_django.field() - @strawberry.field - def device(self, id: int) -> DeviceType: - return models.Device.objects.get(pk=id) + device: DeviceType = strawberry_django.field() device_list: List[DeviceType] = strawberry_django.field() - @strawberry.field - def device_bay(self, id: int) -> DeviceBayType: - return models.DeviceBay.objects.get(pk=id) + device_bay: DeviceBayType = strawberry_django.field() device_bay_list: List[DeviceBayType] = strawberry_django.field() - @strawberry.field - def device_bay_template(self, id: int) -> DeviceBayTemplateType: - return models.DeviceBayTemplate.objects.get(pk=id) + device_bay_template: DeviceBayTemplateType = strawberry_django.field() device_bay_template_list: List[DeviceBayTemplateType] = strawberry_django.field() - @strawberry.field - def device_role(self, id: int) -> DeviceRoleType: - return models.DeviceRole.objects.get(pk=id) + device_role: DeviceRoleType = strawberry_django.field() device_role_list: List[DeviceRoleType] = strawberry_django.field() - @strawberry.field - def device_type(self, id: int) -> DeviceTypeType: - return models.DeviceType.objects.get(pk=id) + device_type: DeviceTypeType = strawberry_django.field() device_type_list: List[DeviceTypeType] = strawberry_django.field() - @strawberry.field - def front_port(self, id: int) -> FrontPortType: - return models.FrontPort.objects.get(pk=id) + front_port: FrontPortType = strawberry_django.field() front_port_list: List[FrontPortType] = strawberry_django.field() - @strawberry.field - def front_port_template(self, id: int) -> FrontPortTemplateType: - return models.FrontPortTemplate.objects.get(pk=id) + front_port_template: FrontPortTemplateType = strawberry_django.field() front_port_template_list: List[FrontPortTemplateType] = strawberry_django.field() - @strawberry.field - def interface(self, id: int) -> InterfaceType: - return models.Interface.objects.get(pk=id) + interface: InterfaceType = strawberry_django.field() interface_list: List[InterfaceType] = strawberry_django.field() - @strawberry.field - def interface_template(self, id: int) -> InterfaceTemplateType: - return models.InterfaceTemplate.objects.get(pk=id) + interface_template: InterfaceTemplateType = strawberry_django.field() interface_template_list: List[InterfaceTemplateType] = strawberry_django.field() - @strawberry.field - def inventory_item(self, id: int) -> InventoryItemType: - return models.InventoryItem.objects.get(pk=id) + inventory_item: InventoryItemType = strawberry_django.field() inventory_item_list: List[InventoryItemType] = strawberry_django.field() - @strawberry.field - def inventory_item_role(self, id: int) -> InventoryItemRoleType: - return models.InventoryItemRole.objects.get(pk=id) + inventory_item_role: InventoryItemRoleType = strawberry_django.field() inventory_item_role_list: List[InventoryItemRoleType] = strawberry_django.field() - @strawberry.field - def inventory_item_template(self, id: int) -> InventoryItemTemplateType: - return models.InventoryItemTemplate.objects.get(pk=id) + inventory_item_template: InventoryItemTemplateType = strawberry_django.field() inventory_item_template_list: List[InventoryItemTemplateType] = strawberry_django.field() - @strawberry.field - def location(self, id: int) -> LocationType: - return models.Location.objects.get(pk=id) + location: LocationType = strawberry_django.field() location_list: List[LocationType] = strawberry_django.field() - @strawberry.field - def manufacturer(self, id: int) -> ManufacturerType: - return models.Manufacturer.objects.get(pk=id) + manufacturer: ManufacturerType = strawberry_django.field() manufacturer_list: List[ManufacturerType] = strawberry_django.field() - @strawberry.field - def module(self, id: int) -> ModuleType: - return models.Module.objects.get(pk=id) + module: ModuleType = strawberry_django.field() module_list: List[ModuleType] = strawberry_django.field() - @strawberry.field - def module_bay(self, id: int) -> ModuleBayType: - return models.ModuleBay.objects.get(pk=id) + module_bay: ModuleBayType = strawberry_django.field() module_bay_list: List[ModuleBayType] = strawberry_django.field() - @strawberry.field - def module_bay_template(self, id: int) -> ModuleBayTemplateType: - return models.ModuleBayTemplate.objects.get(pk=id) + module_bay_template: ModuleBayTemplateType = strawberry_django.field() module_bay_template_list: List[ModuleBayTemplateType] = strawberry_django.field() - @strawberry.field - def module_type(self, id: int) -> ModuleTypeType: - return models.ModuleType.objects.get(pk=id) + module_type: ModuleTypeType = strawberry_django.field() module_type_list: List[ModuleTypeType] = strawberry_django.field() - @strawberry.field - def platform(self, id: int) -> PlatformType: - return models.Platform.objects.get(pk=id) + platform: PlatformType = strawberry_django.field() platform_list: List[PlatformType] = strawberry_django.field() - @strawberry.field - def power_feed(self, id: int) -> PowerFeedType: - return models.PowerFeed.objects.get(pk=id) + power_feed: PowerFeedType = strawberry_django.field() power_feed_list: List[PowerFeedType] = strawberry_django.field() - @strawberry.field - def power_outlet(self, id: int) -> PowerOutletType: - return models.PowerOutlet.objects.get(pk=id) + power_outlet: PowerOutletType = strawberry_django.field() power_outlet_list: List[PowerOutletType] = strawberry_django.field() - @strawberry.field - def power_outlet_template(self, id: int) -> PowerOutletTemplateType: - return models.PowerOutletTemplate.objects.get(pk=id) + power_outlet_template: PowerOutletTemplateType = strawberry_django.field() power_outlet_template_list: List[PowerOutletTemplateType] = strawberry_django.field() - @strawberry.field - def power_panel(self, id: int) -> PowerPanelType: - return models.PowerPanel.objects.get(id=id) + power_panel: PowerPanelType = strawberry_django.field() power_panel_list: List[PowerPanelType] = strawberry_django.field() - @strawberry.field - def power_port(self, id: int) -> PowerPortType: - return models.PowerPort.objects.get(id=id) + power_port: PowerPortType = strawberry_django.field() power_port_list: List[PowerPortType] = strawberry_django.field() - @strawberry.field - def power_port_template(self, id: int) -> PowerPortTemplateType: - return models.PowerPortTemplate.objects.get(id=id) + power_port_template: PowerPortTemplateType = strawberry_django.field() power_port_template_list: List[PowerPortTemplateType] = strawberry_django.field() - @strawberry.field - def rack(self, id: int) -> RackType: - return models.Rack.objects.get(id=id) + rack: RackType = strawberry_django.field() rack_list: List[RackType] = strawberry_django.field() - @strawberry.field - def rack_reservation(self, id: int) -> RackReservationType: - return models.RackReservation.objects.get(id=id) + rack_reservation: RackReservationType = strawberry_django.field() rack_reservation_list: List[RackReservationType] = strawberry_django.field() - @strawberry.field - def rack_role(self, id: int) -> RackRoleType: - return models.RackRole.objects.get(id=id) + rack_role: RackRoleType = strawberry_django.field() rack_role_list: List[RackRoleType] = strawberry_django.field() - @strawberry.field - def rear_port(self, id: int) -> RearPortType: - return models.RearPort.objects.get(id=id) + rear_port: RearPortType = strawberry_django.field() rear_port_list: List[RearPortType] = strawberry_django.field() - @strawberry.field - def rear_port_template(self, id: int) -> RearPortTemplateType: - return models.RearPortTemplate.objects.get(id=id) + rear_port_template: RearPortTemplateType = strawberry_django.field() rear_port_template_list: List[RearPortTemplateType] = strawberry_django.field() - @strawberry.field - def region(self, id: int) -> RegionType: - return models.Region.objects.get(id=id) + region: RegionType = strawberry_django.field() region_list: List[RegionType] = strawberry_django.field() - @strawberry.field - def site(self, id: int) -> SiteType: - return models.Site.objects.get(id=id) + site: SiteType = strawberry_django.field() site_list: List[SiteType] = strawberry_django.field() - @strawberry.field - def site_group(self, id: int) -> SiteGroupType: - return models.SiteGroup.objects.get(id=id) + site_group: SiteGroupType = strawberry_django.field() site_group_list: List[SiteGroupType] = strawberry_django.field() - @strawberry.field - def virtual_chassis(self, id: int) -> VirtualChassisType: - return models.VirtualChassis.objects.get(id=id) + virtual_chassis: VirtualChassisType = strawberry_django.field() virtual_chassis_list: List[VirtualChassisType] = strawberry_django.field() - @strawberry.field - def virtual_device_context(self, id: int) -> VirtualDeviceContextType: - return models.VirtualDeviceContext.objects.get(id=id) + virtual_device_context: VirtualDeviceContextType = strawberry_django.field() virtual_device_context_list: List[VirtualDeviceContextType] = strawberry_django.field() diff --git a/netbox/extras/graphql/schema.py b/netbox/extras/graphql/schema.py index f78285035..b9586ab83 100644 --- a/netbox/extras/graphql/schema.py +++ b/netbox/extras/graphql/schema.py @@ -3,68 +3,43 @@ from typing import List import strawberry import strawberry_django -from extras import models from .types import * -@strawberry.type +@strawberry.type(name="Query") class ExtrasQuery: - @strawberry.field - def config_context(self, id: int) -> ConfigContextType: - return models.ConfigContext.objects.get(pk=id) + config_context: ConfigContextType = strawberry_django.field() config_context_list: List[ConfigContextType] = strawberry_django.field() - @strawberry.field - def config_template(self, id: int) -> ConfigTemplateType: - return models.ConfigTemplate.objects.get(pk=id) + config_template: ConfigTemplateType = strawberry_django.field() config_template_list: List[ConfigTemplateType] = strawberry_django.field() - @strawberry.field - def custom_field(self, id: int) -> CustomFieldType: - return models.CustomField.objects.get(pk=id) + custom_field: CustomFieldType = strawberry_django.field() custom_field_list: List[CustomFieldType] = strawberry_django.field() - @strawberry.field - def custom_field_choice_set(self, id: int) -> CustomFieldChoiceSetType: - return models.CustomFieldChoiceSet.objects.get(pk=id) + custom_field_choice_set: CustomFieldChoiceSetType = strawberry_django.field() custom_field_choice_set_list: List[CustomFieldChoiceSetType] = strawberry_django.field() - @strawberry.field - def custom_link(self, id: int) -> CustomLinkType: - return models.CustomLink.objects.get(pk=id) + custom_link: CustomLinkType = strawberry_django.field() custom_link_list: List[CustomLinkType] = strawberry_django.field() - @strawberry.field - def export_template(self, id: int) -> ExportTemplateType: - return models.ExportTemplate.objects.get(pk=id) + export_template: ExportTemplateType = strawberry_django.field() export_template_list: List[ExportTemplateType] = strawberry_django.field() - @strawberry.field - def image_attachment(self, id: int) -> ImageAttachmentType: - return models.ImageAttachment.objects.get(pk=id) + image_attachment: ImageAttachmentType = strawberry_django.field() image_attachment_list: List[ImageAttachmentType] = strawberry_django.field() - @strawberry.field - def saved_filter(self, id: int) -> SavedFilterType: - return models.SavedFilter.objects.get(pk=id) + saved_filter: SavedFilterType = strawberry_django.field() saved_filter_list: List[SavedFilterType] = strawberry_django.field() - @strawberry.field - def journal_entry(self, id: int) -> JournalEntryType: - return models.JournalEntry.objects.get(pk=id) + journal_entry: JournalEntryType = strawberry_django.field() journal_entry_list: List[JournalEntryType] = strawberry_django.field() - @strawberry.field - def tag(self, id: int) -> TagType: - return models.Tag.objects.get(pk=id) + tag: TagType = strawberry_django.field() tag_list: List[TagType] = strawberry_django.field() - @strawberry.field - def webhook(self, id: int) -> WebhookType: - return models.Webhook.objects.get(pk=id) + webhook: WebhookType = strawberry_django.field() webhook_list: List[WebhookType] = strawberry_django.field() - @strawberry.field - def event_rule(self, id: int) -> EventRuleType: - return models.EventRule.objects.get(pk=id) + event_rule: EventRuleType = strawberry_django.field() event_rule_list: List[EventRuleType] = strawberry_django.field() diff --git a/netbox/ipam/graphql/schema.py b/netbox/ipam/graphql/schema.py index c02788c3a..072f8cbcd 100644 --- a/netbox/ipam/graphql/schema.py +++ b/netbox/ipam/graphql/schema.py @@ -3,88 +3,55 @@ from typing import List import strawberry import strawberry_django -from ipam import models from .types import * -@strawberry.type +@strawberry.type(name="Query") class IPAMQuery: - @strawberry.field - def asn(self, id: int) -> ASNType: - return models.ASN.objects.get(pk=id) + asn: ASNType = strawberry_django.field() asn_list: List[ASNType] = strawberry_django.field() - @strawberry.field - def asn_range(self, id: int) -> ASNRangeType: - return models.ASNRange.objects.get(pk=id) + asn_range: ASNRangeType = strawberry_django.field() asn_range_list: List[ASNRangeType] = strawberry_django.field() - @strawberry.field - def aggregate(self, id: int) -> AggregateType: - return models.Aggregate.objects.get(pk=id) + aggregate: AggregateType = strawberry_django.field() aggregate_list: List[AggregateType] = strawberry_django.field() - @strawberry.field - def ip_address(self, id: int) -> IPAddressType: - return models.IPAddress.objects.get(pk=id) + ip_address: IPAddressType = strawberry_django.field() ip_address_list: List[IPAddressType] = strawberry_django.field() - @strawberry.field - def ip_range(self, id: int) -> IPRangeType: - return models.IPRange.objects.get(pk=id) + ip_range: IPRangeType = strawberry_django.field() ip_range_list: List[IPRangeType] = strawberry_django.field() - @strawberry.field - def prefix(self, id: int) -> PrefixType: - return models.Prefix.objects.get(pk=id) + prefix: PrefixType = strawberry_django.field() prefix_list: List[PrefixType] = strawberry_django.field() - @strawberry.field - def rir(self, id: int) -> RIRType: - return models.RIR.objects.get(pk=id) + rir: RIRType = strawberry_django.field() rir_list: List[RIRType] = strawberry_django.field() - @strawberry.field - def role(self, id: int) -> RoleType: - return models.Role.objects.get(pk=id) + role: RoleType = strawberry_django.field() role_list: List[RoleType] = strawberry_django.field() - @strawberry.field - def route_target(self, id: int) -> RouteTargetType: - return models.RouteTarget.objects.get(pk=id) + route_target: RouteTargetType = strawberry_django.field() route_target_list: List[RouteTargetType] = strawberry_django.field() - @strawberry.field - def service(self, id: int) -> ServiceType: - return models.Service.objects.get(pk=id) + service: ServiceType = strawberry_django.field() service_list: List[ServiceType] = strawberry_django.field() - @strawberry.field - def service_template(self, id: int) -> ServiceTemplateType: - return models.ServiceTemplate.objects.get(pk=id) + service_template: ServiceTemplateType = strawberry_django.field() service_template_list: List[ServiceTemplateType] = strawberry_django.field() - @strawberry.field - def fhrp_group(self, id: int) -> FHRPGroupType: - return models.FHRPGroup.objects.get(pk=id) + fhrp_group: FHRPGroupType = strawberry_django.field() fhrp_group_list: List[FHRPGroupType] = strawberry_django.field() - @strawberry.field - def fhrp_group_assignment(self, id: int) -> FHRPGroupAssignmentType: - return models.FHRPGroupAssignment.objects.get(pk=id) + fhrp_group_assignment: FHRPGroupAssignmentType = strawberry_django.field() fhrp_group_assignment_list: List[FHRPGroupAssignmentType] = strawberry_django.field() - @strawberry.field - def vlan(self, id: int) -> VLANType: - return models.VLAN.objects.get(pk=id) + vlan: VLANType = strawberry_django.field() vlan_list: List[VLANType] = strawberry_django.field() - @strawberry.field - def vlan_group(self, id: int) -> VLANGroupType: - return models.VLANGroup.objects.get(pk=id) + vlan_group: VLANGroupType = strawberry_django.field() vlan_group_list: List[VLANGroupType] = strawberry_django.field() - @strawberry.field - def vrf(self, id: int) -> VRFType: - return models.VRF.objects.get(pk=id) + vrf: VRFType = strawberry_django.field() vrf_list: List[VRFType] = strawberry_django.field() diff --git a/netbox/netbox/settings.py b/netbox/netbox/settings.py index 869b6be31..7c8e561a8 100644 --- a/netbox/netbox/settings.py +++ b/netbox/netbox/settings.py @@ -763,6 +763,7 @@ LOCALE_PATHS = ( # Strawberry (GraphQL) # STRAWBERRY_DJANGO = { + "DEFAULT_PK_FIELD_NAME": "id", "TYPE_DESCRIPTION_FROM_MODEL_DOCSTRING": True, "USE_DEPRECATED_FILTERS": True, } diff --git a/netbox/netbox/tests/dummy_plugin/graphql.py b/netbox/netbox/tests/dummy_plugin/graphql.py index 2651f4e9e..a8bbfcea2 100644 --- a/netbox/netbox/tests/dummy_plugin/graphql.py +++ b/netbox/netbox/tests/dummy_plugin/graphql.py @@ -13,11 +13,9 @@ class DummyModelType: pass -@strawberry.type +@strawberry.type(name="Query") class DummyQuery: - @strawberry.field - def dummymodel(self, id: int) -> DummyModelType: - return None + dummymodel: DummyModelType = strawberry_django.field() dummymodel_list: List[DummyModelType] = strawberry_django.field() diff --git a/netbox/tenancy/graphql/schema.py b/netbox/tenancy/graphql/schema.py index 79f8660d4..857d8ddeb 100644 --- a/netbox/tenancy/graphql/schema.py +++ b/netbox/tenancy/graphql/schema.py @@ -3,38 +3,25 @@ from typing import List import strawberry import strawberry_django -from tenancy import models from .types import * -@strawberry.type +@strawberry.type(name="Query") class TenancyQuery: - @strawberry.field - def tenant(self, id: int) -> TenantType: - return models.Tenant.objects.get(pk=id) + tenant: TenantType = strawberry_django.field() tenant_list: List[TenantType] = strawberry_django.field() - @strawberry.field - def tenant_group(self, id: int) -> TenantGroupType: - return models.TenantGroup.objects.get(pk=id) + tenant_group: TenantGroupType = strawberry_django.field() tenant_group_list: List[TenantGroupType] = strawberry_django.field() - @strawberry.field - def contact(self, id: int) -> ContactType: - return models.Contact.objects.get(pk=id) + contact: ContactType = strawberry_django.field() contact_list: List[ContactType] = strawberry_django.field() - @strawberry.field - def contact_role(self, id: int) -> ContactRoleType: - return models.ContactRole.objects.get(pk=id) + contact_role: ContactRoleType = strawberry_django.field() contact_role_list: List[ContactRoleType] = strawberry_django.field() - @strawberry.field - def contact_group(self, id: int) -> ContactGroupType: - return models.ContactGroup.objects.get(pk=id) + contact_group: ContactGroupType = strawberry_django.field() contact_group_list: List[ContactGroupType] = strawberry_django.field() - @strawberry.field - def contact_assignment(self, id: int) -> ContactAssignmentType: - return models.ContactAssignment.objects.get(pk=id) + contact_assignment: ContactAssignmentType = strawberry_django.field() contact_assignment_list: List[ContactAssignmentType] = strawberry_django.field() diff --git a/netbox/users/graphql/schema.py b/netbox/users/graphql/schema.py index 840887ad2..b59266c57 100644 --- a/netbox/users/graphql/schema.py +++ b/netbox/users/graphql/schema.py @@ -1,21 +1,15 @@ from typing import List + import strawberry import strawberry_django -from django.contrib.auth import get_user_model -from django.contrib.auth.models import Group -from users import models from .types import * -@strawberry.type +@strawberry.type(name="Query") class UsersQuery: - @strawberry.field - def group(self, id: int) -> GroupType: - return models.Group.objects.get(pk=id) + group: GroupType = strawberry_django.field() group_list: List[GroupType] = strawberry_django.field() - @strawberry.field - def user(self, id: int) -> UserType: - return models.User.objects.get(pk=id) + user: UserType = strawberry_django.field() user_list: List[UserType] = strawberry_django.field() diff --git a/netbox/utilities/testing/api.py b/netbox/utilities/testing/api.py index 7bb349a66..a3cfb9b2e 100644 --- a/netbox/utilities/testing/api.py +++ b/netbox/utilities/testing/api.py @@ -16,7 +16,7 @@ from extras.models import ObjectChange from users.models import ObjectPermission, Token from utilities.api import get_graphql_type_for_model from .base import ModelTestCase -from .utils import disable_warnings +from .utils import disable_logging, disable_warnings from ipam.graphql.types import IPAddressFamilyType from strawberry.types.lazy_type import LazyType @@ -523,7 +523,6 @@ class APIViewTestCases: return self._build_query_with_filter(name, filter_string) @override_settings(LOGIN_REQUIRED=True) - @override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user']) def test_graphql_get_object(self): url = reverse('graphql') field_name = self._get_graphql_base_name() @@ -531,57 +530,85 @@ class APIViewTestCases: query = self._build_query(field_name, id=object_id) # Non-authenticated requests should fail + header = { + 'HTTP_ACCEPT': 'application/json', + } with disable_warnings('django.request'): - header = { - 'HTTP_ACCEPT': 'application/json', - } - self.assertHttpStatus(self.client.post(url, data={'query': query}, format="json", **header), status.HTTP_403_FORBIDDEN) + response = self.client.post(url, data={'query': query}, format="json", **header) + self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN) - # Add object-level permission + # Add constrained permission obj_perm = ObjectPermission( name='Test permission', - actions=['view'] + actions=['view'], + constraints={'id': 0} # Impossible constraint ) obj_perm.save() obj_perm.users.add(self.user) obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) + # Request should succeed but return empty result + with disable_logging(): + response = self.client.post(url, data={'query': query}, format="json", **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = json.loads(response.content) + self.assertIn('errors', data) + self.assertIsNone(data['data']) + + # Remove permission constraint + obj_perm.constraints = None + obj_perm.save() + + # Request should return requested object response = self.client.post(url, data={'query': query}, format="json", **self.header) self.assertHttpStatus(response, status.HTTP_200_OK) data = json.loads(response.content) self.assertNotIn('errors', data) + self.assertIsNotNone(data['data']) @override_settings(LOGIN_REQUIRED=True) - @override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user']) def test_graphql_list_objects(self): url = reverse('graphql') field_name = f'{self._get_graphql_base_name()}_list' query = self._build_query(field_name) # Non-authenticated requests should fail + header = { + 'HTTP_ACCEPT': 'application/json', + } with disable_warnings('django.request'): - header = { - 'HTTP_ACCEPT': 'application/json', - } - self.assertHttpStatus(self.client.post(url, data={'query': query}, format="json", **header), status.HTTP_403_FORBIDDEN) + response = self.client.post(url, data={'query': query}, format="json", **header) + self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN) - # Add object-level permission + # Add constrained permission obj_perm = ObjectPermission( name='Test permission', - actions=['view'] + actions=['view'], + constraints={'id': 0} # Impossible constraint ) obj_perm.save() obj_perm.users.add(self.user) obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model)) + # Request should succeed but return empty results list response = self.client.post(url, data={'query': query}, format="json", **self.header) self.assertHttpStatus(response, status.HTTP_200_OK) data = json.loads(response.content) self.assertNotIn('errors', data) - self.assertGreater(len(data['data'][field_name]), 0) + self.assertEqual(len(data['data'][field_name]), 0) + + # Remove permission constraint + obj_perm.constraints = None + obj_perm.save() + + # Request should return all objects + response = self.client.post(url, data={'query': query}, format="json", **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = json.loads(response.content) + self.assertNotIn('errors', data) + self.assertEqual(len(data['data'][field_name]), self.model.objects.count()) @override_settings(LOGIN_REQUIRED=True) - @override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user']) def test_graphql_filter_objects(self): if not hasattr(self, 'graphql_filter'): return diff --git a/netbox/utilities/testing/utils.py b/netbox/utilities/testing/utils.py index 59bce2b7c..987e5ec35 100644 --- a/netbox/utilities/testing/utils.py +++ b/netbox/utilities/testing/utils.py @@ -107,6 +107,16 @@ def disable_warnings(logger_name): logger.setLevel(current_level) +@contextmanager +def disable_logging(level=logging.CRITICAL): + """ + Temporarily suppress log messages at or below the specified level (default: critical). + """ + logging.disable(level) + yield + logging.disable(logging.NOTSET) + + # # Custom field testing # diff --git a/netbox/virtualization/graphql/schema.py b/netbox/virtualization/graphql/schema.py index 72d83155d..212425814 100644 --- a/netbox/virtualization/graphql/schema.py +++ b/netbox/virtualization/graphql/schema.py @@ -3,38 +3,25 @@ from typing import List import strawberry import strawberry_django -from virtualization import models from .types import * -@strawberry.type +@strawberry.type(name="Query") class VirtualizationQuery: - @strawberry.field - def cluster(self, id: int) -> ClusterType: - return models.Cluster.objects.get(pk=id) + cluster: ClusterType = strawberry_django.field() cluster_list: List[ClusterType] = strawberry_django.field() - @strawberry.field - def cluster_group(self, id: int) -> ClusterGroupType: - return models.ClusterGroup.objects.get(pk=id) + cluster_group: ClusterGroupType = strawberry_django.field() cluster_group_list: List[ClusterGroupType] = strawberry_django.field() - @strawberry.field - def cluster_type(self, id: int) -> ClusterTypeType: - return models.ClusterType.objects.get(pk=id) + cluster_type: ClusterTypeType = strawberry_django.field() cluster_type_list: List[ClusterTypeType] = strawberry_django.field() - @strawberry.field - def virtual_machine(self, id: int) -> VirtualMachineType: - return models.VirtualMachine.objects.get(pk=id) + virtual_machine: VirtualMachineType = strawberry_django.field() virtual_machine_list: List[VirtualMachineType] = strawberry_django.field() - @strawberry.field - def vm_interface(self, id: int) -> VMInterfaceType: - return models.VMInterface.objects.get(pk=id) + vm_interface: VMInterfaceType = strawberry_django.field() vm_interface_list: List[VMInterfaceType] = strawberry_django.field() - @strawberry.field - def virtual_disk(self, id: int) -> VirtualDiskType: - return models.VirtualDisk.objects.get(pk=id) + virtual_disk: VirtualDiskType = strawberry_django.field() virtual_disk_list: List[VirtualDiskType] = strawberry_django.field() diff --git a/netbox/vpn/graphql/schema.py b/netbox/vpn/graphql/schema.py index f37e444a2..06ccc577d 100644 --- a/netbox/vpn/graphql/schema.py +++ b/netbox/vpn/graphql/schema.py @@ -3,58 +3,37 @@ from typing import List import strawberry import strawberry_django -from vpn import models from .types import * -@strawberry.type +@strawberry.type(name="Query") class VPNQuery: - @strawberry.field - def ike_policy(self, id: int) -> IKEPolicyType: - return models.IKEPolicy.objects.get(pk=id) + ike_policy: IKEPolicyType = strawberry_django.field() ike_policy_list: List[IKEPolicyType] = strawberry_django.field() - @strawberry.field - def ike_proposal(self, id: int) -> IKEProposalType: - return models.IKEProposal.objects.get(pk=id) + ike_proposal: IKEProposalType = strawberry_django.field() ike_proposal_list: List[IKEProposalType] = strawberry_django.field() - @strawberry.field - def ipsec_policy(self, id: int) -> IPSecPolicyType: - return models.IPSecPolicy.objects.get(pk=id) + ipsec_policy: IPSecPolicyType = strawberry_django.field() ipsec_policy_list: List[IPSecPolicyType] = strawberry_django.field() - @strawberry.field - def ipsec_profile(self, id: int) -> IPSecProfileType: - return models.IPSecProfile.objects.get(pk=id) + ipsec_profile: IPSecProfileType = strawberry_django.field() ipsec_profile_list: List[IPSecProfileType] = strawberry_django.field() - @strawberry.field - def ipsec_proposal(self, id: int) -> IPSecProposalType: - return models.IPSecProposal.objects.get(pk=id) + ipsec_proposal: IPSecProposalType = strawberry_django.field() ipsec_proposal_list: List[IPSecProposalType] = strawberry_django.field() - @strawberry.field - def l2vpn(self, id: int) -> L2VPNType: - return models.L2VPN.objects.get(pk=id) + l2vpn: L2VPNType = strawberry_django.field() l2vpn_list: List[L2VPNType] = strawberry_django.field() - @strawberry.field - def l2vpn_termination(self, id: int) -> L2VPNTerminationType: - return models.L2VPNTermination.objects.get(pk=id) + l2vpn_termination: L2VPNTerminationType = strawberry_django.field() l2vpn_termination_list: List[L2VPNTerminationType] = strawberry_django.field() - @strawberry.field - def tunnel(self, id: int) -> TunnelType: - return models.Tunnel.objects.get(pk=id) + tunnel: TunnelType = strawberry_django.field() tunnel_list: List[TunnelType] = strawberry_django.field() - @strawberry.field - def tunnel_group(self, id: int) -> TunnelGroupType: - return models.TunnelGroup.objects.get(pk=id) + tunnel_group: TunnelGroupType = strawberry_django.field() tunnel_group_list: List[TunnelGroupType] = strawberry_django.field() - @strawberry.field - def tunnel_termination(self, id: int) -> TunnelTerminationType: - return models.TunnelTermination.objects.get(pk=id) + tunnel_termination: TunnelTerminationType = strawberry_django.field() tunnel_termination_list: List[TunnelTerminationType] = strawberry_django.field() diff --git a/netbox/wireless/graphql/schema.py b/netbox/wireless/graphql/schema.py index 80a40c063..4f176031f 100644 --- a/netbox/wireless/graphql/schema.py +++ b/netbox/wireless/graphql/schema.py @@ -3,23 +3,16 @@ from typing import List import strawberry import strawberry_django -from wireless import models from .types import * -@strawberry.type +@strawberry.type(name="Query") class WirelessQuery: - @strawberry.field - def wireless_lan(self, id: int) -> WirelessLANType: - return models.WirelessLAN.objects.get(pk=id) + wireless_lan: WirelessLANType = strawberry_django.field() wireless_lan_list: List[WirelessLANType] = strawberry_django.field() - @strawberry.field - def wireless_lan_group(self, id: int) -> WirelessLANGroupType: - return models.WirelessLANGroup.objects.get(pk=id) + wireless_lan_group: WirelessLANGroupType = strawberry_django.field() wireless_lan_group_list: List[WirelessLANGroupType] = strawberry_django.field() - @strawberry.field - def wireless_link(self, id: int) -> WirelessLinkType: - return models.WirelessLink.objects.get(pk=id) + wireless_link: WirelessLinkType = strawberry_django.field() wireless_link_list: List[WirelessLinkType] = strawberry_django.field()