[pypy-svn] r48706 - in pypy/branch/ropes-unicode/pypy/objspace/std: . test

cfbolz at codespeak.net cfbolz at codespeak.net
Thu Nov 15 13:22:16 CET 2007


Author: cfbolz
Date: Thu Nov 15 13:22:15 2007
New Revision: 48706

Modified:
   pypy/branch/ropes-unicode/pypy/objspace/std/rope.py
   pypy/branch/ropes-unicode/pypy/objspace/std/test/test_rope.py
Log:
Major refactoring of ropes so that they have a chance to work with unicode. One
test is still failing.


Modified: pypy/branch/ropes-unicode/pypy/objspace/std/rope.py
==============================================================================
--- pypy/branch/ropes-unicode/pypy/objspace/std/rope.py	(original)
+++ pypy/branch/ropes-unicode/pypy/objspace/std/rope.py	Thu Nov 15 13:22:15 2007
@@ -9,8 +9,8 @@
 
 # XXX should optimize the numbers
 NEW_NODE_WHEN_LENGTH = 16
+CONVERT_WHEN_SMALLER = 8
 MAX_DEPTH = 32 # maybe should be smaller
-MIN_SLICE_LENGTH = 64
 CONCATENATE_WHEN_MULTIPLYING = 128
 HIGHEST_BIT_SET = intmask(1L << (NBITS - 1))
 
@@ -54,58 +54,84 @@
 class StringNode(object):
     hash_cache = 0
     def length(self):
-        return 0
+        raise NotImplementedError("base class")
+
+    def is_ascii(self):
+        raise NotImplementedError("base class")
+        
+    def is_bytestring(self):
+        raise NotImplementedError("base class")
 
     def depth(self):
         return 0
 
-    def rebalance(self):
-        return self
-
     def hash_part(self):
         raise NotImplementedError("base class")
 
-    def flatten(self):
-        return ''
+    def check_balanced(self):
+        return True
 
-    def __add__(self, other):
-        return concatenate(self, other)
-    
-    def __getitem__(self, index):
-        if isinstance(index, slice):
-            start, stop, step = index.indices(self.length())
-            # XXX sucks
-            slicelength = len(xrange(start, stop, step))
-            return getslice(self, start, stop, step, slicelength)
-        return self.getitem(index)
+    def getchar(self, index):
+        raise NotImplementedError("abstract base class")
 
-    def getitem(self, index):
+    def getunichar(self, index):
         raise NotImplementedError("abstract base class")
 
-    def getitem_slice(self, start, stop):
-        # XXX really horrible, in most cases
-        result = []
-        for i in range(start, stop):
-            result.append(self.getitem(i))
-        return rope_from_charlist(result)
+    def getint(self, index):
+        raise NotImplementedError("abstract base class")
+
+    def getslice(self, start, stop):
+        raise NotImplementedError("abstract base class")
 
     def view(self):
         view([self])
 
-    def check_balanced(self):
-        return True
+    def rebalance(self):
+        return self
+
+    def flatten_string(self):
+        raise NotImplementedError("abstract base class")
+
+    def flatten_unicode(self):
+        raise NotImplementedError("abstract base class")
+
+    def __add__(self, other):
+        return concatenate(self, other)
 
 
-class LiteralStringNode(StringNode):
+class LiteralNode(StringNode):
+    def find_int(self, what, start, stop):
+        raise NotImplementedError("abstract base class")
+
+    def literal_concat(self, other):
+        raise NotImplementedError("abstract base class")
+
+
+class LiteralStringNode(LiteralNode):
     def __init__(self, s):
         self.s = s
+        is_ascii = True
+        for c in s:
+            if ord(c) >= 128:
+                is_ascii = False
+        self._is_ascii = is_ascii
     
     def length(self):
         return len(self.s)
 
-    def flatten(self):
+    def is_ascii(self):
+        return self._is_ascii
+
+    def is_bytestring(self):
+        return True
+
+    def flatten_string(self):
         return self.s
 
+    def flatten_unicode(self):
+        # XXX not RPython yet
+        return self.s.decode('latin-1')
+
     def hash_part(self):
         h = self.hash_cache
         if not h:
@@ -117,13 +143,38 @@
             h = self.hash_cache = x
         return h
 
-    def getitem(self, index):
+    def getchar(self, index):
         return self.s[index]
 
-    def getitem_slice(self, start, stop):
+    def getunichar(self, index):
+        return unichr(ord(self.s[index])) # latin-1, as in flatten_unicode
+
+    def getint(self, index):
+        return ord(self.s[index])
+
+    def getslice(self, start, stop):
         assert 0 <= start <= stop
         return LiteralStringNode(self.s[start:stop])
 
+
+    def find_int(self, what, start, stop):
+        # position of the character with ordinal `what` in self.s[start:stop]
+        # (relative to the start of this leaf), or -1 if it does not occur;
+        # ordinals outside range(256) can never occur in a byte string
+        if not 0 <= what < 256:
+            return -1
+        return self.s.find(chr(what), start, stop)
+
+    def literal_concat(self, other):
+        if (isinstance(other, LiteralStringNode) and
+            len(other.s) + len(self.s) < NEW_NODE_WHEN_LENGTH):
+            return LiteralStringNode(self.s + other.s)
+        elif (isinstance(other, LiteralUnicodeNode) and
+              len(other.u) + len(self.s) < NEW_NODE_WHEN_LENGTH and
+              len(self.s) < CONVERT_WHEN_SMALLER):
+            return LiteralUnicodeNode(self.s.decode("latin-1") + other.u)
+        return BinaryConcatNode(self, other)
+
     def dot(self, seen, toplevel=False):
         if self in seen:
             return
@@ -139,6 +190,70 @@
 del i
 
 
+class LiteralUnicodeNode(LiteralNode): # leaf holding a unicode object
+    def __init__(self, u):
+        self.u = u
+    
+    def length(self):
+        return len(self.u)
+
+    def flatten_unicode(self):
+        return self.u
+
+    def is_ascii(self):
+        return False # conservative: unicode leaves are not scanned
+        
+    def is_bytestring(self):
+        return False
+
+    def hash_part(self):
+        h = self.hash_cache
+        if not h:
+            x = 0
+            for c in self.u:
+                x = (1000003*x) + ord(c)
+            x = intmask(x)
+            x |= HIGHEST_BIT_SET
+            h = self.hash_cache = x
+        return h
+
+    def getunichar(self, index):
+        return self.u[index]
+
+    def getint(self, index):
+        return ord(self.u[index])
+
+    def getslice(self, start, stop):
+        assert 0 <= start <= stop
+        return LiteralUnicodeNode(self.u[start:stop])
+
+    def find_int(self, what, start, stop):
+        # position of code point `what` in self.u[start:stop], relative to
+        # the start of this leaf, or -1 if it does not occur
+        # XXX unichr raises ValueError for what > sys.maxunicode
+        return self.u.find(unichr(what), start, stop)
+
+    def literal_concat(self, other):
+        if (isinstance(other, LiteralUnicodeNode) and
+            len(other.u) + len(self.u) < NEW_NODE_WHEN_LENGTH):
+            return LiteralUnicodeNode(self.u + other.u)
+        elif (isinstance(other, LiteralStringNode) and
+              len(other.s) + len(self.u) < NEW_NODE_WHEN_LENGTH and
+              len(other.s) < CONVERT_WHEN_SMALLER):
+            return LiteralUnicodeNode(self.u + other.s.decode("latin-1"))
+        return BinaryConcatNode(self, other)
+
+    def dot(self, seen, toplevel=False):
+        if self in seen:
+            return
+        seen[self] = True
+        addinfo = self.u.encode("utf-8").replace('"', "'") or "_"
+        if len(addinfo) > 10:
+            addinfo = addinfo[:3] + "..." + addinfo[-3:]
+        yield ('"%s" [shape=box,label="length: %s\\n%s"];' % (
+            id(self), len(self.u),
+            repr(addinfo).replace('"', '').replace("\\", "\\\\")))
+
 class BinaryConcatNode(StringNode):
     def __init__(self, left, right):
         self.left = left
@@ -149,6 +264,14 @@
             raise
         self._depth = max(left.depth(), right.depth()) + 1
         self.balanced = False
+        self._is_ascii = left.is_ascii() and right.is_ascii()
+        self._is_bytestring = left.is_bytestring() and right.is_bytestring()
+
+    def is_ascii(self):
+        return self._is_ascii
+
+    def is_bytestring(self):
+        return self._is_bytestring
 
     def check_balanced(self):
         if self.balanced:
@@ -172,16 +295,34 @@
     def depth(self):
         return self._depth
 
-    def getitem(self, index):
+    def getchar(self, index):
+        llen = self.left.length()
+        if index >= llen:
+            return self.right.getchar(index - llen)
+        else:
+            return self.left.getchar(index)
+
+    def getunichar(self, index):
+        llen = self.left.length()
+        if index >= llen:
+            return self.right.getunichar(index - llen)
+        else:
+            return self.left.getunichar(index)
+
+    def getint(self, index):
         llen = self.left.length()
         if index >= llen:
-            return self.right.getitem(index - llen)
+            return self.right.getint(index - llen)
         else:
-            return self.left.getitem(index)
+            return self.left.getint(index)
+
+    def flatten_string(self):
+        f = fringe(self)
+        return "".join([node.flatten_string() for node in f])
 
-    def flatten(self):
+    def flatten_unicode(self):
         f = fringe(self)
-        return "".join([node.flatten() for node in f])
+        return u"".join([node.flatten_unicode() for node in f])
  
     def hash_part(self):
         h = self.hash_cache
@@ -213,119 +354,37 @@
             for line in child.dot(seen):
                 yield line
 
-class SliceNode(StringNode):
-    def __init__(self, start, stop, node):
-        assert 0 <= start <= stop
-        self.start = start
-        self.stop = stop
-        self.node = node
-
-    def length(self):
-        return self.stop - self.start
-
-    def getitem_slice(self, start, stop):
-        return self.node.getitem_slice(self.start + start, self.start + stop)
-
-    def getitem(self, index):
-        return self.node.getitem(self.start + index)
-
-    def flatten(self):
-        return self.node.flatten()[self.start: self.stop]
-
-    def hash_part(self):
-        h = self.hash_cache
-        if not h:
-            x = 0
-            for i in range(self.start, self.stop):
-                x = (1000003*x) + ord(self.node.getitem(i))
-            x = intmask(x)
-            x |= HIGHEST_BIT_SET
-            h = self.hash_cache = x
-        return h
-
-    def dot(self, seen, toplevel=False):
-        if self in seen:
-            return
-        seen[self] = True
-        yield '"%s" [shape=octagon,label="slice\\nstart=%s, stop=%s"];' % (
-                id(self), self.start, self.stop)
-        yield '"%s" -> "%s";' % (id(self), id(self.node))
-        for line in self.node.dot(seen):
-            yield line
-
-class EfficientGetitemWraper(StringNode):
-    def __init__(self, node):
-        assert isinstance(node, BinaryConcatNode)
-        self.node = node
-        self.iter = SeekableCharIterator(node)
-        self.nextpos = 0
-        self.accesses = 0
-        self.seeks = 0
-
-    def length(self):
-        return self.node.length()
-
-    def depth(self):
-        return self.node.depth()
-
-    def rebalance(self):
-        return EfficientGetitemWraper(self.node.rebalance())
-
-    def hash_part(self):
-        return self.node.hash_part()
-
-    def flatten(self):
-        return self.node.flatten()
-
-    def getitem(self, index):
-        self.accesses += 1
-        nextpos = self.nextpos
-        self.nextpos = index + 1
-        if index < nextpos:
-            self.iter.seekback(nextpos - index)
-            self.seeks += nextpos - index
-        elif index > nextpos:
-            self.iter.seekforward(index - nextpos)
-            self.seeks += index - nextpos
-        return self.iter.next()
-
-    def view(self):
-        return self.node.view()
-
-    def check_balanced(self):
-        return self.node.check_balanced()
-
-
 def concatenate(node1, node2):
     if node1.length() == 0:
         return node2
     if node2.length() == 0:
         return node1
-    if (isinstance(node2, LiteralStringNode) and
-        len(node2.s) <= NEW_NODE_WHEN_LENGTH):
-        if isinstance(node1, LiteralStringNode):
-            if len(node1.s) + len(node2.s) <= NEW_NODE_WHEN_LENGTH:
-                return LiteralStringNode(node1.s + node2.s)
+    if isinstance(node2, LiteralNode):
+        if isinstance(node1, LiteralNode):
+            return node1.literal_concat(node2)
         elif isinstance(node1, BinaryConcatNode):
             r = node1.right
-            if isinstance(r, LiteralStringNode):
-                if len(r.s) + len(node2.s) <= NEW_NODE_WHEN_LENGTH:
-                    return BinaryConcatNode(node1.left,
-                                            LiteralStringNode(r.s + node2.s))
+            if isinstance(r, LiteralNode):
+                return BinaryConcatNode(node1.left,
+                                        r.literal_concat(node2))
     result = BinaryConcatNode(node1, node2)
     if result.depth() > MAX_DEPTH: #XXX better check
         return result.rebalance()
     return result
 
-def getslice(node, start, stop, step, slicelength):
+def getslice(node, start, stop, step, slicelength=-1):
+    if slicelength == -1:
+        # XXX for testing only
+        slicelength = len(xrange(start, stop, step))
     if step != 1:
         start, stop, node = find_straddling(node, start, stop)
-        iter = SeekableCharIterator(node)
+        iter = SeekableItemIterator(node)
         iter.seekforward(start)
-        result = [iter.next()]
+        #XXX doesn't work for unicode
+        result = [iter.nextchar()]
         for i in range(slicelength - 1):
             iter.seekforward(step - 1)
-            result.append(iter.next())
+            result.append(iter.nextchar())
         return rope_from_charlist(result)
     return getslice_one(node, start, stop)
 
@@ -342,7 +401,7 @@
             getslice_right(node.left, start),
             getslice_left(node.right, stop - node.left.length()))
     else:
-        return getslice_primitive(node, start, stop)
+        return node.getslice(start, stop)
 
 def find_straddling(node, start, stop):
     while 1:
@@ -371,7 +430,7 @@
             else:
                 return concatenate(getslice_right(node.left, start),
                                    node.right)
-        return getslice_primitive(node, start, node.length())
+        return node.getslice(start, node.length())
 
 def getslice_left(node, stop):
     while 1:
@@ -385,16 +444,9 @@
             else:
                 return concatenate(node.left,
                                    getslice_left(node.right, stop - llen))
-        return getslice_primitive(node, 0, stop)
+        return node.getslice(0, stop)
 
 
-def getslice_primitive(node, start, stop):
-    if stop - start >= MIN_SLICE_LENGTH:
-        if isinstance(node, SliceNode):
-            return SliceNode(start + node.start, stop + node.start,
-                             node.node)
-        return SliceNode(start, stop, node)
-    return node.getitem_slice(start, stop)
 
 def multiply(node, times):
     if times <= 0:
@@ -510,10 +562,37 @@
         size += len(chars)
     return rebalance(nodelist, size)
 
+def rope_from_unicharlist(charlist): # build a rope from unicode chars
+    nodelist = []
+    length = len(charlist)
+    if not length:
+        return LiteralStringNode.EMPTY
+    i = 0
+    while i < length:
+        chunk = []
+        while i < length: # run of chars >= 256 -> LiteralUnicodeNode
+            c = ord(charlist[i])
+            if c < 256:
+                break
+            chunk.append(unichr(c))
+            i += 1
+        if chunk:
+            nodelist.append(LiteralUnicodeNode(u"".join(chunk)))
+        chunk = []
+        while i < length: # run of latin-1 chars -> LiteralStringNode
+            c = ord(charlist[i])
+            if c >= 256:
+                break
+            chunk.append(chr(c))
+            i += 1
+        if chunk:
+            nodelist.append(LiteralStringNode("".join(chunk)))
+    return rebalance(nodelist, length)
+
 # __________________________________________________________________________
 # searching
 
-def find_char(node, c, start=0, stop=-1):
+def find_int(node, what, start=0, stop=-1):
     offset = 0
     length = node.length()
     if stop == -1:
@@ -524,14 +603,11 @@
         start = newstart
         stop = newstop
     assert 0 <= start <= stop
-    if isinstance(node, LiteralStringNode):
-        result = node.s.find(c, start, stop)
-        if result == -1:
-            return -1
-        return result + offset
-    elif isinstance(node, SliceNode):
-        return find_char(node.node, c, node.start + start,
-                         node.start + stop) - node.start + offset
+    if isinstance(node, LiteralNode):
+        pos = node.find_int(what, start, stop)
+        if pos == -1:
+            return pos
+        return pos + offset
     iter = FringeIterator(node)
     #import pdb; pdb.set_trace()
     i = 0
@@ -546,20 +622,10 @@
             continue
         searchstart = max(0, start - i)
         searchstop = min(stop - i, nodelength)
-        if isinstance(fringenode, LiteralStringNode):
-            st = fringenode.s
-            localoffset = 0
-        else:
-            assert isinstance(fringenode, SliceNode)
-            n = fringenode.node
-            assert isinstance(n, LiteralStringNode)
-            st = n.s
-            localoffset = -fringenode.start
-            searchstart += fringenode.start
-            searchstop += fringenode.stop
-        pos = st.find(c, searchstart, searchstop)
+        assert isinstance(fringenode, LiteralNode)
+        pos = fringenode.find_int(what, searchstart, searchstop)
         if pos != -1:
-            return pos + i + offset + localoffset
+            return pos + i + offset
         i += nodelength
     return -1
 
@@ -570,7 +636,7 @@
     if stop > len1 or stop == -1:
         stop = len1
     if len2 == 1:
-        return find_char(node, subnode.getitem(0), start, stop)
+        return find_int(node, subnode.getint(0), start, stop)
     if len2 == 0:
         if (stop - start) < 0:
             return -1
@@ -579,25 +645,27 @@
     return _find_node(node, subnode, start, stop, restart)
 
 def _find(node, substring, start, stop, restart):
+    # XXX
+    assert node.is_bytestring()
     len2 = len(substring)
     i = 0
     m = start
-    iter = SeekableCharIterator(node)
+    iter = SeekableItemIterator(node)
     iter.seekforward(start)
-    c = iter.next()
+    c = iter.nextchar()
     while m + i < stop:
         if c == substring[i]:
             i += 1
             if i == len2:
                 return m
             if m + i < stop:
-                c = iter.next()
+                c = iter.nextchar()
         else:
             # mismatch, go back to the last possible starting pos
             if i==0:
                 m += 1
                 if m + i < stop:
-                    c = iter.next()
+                    c = iter.nextchar()
             else:
                 e = restart[i-1]
                 new_m = m + i - e
@@ -605,7 +673,7 @@
                 seek = m + i - new_m
                 if seek:
                     iter.seekback(m + i - new_m)
-                    c = iter.next()
+                    c = iter.nextchar()
                 m = new_m
                 i = e
     return -1
@@ -613,26 +681,26 @@
 def _find_node(node, subnode, start, stop, restart):
     len2 = subnode.length()
     m = start
-    iter = SeekableCharIterator(node)
+    iter = SeekableItemIterator(node)
     iter.seekforward(start)
-    c = iter.next()
+    c = iter.nextint()
     i = 0
-    subiter = SeekableCharIterator(subnode)
-    d = subiter.next()
+    subiter = SeekableItemIterator(subnode)
+    d = subiter.nextint()
     while m + i < stop:
         if c == d:
             i += 1
             if i == len2:
                 return m
-            d = subiter.next()
+            d = subiter.nextint()
             if m + i < stop:
-                c = iter.next()
+                c = iter.nextint()
         else:
             # mismatch, go back to the last possible starting pos
             if i == 0:
                 m += 1
                 if m + i < stop:
-                    c = iter.next()
+                    c = iter.nextint()
             else:
                 e = restart[i - 1]
                 new_m = m + i - e
@@ -640,20 +708,20 @@
                 seek = m + i - new_m
                 if seek:
                     iter.seekback(m + i - new_m)
-                    c = iter.next()
+                    c = iter.nextint()
                 m = new_m
                 subiter.seekback(i - e + 1)
-                d = subiter.next()
+                d = subiter.nextint()
                 i = e
     return -1
 
 def construct_restart_positions(s):
-    l = len(s)
-    restart = [0] * l
+    length = len(s)
+    restart = [0] * length
     restart[0] = 0
     i = 1
     j = 0
-    while i < l:
+    while i < length:
         if s[i] == s[j]:
             j += 1
             restart[i] = j
@@ -668,43 +736,43 @@
 
 def construct_restart_positions_node(node):
     # really a bit overkill
-    l = node.length()
-    restart = [0] * l
+    length = node.length()
+    restart = [0] * length
     restart[0] = 0
     i = 1
     j = 0
-    iter1 = CharIterator(node)
-    iter1.next()
-    c1 = iter1.next()
-    iter2 = SeekableCharIterator(node)
-    c2 = iter2.next()
-    while i < l:
+    iter1 = ItemIterator(node)
+    iter1.nextint()
+    c1 = iter1.nextint()
+    iter2 = SeekableItemIterator(node)
+    c2 = iter2.nextint()
+    while 1:
         if c1 == c2:
             j += 1
-            if j != l:
-                c2 = iter2.next()
+            if j < length:
+                c2 = iter2.nextint()
             restart[i] = j
             i += 1
-            if i != l:
-                c1 = iter1.next()
+            if i < length:
+                c1 = iter1.nextint()
             else:
                 break
         elif j>0:
             new_j = restart[j-1]
             assert new_j < j
             iter2.seekback(j - new_j)
-            c2 = iter2.next()
+            c2 = iter2.nextint()
             j = new_j
         else:
             restart[i] = 0
             i += 1
-            if i != l:
-                c1 = iter1.next()
+            if i < length:
+                c1 = iter1.nextint()
             else:
                 break
             j = 0
-            iter2 = SeekableCharIterator(node)
-            c2 = iter2.next()
+            iter2 = SeekableItemIterator(node)
+            c2 = iter2.nextint()
     return restart
 
 def view(objs):
@@ -785,37 +853,60 @@
         return result
 
 
-class CharIterator(object):
+class ItemIterator(object):
     def __init__(self, node):
         self.iter = FringeIterator(node)
         self.node = None
         self.nodelength = 0
         self.index = 0
 
-    def next(self):
+
+    def getnode(self):
         node = self.node
         if node is None:
             while 1:
                 node = self.node = self.iter.next()
                 nodelength = self.nodelength = node.length()
                 if nodelength != 0:
-                    break
-            self.index = 0
+                    self.index = 0
+                    return node
+        return node
+
+    def advance_index(self):
         index = self.index
-        result = node.getitem(index)
         if index == self.nodelength - 1:
             self.node = None
         else:
             self.index = index + 1
+
+    def nextchar(self):
+        node = self.getnode()
+        index = self.index
+        result = node.getchar(self.index)
+        self.advance_index()
+        return result
+
+    def nextunichar(self):
+        node = self.getnode()
+        index = self.index
+        result = node.getunichar(self.index)
+        self.advance_index()
         return result
 
-class ReverseCharIterator(object):
+    def nextint(self):
+        node = self.getnode()
+        index = self.index
+        result = node.getint(self.index)
+        self.advance_index()
+        return result
+
+class ReverseItemIterator(object):
     def __init__(self, node):
         self.iter = ReverseFringeIterator(node)
         self.node = None
         self.index = 0
 
-    def next(self):
+    def getnode(self):
         node = self.node
         index = self.index
         if node is None:
@@ -823,16 +914,36 @@
                 node = self.node = self.iter.next()
                 index = self.index = node.length() - 1
                 if index != -1:
-                    break
-        result = node.getitem(index)
-        if index == 0:
+                    return node
+        return node
+
+
+    def advance_index(self):
+        if self.index == 0:
             self.node = None
         else:
-            self.index = index - 1
+            self.index -= 1
+
+    def nextchar(self):
+        node = self.getnode()
+        result = node.getchar(self.index)
+        self.advance_index()
+        return result
+
+    def nextint(self):
+        node = self.getnode()
+        result = node.getint(self.index)
+        self.advance_index()
+        return result
+
+    def nextunichar(self):
+        node = self.getnode()
+        result = node.getunichar(self.index)
+        self.advance_index()
         return result
 
 
-class SeekableCharIterator(object):
+class SeekableItemIterator(object):
     def __init__(self, node):
         self.iter = SeekableFringeIterator(node)
         self.node = self.nextnode()
@@ -848,15 +959,34 @@
         self.index = 0
         return node
 
-    def next(self):
+    
+    def advance_index(self):
+        if self.index == self.nodelength - 1:
+            self.node = None
+        self.index += 1
+
+    def nextchar(self):
         node = self.node
         if node is None:
             node = self.nextnode()
-        index = self.index
-        result = self.node.getitem(index)
-        if self.index == self.nodelength - 1:
-            self.node = None
-        self.index = index + 1
+        result = self.node.getchar(self.index)
+        self.advance_index()
+        return result
+
+    def nextunichar(self):
+        node = self.node
+        if node is None:
+            node = self.nextnode()
+        result = self.node.getunichar(self.index)
+        self.advance_index()
+        return result
+
+    def nextint(self):
+        node = self.node
+        if node is None:
+            node = self.nextnode()
+        result = self.node.getint(self.index)
+        self.advance_index()
         return result
 
     def seekforward(self, numchars):
@@ -898,16 +1028,18 @@
 class FindIterator(object):
     def __init__(self, node, sub, start=0, stop=-1):
         self.node = node
+        self.sub = sub
         len1 = self.length = node.length()
-        substring = self.substring = sub.flatten() # XXX for now
-        len2 = len(substring)
+        len2 = sub.length()
         self.search_length = len2
         if len2 == 0:
             self.restart_positions = None
         elif len2 == 1:
             self.restart_positions = None
         else:
-            self.restart_positions = construct_restart_positions(substring)
+            self.restart_positions = construct_restart_positions_node(sub)
+            # XXX
+            assert self.restart_positions == construct_restart_positions(sub.flatten_string())
         self.start = start
         if stop == -1 or stop > len1:
             stop = len1
@@ -921,8 +1053,8 @@
             self.start += 1
             return start
         elif self.search_length == 1:
-            result = find_char(self.node, self.substring[0],
-                               self.start, self.stop)
+            result = find_int(self.node, self.sub.getint(0),
+                              self.start, self.stop)
             if result == -1:
                 self.start = self.length
                 raise StopIteration
@@ -930,8 +1062,8 @@
             return result
         if self.start >= self.stop:
             raise StopIteration
-        result = _find(self.node, self.substring, self.start,
-                       self.stop, self.restart_positions)
+        result = _find_node(self.node, self.sub, self.start,
+                            self.stop, self.restart_positions)
         if result == -1:
             self.start = self.length
             raise StopIteration
@@ -952,15 +1084,18 @@
     if (isinstance(node1, LiteralStringNode) and
         isinstance(node2, LiteralStringNode)):
         return node1.s == node2.s
-    iter1 = CharIterator(node1)
-    iter2 = CharIterator(node2)
+    if (isinstance(node1, LiteralUnicodeNode) and
+        isinstance(node2, LiteralUnicodeNode)):
+        return node1.u == node2.u
+    iter1 = ItemIterator(node1)
+    iter2 = ItemIterator(node2)
     # XXX could be cleverer and detect partial equalities
     while 1:
         try:
-            c = iter1.next()
+            c = iter1.nextint()
         except StopIteration:
             return True
-        if c != iter2.next():
+        if c != iter2.nextint():
             return False
 
 def compare(node1, node2):
@@ -975,10 +1110,10 @@
 
     cmplen = min(len1, len2)
     i = 0
-    iter1 = CharIterator(node1)
-    iter2 = CharIterator(node2)
+    iter1 = ItemIterator(node1)
+    iter2 = ItemIterator(node2)
     while i < cmplen:
-        diff = ord(iter1.next()) - ord(iter2.next())
+        diff = iter1.nextint() - iter2.nextint()
         if diff != 0:
             return diff
         i += 1
@@ -994,6 +1129,6 @@
         return -1
     x = rope.hash_part()
     x <<= 1 # get rid of the bit that is always set
-    x ^= ord(rope.getitem(0))
+    x ^= rope.getint(0)
     x ^= rope.length()
     return intmask(x)

Modified: pypy/branch/ropes-unicode/pypy/objspace/std/test/test_rope.py
==============================================================================
--- pypy/branch/ropes-unicode/pypy/objspace/std/test/test_rope.py	(original)
+++ pypy/branch/ropes-unicode/pypy/objspace/std/test/test_rope.py	Thu Nov 15 13:22:15 2007
@@ -28,7 +28,7 @@
                 continue
             start = random.randrange(len(st) // 3)
             stop = random.randrange(len(st) // 3 * 2, len(st))
-            curr = curr[start: stop]
+            curr = getslice_one(curr, start, stop)
             st = st[start: stop]
     return curr, st
 
@@ -38,20 +38,20 @@
          LiteralStringNode("d" * 32) + LiteralStringNode("ef" * 32) +
          LiteralStringNode(""))
     assert s.depth() == 3
-    assert s.flatten() == "".join([c * 32 for c in "a", "bc", "d", "ef"])
+    assert s.flatten_string() == "".join([c * 32 for c in "a", "bc", "d", "ef"])
     s = s.rebalance()
-    assert s.flatten() == "".join([c * 32 for c in "a", "bc", "d", "ef"])
+    assert s.flatten_string() == "".join([c * 32 for c in "a", "bc", "d", "ef"])
 
 def test_dont_rebalance_again():
     s = (LiteralStringNode("a" * 32) + LiteralStringNode("b" * 32) +
          LiteralStringNode("d" * 32) + LiteralStringNode("e" * 32) +
          LiteralStringNode(""))
     assert s.depth() == 3
-    assert s.flatten() == "".join([c * 32 for c in "abde"])
+    assert s.flatten_string() == "".join([c * 32 for c in "abde"])
     s = s.rebalance()
     assert s.check_balanced()
     assert s.balanced
-    assert s.flatten() == "".join([c * 32 for c in "abde"])
+    assert s.flatten_string() == "".join([c * 32 for c in "abde"])
 
 def test_random_addition_test():
     seed = random.randrange(10000)
@@ -67,9 +67,9 @@
         else:
             curr = LiteralStringNode(a) + curr
             st = a + st
-        assert curr.flatten() == st
+        assert curr.flatten_string() == st
     curr = curr.rebalance()
-    assert curr.flatten() == st
+    assert curr.flatten_string() == st
 
 def test_getitem():
     result = "".join([c * 32 for c in "a", "bc", "d", "ef"])
@@ -79,7 +79,8 @@
     s2 = s1.rebalance()
     for i in range(len(result)):
         for s in [s1, s2]:
-            assert s[i] == result[i]
+            assert s.getchar(i) == result[i]
+            assert s.getint(i) == ord(result[i])
 
 def test_getslice():
     result = "".join([c * 32 for c in "a", "bc", "d", "ef"])
@@ -90,14 +91,14 @@
     for s in [s1, s2]:
         for start in range(0, len(result)):
             for stop in range(start, len(result)):
-                assert s[start:stop].flatten() == result[start:stop]
+                assert getslice_one(s, start, stop).flatten_string() == result[start:stop]
 
 def test_getslice_bug():
     s1 = LiteralStringNode("/home/arigo/svn/pypy/branch/rope-branch/pypy/bin")
     s2 = LiteralStringNode("/pypy")
     s = s1 + s2
     r = getslice_one(s, 1, 5)
-    assert r.flatten() == "home"
+    assert r.flatten_string() == "home"
 
 
 def test_getslice_step():
@@ -105,13 +106,13 @@
           LiteralStringNode("nopqrstu") + LiteralStringNode("vwxyz") + 
           LiteralStringNode("zyxwvut") + LiteralStringNode("srqpomnlk"))
     s2 = s1.rebalance()
-    result = s1.flatten()
-    assert s2.flatten() == result
+    result = s1.flatten_string()
+    assert s2.flatten_string() == result
     for s in [s1, s2]:
         for start in range(0, len(result)):
             for stop in range(start, len(result)):
                 for step in range(1, stop - start):
-                    assert s[start:stop:step].flatten() == result[start:stop:step]
+                    assert getslice(s, start, stop, step).flatten_string() == result[start:stop:step]
 
 
 def test_random_addition_and_slicing():
@@ -141,27 +142,27 @@
             #import pdb; pdb.set_trace()
             start = random.randrange(len(st) // 3)
             stop = random.randrange(len(st) // 3 * 2, len(st))
-            curr = curr[start: stop]
+            curr = getslice_one(curr, start, stop)
             st = st[start: stop]
-        assert curr.flatten() == st
+        assert curr.flatten_string() == st
     curr = curr.rebalance()
-    assert curr.flatten() == st
+    assert curr.flatten_string() == st
 
 def test_iteration():
     rope, real_st = make_random_string(200)
-    iter = CharIterator(rope)
+    iter = ItemIterator(rope)
     for c in real_st:
-        c2 = iter.next()
+        c2 = iter.nextchar()
         assert c2 == c
-    py.test.raises(StopIteration, iter.next)
+    py.test.raises(StopIteration, iter.nextchar)
 
 def test_reverse_iteration():
     rope, real_st = make_random_string(200)
-    iter = ReverseCharIterator(rope)
+    iter = ReverseItemIterator(rope)
     for c in py.builtin.reversed(real_st):
-        c2 = iter.next()
+        c2 = iter.nextchar()
         assert c2 == c
-    py.test.raises(StopIteration, iter.next)
+    py.test.raises(StopIteration, iter.nextchar)
 
 def test_multiply():
     strs = [(LiteralStringNode("a"), "a"), (LiteralStringNode("abc"), "abc"),
@@ -175,7 +176,7 @@
     for r, st in strs:
         for i in times:
             r2 = multiply(r, i)
-            assert r2.flatten() == st * i
+            assert r2.flatten_string() == st * i
 
 def test_join():
     seps = [(LiteralStringNode("a"), "a"), (LiteralStringNode("abc"), "abc"),
@@ -186,7 +187,7 @@
     l = list(l)
     for s, st in seps:
         node = join(s, l)
-        result1 = node.flatten()
+        result1 = node.flatten_string()
         result2 = st.join(strs)
         for i in range(node.length()):
             assert result1[i] == result2[i]
@@ -196,7 +197,7 @@
                ':', '213', '>']
     l = [LiteralStringNode(s) for s in strings]
     node = join(LiteralStringNode(""), l)
-    assert node.flatten() == ''.join(strings)
+    assert node.flatten_string() == ''.join(strings)
 
 def test_join_random():
     l, strs = zip(*[make_random_string(10 * i) for i in range(1, 5)])
@@ -205,7 +206,7 @@
             make_random_string(500)]
     for s, st in seps:
         node = join(s, l)
-        result1 = node.flatten()
+        result1 = node.flatten_string()
         result2 = st.join(strs)
         for i in range(node.length()):
             assert result1[i] == result2[i]
@@ -214,18 +215,18 @@
     rope = BinaryConcatNode(BinaryConcatNode(LiteralStringNode("abc"),
                                              LiteralStringNode("def")),
                             LiteralStringNode("ghi"))
-    iter = SeekableCharIterator(rope)
+    iter = SeekableItemIterator(rope)
     for c in "abcdefgh":
-        c2 = iter.next()
+        c2 = iter.nextchar()
         assert c2 == c
     for i in range(7):
         iter.seekback(i)
         for c in "abcdefghi"[-1-i:-1]:
-            c2 = iter.next()
+            c2 = iter.nextchar()
             assert c2 == c
-    c2 = iter.next()
+    c2 = iter.nextchar()
     assert c2 == "i"
-    py.test.raises(StopIteration, iter.next)
+    py.test.raises(StopIteration, iter.nextchar)
 
 def test_fringe_iterator():
     ABC = LiteralStringNode("abc")
@@ -277,41 +278,37 @@
                                              LiteralStringNode("def")),
                             LiteralStringNode("ghi"))
     rope = rope + rope
-    result = rope.flatten()
+    result = rope.flatten_string()
     for j in range(len(result) - 1):
         for i in range(len(result) - 1 - j):
-            iter = SeekableCharIterator(rope)
+            iter = SeekableItemIterator(rope)
 #            if (j, i) == (3, 1):
 #                import pdb; pdb.set_trace()
             for c in result[:j]:
-                c2 = iter.next()
+                c2 = iter.nextchar()
                 assert c2 == c
             iter.seekforward(i)
             for c in result[i + j:]:
-                c2 = iter.next()
+                c2 = iter.nextchar()
                 assert c2 == c
-        py.test.raises(StopIteration, iter.next)
+        py.test.raises(StopIteration, iter.nextchar)
 
-def test_find_char():
+def test_find_int():
     rope, st = make_random_string()
-    rope = rope[10:100]
+    rope = getslice_one(rope, 10, 100)
     st = st[10:100]
     for i in range(len(st)):
         print i
         for j in range(i + 1, len(st)):
             c = st[i:j][(j - i) // 2]
-            pos = find_char(rope, c, i, j)
+            pos = find_int(rope, ord(c), i, j)
             assert pos == st.find(c, i, j)
 
-def test_find_char_bugs():
-    r = find_char(LiteralStringNode("ascii"), " ", 0, 5)
+def test_find_int_bugs():
+    r = find_int(LiteralStringNode("ascii"), ord(" "), 0, 5)
     assert r == -1
-    r = find_char(LiteralStringNode("a"), "a")
+    r = find_int(LiteralStringNode("a"), ord("a"))
     assert r == 0
-    r = find_char(BinaryConcatNode(
-        LiteralStringNode("x"),
-        SliceNode(1, 9, LiteralStringNode("a" * 10))), "a")
-    assert r == 1
 
 
 def test_restart_positions():
@@ -335,16 +332,16 @@
     for i in range(0, 100, 10):
         chars = ["a"] * 50 + ["b"] * i
         node = rope_from_charlist(chars)
-        assert node.flatten() == "a" * 50  + "b" * i
-    assert rope_from_charlist([]).flatten() == ""
+        assert node.flatten_string() == "a" * 50  + "b" * i
+    assert rope_from_charlist([]).flatten_string() == ""
 
 def test_find_iterator():
-    for searchstring in ["abc", "a", "", "x", "xyz"]:
+    for searchstring in ["abc", "a", "", "x", "xyz", "abababcabcabb"]:
         node = join(LiteralStringNode(searchstring),
                     [LiteralStringNode("cde" * i) for i in range(1, 10)])
-    #   node.view()
+        #node.view()
         iter = FindIterator(node, LiteralStringNode(searchstring))
-        s = node.flatten()
+        s = node.flatten_string()
         assert s == searchstring.join(["cde" * i for i in range(1, 10)])
         start = 0
         while 1:
@@ -365,10 +362,10 @@
             assert hash_rope(rope) == -1
             continue
         h = hash_rope(rope)
-        x = LiteralStringNode(rope.flatten()).hash_part()
+        x = LiteralStringNode(rope.flatten_string()).hash_part()
         assert x == rope.hash_part()
         x <<= 1
-        x ^= ord(rope.getitem(0))
+        x ^= rope.getint(0)
         x ^= rope.length()
         assert intmask(x) == h
         # hash again to check for cache effects
@@ -496,24 +493,25 @@
         for j in range(1, 10000, 7):
             assert intmask(i ** j) == masked_power(i, j)
 
-def test_EfficientGetitemWraper():
-    node1, _ = make_random_string(slicing=False)
-    node2 = EfficientGetitemWraper(node1)
-    for i in range(node2.length()):
-        assert node1.getitem(i) == node2.getitem(i)
-    for j in range(1000):
-        i = random.randrange(node1.length())
-        assert node1.getitem(i) == node2.getitem(i)
 
 def test_seekable_bug():
     node = BinaryConcatNode(LiteralStringNode("abc"), LiteralStringNode("def"))
-    iter = SeekableCharIterator(node)
-    c = iter.next(); assert c == "a"
-    c = iter.next(); assert c == "b"
-    c = iter.next(); assert c == "c"
+    iter = SeekableItemIterator(node)
+    c = iter.nextchar(); assert c == "a"
+    c = iter.nextchar(); assert c == "b"
+    c = iter.nextchar(); assert c == "c"
     iter.seekback(1)
-    c = iter.next(); assert c == "c"
-    c = iter.next(); assert c == "d"
-    c = iter.next(); assert c == "e"
-    c = iter.next(); assert c == "f"
-    py.test.raises(StopIteration, iter.next)
+    c = iter.nextchar(); assert c == "c"
+    c = iter.nextchar(); assert c == "d"
+    c = iter.nextchar(); assert c == "e"
+    c = iter.nextchar(); assert c == "f"
+    py.test.raises(StopIteration, iter.nextchar)
+    node = LiteralStringNode("abcdef")
+    iter = SeekableItemIterator(node)
+    c = iter.nextchar(); assert c == "a"
+    c = iter.nextchar(); assert c == "b"
+    c = iter.nextchar(); assert c == "c"
+    iter.seekback(3)
+    c = iter.nextchar(); assert c == "a"
+    c = iter.nextchar(); assert c == "b"
+    c = iter.nextchar(); assert c == "c"



More information about the Pypy-commit mailing list