Refactor fetch methods into backend classes
parent d373729f1b
commit 9eafac8ade
netbox/core/data_backends.py (new file, 82 lines)
@@ -0,0 +1,82 @@
+import logging
+import subprocess
+import tempfile
+from contextlib import contextmanager
+from urllib.parse import quote, urlunparse, urlparse
+
+from django.conf import settings
+
+from .exceptions import SyncError
+
+__all__ = (
+    'LocalBakend',
+    'GitBackend',
+)
+
+logger = logging.getLogger('netbox.data_backends')
+
+
+class DataBackend:
+
+    def __init__(self, url, **kwargs):
+        self.url = url
+        self.params = kwargs
+
+    @property
+    def url_scheme(self):
+        return urlparse(self.url).scheme.lower()
+
+    @contextmanager
+    def fetch(self):
+        raise NotImplemented()
+
+
+class LocalBakend(DataBackend):
+
+    @contextmanager
+    def fetch(self):
+        logger.debug(f"Data source type is local; skipping fetch")
+        local_path = urlparse(self.url).path  # Strip file:// scheme
+
+        yield local_path
+
+
+class GitBackend(DataBackend):
+
+    @contextmanager
+    def fetch(self):
+        local_path = tempfile.TemporaryDirectory()
+
+        # Add authentication credentials to URL (if specified)
+        username = self.params.get('username')
+        password = self.params.get('password')
+        if username and password:
+            url_components = list(urlparse(self.url))
+            # Prepend username & password to netloc
+            url_components[1] = quote(f'{username}@{password}:') + url_components[1]
+            url = urlunparse(url_components)
+        else:
+            url = self.url
+
+        # Compile git arguments
+        args = ['git', 'clone', '--depth', '1']
+        if branch := self.params.get('branch'):
+            args.extend(['--branch', branch])
+        args.extend([url, local_path.name])
+
+        # Prep environment variables
+        env_vars = {}
+        if settings.HTTP_PROXIES and self.url_scheme in ('http', 'https'):
+            env_vars['http_proxy'] = settings.HTTP_PROXIES.get(self.url_scheme)
+
+        logger.debug(f"Cloning git repo: {' '.join(args)}")
+        try:
+            subprocess.run(args, check=True, capture_output=True, env=env_vars)
+        except subprocess.CalledProcessError as e:
+            raise SyncError(
+                f"Fetching remote data failed: {e.stderr}"
+            )
+
+        yield local_path.name
+
+        local_path.cleanup()
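The backends above are meant to be consumed as context managers: fetch() yields a local filesystem path for the duration of the with-block, and GitBackend removes its temporary clone after the block completes. A minimal usage sketch; the repository URL, credentials, and branch are placeholder values, and the import path assumes the "core" app layout implied by the relative import in the next hunk:

    from core.data_backends import GitBackend

    backend = GitBackend(
        'https://example.com/configs.git',  # placeholder repository URL
        username='demo',                    # optional; injected into the URL netloc
        password='secret',
        branch='main',                      # optional; passed as `git clone --branch`
    )
    with backend.fetch() as local_path:
        # local_path points at the shallow clone while this block runs;
        # the temporary directory is cleaned up when the block exits normally.
        print(local_path)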
@@ -1,11 +1,8 @@
 import logging
 import os
-import subprocess
-import tempfile
 from fnmatch import fnmatchcase
-from urllib.parse import quote, urlunparse, urlparse
+from urllib.parse import urlparse

-from django.conf import settings
 from django.contrib.contenttypes.models import ContentType
 from django.core.exceptions import ValidationError
 from django.core.validators import RegexValidator
@@ -20,8 +17,8 @@ from netbox.models import ChangeLoggedModel
 from utilities.files import sha256_hash
 from utilities.querysets import RestrictedQuerySet
 from ..choices import *
+from ..data_backends import GitBackend, LocalBakend
 from ..exceptions import SyncError
-from ..utils import FakeTempDirectory

 __all__ = (
     'DataSource',
@@ -135,6 +132,14 @@ class DataSource(ChangeLoggedModel):

         return job_result

+    def get_backend(self):
+        backend_cls = {
+            DataSourceTypeChoices.LOCAL: LocalBakend,
+            DataSourceTypeChoices.GIT: GitBackend,
+        }.get(self.type)
+
+        return backend_cls(self.url)
+
     def sync(self):
         """
         Create/update/delete child DataFiles as necessary to synchronize with the remote source.
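get_backend() is now the single point that maps a DataSource's type field to a backend class, so supporting an additional source type amounts to writing one more DataBackend subclass and registering it in this mapping, alongside a corresponding DataSourceTypeChoices value. A hedged sketch of such a subclass; ExampleArchiveBackend and its 'root' parameter are illustrative only, not part of this commit:

    from contextlib import contextmanager

    from core.data_backends import DataBackend  # import path assumed, as above


    class ExampleArchiveBackend(DataBackend):
        """Hypothetical backend that yields an already-extracted local directory."""

        @contextmanager
        def fetch(self):
            # A real implementation would download and unpack the source here.
            yield self.params.get('root', '/tmp/example-source')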
@@ -145,112 +150,54 @@
         self.status = DataSourceStatusChoices.SYNCING
         DataSource.objects.filter(pk=self.pk).update(status=self.status)

-        # Replicate source data locally (if needed)
-        local_path = self.fetch()
+        # Replicate source data locally
+        backend = self.get_backend()
+        with backend.fetch() as local_path:

-        logger.debug(f'Syncing files from source root {local_path.name}')
+            logger.debug(f'Syncing files from source root {local_path}')
             data_files = self.datafiles.all()
             known_paths = {df.path for df in data_files}
             logger.debug(f'Starting with {len(known_paths)} known files')

             # Check for any updated/deleted files
             updated_files = []
             deleted_file_ids = []
             for datafile in data_files:

                 try:
-                    if datafile.refresh_from_disk(source_root=local_path.name):
+                    if datafile.refresh_from_disk(source_root=local_path):
                         updated_files.append(datafile)
                 except FileNotFoundError:
                     # File no longer exists
                     deleted_file_ids.append(datafile.pk)
                     continue

             # Bulk update modified files
             updated_count = DataFile.objects.bulk_update(updated_files, ['hash'])
             logger.debug(f"Updated {updated_count} files")

             # Bulk delete deleted files
             deleted_count, _ = DataFile.objects.filter(pk__in=deleted_file_ids).delete()
             logger.debug(f"Deleted {updated_count} files")

             # Walk the local replication to find new files
-            new_paths = self._walk(local_path.name) - known_paths
+            new_paths = self._walk(local_path) - known_paths

             # Bulk create new files
             new_datafiles = []
             for path in new_paths:
                 datafile = DataFile(source=self, path=path)
-                datafile.refresh_from_disk(source_root=local_path.name)
+                datafile.refresh_from_disk(source_root=local_path)
                 datafile.full_clean()
                 new_datafiles.append(datafile)
             created_count = len(DataFile.objects.bulk_create(new_datafiles, batch_size=100))
             logger.debug(f"Created {created_count} data files")

         # Update status & last_synced time
         self.status = DataSourceStatusChoices.COMPLETED
         self.last_updated = timezone.now()
         DataSource.objects.filter(pk=self.pk).update(status=self.status, last_updated=self.last_updated)

-        local_path.cleanup()
-
-    def fetch(self):
-        """
-        Replicate the file structure from the remote data source and return the local path.
-        """
-        logger.debug(f"Fetching source data for {self} ({self.get_type_display()})")
-        try:
-            fetch_method = getattr(self, f'fetch_{self.type}')
-        except AttributeError:
-            raise NotImplemented(f"fetch() not yet supported for {self.get_type_display()} data sources")
-
-        return fetch_method()
-
-    def fetch_local(self, path):
-        """
-        Skip fetching for local paths; return the source path directly.
-        """
-        logger.debug(f"Data source type is local; skipping fetch")
-        local_path = urlparse(self.url).path
-
-        return FakeTempDirectory(local_path)
-
-    def fetch_git(self):
-        """
-        Perform a shallow clone of the remote repository using the `git` executable.
-        """
-        local_path = tempfile.TemporaryDirectory()
-
-        # Add authentication credentials to URL (if specified)
-        if self.username and self.password:
-            url_components = list(urlparse(self.url))
-            # Prepend username & password to netloc
-            url_components[1] = quote(f'{self.username}@{self.password}:') + url_components[1]
-            url = urlunparse(url_components)
-        else:
-            url = self.url
-
-        # Compile git arguments
-        args = ['git', 'clone', '--depth', '1']
-        if self.git_branch:
-            args.extend(['--branch', self.git_branch])
-        args.extend([url, local_path.name])
-
-        # Prep environment variables
-        env_vars = {}
-        if settings.HTTP_PROXIES and self.url_scheme in ('http', 'https'):
-            env_vars['http_proxy'] = settings.HTTP_PROXIES.get(self.url_scheme)
-
-        logger.debug(f"Cloning git repo: {' '.join(args)}")
-        try:
-            subprocess.run(args, check=True, capture_output=True, env=env_vars)
-        except subprocess.CalledProcessError as e:
-            raise SyncError(
-                f"Fetching remote data failed: {e.stderr}"
-            )
-
-        return local_path
-
     def _walk(self, root):
         """
         Return a set of all non-excluded files within the root path.
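The mechanism that lets sync() replace its manual cleanup() call with a with-block is contextlib.contextmanager: code before the yield runs when the block is entered, the yielded value is bound to local_path, and code after the yield runs when the block exits normally, which is where GitBackend removes its clone. A standalone sketch of that control flow; fetch_sketch is an illustrative name, not part of this commit:

    import tempfile
    from contextlib import contextmanager


    @contextmanager
    def fetch_sketch():
        tmp = tempfile.TemporaryDirectory()  # created when the with-block is entered
        yield tmp.name                       # bound by "with fetch_sketch() as local_path:"
        tmp.cleanup()                        # runs after the block exits normally


    with fetch_sketch() as local_path:
        print(local_path)  # the temporary directory exists here
    # ...and has been removed by this point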
@@ -1,14 +0,0 @@
-__all__ = (
-    'FakeTempDirectory',
-)
-
-
-class FakeTempDirectory:
-    """
-    Mimic tempfile.TemporaryDirectory to represent a real local path.
-    """
-    def __init__(self, name):
-        self.name = name
-
-    def cleanup(self):
-        pass