[Python-checkins] cpython: Issue 14814: Make the ipaddress code easier to follow by using newer language features

nick.coghlan python-checkins at python.org
Sat Jul 7 13:43:42 CEST 2012


http://hg.python.org/cpython/rev/af4ae710daf3
changeset:   77965:af4ae710daf3
user:        Nick Coghlan <ncoghlan at gmail.com>
date:        Sat Jul 07 21:43:30 2012 +1000
summary:
  Issue 14814: Make the ipaddress code easier to follow by using newer language features (patch by Serhiy Storchaka)

files:
  Lib/ipaddress.py |  156 ++++++++++++++--------------------
  1 files changed, 66 insertions(+), 90 deletions(-)


diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py
--- a/Lib/ipaddress.py
+++ b/Lib/ipaddress.py
@@ -214,8 +214,10 @@
     if number == 0:
         return bits
     for i in range(bits):
-        if (number >> i) % 2:
+        if (number >> i) & 1:
             return i
+    # All bits of interest were zero, even if there are more in the number
+    return bits
 
 
 def summarize_address_range(first, last):
@@ -263,20 +265,13 @@
     first_int = first._ip
     last_int = last._ip
     while first_int <= last_int:
-        nbits = _count_righthand_zero_bits(first_int, ip_bits)
-        current = None
-        while nbits >= 0:
-            addend = 2**nbits - 1
-            current = first_int + addend
-            nbits -= 1
-            if current <= last_int:
-                break
-        prefix = _get_prefix_length(first_int, current, ip_bits)
-        net = ip('%s/%d' % (first, prefix))
+        nbits = min(_count_righthand_zero_bits(first_int, ip_bits),
+                    (last_int - first_int + 1).bit_length() - 1)
+        net = ip('%s/%d' % (first, ip_bits - nbits))
         yield net
-        if current == ip._ALL_ONES:
+        first_int += 1 << nbits
+        if first_int - 1 == ip._ALL_ONES:
             break
-        first_int = current + 1
         first = first.__class__(first_int)
 
 
@@ -304,26 +299,28 @@
         passed.
 
     """
-    ret_array = []
-    optimized = False
+    while True:
+        last_addr = None
+        ret_array = []
+        optimized = False
 
-    for cur_addr in addresses:
-        if not ret_array:
-            ret_array.append(cur_addr)
-            continue
-        if (cur_addr.network_address >= ret_array[-1].network_address and
-            cur_addr.broadcast_address <= ret_array[-1].broadcast_address):
-            optimized = True
-        elif cur_addr == list(ret_array[-1].supernet().subnets())[1]:
-            ret_array.append(ret_array.pop().supernet())
-            optimized = True
-        else:
-            ret_array.append(cur_addr)
+        for cur_addr in addresses:
+            if not ret_array:
+                last_addr = cur_addr
+                ret_array.append(cur_addr)
+            elif (cur_addr.network_address >= last_addr.network_address and
+                cur_addr.broadcast_address <= last_addr.broadcast_address):
+                optimized = True
+            elif cur_addr == list(last_addr.supernet().subnets())[1]:
+                ret_array[-1] = last_addr = last_addr.supernet()
+                optimized = True
+            else:
+                last_addr = cur_addr
+                ret_array.append(cur_addr)
 
-    if optimized:
-        return _collapse_addresses_recursive(ret_array)
-
-    return ret_array
+        addresses = ret_array
+        if not optimized:
+            return addresses
 
 
 def collapse_addresses(addresses):
@@ -452,13 +449,7 @@
             An integer, the prefix length.
 
         """
-        while mask:
-            if ip_int & 1 == 1:
-                break
-            ip_int >>= 1
-            mask -= 1
-
-        return mask
+        return mask - _count_righthand_zero_bits(ip_int, mask)
 
     def _ip_string_from_prefix(self, prefixlen=None):
         """Turn a prefix length into a dotted decimal string.
@@ -597,18 +588,16 @@
         or broadcast addresses.
 
         """
-        cur = int(self.network_address) + 1
-        bcast = int(self.broadcast_address) - 1
-        while cur <= bcast:
-            cur += 1
-            yield self._address_class(cur - 1)
+        network = int(self.network_address)
+        broadcast = int(self.broadcast_address)
+        for x in range(network + 1, broadcast):
+            yield self._address_class(x)
 
     def __iter__(self):
-        cur = int(self.network_address)
-        bcast = int(self.broadcast_address)
-        while cur <= bcast:
-            cur += 1
-            yield self._address_class(cur - 1)
+        network = int(self.network_address)
+        broadcast = int(self.broadcast_address)
+        for x in range(network, broadcast + 1):
+            yield self._address_class(x)
 
     def __getitem__(self, n):
         network = int(self.network_address)
@@ -998,7 +987,7 @@
     _DECIMAL_DIGITS = frozenset('0123456789')
 
     # the valid octets for host and netmasks. only useful for IPv4.
-    _valid_mask_octets = set((255, 254, 252, 248, 240, 224, 192, 128, 0))
+    _valid_mask_octets = frozenset((255, 254, 252, 248, 240, 224, 192, 128, 0))
 
     def __init__(self, address):
         self._version = 4
@@ -1027,13 +1016,10 @@
         if len(octets) != 4:
             raise AddressValueError("Expected 4 octets in %r" % ip_str)
 
-        packed_ip = 0
-        for oc in octets:
-            try:
-                packed_ip = (packed_ip << 8) | self._parse_octet(oc)
-            except ValueError as exc:
-                raise AddressValueError("%s in %r" % (exc, ip_str)) from None
-        return packed_ip
+        try:
+            return int.from_bytes(map(self._parse_octet, octets), 'big')
+        except ValueError as exc:
+            raise AddressValueError("%s in %r" % (exc, ip_str)) from None
 
     def _parse_octet(self, octet_str):
         """Convert a decimal octet into an integer.
@@ -1075,11 +1061,7 @@
             The IP address as a string in dotted decimal notation.
 
         """
-        octets = []
-        for _ in range(4):
-            octets.insert(0, str(ip_int & 0xFF))
-            ip_int >>= 8
-        return '.'.join(octets)
+        return '.'.join(map(str, ip_int.to_bytes(4, 'big')))
 
     def _is_valid_netmask(self, netmask):
         """Verify that the netmask is valid.
@@ -1095,17 +1077,16 @@
         """
         mask = netmask.split('.')
         if len(mask) == 4:
-            for x in mask:
-                try:
-                    if int(x) in self._valid_mask_octets:
-                        continue
-                except ValueError:
-                    pass
+            try:
+                for x in mask:
+                    if int(x) not in self._valid_mask_octets:
+                        return False
+            except ValueError:
                 # Found something that isn't an integer or isn't valid
                 return False
-            if [y for idx, y in enumerate(mask) if idx > 0 and
-                y > mask[idx - 1]]:
-                return False
+            for idx, y in enumerate(mask):
+                if idx > 0 and y > mask[idx - 1]:
+                    return False
             return True
         try:
             netmask = int(netmask)
@@ -1125,7 +1106,7 @@
         """
         bits = ip_str.split('.')
         try:
-            parts = [int(x) for x in bits if int(x) in self._valid_mask_octets]
+            parts = [x for x in map(int, bits) if x in self._valid_mask_octets]
         except ValueError:
             return False
         if len(parts) != len(bits):
@@ -1526,14 +1507,14 @@
 
         # Disregarding the endpoints, find '::' with nothing in between.
         # This indicates that a run of zeroes has been skipped.
-        try:
-            skip_index, = (
-                [i for i in range(1, len(parts) - 1) if not parts[i]] or
-                [None])
-        except ValueError:
-            # Can't have more than one '::'
-            msg = "At most one '::' permitted in %r" % ip_str
-            raise AddressValueError(msg) from None
+        skip_index = None
+        for i in range(1, len(parts) - 1):
+            if not parts[i]:
+                if skip_index is not None:
+                    # Can't have more than one '::'
+                    msg = "At most one '::' permitted in %r" % ip_str
+                    raise AddressValueError(msg)
+                skip_index = i
 
         # parts_hi is the number of parts to copy from above/before the '::'
         # parts_lo is the number of parts to copy from below/after the '::'
@@ -1680,9 +1661,7 @@
             raise ValueError('IPv6 address is too large')
 
         hex_str = '%032x' % ip_int
-        hextets = []
-        for x in range(0, 32, 4):
-            hextets.append('%x' % int(hex_str[x:x+4], 16))
+        hextets = ['%x' % int(hex_str[x:x+4], 16) for x in range(0, 32, 4)]
 
         hextets = self._compress_hextets(hextets)
         return ':'.join(hextets)
@@ -1705,11 +1684,8 @@
             ip_str = str(self)
 
         ip_int = self._ip_int_from_string(ip_str)
-        parts = []
-        for i in range(self._HEXTET_COUNT):
-            parts.append('%04x' % (ip_int & 0xFFFF))
-            ip_int >>= 16
-        parts.reverse()
+        hex_str = '%032x' % ip_int
+        parts = [hex_str[x:x+4] for x in range(0, 32, 4)]
         if isinstance(self, (_BaseNetwork, IPv6Interface)):
             return '%s/%d' % (':'.join(parts), self.prefixlen)
         return ':'.join(parts)
@@ -1756,9 +1732,9 @@
                              IPv6Network('FE00::/9')]
 
         if isinstance(self, _BaseAddress):
-            return len([x for x in reserved_networks if self in x]) > 0
-        return len([x for x in reserved_networks if self.network_address in x
-                    and self.broadcast_address in x]) > 0
+            return any(self in x for x in reserved_networks)
+        return any(self.network_address in x and self.broadcast_address in x
+                   for x in reserved_networks)
 
     @property
     def is_link_local(self):

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list