Employ urlparse() to strip port numbers from IPs

This commit is contained in:
Jeremy Stretch 2023-12-07 09:13:56 -05:00
parent 904e31b4c5
commit 7828008ff9
2 changed files with 15 additions and 9 deletions

View File

@ -1,4 +1,5 @@
from netaddr import AddrFormatError, IPAddress from netaddr import AddrFormatError, IPAddress
from urllib.parse import urlparse
__all__ = ( __all__ = (
'get_client_ip', 'get_client_ip',
@ -18,15 +19,17 @@ def get_client_ip(request, additional_headers=()):
for header in HTTP_HEADERS: for header in HTTP_HEADERS:
if header in request.META: if header in request.META:
ip = request.META[header].split(',')[0].strip() ip = request.META[header].split(',')[0].strip()
# Check if the IP address is v6 or v4
if ip.count(':') > 1:
client_ip = ip
else:
client_ip = ip.partition(':')[0]
try: try:
return IPAddress(client_ip) return IPAddress(ip)
except (AddrFormatError, ValueError): except AddrFormatError:
raise ValueError(f"Invalid IP address set for {header}: {client_ip}") # Parse the string with urlparse() to remove port number or any other cruft
ip = urlparse(f'//{ip}').hostname
try:
return IPAddress(ip)
except AddrFormatError:
# We did our best
raise ValueError(f"Invalid IP address set for {header}: {ip}")
# Could not determine the client IP address from request headers # Could not determine the client IP address from request headers
return None return None

View File

@ -11,13 +11,16 @@ class GetClientIPTests(TestCase):
def test_ipv4_address(self): def test_ipv4_address(self):
request = self.factory.get('/', HTTP_X_FORWARDED_FOR='192.168.1.1') request = self.factory.get('/', HTTP_X_FORWARDED_FOR='192.168.1.1')
self.assertEqual(get_client_ip(request), IPAddress('192.168.1.1')) self.assertEqual(get_client_ip(request), IPAddress('192.168.1.1'))
request = self.factory.get('/', HTTP_X_FORWARDED_FOR='192.168.1.1:8080') request = self.factory.get('/', HTTP_X_FORWARDED_FOR='192.168.1.1:8080')
self.assertEqual(get_client_ip(request), IPAddress('192.168.1.1')) self.assertEqual(get_client_ip(request), IPAddress('192.168.1.1'))
def test_ipv6_address(self): def test_ipv6_address(self):
request = self.factory.get('/', HTTP_X_FORWARDED_FOR='2001:db8::8a2e:370:7334') request = self.factory.get('/', HTTP_X_FORWARDED_FOR='2001:db8::8a2e:370:7334')
self.assertEqual(get_client_ip(request), IPAddress('2001:db8::8a2e:370:7334')) self.assertEqual(get_client_ip(request), IPAddress('2001:db8::8a2e:370:7334'))
request = self.factory.get('/', HTTP_X_FORWARDED_FOR='[2001:db8::8a2e:370:7334]')
self.assertEqual(get_client_ip(request), IPAddress('2001:db8::8a2e:370:7334'))
request = self.factory.get('/', HTTP_X_FORWARDED_FOR='[2001:db8::8a2e:370:7334]:8080')
self.assertEqual(get_client_ip(request), IPAddress('2001:db8::8a2e:370:7334'))
def test_invalid_ip_address(self): def test_invalid_ip_address(self):
request = self.factory.get('/', HTTP_X_FORWARDED_FOR='invalid_ip') request = self.factory.get('/', HTTP_X_FORWARDED_FOR='invalid_ip')