[pypy-commit] pypy utf8-unicode2: Simplify iterators. Use iterators consistently when encoding unicode strings

waedt noreply at buildbot.pypy.org
Thu Aug 14 08:17:55 CEST 2014


Author: Tyler Wade <wayedt at gmail.com>
Branch: utf8-unicode2
Changeset: r72787:9f7fc269657f
Date: 2014-08-14 01:02 -0500
http://bitbucket.org/pypy/pypy/changeset/9f7fc269657f/

Log:	Simplify iterators. Use iterators consistently when encoding unicode
	strings

diff --git a/pypy/interpreter/test/test_utf8.py b/pypy/interpreter/test/test_utf8.py
--- a/pypy/interpreter/test/test_utf8.py
+++ b/pypy/interpreter/test/test_utf8.py
@@ -2,8 +2,7 @@
 
 import py
 import sys
-from pypy.interpreter.utf8 import (
-    Utf8Str, Utf8Builder, utf8chr, utf8ord)
+from pypy.interpreter.utf8 import Utf8Str, Utf8Builder, utf8chr, utf8ord
 from rpython.rtyper.lltypesystem import rffi
 from rpython.rtyper.test.test_llinterp import interpret
 
@@ -29,24 +28,8 @@
 def test_iterator():
     s = build_utf8str()
     iter = s.codepoint_iter()
-    assert iter.peek_next() == 0x41
     assert list(iter) == [0x41, 0x10F, 0x20AC, 0x1F63D]
 
-    for i in range(1, 5):
-        iter = s.codepoint_iter()
-        iter.move(i)
-        if i != 4:
-            assert iter.peek_next() == [0x41, 0x10F, 0x20AC, 0x1F63D][i]
-        l = list(iter)
-        assert l == [0x41, 0x10F, 0x20AC, 0x1F63D][i:]
-
-    for i in range(1, 5):
-        iter = s.codepoint_iter()
-        list(iter) # move the iterator to the end
-        iter.move(-i)
-        l = list(iter)
-        assert l == [0x41, 0x10F, 0x20AC, 0x1F63D][4-i:]
-
     iter = s.char_iter()
     l = [s.bytes.decode('utf8') for s in list(iter)]
     if sys.maxunicode < 65536:
@@ -54,26 +37,17 @@
     else:
         assert l == [u'A', u'\u010F', u'\u20AC', u'\U0001F63D']
 
-def test_reverse_iterator():
+def test_new_iterator():
     s = build_utf8str()
-    iter = s.reverse_codepoint_iter()
-    assert iter.peek_next() == 0x1F63D
-    assert list(iter) == [0x1F63D, 0x20AC, 0x10F, 0x41]
+    i = s.iter()  # forward pass: compare with utf8ord at every index
+    while not i.finished():
+        assert utf8ord(s, i.pos()) == i.current()
+        i.move(1)
 
-    for i in range(1, 5):
-        iter = s.reverse_codepoint_iter()
-        iter.move(i)
-        if i != 4:
-            assert iter.peek_next() == [0x1F63D, 0x20AC, 0x10F, 0x41][i]
-        l = list(iter)
-        assert l == [0x1F63D, 0x20AC, 0x10F, 0x41][i:]
-
-    for i in range(1, 5):
-        iter = s.reverse_codepoint_iter()
-        list(iter) # move the iterator to the end
-        iter.move(-i)
-        l = list(iter)
-        assert l == [0x1F63D, 0x20AC, 0x10F, 0x41][4-i:]
+    i = s.iter(len(s) - 1)  # backward pass, starting at the last codepoint
+    while i.pos() >= 0:
+        assert utf8ord(s, i.pos()) == i.current()
+        i.move(-1)
 
 def test_builder_append_slice():
     builder = Utf8Builder()
@@ -146,7 +120,6 @@
     s = Utf8Str(' ')
     assert s.join([]) == u''
 
-    
     assert s.join([Utf8Str('one')]) == u'one'
     assert s.join([Utf8Str('one'), Utf8Str('two')]) == u'one two'
 
diff --git a/pypy/interpreter/utf8.py b/pypy/interpreter/utf8.py
--- a/pypy/interpreter/utf8.py
+++ b/pypy/interpreter/utf8.py
@@ -309,18 +309,15 @@
     def __unicode__(self):
         return unicode(self.bytes, 'utf8')
 
+    def iter(self, start=0):
+        return Utf8Iterator(self, start)  # codepoint cursor; moves both ways
+
     def char_iter(self):
         return Utf8CharacterIter(self)
 
-    def reverse_char_iter(self):
-        return Utf8ReverseCharacterIter(self)
-
     def codepoint_iter(self):
         return Utf8CodePointIter(self)
 
-    def reverse_codepoint_iter(self):
-        return Utf8ReverseCodePointIter(self)
-
     @specialize.argtype(1, 2)
     def _bound_check(self, start, end):
         if start is None:
@@ -432,7 +429,7 @@
             else:
                 break
 
-            start_byte = iter.byte_pos
+            start_byte = iter._byte_pos
             assert start_byte >= 0
 
             if maxsplit == 0:
@@ -449,7 +446,7 @@
                            self._is_ascii))
                 break
 
-            end = iter.byte_pos
+            end = iter._byte_pos
             assert end >= 0
             res.append(Utf8Str(self.bytes[start_byte:end], self._is_ascii))
             maxsplit -= 1
@@ -466,32 +463,32 @@
                 other_bytes = other.bytes
             return [Utf8Str(s) for s in self.bytes.rsplit(other_bytes, maxsplit)]
 
+        if len(self) == 0:  # don't start the reverse scan at index -1
+            return []
+
         res = []
-        iter = self.reverse_codepoint_iter()
+        iter = self.iter(len(self) - 1)
         while True:
             # Find the start of the next word
-            for cd in iter:
-                if not unicodedb.isspace(cd):
-                    break
-            else:
+            while iter.pos() >= 0 and unicodedb.isspace(iter.current()):
+                iter.move(-1)
+            if iter.pos() < 0:
                 break
 
-            start_byte = self.next_char(iter.byte_pos)
-
+            start_byte = self.next_char(iter.byte_pos())
             if maxsplit == 0:
                 res.append(Utf8Str(self.bytes[0:start_byte], self._is_ascii))
                 break
 
             # Find the end of the word
-            for cd in iter:
-                if unicodedb.isspace(cd):
-                    break
-            else:
+            while iter.pos() >= 0 and not unicodedb.isspace(iter.current()):
+                iter.move(-1)
+            if iter.pos() < 0:
                 # We hit the end of the string
                 res.append(Utf8Str(self.bytes[0:start_byte], self._is_ascii))
                 break
 
-            end_byte = self.next_char(iter.byte_pos)
+            end_byte = self.next_char(iter.byte_pos())
             res.append(Utf8Str(self.bytes[end_byte:start_byte],
                                self._is_ascii))
             maxsplit -= 1
@@ -756,117 +753,27 @@
 
 # _______________________________________________
 
-# iter.current is the current (ie the last returned) element
-# iter.pos isthe position of the current element
-# iter.byte_pos isthe byte position of the current element
-# In the before-the-start state, for foward iterators iter.pos and
-# iter.byte_pos are -1. For reverse iterators, they are len(ustr) and
-# len(ustr.bytes) respectively.
-
 class ForwardIterBase(object):
     def __init__(self, ustr):
-        self.ustr = ustr
-        self.pos = -1
-
-        self._byte_pos = 0
-        self.byte_pos = -1
-        self.current = self._default
+        self._str = ustr
+        self._byte_pos = -1  # -1: next() has not returned anything yet
 
     def __iter__(self):
         return self
 
     def next(self):
-        if self.pos + 1 == len(self.ustr):
+        if self._byte_pos == -1:
+            if len(self._str) == 0:
+                raise StopIteration()
+            self._byte_pos = 0
+            return self._value(0)
+
+        self._byte_pos = self._str.next_char(self._byte_pos)  # advance first
+        if self._byte_pos == len(self._str.bytes):
             raise StopIteration()
 
-        self.pos += 1
-        self.byte_pos = self._byte_pos
-
-        self.current = self._value(self.byte_pos)
-
-        self._byte_pos = self.ustr.next_char(self._byte_pos)
-        return self.current
-
-    def peek_next(self):
         return self._value(self._byte_pos)
 
-    def peek_prev(self):
-        return self._value(self._move_backward(self.byte_pos))
-
-    def move(self, count):
-        if count > 0:
-            self.pos += count
-
-            while count != 1:
-                self._byte_pos = self.ustr.next_char(self._byte_pos)
-                count -= 1
-            self.byte_pos = self._byte_pos
-            self._byte_pos = self.ustr.next_char(self._byte_pos)
-            self.current = self._value(self.byte_pos)
-
-        elif count < 0:
-            self.pos += count
-            while count < -1:
-                self.byte_pos = self.ustr.prev_char(self.byte_pos)
-                count += 1
-            self._byte_pos = self.byte_pos
-            self.byte_pos = self.ustr.prev_char(self.byte_pos)
-            self.current = self._value(self.byte_pos)
-
-    def copy(self):
-        iter = self.__class__(self.ustr)
-        iter.pos = self.pos
-        iter.byte_pos = self.byte_pos
-        iter._byte_pos = self._byte_pos
-        iter.current = self.current
-        return iter
-
-class ReverseIterBase(object):
-    def __init__(self, ustr):
-        self.ustr = ustr
-        self.pos = len(ustr)
-        self.byte_pos = len(ustr.bytes)
-        self.current = self._default
-
-    def __iter__(self):
-        return self
-
-    def next(self):
-        if self.pos == 0:
-            raise StopIteration()
-
-        self.pos -= 1
-        self.byte_pos = self.ustr.prev_char(self.byte_pos)
-        self.current = self._value(self.byte_pos)
-        return self.current
-
-    def peek_next(self):
-        return self._value(self.ustr.prev_char(self.byte_pos))
-
-    def peek_prev(self):
-        return self._value(self.ustr.next_char(self.byte_pos))
-
-    def move(self, count):
-        if count > 0:
-            self.pos -= count
-            while count != 0:
-                self.byte_pos = self.ustr.prev_char(self.byte_pos)
-                count -= 1
-            self.current = self._value(self.byte_pos)
-        elif count < 0:
-            self.pos -= count
-            while count != 0:
-                self.byte_pos = self.ustr.next_char(self.byte_pos)
-                count += 1
-            self.current = self._value(self.byte_pos)
-
-    def copy(self):
-        iter = self.__class__(self.ustr)
-        iter.pos = self.pos
-        iter.byte_pos = self.byte_pos
-        iter.current = self.current
-        return iter
-
 def make_iterator(name, base, calc_value, default):
     class C(object):
         import_from_mixin(base, ['__init__', '__iter__'])
@@ -876,32 +783,102 @@
     return C
 
 def codepoint_calc_value(self, byte_pos):
-    if byte_pos == -1 or byte_pos == len(self.ustr.bytes):
+    if byte_pos == -1 or byte_pos == len(self._str.bytes):
         return -1
-    return utf8ord_bytes(self.ustr.bytes, byte_pos)
+    return utf8ord_bytes(self._str.bytes, byte_pos)
 
 def character_calc_value(self, byte_pos):
-    if byte_pos == -1 or byte_pos == len(self.ustr.bytes):
+    if byte_pos == -1 or byte_pos == len(self._str.bytes):
         return None
-    length = utf8_code_length[ord(self.ustr.bytes[self.byte_pos])]
-    return Utf8Str(''.join([self.ustr.bytes[i]
-                    for i in range(self.byte_pos, self.byte_pos + length)]),
+    # Decode at the byte_pos argument, not at the iterator's own _byte_pos.
+    length = utf8_code_length[ord(self._str.bytes[byte_pos])]
+    return Utf8Str(''.join([self._str.bytes[i]
+                    for i in range(byte_pos, byte_pos + length)]),
                     length == 1)
 
 Utf8CodePointIter = make_iterator("Utf8CodePointIter", ForwardIterBase,
                                   codepoint_calc_value, -1)
 Utf8CharacterIter = make_iterator("Utf8CharacterIter", ForwardIterBase,
                                   character_calc_value, None)
-Utf8ReverseCodePointIter = make_iterator(
-    "Utf8ReverseCodePointIter", ReverseIterBase, codepoint_calc_value, -1)
-Utf8ReverseCharacterIter = make_iterator(
-    "Utf8ReverseCharacterIter", ReverseIterBase, character_calc_value, None)
 
 del make_iterator
 del codepoint_calc_value
 del character_calc_value
 del ForwardIterBase
-del ReverseIterBase
 
 
 
+# _______________________________________________
+
+class Utf8Iterator(object):
+    """Bidirectional cursor over the codepoints of a Utf8Str.
+
+    pos() is a codepoint index and byte_pos() the byte offset it maps to.
+    A cursor moved before the start has pos() < 0 and byte_pos() == 0;
+    finished() becomes true once pos() >= len(str).  current() returns the
+    codepoint at pos() and raises IndexError if pos() is out of range.
+    """
+    def __init__(self, str, start=0):
+        self._str = str
+
+        self._pos = start
+        self._byte_pos = str.index_of_char(start)
+
+        # -1 means "not decoded yet": current() decodes and bounds-checks
+        # lazily, so construction itself does not bounds-check start.
+        self._current = -1
+
+    def _calc_current(self):
+        if self._pos >= len(self._str) or self._pos < 0:
+            raise IndexError()
+        else:
+            self._current = utf8ord_bytes(self._str.bytes, self._byte_pos)
+
+    def current(self):
+        if self._current == -1:
+            self._calc_current()
+        return self._current
+
+    def pos(self):
+        return self._pos
+
+    def byte_pos(self):
+        return self._byte_pos
+
+    def move(self, count):
+        # TODO: As an optimization, we could delay moving byte_pos until we
+        #       _calc_current
+        # Every pos < 0 is parked at byte offset 0 (the offset of pos 0),
+        # so only the steps taken at pos >= 0 may move byte_pos.
+        old_pos = self._pos
+        self._pos += count
+        if count > 0:
+            if self._pos <= 0:
+                self._byte_pos = 0
+            else:
+                count = self._pos - max(old_pos, 0)
+                while count > 0:
+                    self._byte_pos = self._str.next_char(self._byte_pos)
+                    count -= 1
+            self._current = -1
+
+        elif count < 0:
+            if self._pos < 0:
+                self._byte_pos = 0
+            else:
+                while count < 0:
+                    self._byte_pos = self._str.prev_char(self._byte_pos)
+                    count += 1
+            self._current = -1
+
+    def finished(self):
+        return self._pos >= len(self._str)
+
+    def copy(self):
+        i = Utf8Iterator(self._str)
+        i._pos = self._pos
+        i._byte_pos = self._byte_pos
+        i._current = self._current
+        return i
+
+
diff --git a/pypy/interpreter/utf8_codecs.py b/pypy/interpreter/utf8_codecs.py
--- a/pypy/interpreter/utf8_codecs.py
+++ b/pypy/interpreter/utf8_codecs.py
@@ -327,15 +327,15 @@
     if size == 0:
         return ''
     result = StringBuilder(size)
-    pos = 0
-    while pos < size:
-        oc = utf8ord(s, pos)
-
+    iter = s.iter()  # size == 0 already returned above
+    while not iter.finished():
+        oc = iter.current()
         if oc < 0x100:
             result.append(chr(oc))
         else:
             raw_unicode_escape_helper(result, oc)
-        pos += 1
+
+        iter.move(1)
 
     return result.build()
 
@@ -397,28 +397,29 @@
     if size == 0:
         return ''
     result = StringBuilder(size)
-    pos = 0
-    while pos < size:
-        od = utf8ord(p, pos)
+    iter = p.iter()
+    while not iter.finished():
+        od = iter.current()
 
         if od < limit:
             result.append(chr(od))
-            pos += 1
+            iter.move(1)
         else:
-            # startpos for collecting unencodable chars
-            collstart = pos
-            collend = pos+1
-            while collend < len(p) and utf8ord(p, collend) >= limit:
-                collend += 1
+            coll = iter.copy()  # scan ahead to the end of the unencodable run
+            while not coll.finished() and coll.current() >= limit:
+                coll.move(1)
+            collstart = iter.pos()
+            collend = coll.pos()
+
             ru, rs, pos = errorhandler(errors, encoding, reason, p,
                                        collstart, collend)
+            iter.move(pos - iter.pos())  # resume where the handler says
             if rs is not None:
                 # py3k only
                 result.append(rs)
                 continue
 
-            for ch in ru:
-                cd = utf8ord(ch, 0)
+            for cd in ru.codepoint_iter():
                 if cd < limit:
                     result.append(chr(cd))
                 else:
@@ -452,41 +453,48 @@
                                       allow_surrogates)
 
 def unicode_encode_utf_8_impl(s, size, errors, errorhandler, allow_surrogates):
-    iter = s.codepoint_iter()
-    for oc in iter:
+    iter = s.iter()
+
+    while not iter.finished():
+        oc = iter.current()
         if oc >= 0xD800 and oc <= 0xDFFF:
             break
-        if iter.pos == size:
-            return s.bytes
-    else:
+        iter.move(1)
+    if iter.finished():  # no surrogates: the stored bytes are the encoding
         return s.bytes
 
     result = Utf8Builder(len(s.bytes))
-    result.append_slice(s.bytes, 0, iter.byte_pos)
+    result.append_slice(s.bytes, 0, iter.byte_pos())
 
-    iter.move(-1)
-    for oc in iter:
+    while not iter.finished():
+        oc = iter.current()
+        iter.move(1)  # from here on, oc sits at index iter.pos() - 1
+
         if oc >= 0xD800 and oc <= 0xDFFF:
             # Check the next character to see if this is a surrogate pair
-            if (iter.pos != len(s) and oc <= 0xDBFF and
-                0xDC00 <= iter.peek_next() <= 0xDFFF):
-                oc2 = iter.next()
+            if (not iter.finished() and oc <= 0xDBFF and
+                0xDC00 <= iter.current() <= 0xDFFF):
+
+                oc2 = iter.current()
                 result.append_codepoint(
                         ((oc - 0xD800) << 10 | (oc2 - 0xDC00)) + 0x10000)
+                iter.move(1)
+
             elif allow_surrogates:
                 result.append_codepoint(oc)
             else:
                 ru, rs, pos = errorhandler(errors, 'utf8',
                                         'surrogates not allowed', s,
-                                        iter.pos-1, iter.pos)
-                iter.move(pos - iter.pos)
+                                        iter.pos()-1, iter.pos())
+                iter.move(pos - iter.pos())
                 if rs is not None:
                     # py3k only
                     result.append_utf8(rs)
+                    # iter already sits at the handler's resume position
                     continue
-                for ch in ru:
-                    if ord(ch) < 0x80:
-                        result.append_ascii(ch)
+                for ch in ru.codepoint_iter():
+                    if ch < 0x80:
+                        result.append_ascii(chr(ch))
                     else:
                         errorhandler('strict', 'utf8',
                                     'surrogates not allowed',
@@ -809,10 +817,10 @@
         _STORECHAR(result, 0xFEFF, BYTEORDER)
         byteorder = BYTEORDER
 
-    i = 0
-    while i < size:
-        ch = utf8ord(s, i)
-        i += 1
+    iter = s.iter()  # NOTE(review): confirm size == 0 is handled before here
+    while not iter.finished():
+        ch = iter.current()
+        iter.move(1)
         ch, ch2 = create_surrogate_pair(ch)
 
         _STORECHAR(result, ch, byteorder)
@@ -980,16 +988,16 @@
         _STORECHAR32(result, 0xFEFF, BYTEORDER)
         byteorder = BYTEORDER
 
-    i = 0
-    while i < size:
-        ch = utf8ord(s, i)
-        i += 1
+    iter = s.iter()
+    while not iter.finished():
+        ch = iter.current()
+        iter.move(1)  # iter now points at the codepoint after ch
         ch2 = 0
-        if MAXUNICODE < 65536 and 0xD800 <= ch <= 0xDBFF and i < size:
+        if MAXUNICODE < 65536 and 0xD800 <= ch <= 0xDBFF and not iter.finished():
-            ch2 = ord(s[i])
+            ch2 = iter.current()
             if 0xDC00 <= ch2 <= 0xDFFF:
                 ch = (((ch & 0x3FF)<<10) | (ch2 & 0x3FF)) + 0x10000;
-                i += 1
+                iter.move(1)
         _STORECHAR32(result, ch, byteorder)
 
     return result.build()
@@ -1228,10 +1236,9 @@
     base64bits = 0
     base64buffer = 0
 
-    # TODO: Looping like this is worse than O(n)
-    pos = 0
-    while pos < size:
-        oc = utf8ord(s, pos)
+    iter = s.iter()  # sequential walk instead of utf8ord(s, pos) per index
+    while not iter.finished():
+        oc = iter.current()
         if not inShift:
             if oc == ord('+'):
                 result.append('+-')
@@ -1260,7 +1267,7 @@
             else:
                 base64bits, base64buffer = _utf7_ENCODE_CHAR(
                     result, oc, base64bits, base64buffer)
-        pos += 1
+        iter.move(1)
 
     if base64bits:
         result.append(_utf7_TO_BASE64(base64buffer << (6 - base64bits)))
@@ -1318,15 +1325,17 @@
     if size == 0:
         return ''
     result = StringBuilder(size)
-    pos = 0
-    while pos < size:
-        ch = s[pos]
+
+    iter = s.iter()
+    while not iter.finished():
+        ch = utf8chr(iter.current())
 
         c = mapping.get(ch, '')
         if len(c) == 0:
             ru, rs, pos = errorhandler(errors, "charmap",
                                        "character maps to <undefined>",
-                                       s, pos, pos + 1)
+                                       s, iter.pos(), iter.pos() + 1)
+            iter.move(pos - iter.pos())  # resume where the handler says
             if rs is not None:
                 # py3k only
                 result.append(rs)
@@ -1337,11 +1346,11 @@
                     errorhandler(
                         "strict", "charmap",
                         "character maps to <undefined>",
-                        s,  pos, pos + 1)
+                        s, iter.pos(), iter.pos() + 1)  # NOTE(review): iter is already at the resume pos
                 result.append(c2)
             continue
         result.append(c)
-        pos += 1
+        iter.move(1)
     return result.build()
 
 # }}}
@@ -1367,9 +1376,9 @@
         errorhandler = default_unicode_error_decode
 
     if BYTEORDER == 'little':
-        iorder = [0, 1, 2, 3]
+        iorder = (0, 1, 2, 3)  # tuple: read-only index table
     else:
-        iorder = [3, 2, 1, 0]
+        iorder = (3, 2, 1, 0)
 
     if size == 0:
         return Utf8Str(''), 0
@@ -1542,30 +1551,35 @@
     if size == 0:
         return ''
     result = StringBuilder(size)
-    pos = 0
-    while pos < size:
-        ch = utf8ord(s, pos)
+
+    iter = s.iter()
+    while not iter.finished():
+        ch = iter.current()
+
         if unicodedb.isspace(ch):
             result.append(' ')
-            pos += 1
+            iter.move(1)
             continue
+
         try:
             decimal = unicodedb.decimal(ch)
         except KeyError:
             pass
         else:
             result.append(chr(48 + decimal))
-            pos += 1
+            iter.move(1)
             continue
+
         if 0 < ch < 256:
             result.append(chr(ch))
-            pos += 1
+            iter.move(1)
             continue
+
         # All other characters are considered unencodable
-        collstart = pos
-        collend = collstart + 1
-        while collend < size:
-            ch = utf8ord(s, collend)
+        colliter = iter.copy()  # scan ahead to the end of the unencodable run
+        colliter.move(1)
+        while not colliter.finished():
+            ch = colliter.current()
             try:
                 if (0 < ch < 256 or
                     unicodedb.isspace(ch) or
@@ -1574,15 +1588,19 @@
             except KeyError:
                 # not a decimal
                 pass
-            collend += 1
+            colliter.move(1)
+
+        collstart = iter.pos()
+        collend = colliter.pos()
+
         msg = "invalid decimal Unicode string"
         ru, rs, pos = errorhandler(errors, 'decimal',
                                    msg, s, collstart, collend)
+        iter.move(pos - iter.pos())  # resume where the handler says
         if rs is not None:
             # py3k only
             errorhandler('strict', 'decimal', msg, s, collstart, collend)
-        for i in range(len(ru)):
-            ch = utf8.ORD(ru, i)
+        for ch in ru.codepoint_iter():
             if unicodedb.isspace(ch):
                 result.append(' ')
                 continue


More information about the pypy-commit mailing list