diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py index 352c9b87c46..9a1ba728f2b 100644 --- a/Lib/ipaddress.py +++ b/Lib/ipaddress.py @@ -1250,7 +1250,9 @@ class IPv4Address(_BaseV4, _BaseAddress): return # Constructing from a packed address - if isinstance(address, bytes) and len(address) == 4: + if isinstance(address, bytes): + if len(address) != 4: + raise AddressValueError(address) self._ip = struct.unpack('!I', address)[0] return @@ -1379,7 +1381,9 @@ class IPv4Network(_BaseV4, _BaseNetwork): _BaseNetwork.__init__(self, address) # Constructing from a packed address - if isinstance(address, bytes) and len(address) == 4: + if isinstance(address, bytes): + if len(address) != 4: + raise AddressValueError(address) self.network_address = IPv4Address( struct.unpack('!I', address)[0]) self._prefixlen = self._max_prefixlen @@ -1864,7 +1868,9 @@ class IPv6Address(_BaseV6, _BaseAddress): return # Constructing from a packed address - if isinstance(address, bytes) and len(address) == 16: + if isinstance(address, bytes): + if len(address) != 16: + raise AddressValueError(address) tmp = struct.unpack('!QQ', address) self._ip = (tmp[0] << 64) | tmp[1] return @@ -1996,7 +2002,9 @@ class IPv6Network(_BaseV6, _BaseNetwork): return # Constructing from a packed address - if isinstance(address, bytes) and len(address) == 16: + if isinstance(address, bytes): + if len(address) != 16: + raise AddressValueError(address) tmp = struct.unpack('!QQ', address) self.network_address = IPv6Address((tmp[0] << 64) | tmp[1]) self._prefixlen = self._max_prefixlen diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py index c9ced59c234..5cd2ad4d198 100644 --- a/Lib/test/test_ipaddress.py +++ b/Lib/test/test_ipaddress.py @@ -8,10 +8,6 @@ import unittest import ipaddress -# Compatibility function to cast str to bytes objects -_cb = lambda bytestr: bytes(bytestr, 'charmap') - - class IpaddrUnitTest(unittest.TestCase): def setUp(self): @@ -267,25 +263,36 @@ class IpaddrUnitTest(unittest.TestCase): 6) def testIpFromPacked(self): - ip = ipaddress.ip_network - + address = ipaddress.ip_address self.assertEqual(self.ipv4_interface._ip, - ipaddress.ip_interface(_cb('\x01\x02\x03\x04'))._ip) - self.assertEqual(ip('255.254.253.252'), - ip(_cb('\xff\xfe\xfd\xfc'))) - self.assertRaises(ValueError, ipaddress.ip_network, _cb('\x00' * 3)) - self.assertRaises(ValueError, ipaddress.ip_network, _cb('\x00' * 5)) + ipaddress.ip_interface(b'\x01\x02\x03\x04')._ip) + self.assertEqual(address('255.254.253.252'), + address(b'\xff\xfe\xfd\xfc')) self.assertEqual(self.ipv6_interface.ip, ipaddress.ip_interface( - _cb('\x20\x01\x06\x58\x02\x2a\xca\xfe' - '\x02\x00\x00\x00\x00\x00\x00\x01')).ip) - self.assertEqual(ip('ffff:2:3:4:ffff::'), - ip(_cb('\xff\xff\x00\x02\x00\x03\x00\x04' + - '\xff\xff' + '\x00' * 6))) - self.assertEqual(ip('::'), - ip(_cb('\x00' * 16))) - self.assertRaises(ValueError, ip, _cb('\x00' * 15)) - self.assertRaises(ValueError, ip, _cb('\x00' * 17)) + b'\x20\x01\x06\x58\x02\x2a\xca\xfe' + b'\x02\x00\x00\x00\x00\x00\x00\x01').ip) + self.assertEqual(address('ffff:2:3:4:ffff::'), + address(b'\xff\xff\x00\x02\x00\x03\x00\x04' + + b'\xff\xff' + b'\x00' * 6)) + self.assertEqual(address('::'), + address(b'\x00' * 16)) + + def testIpFromPackedErrors(self): + def assertInvalidPackedAddress(f, length): + self.assertRaises(ValueError, f, b'\x00' * length) + assertInvalidPackedAddress(ipaddress.ip_address, 3) + assertInvalidPackedAddress(ipaddress.ip_address, 5) + assertInvalidPackedAddress(ipaddress.ip_address, 15) + assertInvalidPackedAddress(ipaddress.ip_address, 17) + assertInvalidPackedAddress(ipaddress.ip_interface, 3) + assertInvalidPackedAddress(ipaddress.ip_interface, 5) + assertInvalidPackedAddress(ipaddress.ip_interface, 15) + assertInvalidPackedAddress(ipaddress.ip_interface, 17) + assertInvalidPackedAddress(ipaddress.ip_network, 3) + assertInvalidPackedAddress(ipaddress.ip_network, 5) + assertInvalidPackedAddress(ipaddress.ip_network, 15) + assertInvalidPackedAddress(ipaddress.ip_network, 17) def testGetIp(self): self.assertEqual(int(self.ipv4_interface.ip), 16909060) @@ -893,17 +900,17 @@ class IpaddrUnitTest(unittest.TestCase): def testPacked(self): self.assertEqual(self.ipv4_address.packed, - _cb('\x01\x02\x03\x04')) + b'\x01\x02\x03\x04') self.assertEqual(ipaddress.IPv4Interface('255.254.253.252').packed, - _cb('\xff\xfe\xfd\xfc')) + b'\xff\xfe\xfd\xfc') self.assertEqual(self.ipv6_address.packed, - _cb('\x20\x01\x06\x58\x02\x2a\xca\xfe' - '\x02\x00\x00\x00\x00\x00\x00\x01')) + b'\x20\x01\x06\x58\x02\x2a\xca\xfe' + b'\x02\x00\x00\x00\x00\x00\x00\x01') self.assertEqual(ipaddress.IPv6Interface('ffff:2:3:4:ffff::').packed, - _cb('\xff\xff\x00\x02\x00\x03\x00\x04\xff\xff' - + '\x00' * 6)) + b'\xff\xff\x00\x02\x00\x03\x00\x04\xff\xff' + + b'\x00' * 6) self.assertEqual(ipaddress.IPv6Interface('::1:0:0:0:0').packed, - _cb('\x00' * 6 + '\x00\x01' + '\x00' * 8)) + b'\x00' * 6 + b'\x00\x01' + b'\x00' * 8) def testIpStrFromPrefixlen(self): ipv4 = ipaddress.IPv4Interface('1.2.3.4/24')