Refactor fetch methods into backend classes

jeremystretch 2023-01-30 16:18:23 -05:00
parent d373729f1b
commit 9eafac8ade
3 changed files with 127 additions and 112 deletions
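In short: DataSource previously implemented a type-dispatched fetch() that returned a TemporaryDirectory-like object sync() had to clean up; this commit moves each fetch strategy into its own backend class whose fetch() is a context manager yielding a local path. A condensed paraphrase of the change (process() is a hypothetical stand-in for the file-sync body, not code from this commit):

# Before: sync() owned teardown of a TemporaryDirectory-like object
def sync_before(self):
    local_path = self.fetch()         # dispatched via getattr(self, f'fetch_{self.type}')
    process(local_path.name)          # hypothetical stand-in for the sync body
    local_path.cleanup()              # easy to miss on error paths

# After: the backend owns setup and teardown
def sync_after(self):
    backend = self.get_backend()      # LocalBakend or GitBackend, keyed on self.type
    with backend.fetch() as local_path:
        process(local_path)           # local_path is now a plain path string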

View File

@@ -0,0 +1,82 @@
import logging
import subprocess
import tempfile
from contextlib import contextmanager
from urllib.parse import quote, urlunparse, urlparse

from django.conf import settings

from .exceptions import SyncError

__all__ = (
    'LocalBakend',
    'GitBackend',
)

logger = logging.getLogger('netbox.data_backends')

class DataBackend:

    def __init__(self, url, **kwargs):
        self.url = url
        self.params = kwargs

    @property
    def url_scheme(self):
        return urlparse(self.url).scheme.lower()

    @contextmanager
    def fetch(self):
        raise NotImplementedError()

class LocalBakend(DataBackend):

    @contextmanager
    def fetch(self):
        logger.debug("Data source type is local; skipping fetch")
        local_path = urlparse(self.url).path  # Strip file:// scheme

        yield local_path

class GitBackend(DataBackend):

    @contextmanager
    def fetch(self):
        local_path = tempfile.TemporaryDirectory()

        # Add authentication credentials to URL (if specified)
        username = self.params.get('username')
        password = self.params.get('password')
        if username and password:
            url_components = list(urlparse(self.url))
            # Prepend username & password to netloc
            url_components[1] = f"{quote(username, safe='')}:{quote(password, safe='')}@" + url_components[1]
            url = urlunparse(url_components)
        else:
            url = self.url

        # Compile git arguments
        args = ['git', 'clone', '--depth', '1']
        if branch := self.params.get('branch'):
            args.extend(['--branch', branch])
        args.extend([url, local_path.name])

        # Prep environment variables
        env_vars = {}
        if settings.HTTP_PROXIES and self.url_scheme in ('http', 'https'):
            env_vars['http_proxy'] = settings.HTTP_PROXIES.get(self.url_scheme)

        logger.debug(f"Cloning git repo: {' '.join(args)}")
        try:
            subprocess.run(args, check=True, capture_output=True, env=env_vars)
        except subprocess.CalledProcessError as e:
            raise SyncError(
                f"Fetching remote data failed: {e.stderr}"
            )

        yield local_path.name

        local_path.cleanup()
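The value of the base class is that it pins down the contract: a backend is constructed with a URL plus parameters, and its fetch() is a context manager yielding a local filesystem path. A hedged sketch of what a hypothetical additional backend might look like (ArchiveBackend, its parameters, and its body are illustrative assumptions, not part of this commit):

import tempfile
from contextlib import contextmanager

class ArchiveBackend(DataBackend):  # hypothetical; not part of this commit

    @contextmanager
    def fetch(self):
        local_path = tempfile.TemporaryDirectory()
        # ... download self.url and unpack it into local_path.name here ...
        try:
            yield local_path.name
        finally:
            # try/finally guarantees cleanup even if the caller raises inside
            # its with block; GitBackend above skips cleanup in that case
            local_path.cleanup()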

View File

@@ -1,11 +1,8 @@
 import logging
 import os
-import subprocess
-import tempfile
 from fnmatch import fnmatchcase
-from urllib.parse import quote, urlunparse, urlparse
+from urllib.parse import urlparse
 
-from django.conf import settings
 from django.contrib.contenttypes.models import ContentType
 from django.core.exceptions import ValidationError
 from django.core.validators import RegexValidator
@@ -20,8 +17,8 @@ from netbox.models import ChangeLoggedModel
 from utilities.files import sha256_hash
 from utilities.querysets import RestrictedQuerySet
 from ..choices import *
+from ..data_backends import GitBackend, LocalBakend
 from ..exceptions import SyncError
-from ..utils import FakeTempDirectory
 
 __all__ = (
     'DataSource',
@@ -135,6 +132,14 @@ class DataSource(ChangeLoggedModel):
 
         return job_result
 
+    def get_backend(self):
+        backend_cls = {
+            DataSourceTypeChoices.LOCAL: LocalBakend,
+            DataSourceTypeChoices.GIT: GitBackend,
+        }.get(self.type)
+
+        return backend_cls(self.url)
+
     def sync(self):
         """
         Create/update/delete child DataFiles as necessary to synchronize with the remote source.
@@ -145,112 +150,54 @@
         self.status = DataSourceStatusChoices.SYNCING
         DataSource.objects.filter(pk=self.pk).update(status=self.status)
 
-        # Replicate source data locally (if needed)
-        local_path = self.fetch()
-        logger.debug(f'Syncing files from source root {local_path.name}')
-
-        data_files = self.datafiles.all()
-        known_paths = {df.path for df in data_files}
-        logger.debug(f'Starting with {len(known_paths)} known files')
-
-        # Check for any updated/deleted files
-        updated_files = []
-        deleted_file_ids = []
-        for datafile in data_files:
-            try:
-                if datafile.refresh_from_disk(source_root=local_path.name):
-                    updated_files.append(datafile)
-            except FileNotFoundError:
-                # File no longer exists
-                deleted_file_ids.append(datafile.pk)
-                continue
-
-        # Bulk update modified files
-        updated_count = DataFile.objects.bulk_update(updated_files, ['hash'])
-        logger.debug(f"Updated {updated_count} files")
-
-        # Bulk delete deleted files
-        deleted_count, _ = DataFile.objects.filter(pk__in=deleted_file_ids).delete()
-        logger.debug(f"Deleted {updated_count} files")
-
-        # Walk the local replication to find new files
-        new_paths = self._walk(local_path.name) - known_paths
-
-        # Bulk create new files
-        new_datafiles = []
-        for path in new_paths:
-            datafile = DataFile(source=self, path=path)
-            datafile.refresh_from_disk(source_root=local_path.name)
-            datafile.full_clean()
-            new_datafiles.append(datafile)
-        created_count = len(DataFile.objects.bulk_create(new_datafiles, batch_size=100))
-        logger.debug(f"Created {created_count} data files")
+        # Replicate source data locally
+        backend = self.get_backend()
+        with backend.fetch() as local_path:
+            logger.debug(f'Syncing files from source root {local_path}')
+
+            data_files = self.datafiles.all()
+            known_paths = {df.path for df in data_files}
+            logger.debug(f'Starting with {len(known_paths)} known files')
+
+            # Check for any updated/deleted files
+            updated_files = []
+            deleted_file_ids = []
+            for datafile in data_files:
+                try:
+                    if datafile.refresh_from_disk(source_root=local_path):
+                        updated_files.append(datafile)
+                except FileNotFoundError:
+                    # File no longer exists
+                    deleted_file_ids.append(datafile.pk)
+                    continue
+
+            # Bulk update modified files
+            updated_count = DataFile.objects.bulk_update(updated_files, ['hash'])
+            logger.debug(f"Updated {updated_count} files")
+
+            # Bulk delete deleted files
+            deleted_count, _ = DataFile.objects.filter(pk__in=deleted_file_ids).delete()
+            logger.debug(f"Deleted {deleted_count} files")
+
+            # Walk the local replication to find new files
+            new_paths = self._walk(local_path) - known_paths
+
+            # Bulk create new files
+            new_datafiles = []
+            for path in new_paths:
+                datafile = DataFile(source=self, path=path)
+                datafile.refresh_from_disk(source_root=local_path)
+                datafile.full_clean()
+                new_datafiles.append(datafile)
+            created_count = len(DataFile.objects.bulk_create(new_datafiles, batch_size=100))
+            logger.debug(f"Created {created_count} data files")
 
         # Update status & last_synced time
         self.status = DataSourceStatusChoices.COMPLETED
         self.last_updated = timezone.now()
         DataSource.objects.filter(pk=self.pk).update(status=self.status, last_updated=self.last_updated)
 
-        local_path.cleanup()
-
-    def fetch(self):
-        """
-        Replicate the file structure from the remote data source and return the local path.
-        """
-        logger.debug(f"Fetching source data for {self} ({self.get_type_display()})")
-        try:
-            fetch_method = getattr(self, f'fetch_{self.type}')
-        except AttributeError:
-            raise NotImplemented(f"fetch() not yet supported for {self.get_type_display()} data sources")
-
-        return fetch_method()
-
-    def fetch_local(self, path):
-        """
-        Skip fetching for local paths; return the source path directly.
-        """
-        logger.debug(f"Data source type is local; skipping fetch")
-        local_path = urlparse(self.url).path
-
-        return FakeTempDirectory(local_path)
-
-    def fetch_git(self):
-        """
-        Perform a shallow clone of the remote repository using the `git` executable.
-        """
-        local_path = tempfile.TemporaryDirectory()
-
-        # Add authentication credentials to URL (if specified)
-        if self.username and self.password:
-            url_components = list(urlparse(self.url))
-            # Prepend username & password to netloc
-            url_components[1] = quote(f'{self.username}@{self.password}:') + url_components[1]
-            url = urlunparse(url_components)
-        else:
-            url = self.url
-
-        # Compile git arguments
-        args = ['git', 'clone', '--depth', '1']
-        if self.git_branch:
-            args.extend(['--branch', self.git_branch])
-        args.extend([url, local_path.name])
-
-        # Prep environment variables
-        env_vars = {}
-        if settings.HTTP_PROXIES and self.url_scheme in ('http', 'https'):
-            env_vars['http_proxy'] = settings.HTTP_PROXIES.get(self.url_scheme)
-
-        logger.debug(f"Cloning git repo: {' '.join(args)}")
-        try:
-            subprocess.run(args, check=True, capture_output=True, env=env_vars)
-        except subprocess.CalledProcessError as e:
-            raise SyncError(
-                f"Fetching remote data failed: {e.stderr}"
-            )
-
-        return local_path
-
     def _walk(self, root):
         """
         Return a set of all non-excluded files within the root path.
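The net effect on sync() is that teardown leaves the method entirely: the with block scopes the replicated data, and the backend disposes of it on exit. A self-contained sketch of the context-manager semantics sync() now relies on (plain Python, independent of NetBox); note that GitBackend as written yields and then calls cleanup() sequentially, so an exception raised by the caller would skip cleanup, whereas the try/finally variant below would not:

import tempfile
from contextlib import contextmanager

@contextmanager
def fetch():
    tmp = tempfile.TemporaryDirectory()
    try:
        yield tmp.name      # the caller sees only a plain path string
    finally:
        tmp.cleanup()       # runs on normal exit and on error alike

with fetch() as local_path:
    print(f'syncing files under {local_path}')
# the temporary directory no longer exists here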

View File

@@ -1,14 +0,0 @@
__all__ = (
    'FakeTempDirectory',
)


class FakeTempDirectory:
    """
    Mimic tempfile.TemporaryDirectory to represent a real local path.
    """
    def __init__(self, name):
        self.name = name

    def cleanup(self):
        pass
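FakeTempDirectory existed only so fetch_local() could return something with the same name/cleanup() surface as tempfile.TemporaryDirectory; now that backends yield plain path strings, the shim has no remaining callers. For illustration, the duck-typing contract it satisfied (using the class definition above; the local path is a made-up example):

import tempfile

real = tempfile.TemporaryDirectory()
fake = FakeTempDirectory('/opt/example-data')  # hypothetical local path

for d in (real, fake):
    print(d.name)   # the old sync() read files under this path
    d.cleanup()     # no-op for the fake; removes the real temp dir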