From 9eafac8adec9806c36b0cb8691d229f1dbf58f36 Mon Sep 17 00:00:00 2001
From: jeremystretch
Date: Mon, 30 Jan 2023 16:18:23 -0500
Subject: [PATCH] Refactor fetch methods into backend classes

---
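Note: the backend API introduced in data_backends.py is consumed as a context
manager: fetch() yields a local filesystem path and performs any cleanup after
the caller's block exits. A minimal runnable sketch of that contract, using a
hypothetical ExampleBackend that is not part of this patch:

    import tempfile
    from contextlib import contextmanager


    class DataBackend:
        # Trimmed mirror of the base class added below
        def __init__(self, url, **kwargs):
            self.url = url
            self.params = kwargs

        @contextmanager
        def fetch(self):
            raise NotImplementedError()


    class ExampleBackend(DataBackend):
        # Hypothetical backend: yield a local path, clean up afterward
        @contextmanager
        def fetch(self):
            tmp = tempfile.TemporaryDirectory()
            try:
                # ... replicate the remote data into tmp.name here ...
                yield tmp.name
            finally:
                tmp.cleanup()


    backend = ExampleBackend('https://example.com/data')
    with backend.fetch() as local_path:
        print(f'Source replicated to {local_path}')

Unlike GitBackend below, the sketch runs cleanup in a finally block, so the
temporary directory is removed even if the caller's block raises.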
 netbox/core/data_backends.py |  82 ++++++++++++++++++++
 netbox/core/models/data.py   | 143 +++++++++++------------------------
 netbox/core/utils.py         |  14 ----
 3 files changed, 127 insertions(+), 112 deletions(-)
 create mode 100644 netbox/core/data_backends.py
 delete mode 100644 netbox/core/utils.py

diff --git a/netbox/core/data_backends.py b/netbox/core/data_backends.py
new file mode 100644
index 000000000..c14f836b3
--- /dev/null
+++ b/netbox/core/data_backends.py
@@ -0,0 +1,82 @@
+import logging
+import subprocess
+import tempfile
+from contextlib import contextmanager
+from urllib.parse import quote, urlunparse, urlparse
+
+from django.conf import settings
+
+from .exceptions import SyncError
+
+__all__ = (
+    'LocalBackend',
+    'GitBackend',
+)
+
+logger = logging.getLogger('netbox.data_backends')
+
+
+class DataBackend:
+
+    def __init__(self, url, **kwargs):
+        self.url = url
+        self.params = kwargs
+
+    @property
+    def url_scheme(self):
+        return urlparse(self.url).scheme.lower()
+
+    @contextmanager
+    def fetch(self):
+        raise NotImplementedError()
+
+
+class LocalBackend(DataBackend):
+
+    @contextmanager
+    def fetch(self):
+        logger.debug("Data source type is local; skipping fetch")
+        local_path = urlparse(self.url).path  # Strip file:// scheme
+
+        yield local_path
+
+
+class GitBackend(DataBackend):
+
+    @contextmanager
+    def fetch(self):
+        local_path = tempfile.TemporaryDirectory()
+
+        # Add authentication credentials to URL (if specified)
+        username = self.params.get('username')
+        password = self.params.get('password')
+        if username and password:
+            url_components = list(urlparse(self.url))
+            # Prepend username & password to netloc
+            url_components[1] = f"{quote(username, safe='')}:{quote(password, safe='')}@" + url_components[1]
+            url = urlunparse(url_components)
+        else:
+            url = self.url
+
+        # Compile git arguments
+        args = ['git', 'clone', '--depth', '1']
+        if branch := self.params.get('branch'):
+            args.extend(['--branch', branch])
+        args.extend([url, local_path.name])
+
+        # Prep environment variables
+        env_vars = {}
+        if settings.HTTP_PROXIES and self.url_scheme in ('http', 'https'):
+            env_vars['http_proxy'] = settings.HTTP_PROXIES.get(self.url_scheme)
+
+        logger.debug(f"Cloning git repo: {' '.join(args)}")
+        try:
+            subprocess.run(args, check=True, capture_output=True, env=env_vars)
+        except subprocess.CalledProcessError as e:
+            raise SyncError(
+                f"Fetching remote data failed: {e.stderr}"
+            )
+
+        yield local_path.name
+
+        local_path.cleanup()
diff --git a/netbox/core/models/data.py b/netbox/core/models/data.py
index 8fbaa7b8a..96a39249c 100644
--- a/netbox/core/models/data.py
+++ b/netbox/core/models/data.py
@@ -1,11 +1,8 @@
 import logging
 import os
-import subprocess
-import tempfile
 from fnmatch import fnmatchcase
-from urllib.parse import quote, urlunparse, urlparse
+from urllib.parse import urlparse
 
-from django.conf import settings
 from django.contrib.contenttypes.models import ContentType
 from django.core.exceptions import ValidationError
 from django.core.validators import RegexValidator
@@ -20,8 +17,8 @@ from netbox.models import ChangeLoggedModel
 from utilities.files import sha256_hash
 from utilities.querysets import RestrictedQuerySet
 from ..choices import *
+from ..data_backends import GitBackend, LocalBackend
 from ..exceptions import SyncError
-from ..utils import FakeTempDirectory
 
 __all__ = (
     'DataSource',
@@ -135,6 +132,14 @@ class DataSource(ChangeLoggedModel):
 
         return job_result
 
+    def get_backend(self):
+        backend_cls = {
+            DataSourceTypeChoices.LOCAL: LocalBackend,
+            DataSourceTypeChoices.GIT: GitBackend,
+        }.get(self.type)
+
+        return backend_cls(self.url)
+
     def sync(self):
         """
         Create/update/delete child DataFiles as necessary to synchronize with the remote source.
@@ -145,112 +150,54 @@ class DataSource(ChangeLoggedModel):
         self.status = DataSourceStatusChoices.SYNCING
         DataSource.objects.filter(pk=self.pk).update(status=self.status)
 
-        # Replicate source data locally (if needed)
-        local_path = self.fetch()
+        # Replicate source data locally
+        backend = self.get_backend()
+        with backend.fetch() as local_path:
 
-        logger.debug(f'Syncing files from source root {local_path.name}')
-        data_files = self.datafiles.all()
-        known_paths = {df.path for df in data_files}
-        logger.debug(f'Starting with {len(known_paths)} known files')
+            logger.debug(f'Syncing files from source root {local_path}')
+            data_files = self.datafiles.all()
+            known_paths = {df.path for df in data_files}
+            logger.debug(f'Starting with {len(known_paths)} known files')
 
-        # Check for any updated/deleted files
-        updated_files = []
-        deleted_file_ids = []
-        for datafile in data_files:
+            # Check for any updated/deleted files
+            updated_files = []
+            deleted_file_ids = []
+            for datafile in data_files:
 
-            try:
-                if datafile.refresh_from_disk(source_root=local_path.name):
-                    updated_files.append(datafile)
-            except FileNotFoundError:
-                # File no longer exists
-                deleted_file_ids.append(datafile.pk)
-                continue
+                try:
+                    if datafile.refresh_from_disk(source_root=local_path):
+                        updated_files.append(datafile)
+                except FileNotFoundError:
+                    # File no longer exists
+                    deleted_file_ids.append(datafile.pk)
+                    continue
 
-        # Bulk update modified files
-        updated_count = DataFile.objects.bulk_update(updated_files, ['hash'])
-        logger.debug(f"Updated {updated_count} files")
+            # Bulk update modified files
+            updated_count = DataFile.objects.bulk_update(updated_files, ['hash'])
+            logger.debug(f"Updated {updated_count} files")
 
-        # Bulk delete deleted files
-        deleted_count, _ = DataFile.objects.filter(pk__in=deleted_file_ids).delete()
-        logger.debug(f"Deleted {updated_count} files")
+            # Bulk delete deleted files
+            deleted_count, _ = DataFile.objects.filter(pk__in=deleted_file_ids).delete()
+            logger.debug(f"Deleted {deleted_count} files")
 
-        # Walk the local replication to find new files
-        new_paths = self._walk(local_path.name) - known_paths
+            # Walk the local replication to find new files
+            new_paths = self._walk(local_path) - known_paths
 
-        # Bulk create new files
-        new_datafiles = []
-        for path in new_paths:
-            datafile = DataFile(source=self, path=path)
-            datafile.refresh_from_disk(source_root=local_path.name)
-            datafile.full_clean()
-            new_datafiles.append(datafile)
-        created_count = len(DataFile.objects.bulk_create(new_datafiles, batch_size=100))
-        logger.debug(f"Created {created_count} data files")
+            # Bulk create new files
+            new_datafiles = []
+            for path in new_paths:
+                datafile = DataFile(source=self, path=path)
+                datafile.refresh_from_disk(source_root=local_path)
+                datafile.full_clean()
+                new_datafiles.append(datafile)
+            created_count = len(DataFile.objects.bulk_create(new_datafiles, batch_size=100))
+            logger.debug(f"Created {created_count} data files")
 
         # Update status & last_synced time
         self.status = DataSourceStatusChoices.COMPLETED
         self.last_updated = timezone.now()
         DataSource.objects.filter(pk=self.pk).update(status=self.status, last_updated=self.last_updated)
 
-        local_path.cleanup()
-
-    def fetch(self):
-        """
-        Replicate the file structure from the remote data source and return the local path.
-        """
-        logger.debug(f"Fetching source data for {self} ({self.get_type_display()})")
-        try:
-            fetch_method = getattr(self, f'fetch_{self.type}')
-        except AttributeError:
-            raise NotImplemented(f"fetch() not yet supported for {self.get_type_display()} data sources")
-
-        return fetch_method()
-
-    def fetch_local(self, path):
-        """
-        Skip fetching for local paths; return the source path directly.
-        """
-        logger.debug(f"Data source type is local; skipping fetch")
-        local_path = urlparse(self.url).path
-
-        return FakeTempDirectory(local_path)
-
-    def fetch_git(self):
-        """
-        Perform a shallow clone of the remote repository using the `git` executable.
-        """
-        local_path = tempfile.TemporaryDirectory()
-
-        # Add authentication credentials to URL (if specified)
-        if self.username and self.password:
-            url_components = list(urlparse(self.url))
-            # Prepend username & password to netloc
-            url_components[1] = quote(f'{self.username}@{self.password}:') + url_components[1]
-            url = urlunparse(url_components)
-        else:
-            url = self.url
-
-        # Compile git arguments
-        args = ['git', 'clone', '--depth', '1']
-        if self.git_branch:
-            args.extend(['--branch', self.git_branch])
-        args.extend([url, local_path.name])
-
-        # Prep environment variables
-        env_vars = {}
-        if settings.HTTP_PROXIES and self.url_scheme in ('http', 'https'):
-            env_vars['http_proxy'] = settings.HTTP_PROXIES.get(self.url_scheme)
-
-        logger.debug(f"Cloning git repo: {' '.join(args)}")
-        try:
-            subprocess.run(args, check=True, capture_output=True, env=env_vars)
-        except subprocess.CalledProcessError as e:
-            raise SyncError(
-                f"Fetching remote data failed: {e.stderr}"
-            )
-
-        return local_path
-
     def _walk(self, root):
         """
         Return a set of all non-excluded files within the root path.
diff --git a/netbox/core/utils.py b/netbox/core/utils.py
deleted file mode 100644
index 92c1786f6..000000000
--- a/netbox/core/utils.py
+++ /dev/null
@@ -1,14 +0,0 @@
-__all__ = (
-    'FakeTempDirectory',
-)
-
-
-class FakeTempDirectory:
-    """
-    Mimic tempfile.TemporaryDirectory to represent a real local path.
-    """
-    def __init__(self, name):
-        self.name = name
-
-    def cleanup(self):
-        pass
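
GitBackend's credential handling rewrites the URL netloc into the standard
user:pass@host form, percent-encoding each component so reserved characters
in a username or password cannot corrupt the URL. The same transformation as
a standalone, runnable sketch (the helper name add_credentials is
hypothetical):

    from urllib.parse import quote, urlparse, urlunparse


    def add_credentials(url, username, password):
        # Embed credentials in the netloc as user:pass@host, quoting each
        # component so '@', ':' and '/' in either value survive intact
        components = list(urlparse(url))
        components[1] = f"{quote(username, safe='')}:{quote(password, safe='')}@{components[1]}"
        return urlunparse(components)


    print(add_credentials('https://github.com/org/repo.git', 'bob', 'p@ss:word'))
    # https://bob:p%40ss%3Aword@github.com/org/repo.git

Quoting the username and password separately, with safe='', keeps the ':' and
'@' delimiters themselves unescaped while neutralizing any occurrence of those
characters inside the credentials.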