[Python-Dev] Proof of the pudding: str.partition()

Raymond Hettinger raymond.hettinger at verizon.net
Mon Aug 29 07:48:57 CEST 2005


As promised, here is a full set of real-world comparative code
transformations using str.partition().  The patch isn't intended to be
applied; rather, it is here to test/demonstrate whether the new
construct offers benefits under a variety of use cases.

Overall, I found that partition() usefully encapsulated commonly
occurring low-level programming patterns.  In most cases, it completely
eliminated the need for slicing and indices.  In several cases, code was
simplified dramatically; in some, the simplification was minor; and in a
few cases, the complexity was about the same.  No cases were made worse.

Most patterns using str.find() directly translated into an equivalent
using partition.  The only awkwardness that arose was in cases where the
original code had a test like, "if s.find(pat) > 0".  That case
translated to a double-term test, "if found and head".  Also, some
pieces of code needed a tail that included the separator.  That need was
met by inserting a line like "tail = sep + tail".  That solution led to
a minor naming discomfort for the middle term of the result tuple: it
was being used both as a Boolean found flag and as a string containing
the separator (hence the conflict between the names "found" and "sep").
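
To make those two translations concrete, here is a minimal sketch.  It
is not part of the patch: the function names are made up for
illustration, it is written for a current Python 3 where
str.partition() is a built-in, and it assumes a non-empty separator.

```python
def has_nonempty_head_find(s, pat):
    # Old idiom: true only when pat occurs somewhere after position 0.
    return s.find(pat) > 0


def has_nonempty_head_partition(s, pat):
    # Double-term test: the separator was found AND something precedes it.
    head, found, tail = s.partition(pat)
    return bool(found and head)


def tail_with_separator(s, pat):
    # Equivalent of s[:i], s[i:] -- the tail keeps the separator.
    head, sep, tail = s.partition(pat)
    tail = sep + tail  # sep is '' when pat is absent, so this is a no-op then
    return head, tail
```

For example, tail_with_separator("a%b%c", "%") gives ("a", "%b%c"),
which is what the find()-based slicing s[:i], s[i:] would produce.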

In most cases, there was some increase in efficiency resulting from
fewer total steps and tests, and from eliminating double searches.  However,
in a few cases, the new code was less efficient because the fragment
only needed either the head or tail but not both as provided by
partition().
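
As a rough illustration of both sides of that trade-off (a sketch with
made-up function names, written for a current Python 3, not code from
the patch):

```python
def head_tail_double_search(name):
    # Old pattern: "'.' in name" scans the string once and
    # name.find('.') scans it a second time.
    if '.' in name:
        i = name.find('.')
        return name[:i], name[i + 1:]
    return name, ''


def head_tail_partition(name):
    # One search, no indices, no conditional path.
    head, _, tail = name.partition('.')
    return head, tail


def head_only(name):
    # The less efficient case: partition() still builds the tail
    # string even though only the head is wanted.
    head, _, _ = name.partition('.')
    return head
```

The first two functions return the same pair for any input; the third
shows where partition() does work that a bare find() plus one slice
would have avoided.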

In every case, the code was clearer after the transformation.  Also,
none of the transformations required str.partition() to be used in a
tricky way.  In contrast, I found many contortions using str.find()
where I had to diagram every possible path to understand what the code
was trying to do or to assure myself that it worked.

The new methods excelled at reducing cyclomatic complexity by
eliminating conditional paths.  The methods were especially helpful in
the context of multiple finds (e.g. split at the leftmost colon, if
present, within the group following the rightmost forward slash, if
present).  In several cases, the replaced code exactly matched the pure
Python version of str.partition() -- this confirms that people are
routinely writing multi-step, low-level, in-line code that duplicates
what str.partition() does in a single step.
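
For reference, a pure-Python equivalent might look like the sketch
below.  This is my own reconstruction for a current Python 3, checked
against the built-in str.partition(); it is not the text of the
proposal, and the reverse variant is deliberately left out.

```python
def partition(s, sep):
    """Split s at the first occurrence of sep.

    Returns (head, sep, tail) when sep is found and (s, '', '')
    otherwise, mirroring the built-in str.partition().
    """
    if not sep:
        raise ValueError("empty separator")
    i = s.find(sep)
    if i < 0:
        return s, '', ''
    return s[:i], sep, s[i + len(sep):]
```

The find / test / slice sequence in the body is the same shape as
many of the "before" fragments in the patch.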

The more complex transformations were handled by first figuring out
exactly what the original code did under all possible cases and then
writing the partition() version to match that spec.  The lesson was that
it is much easier to program from scratch using partition() than it is
to code using find().  The new method more naturally expresses a series
of parsing steps interleaved with other code.
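
Here is a small made-up example of such a multi-step parse, written
both ways.  It is only a sketch: the function names are hypothetical,
and it relies on the built-in str.rpartition() of a current Python 3,
which returns ('', '', s) when the separator is absent and so may not
match the draft semantics assumed elsewhere in this patch.

```python
def last_group_find(s):
    # Take the group after the rightmost '/', if any ...
    i = s.rfind('/')
    if i >= 0:
        group = s[i + 1:]
    else:
        group = s
    # ... then split that group at its leftmost ':', if any.
    j = group.find(':')
    if j >= 0:
        return group[:j], group[j + 1:]
    return group, ''


def last_group_partition(s):
    # Same two parsing steps, with no indices and no branches.
    _, _, group = s.rpartition('/')
    head, _, tail = group.partition(':')
    return head, tail
```

Both return ("c", "d:e") for "a/b/c:d:e"; the second version has a
single path through the code instead of four.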

Without further ado, here are the comparative code fragments:

Index: CGIHTTPServer.py
===================================================================
*** 106,121 ****
      def run_cgi(self):
          """Execute a CGI script."""
          dir, rest = self.cgi_info
!         i = rest.rfind('?')
!         if i >= 0:
!             rest, query = rest[:i], rest[i+1:]
!         else:
!             query = ''
!         i = rest.find('/')
!         if i >= 0:
!             script, rest = rest[:i], rest[i:]
!         else:
!             script, rest = rest, ''
          scriptname = dir + '/' + script
          scriptfile = self.translate_path(scriptname)
          if not os.path.exists(scriptfile):
--- 106,114 ----
      def run_cgi(self):
          """Execute a CGI script."""
          dir, rest = self.cgi_info
!         rest, _, query = rest.rpartition('?')
!         script, sep, rest = rest.partition('/')
!         rest = sep + rest
          scriptname = dir + '/' + script
          scriptfile = self.translate_path(scriptname)
          if not os.path.exists(scriptfile):
Index: ConfigParser.py
===================================================================
*** 599,612 ****
          if depth > MAX_INTERPOLATION_DEPTH:
              raise InterpolationDepthError(option, section, rest)
          while rest:
!             p = rest.find("%")
!             if p < 0:
!                 accum.append(rest)
                  return
!             if p > 0:
!                 accum.append(rest[:p])
!                 rest = rest[p:]
!             # p is no longer used
              c = rest[1:2]
              if c == "%":
                  accum.append("%")
--- 599,611 ----
          if depth > MAX_INTERPOLATION_DEPTH:
              raise InterpolationDepthError(option, section, rest)
          while rest:
!             head, sep, rest = rest.partition("%")
!             if not sep:
!                 accum.append(head)
                  return
!             rest = sep + rest
!             if head:
!                 accum.append(head)
              c = rest[1:2]
              if c == "%":
                  accum.append("%")
Index: cgi.py
===================================================================
*** 337,346 ****
      key = plist.pop(0).lower()
      pdict = {}
      for p in plist:
!         i = p.find('=')
!         if i >= 0:
!             name = p[:i].strip().lower()
!             value = p[i+1:].strip()
              if len(value) >= 2 and value[0] == value[-1] == '"':
                  value = value[1:-1]
                  value = value.replace('\\\\', '\\').replace('\\"', '"')
--- 337,346 ----
      key = plist.pop(0).lower()
      pdict = {}
      for p in plist:
!         name, found, value = p.partition('=')
!         if found:
!             name = name.strip().lower()
!             value = value.strip()
              if len(value) >= 2 and value[0] == value[-1] == '"':
                  value = value[1:-1]
                  value = value.replace('\\\\', '\\').replace('\\"', '"')
Index: cookielib.py
===================================================================
*** 610,618 ****
  
  def request_port(request):
      host = request.get_host()
!     i = host.find(':')
!     if i >= 0:
!         port = host[i+1:]
          try:
              int(port)
          except ValueError:
--- 610,617 ----
  
  def request_port(request):
      host = request.get_host()
!     _, sep, port = host.partition(':')
!     if sep:
          try:
              int(port)
          except ValueError:
***************
*** 670,681 ****
      '.local'
  
      """
!     i = h.find(".")
!     if i >= 0:
!         #a = h[:i]  # this line is only here to show what a is
!         b = h[i+1:]
!         i = b.find(".")
!         if is_HDN(h) and (i >= 0 or b == "local"):
              return "."+b
      return h
  
--- 669,677 ----
      '.local'
  
      """
!     a, found, b = h.partition('.')
!     if found:
!         if is_HDN(h) and ('.' in b or b == "local"):
              return "."+b
      return h
  
***************
*** 1451,1463 ****
          else:
              path_specified = False
              path = request_path(request)
!             i = path.rfind("/")
!             if i != -1:
                  if version == 0:
                      # Netscape spec parts company from reality here
!                     path = path[:i]
                  else:
!                     path = path[:i+1]
              if len(path) == 0: path = "/"
  
          # set default domain
--- 1447,1459 ----
          else:
              path_specified = False
              path = request_path(request)
!             head, sep, _ = path.rpartition('/')
!             if sep:
                  if version == 0:
                      # Netscape spec parts company from reality here
!                     path = head
                  else:
!                     path = head + sep
              if len(path) == 0: path = "/"
  
          # set default domain
Index: gopherlib.py
===================================================================
*** 57,65 ****
      """Send a selector to a given host and port, return a file with the reply."""
      import socket
      if not port:
!         i = host.find(':')
!         if i >= 0:
!             host, port = host[:i], int(host[i+1:])
      if not port:
          port = DEF_PORT
      elif type(port) == type(''):
--- 57,65 ----
      """Send a selector to a given host and port, return a file with the reply."""
      import socket
      if not port:
!         head, found, tail = host.partition(':')
!         if found:
!             host, port = head, int(tail)
      if not port:
          port = DEF_PORT
      elif type(port) == type(''):
Index: httplib.py
===================================================================
*** 490,498 ****
          while True:
              if chunk_left is None:
                  line = self.fp.readline()
!                 i = line.find(';')
!                 if i >= 0:
!                     line = line[:i] # strip chunk-extensions
                  chunk_left = int(line, 16)
                  if chunk_left == 0:
                      break
--- 490,496 ----
          while True:
              if chunk_left is None:
                  line = self.fp.readline()
!                 line, _, _ = line.partition(';')  # strip chunk-extensions
                  chunk_left = int(line, 16)
                  if chunk_left == 0:
                      break
***************
*** 586,599 ****
  
      def _set_hostport(self, host, port):
          if port is None:
!             i = host.rfind(':')
!             j = host.rfind(']')         # ipv6 addresses have [...]
!             if i > j:
                  try:
!                     port = int(host[i+1:])
                  except ValueError:
!                     raise InvalidURL("nonnumeric port: '%s'" % host[i+1:])
!                 host = host[:i]
              else:
                  port = self.default_port
              if host and host[0] == '[' and host[-1] == ']':
--- 584,595 ----
  
      def _set_hostport(self, host, port):
          if port is None:
!             head, found, tail = host.rpartition(':')
!             if found and ']' not in tail:   # ipv6 addresses have [...]
                  try:
!                     host, port = head, int(tail)
                  except ValueError:
!                     raise InvalidURL("nonnumeric port: '%s'" % tail)
              else:
                  port = self.default_port
              if host and host[0] == '[' and host[-1] == ']':
***************
*** 976,998 ****
          L = [self._buf]
          self._buf = ''
          while 1:
!             i = L[-1].find("\n")
!             if i >= 0:
                  break
              s = self._read()
              if s == '':
                  break
              L.append(s)
!         if i == -1:
              # loop exited because there is no more data
              return "".join(L)
          else:
!             all = "".join(L)
!             # XXX could do enough bookkeeping not to do a 2nd search
!             i = all.find("\n") + 1
!             line = all[:i]
!             self._buf = all[i:]
!             return line
  
      def readlines(self, sizehint=0):
          total = 0
--- 972,990 ----
          L = [self._buf]
          self._buf = ''
          while 1:
!             head, found, tail = L[-1].partition('\n')
!             if found:
                  break
              s = self._read()
              if s == '':
                  break
              L.append(s)
!         if not found:
              # loop exited because there is no more data
              return "".join(L)
          else:
!             self._buf = tail
!             return "".join(L[:-1]) + head + found
  
      def readlines(self, sizehint=0):
          total = 0
Index: ihooks.py
===================================================================
*** 426,438 ****
          return None
  
      def find_head_package(self, parent, name):
!         if '.' in name:
!             i = name.find('.')
!             head = name[:i]
!             tail = name[i+1:]
!         else:
!             head = name
!             tail = ""
          if parent:
              qname = "%s.%s" % (parent.__name__, head)
          else:
--- 426,432 ----
          return None
  
      def find_head_package(self, parent, name):
!         head, _, tail = name.partition('.')
          if parent:
              qname = "%s.%s" % (parent.__name__, head)
          else:
***************
*** 449,457 ****
      def load_tail(self, q, tail):
          m = q
          while tail:
!             i = tail.find('.')
!             if i < 0: i = len(tail)
!             head, tail = tail[:i], tail[i+1:]
              mname = "%s.%s" % (m.__name__, head)
              m = self.import_it(head, mname, m)
              if not m:
--- 443,449 ----
      def load_tail(self, q, tail):
          m = q
          while tail:
!             head, _, tail = tail.partition('.')
              mname = "%s.%s" % (m.__name__, head)
              m = self.import_it(head, mname, m)
              if not m:
Index: locale.py
===================================================================
*** 98,106 ****
      seps = 0
      spaces = ""
      if s[-1] == ' ':
!         sp = s.find(' ')
!         spaces = s[sp:]
!         s = s[:sp]
      while s and grouping:
          # if grouping is -1, we are done
          if grouping[0]==CHAR_MAX:
--- 98,105 ----
      seps = 0
      spaces = ""
      if s[-1] == ' ':
!         s, sep, tail = s.partition(' ')
!         spaces = sep + tail
      while s and grouping:
          # if grouping is -1, we are done
          if grouping[0]==CHAR_MAX:
***************
*** 148,156 ****
          # so, kill as much spaces as there where separators.
          # Leading zeroes as fillers are not yet dealt with, as it is
          # not clear how they should interact with grouping.
!         sp = result.find(" ")
!         if sp==-1:break
!         result = result[:sp]+result[sp+1:]
          seps -= 1
  
      return result
--- 147,156 ----
          # so, kill as much spaces as there where separators.
          # Leading zeroes as fillers are not yet dealt with, as it is
          # not clear how they should interact with grouping.
!         head, found, tail = result.partition(' ')
!         if not found:
!             break
!         result = head + tail
          seps -= 1
  
      return result
Index: mailcap.py
===================================================================
*** 105,117 ****
      key, view, rest = fields[0], fields[1], fields[2:]
      fields = {'view': view}
      for field in rest:
!         i = field.find('=')
!         if i < 0:
!             fkey = field
!             fvalue = ""
!         else:
!             fkey = field[:i].strip()
!             fvalue = field[i+1:].strip()
          if fkey in fields:
              # Ignore it
              pass
--- 105,113 ----
      key, view, rest = fields[0], fields[1], fields[2:]
      fields = {'view': view}
      for field in rest:
!         fkey, found, fvalue = field.partition('=')
!         fkey = fkey.strip()
!         fvalue = fvalue.strip()
          if fkey in fields:
              # Ignore it
              pass
Index: mhlib.py
===================================================================
*** 356,364 ****
          if seq == 'all':
              return all
          # Test for X:Y before X-Y because 'seq:-n' matches both
!         i = seq.find(':')
!         if i >= 0:
!             head, dir, tail = seq[:i], '', seq[i+1:]
              if tail[:1] in '-+':
                  dir, tail = tail[:1], tail[1:]
              if not isnumeric(tail):
--- 356,364 ----
          if seq == 'all':
              return all
          # Test for X:Y before X-Y because 'seq:-n' matches both
!         head, found, tail = seq.partition(':')
!         if found:
!             dir = ''
              if tail[:1] in '-+':
                  dir, tail = tail[:1], tail[1:]
              if not isnumeric(tail):
***************
*** 394,403 ****
                      i = bisect(all, anchor-1)
                      return all[i:i+count]
          # Test for X-Y next
!         i = seq.find('-')
!         if i >= 0:
!             begin = self._parseindex(seq[:i], all)
!             end = self._parseindex(seq[i+1:], all)
              i = bisect(all, begin-1)
              j = bisect(all, end)
              r = all[i:j]
--- 394,403 ----
                      i = bisect(all, anchor-1)
                      return all[i:i+count]
          # Test for X-Y next
!         head, found, tail = seq.partition('-')
!         if found:
!             begin = self._parseindex(head, all)
!             end = self._parseindex(tail, all)
              i = bisect(all, begin-1)
              j = bisect(all, end)
              r = all[i:j]
Index: modulefinder.py
===================================================================
*** 140,148 ****
              assert caller is parent
              self.msgout(4, "determine_parent ->", parent)
              return parent
!         if '.' in pname:
!             i = pname.rfind('.')
!             pname = pname[:i]
              parent = self.modules[pname]
              assert parent.__name__ == pname
              self.msgout(4, "determine_parent ->", parent)
--- 140,147 ----
              assert caller is parent
              self.msgout(4, "determine_parent ->", parent)
              return parent
!         pname, found, _ = pname.rpartition('.')
!         if found:
              parent = self.modules[pname]
              assert parent.__name__ == pname
              self.msgout(4, "determine_parent ->", parent)
***************
*** 152,164 ****
  
      def find_head_package(self, parent, name):
          self.msgin(4, "find_head_package", parent, name)
!         if '.' in name:
!             i = name.find('.')
!             head = name[:i]
!             tail = name[i+1:]
!         else:
!             head = name
!             tail = ""
          if parent:
              qname = "%s.%s" % (parent.__name__, head)
          else:
--- 151,157 ----
  
      def find_head_package(self, parent, name):
          self.msgin(4, "find_head_package", parent, name)
!         head, _, tail = name.partition('.')
          if parent:
              qname = "%s.%s" % (parent.__name__, head)
          else:
Index: pdb.py
===================================================================
*** 189,200 ****
          # split into ';;' separated commands
          # unless it's an alias command
          if args[0] != 'alias':
!             marker = line.find(';;')
!             if marker >= 0:
!                 # queue up everything after marker
!                 next = line[marker+2:].lstrip()
                  self.cmdqueue.append(next)
!                 line = line[:marker].rstrip()
          return line
  
      # Command definitions, called by cmdloop()
--- 189,200 ----
          # split into ';;' separated commands
          # unless it's an alias command
          if args[0] != 'alias':
!             line, found, next = line.partition(';;')
!             if found:
!                 # queue up everything after command separator
!                 next = next.lstrip()
                  self.cmdqueue.append(next)
!                 line = line.rstrip()
          return line
  
      # Command definitions, called by cmdloop()
***************
*** 217,232 ****
          filename = None
          lineno = None
          cond = None
!         comma = arg.find(',')
!         if comma > 0:
              # parse stuff after comma: "condition"
!             cond = arg[comma+1:].lstrip()
!             arg = arg[:comma].rstrip()
          # parse stuff before comma: [filename:]lineno | function
-         colon = arg.rfind(':')
          funcname = None
!         if colon >= 0:
!             filename = arg[:colon].rstrip()
              f = self.lookupmodule(filename)
              if not f:
                  print '*** ', repr(filename),
--- 217,232 ----
          filename = None
          lineno = None
          cond = None
!         arg, found, cond = arg.partition(',')
!         if found and arg:
              # parse stuff after comma: "condition"
!             arg = arg.rstrip()
!             cond = cond.lstrip()
          # parse stuff before comma: [filename:]lineno | function
          funcname = None
!         filename, found, arg = arg.rpartition(':')
!         if found:
!             filename = filename.rstrip()
              f = self.lookupmodule(filename)
              if not f:
                  print '*** ', repr(filename),
***************
*** 234,240 ****
                  return
              else:
                  filename = f
!             arg = arg[colon+1:].lstrip()
              try:
                  lineno = int(arg)
              except ValueError, msg:
--- 234,240 ----
                  return
              else:
                  filename = f
!             arg = arg.lstrip()
              try:
                  lineno = int(arg)
              except ValueError, msg:
***************
*** 437,445 ****
              return
          if ':' in arg:
              # Make sure it works for "clear C:\foo\bar.py:12"
!             i = arg.rfind(':')
!             filename = arg[:i]
!             arg = arg[i+1:]
              try:
                  lineno = int(arg)
              except:
--- 437,443 ----
              return
          if ':' in arg:
              # Make sure it works for "clear C:\foo\bar.py:12"
!             filename, _, arg = arg.rpartition(':')
              try:
                  lineno = int(arg)
              except:
Index: rfc822.py
===================================================================
*** 197,205 ****
          You may override this method in order to use Message parsing on tagged
          data in RFC 2822-like formats with special header formats.
          """
!         i = line.find(':')
!         if i > 0:
!             return line[:i].lower()
          return None
  
      def islast(self, line):
--- 197,205 ----
          You may override this method in order to use Message parsing on tagged
          data in RFC 2822-like formats with special header formats.
          """
!         head, found, tail = line.partition(':')
!         if found and head:
!             return head.lower()
          return None
  
      def islast(self, line):
***************
*** 340,348 ****
              else:
                  if raw:
                      raw.append(', ')
!                 i = h.find(':')
!                 if i > 0:
!                     addr = h[i+1:]
                  raw.append(addr)
          alladdrs = ''.join(raw)
          a = AddressList(alladdrs)
--- 340,348 ----
              else:
                  if raw:
                      raw.append(', ')
!                 head, found, tail = h.partition(':')
!                 if found and head:
!                     addr = tail
                  raw.append(addr)
          alladdrs = ''.join(raw)
          a = AddressList(alladdrs)
***************
*** 859,867 ****
              data = stuff + data[1:]
      if len(data) == 4:
          s = data[3]
!         i = s.find('+')
!         if i > 0:
!             data[3:] = [s[:i], s[i+1:]]
          else:
              data.append('') # Dummy tz
      if len(data) < 5:
--- 859,867 ----
              data = stuff + data[1:]
      if len(data) == 4:
          s = data[3]
!         head, found, tail = s.partition('+')
!         if found and head:
!             data[3:] = [head, tail]
          else:
              data.append('') # Dummy tz
      if len(data) < 5:
Index: robotparser.py
===================================================================
*** 104,112 ****
                      entry = Entry()
                      state = 0
              # remove optional comment and strip line
!             i = line.find('#')
!             if i>=0:
!                 line = line[:i]
              line = line.strip()
              if not line:
                  continue
--- 104,110 ----
                      entry = Entry()
                      state = 0
              # remove optional comment and strip line
!             line, _, _ = line.partition('#')
              line = line.strip()
              if not line:
                  continue
Index: smtpd.py
===================================================================
*** 144,156 ****
                  self.push('500 Error: bad syntax')
                  return
              method = None
!             i = line.find(' ')
!             if i < 0:
!                 command = line.upper()
                  arg = None
              else:
!                 command = line[:i].upper()
!                 arg = line[i+1:].strip()
              method = getattr(self, 'smtp_' + command, None)
              if not method:
                  self.push('502 Error: command "%s" not implemented' % command)
--- 144,155 ----
                  self.push('500 Error: bad syntax')
                  return
              method = None
!             command, found, arg = line.partition(' ')
!             command = command.upper()
!             if not found:
                  arg = None
              else:
!                 arg = arg.strip()
              method = getattr(self, 'smtp_' + command, None)
              if not method:
                  self.push('502 Error: command "%s" not implemented' % command)
***************
*** 495,514 ****
          usage(1, 'Invalid arguments: %s' % COMMASPACE.join(args))
  
      # split into host/port pairs
!     i = localspec.find(':')
!     if i < 0:
          usage(1, 'Bad local spec: %s' % localspec)
!     options.localhost = localspec[:i]
      try:
!         options.localport = int(localspec[i+1:])
      except ValueError:
          usage(1, 'Bad local port: %s' % localspec)
!     i = remotespec.find(':')
!     if i < 0:
          usage(1, 'Bad remote spec: %s' % remotespec)
!     options.remotehost = remotespec[:i]
      try:
!         options.remoteport = int(remotespec[i+1:])
      except ValueError:
          usage(1, 'Bad remote port: %s' % remotespec)
      return options
--- 494,513 ----
          usage(1, 'Invalid arguments: %s' % COMMASPACE.join(args))
  
      # split into host/port pairs
!     head, found, tail = localspec.partition(':')
!     if not found:
          usage(1, 'Bad local spec: %s' % localspec)
!     options.localhost = head
      try:
!         options.localport = int(tail)
      except ValueError:
          usage(1, 'Bad local port: %s' % localspec)
!     head, found, tail = remotespec.partition(':')
!     if not found:
          usage(1, 'Bad remote spec: %s' % remotespec)
!     options.remotehost = head
      try:
!         options.remoteport = int(tail)
      except ValueError:
          usage(1, 'Bad remote port: %s' % remotespec)
      return options
Index: smtplib.py
===================================================================
*** 276,284 ****
  
          """
          if not port and (host.find(':') == host.rfind(':')):
!             i = host.rfind(':')
!             if i >= 0:
!                 host, port = host[:i], host[i+1:]
                  try: port = int(port)
                  except ValueError:
                      raise socket.error, "nonnumeric port"
--- 276,283 ----
  
          """
          if not port and (host.find(':') == host.rfind(':')):
!             host, found, port = host.rpartition(':')
!             if found:
                  try: port = int(port)
                  except ValueError:
                      raise socket.error, "nonnumeric port"
Index: urllib2.py
===================================================================
*** 289,301 ****
      def add_handler(self, handler):
          added = False
          for meth in dir(handler):
!             i = meth.find("_")
!             protocol = meth[:i]
!             condition = meth[i+1:]
! 
              if condition.startswith("error"):
!                 j = condition.find("_") + i + 1
!                 kind = meth[j+1:]
                  try:
                      kind = int(kind)
                  except ValueError:
--- 289,297 ----
      def add_handler(self, handler):
          added = False
          for meth in dir(handler):
!             protocol, _, condition = meth.partition('_')
              if condition.startswith("error"):
!                 _, _, kind = condition.partition('_')
                  try:
                      kind = int(kind)
                  except ValueError:
Index: zipfile.py
===================================================================
*** 117,125 ****
          self.orig_filename = filename   # Original file name in archive
  # Terminate the file name at the first null byte.  Null bytes in file
  # names are used as tricks by viruses in archives.
!         null_byte = filename.find(chr(0))
!         if null_byte >= 0:
!             filename = filename[0:null_byte]
  # This is used to ensure paths in generated ZIP files always use
  # forward slashes as the directory separator, as required by the
  # ZIP format specification.
--- 117,123 ----
          self.orig_filename = filename   # Original file name in archive
  # Terminate the file name at the first null byte.  Null bytes in file
  # names are used as tricks by viruses in archives.
!         filename, _, _ = filename.partition(chr(0))
  # This is used to ensure paths in generated ZIP files always use
  # forward slashes as the directory separator, as required by the
  # ZIP format specification.


