[pypy-commit] pypy unicode-utf8: progress on having flags correctly propagated, almost there

fijal pypy.commits at gmail.com
Sat Nov 4 18:16:53 EDT 2017


Author: fijal
Branch: unicode-utf8
Changeset: r92934:29ce3a4ea76f
Date: 2017-11-04 14:38 +0100
http://bitbucket.org/pypy/pypy/changeset/29ce3a4ea76f/

Log:	progress on having flags correctly propagated, almost there

diff --git a/TODO b/TODO
--- a/TODO
+++ b/TODO
@@ -4,3 +4,7 @@
   if one is not already readily available
 * fix _pypyjson
 * fix cpyext
+* write the correct jit.elidable in _get_index_storage
+* better flag handling in split/splitlines maybe?
+* find all the fast-paths that we want to do with utf8 (we only do
+  utf-8 now, not UTF8 or utf8) for decode/encode
diff --git a/pypy/interpreter/baseobjspace.py b/pypy/interpreter/baseobjspace.py
--- a/pypy/interpreter/baseobjspace.py
+++ b/pypy/interpreter/baseobjspace.py
@@ -1764,8 +1764,10 @@
         return self.realutf8_w(w_obj).decode('utf8')
 
     def newunicode(self, u):
+        from pypy.interpreter import unicodehelper
         assert isinstance(u, unicode)
-        return self.newutf8(u.encode("utf8"), len(u))
+        # XXX let's disallow that
+        return self.newutf8(u.encode("utf8"), len(u), unicodehelper._get_flag(u))
 
     def convert_to_w_unicode(self, w_obj):
         return w_obj.convert_to_w_unicode(self)
diff --git a/pypy/interpreter/pyparser/parsestring.py b/pypy/interpreter/pyparser/parsestring.py
--- a/pypy/interpreter/pyparser/parsestring.py
+++ b/pypy/interpreter/pyparser/parsestring.py
@@ -59,10 +59,11 @@
         else:
             substr = decode_unicode_utf8(space, s, ps, q)
         if rawmode:
-            v, length = unicodehelper.decode_raw_unicode_escape(space, substr)
+            r = unicodehelper.decode_raw_unicode_escape(space, substr)
         else:
-            v, length = unicodehelper.decode_unicode_escape(space, substr)
-        return space.newutf8(v, length)
+            r = unicodehelper.decode_unicode_escape(space, substr)
+        v, length, flag = r
+        return space.newutf8(v, length, flag)
 
     need_encoding = (encoding is not None and
                      encoding != "utf-8" and encoding != "utf8" and
diff --git a/pypy/interpreter/unicodehelper.py b/pypy/interpreter/unicodehelper.py
--- a/pypy/interpreter/unicodehelper.py
+++ b/pypy/interpreter/unicodehelper.py
@@ -20,11 +20,11 @@
 @specialize.memo()
 def encode_error_handler(space):
     # Fast version of the "strict" errors handler.
-    def raise_unicode_exception_encode(errors, encoding, msg, u, u_len,
+    def raise_unicode_exception_encode(errors, encoding, msg, w_u,
                                        startingpos, endingpos):
         raise OperationError(space.w_UnicodeEncodeError,
                              space.newtuple([space.newtext(encoding),
-                                             space.newutf8(u, u_len),
+                                             w_u,
                                              space.newint(startingpos),
                                              space.newint(endingpos),
                                              space.newtext(msg)]))
@@ -41,6 +41,21 @@
     from pypy.objspace.std.unicodeobject import encode_object
     return encode_object(space, w_data, encoding, errors)
 
+def _has_surrogate(u):
+    for c in u:
+        if 0xD800 <= ord(c) <= 0xDFFF:
+            return True
+    return False
+
+def _get_flag(u):
+    flag = rutf8.FLAG_ASCII
+    for c in u:
+        if 0xD800 <= ord(c) <= 0xDFFF:
+            return rutf8.FLAG_HAS_SURROGATES
+        if ord(c) >= 0x80:
+            flag = rutf8.FLAG_REGULAR
+    return flag
+
 # These functions take and return unwrapped rpython strings and unicodes
 def decode_unicode_escape(space, string):
     state = space.fromcache(interp_codecs.CodecState)
@@ -52,7 +67,14 @@
         final=True, errorhandler=DecodeWrapper(decode_error_handler(space)).handle,
         unicodedata_handler=unicodedata_handler)
     # XXX argh.  we want each surrogate to be encoded separately
-    return ''.join([u.encode('utf8') for u in result_u]), len(result_u)
+    utf8 = ''.join([u.encode('utf8') for u in result_u])
+    if rutf8.first_non_ascii_char(utf8) == -1:
+        flag = rutf8.FLAG_ASCII
+    elif _has_surrogate(result_u):
+        flag = rutf8.FLAG_HAS_SURROGATES
+    else:
+        flag = rutf8.FLAG_REGULAR
+    return utf8, len(result_u), flag
 
 def decode_raw_unicode_escape(space, string):
     # XXX pick better length, maybe
@@ -61,7 +83,14 @@
         string, len(string), "strict",
         final=True, errorhandler=DecodeWrapper(decode_error_handler(space)).handle)
     # XXX argh.  we want each surrogate to be encoded separately
-    return ''.join([u.encode('utf8') for u in result_u]), len(result_u)
+    utf8 = ''.join([u.encode('utf8') for u in result_u])
+    if rutf8.first_non_ascii_char(utf8) == -1:
+        flag = rutf8.FLAG_ASCII
+    elif _has_surrogate(result_u):
+        flag = rutf8.FLAG_HAS_SURROGATES
+    else:
+        flag = rutf8.FLAG_REGULAR
+    return utf8, len(result_u), flag
 
 def check_ascii_or_raise(space, string):
     try:
@@ -78,12 +107,12 @@
     # you still get two surrogate unicode characters in the result.
     # These are the Python2 rules; Python3 differs.
     try:
-        length = rutf8.check_utf8(string, allow_surrogates=True)
+        length, flag = rutf8.check_utf8(string, allow_surrogates=True)
     except rutf8.CheckError as e:
         decode_error_handler(space)('strict', 'utf8', 'invalid utf-8', string,
                                     e.pos, e.pos + 1)
         assert False, "unreachable"
-    return length
+    return length, flag
 
 def encode_utf8(space, uni):
     # DEPRECATED
@@ -116,7 +145,7 @@
     except rutf8.CheckError:
         w = DecodeWrapper((errorhandler))
         u, pos = runicode.str_decode_ascii(s, slen, errors, final, w.handle)
-        return u.encode('utf8'), pos, len(u)
+        return u.encode('utf8'), pos, len(u), _get_flag(u)
 
 # XXX wrappers, think about speed
 
@@ -139,14 +168,14 @@
     w = DecodeWrapper(errorhandler)
     u, pos = runicode.str_decode_utf_8_impl(s, slen, errors, final, w.handle,
         runicode.allow_surrogate_by_default)
-    return u.encode('utf8'), pos, len(u)
+    return u.encode('utf8'), pos, len(u), _get_flag(u)
 
 def str_decode_unicode_escape(s, slen, errors, final, errorhandler, ud_handler):
     w = DecodeWrapper(errorhandler)
     u, pos = runicode.str_decode_unicode_escape(s, slen, errors, final,
                                                 w.handle,
                                                 ud_handler)
-    return u.encode('utf8'), pos, len(u)
+    return u.encode('utf8'), pos, len(u), _get_flag(u)
 
 def setup_new_encoders(encoding):
     encoder_name = 'utf8_encode_' + encoding
@@ -160,7 +189,7 @@
     def decoder(s, slen, errors, final, errorhandler):
         w = DecodeWrapper((errorhandler))
         u, pos = getattr(runicode, decoder_name)(s, slen, errors, final, w.handle)
-        return u.encode('utf8'), pos, len(u)
+        return u.encode('utf8'), pos, len(u), _get_flag(u)
     encoder.__name__ = encoder_name
     decoder.__name__ = decoder_name
     if encoder_name not in globals():
diff --git a/pypy/module/__builtin__/operation.py b/pypy/module/__builtin__/operation.py
--- a/pypy/module/__builtin__/operation.py
+++ b/pypy/module/__builtin__/operation.py
@@ -28,7 +28,13 @@
         s = rutf8.unichr_as_utf8(code, allow_surrogates=True)
     except ValueError:
         raise oefmt(space.w_ValueError, "unichr() arg out of range")
-    return space.newutf8(s, 1)
+    if code < 0x80:
+        flag = rutf8.FLAG_ASCII
+    elif 0xD800 <= code <= 0xDFFF:
+        flag = rutf8.FLAG_HAS_SURROGATES
+    else:
+        flag = rutf8.FLAG_REGULAR
+    return space.newutf8(s, 1, flag)
 
 def len(space, w_obj):
     "len(object) -> integer\n\nReturn the number of items of a sequence or mapping."
diff --git a/pypy/module/_codecs/interp_codecs.py b/pypy/module/_codecs/interp_codecs.py
--- a/pypy/module/_codecs/interp_codecs.py
+++ b/pypy/module/_codecs/interp_codecs.py
@@ -39,8 +39,8 @@
                 w_input = space.newbytes(input)
             else:
                 w_cls = space.w_UnicodeEncodeError
-                length = rutf8.check_utf8(input, allow_surrogates=True)
-                w_input = space.newutf8(input, length)
+                length, flag = rutf8.check_utf8(input, allow_surrogates=True)
+                w_input = space.newutf8(input, length, flag)
             w_exc =  space.call_function(
                 w_cls,
                 space.newtext(encoding),
@@ -189,7 +189,7 @@
 def ignore_errors(space, w_exc):
     check_exception(space, w_exc)
     w_end = space.getattr(w_exc, space.newtext('end'))
-    return space.newtuple([space.newutf8('', 0), w_end])
+    return space.newtuple([space.newutf8('', 0, rutf8.FLAG_ASCII), w_end])
 
 REPLACEMENT = u'\ufffd'.encode('utf8')
 
@@ -200,13 +200,13 @@
     size = space.int_w(w_end) - space.int_w(w_start)
     if space.isinstance_w(w_exc, space.w_UnicodeEncodeError):
         text = '?' * size
-        return space.newtuple([space.newutf8(text, size), w_end])
+        return space.newtuple([space.newutf8(text, size, rutf8.FLAG_ASCII), w_end])
     elif space.isinstance_w(w_exc, space.w_UnicodeDecodeError):
         text = REPLACEMENT
-        return space.newtuple([space.newutf8(text, 1), w_end])
+        return space.newtuple([space.newutf8(text, 1, rutf8.FLAG_REGULAR), w_end])
     elif space.isinstance_w(w_exc, space.w_UnicodeTranslateError):
         text = REPLACEMENT * size
-        return space.newtuple([space.newutf8(text, size), w_end])
+        return space.newtuple([space.newutf8(text, size, rutf8.FLAG_REGULAR), w_end])
     else:
         raise oefmt(space.w_TypeError,
                     "don't know how to handle %T in error callback", w_exc)
@@ -403,9 +403,9 @@
         final = space.is_true(w_final)
         state = space.fromcache(CodecState)
         func = getattr(unicodehelper, rname)
-        result, consumed, length = func(string, len(string), errors,
-                                final, state.decode_error_handler)
-        return space.newtuple([space.newutf8(result, length),
+        result, consumed, length, flag = func(string, len(string), errors,
+                                              final, state.decode_error_handler)
+        return space.newtuple([space.newutf8(result, length, flag),
                                space.newint(consumed)])
     wrap_decoder.func_name = rname
     globals()[name] = wrap_decoder
@@ -448,7 +448,7 @@
 # "allow_surrogates=True"
 @unwrap_spec(utf8='utf8', errors='text_or_none')
 def utf_8_encode(space, utf8, errors="strict"):
-    length = rutf8.check_utf8(utf8, allow_surrogates=True)
+    length, _ = rutf8.check_utf8(utf8, allow_surrogates=True)
     return space.newtuple([space.newbytes(utf8), space.newint(length)])
 #@unwrap_spec(uni=unicode, errors='text_or_none')
 #def utf_8_encode(space, uni, errors="strict"):
@@ -474,16 +474,16 @@
     state = space.fromcache(CodecState)
     # call the fast version for checking
     try:
-        lgt = rutf8.check_utf8(string, allow_surrogates=True)
+        lgt, flag = rutf8.check_utf8(string, allow_surrogates=True)
     except rutf8.CheckError as e:
         # XXX do the way around runicode - we can optimize it later if we
         # decide we care about obscure cases
-        res, consumed, lgt = unicodehelper.str_decode_utf8(string, len(string),
-            errors, final, state.decode_error_handler)
-        return space.newtuple([space.newutf8(res, lgt),
+        res, consumed, lgt, flag = unicodehelper.str_decode_utf8(string,
+            len(string), errors, final, state.decode_error_handler)
+        return space.newtuple([space.newutf8(res, lgt, flag),
                                space.newint(consumed)])
     else:
-        return space.newtuple([space.newutf8(string, lgt),
+        return space.newtuple([space.newutf8(string, lgt, flag),
                                space.newint(len(string))])
 
 @unwrap_spec(data='bufferstr', errors='text_or_none', byteorder=int,
diff --git a/pypy/objspace/std/marshal_impl.py b/pypy/objspace/std/marshal_impl.py
--- a/pypy/objspace/std/marshal_impl.py
+++ b/pypy/objspace/std/marshal_impl.py
@@ -403,8 +403,8 @@
 @unmarshaller(TYPE_UNICODE)
 def unmarshal_unicode(space, u, tc):
     arg = u.get_str()
-    length = unicodehelper.check_utf8_or_raise(space, arg)
-    return space.newutf8(arg, length)
+    length, flag = unicodehelper.check_utf8_or_raise(space, arg)
+    return space.newutf8(arg, length, flag)
 
 @marshaller(W_SetObject)
 def marshal_set(space, w_set, m):
diff --git a/pypy/objspace/std/objspace.py b/pypy/objspace/std/objspace.py
--- a/pypy/objspace/std/objspace.py
+++ b/pypy/objspace/std/objspace.py
@@ -317,8 +317,8 @@
         for utf in lst:
             assert utf is not None
             assert isinstance(utf, str)
-            length = rutf8.check_utf8(utf, allow_surrogates=True)
-            res_w.append(self.newutf8(utf, length))
+            length, flag = rutf8.check_utf8(utf, allow_surrogates=True)
+            res_w.append(self.newutf8(utf, length, flag))
         return self.newlist(res_w)
 
     def newlist_int(self, list_i):
@@ -369,10 +369,10 @@
             return self.w_None
         return self.newtext(s)
 
-    def newutf8(self, utf8s, length):
+    def newutf8(self, utf8s, length, flag):
         assert utf8s is not None
         assert isinstance(utf8s, str)
-        return W_UnicodeObject(utf8s, length)
+        return W_UnicodeObject(utf8s, length, flag)
 
     def newfilename(self, s):
         assert isinstance(s, str) # on pypy3, this decodes the byte string
diff --git a/pypy/objspace/std/test/test_unicodeobject.py b/pypy/objspace/std/test/test_unicodeobject.py
--- a/pypy/objspace/std/test/test_unicodeobject.py
+++ b/pypy/objspace/std/test/test_unicodeobject.py
@@ -3,6 +3,7 @@
 import py
 import sys
 from hypothesis import given, strategies, settings, example
+from rpython.rlib import rutf8
 from pypy.interpreter.error import OperationError
 
 
@@ -27,12 +28,12 @@
 
     def test_listview_unicode(self):
         py.test.skip("skip for new")
-        w_str = self.space.wrap(u'abcd')
+        w_str = self.space.newutf8('abcd', 4, rutf8.FLAG_ASCII)
         assert self.space.listview_unicode(w_str) == list(u"abcd")
 
     def test_new_shortcut(self):
         space = self.space
-        w_uni = self.space.wrap(u'abcd')
+        w_uni = self.space.newutf8('abcd', 4, rutf8.FLAG_ASCII)
         w_new = space.call_method(
                 space.w_unicode, "__new__", space.w_unicode, w_uni)
         assert w_new is w_uni
@@ -44,8 +45,8 @@
             return   # skip this case
         v = u[start : start + len1]
         space = self.space
-        w_u = space.wrap(u)
-        w_v = space.wrap(v)
+        w_u = space.newutf8(u.encode('utf8'), len(u), rutf8.FLAG_REGULAR)
+        w_v = space.newutf8(v.encode('utf8'), len(v), rutf8.FLAG_REGULAR)
         expected = u.find(v, start, start + len1)
         try:
             w_index = space.call_method(w_u, 'index', w_v,
diff --git a/pypy/objspace/std/unicodeobject.py b/pypy/objspace/std/unicodeobject.py
--- a/pypy/objspace/std/unicodeobject.py
+++ b/pypy/objspace/std/unicodeobject.py
@@ -36,14 +36,24 @@
     _immutable_fields_ = ['_utf8']
 
     @enforceargs(utf8str=str)
-    def __init__(self, utf8str, length):
+    def __init__(self, utf8str, length, flag):
         assert isinstance(utf8str, str)
         assert length >= 0
         self._utf8 = utf8str
         self._length = length
-        self._index_storage = rutf8.null_storage()
-        #if not we_are_translated():
-        #    assert rutf8.check_utf8(utf8str, allow_surrogates=True) == length
+        if flag == rutf8.FLAG_ASCII:
+            self._index_storage = rutf8.UTF8_IS_ASCII
+        elif flag == rutf8.FLAG_HAS_SURROGATES:
+            self._index_storage = rutf8.UTF8_HAS_SURROGATES
+        else:
+            assert flag == rutf8.FLAG_REGULAR
+            self._index_storage = rutf8.null_storage()
+        # the storage can be one of:
+        # - null, unicode with no surrogates
+        # - rutf8.UTF8_HAS_SURROGATES
+        # - rutf8.UTF8_IS_ASCII
+        # - malloced object, which means it has index, then
+        #   _index_storage.flags determines the kind
 
     def __repr__(self):
         """representation for debugging purposes"""
@@ -222,7 +232,11 @@
 
         assert isinstance(w_value, W_UnicodeObject)
         w_newobj = space.allocate_instance(W_UnicodeObject, w_unicodetype)
-        W_UnicodeObject.__init__(w_newobj, w_value._utf8, w_value._length)
+        W_UnicodeObject.__init__(w_newobj, w_value._utf8, w_value._length,
+                                 w_value._get_flag())
+        if w_value._index_storage:
+            # copy the storage if it's there
+            w_newobj._index_storage = w_value._index_storage
         return w_newobj
 
     def descr_repr(self, space):
@@ -326,29 +340,33 @@
     def descr_swapcase(self, space):
         selfvalue = self._utf8
         builder = StringBuilder(len(selfvalue))
+        flag = self._get_flag()
         i = 0
         while i < len(selfvalue):
             ch = rutf8.codepoint_at_pos(selfvalue, i)
             i = rutf8.next_codepoint_pos(selfvalue, i)
             if unicodedb.isupper(ch):
-                rutf8.unichr_as_utf8_append(builder, unicodedb.tolower(ch))
+                ch = unicodedb.tolower(ch)
             elif unicodedb.islower(ch):
-                rutf8.unichr_as_utf8_append(builder, unicodedb.toupper(ch))
-            else:
-                rutf8.unichr_as_utf8_append(builder, ch)
-        return W_UnicodeObject(builder.build(), self._length)
+                ch = unicodedb.toupper(ch)
+            if ch >= 0x80:
+                flag = self._combine_flags(flag, rutf8.FLAG_REGULAR)
+            rutf8.unichr_as_utf8_append(builder, ch)
+        return W_UnicodeObject(builder.build(), self._length, flag)
 
     def descr_title(self, space):
         if len(self._utf8) == 0:
             return self
-        return W_UnicodeObject(self.title(self._utf8), self._len())
+        utf8, flag = self.title_unicode(self._utf8)
+        return W_UnicodeObject(utf8, self._len(), flag)
 
     @jit.elidable
-    def title(self, value):
+    def title_unicode(self, value):
         input = self._utf8
         builder = StringBuilder(len(input))
         i = 0
         previous_is_cased = False
+        flag = self._get_flag()
         while i < len(input):
             ch = rutf8.codepoint_at_pos(input, i)
             i = rutf8.next_codepoint_pos(input, i)
@@ -356,14 +374,17 @@
                 ch = unicodedb.totitle(ch)
             else:
                 ch = unicodedb.tolower(ch)
+            if ch >= 0x80:
+                flag = self._combine_flags(flag, rutf8.FLAG_REGULAR)
             rutf8.unichr_as_utf8_append(builder, ch)
             previous_is_cased = unicodedb.iscased(ch)
-        return builder.build()
+        return builder.build(), flag
 
     def descr_translate(self, space, w_table):
         input = self._utf8
         result = StringBuilder(len(input))
         result_length = 0
+        flag = self._get_flag()
         i = 0
         while i < len(input):
             codepoint = rutf8.codepoint_at_pos(input, i)
@@ -380,6 +401,7 @@
                     codepoint = space.int_w(w_newval)
                 elif isinstance(w_newval, W_UnicodeObject):
                     result.append(w_newval._utf8)
+                    flag = self._combine_flags(flag, w_newval._get_flag())
                     result_length += w_newval._length
                     continue
                 else:
@@ -387,13 +409,17 @@
                                 "character mapping must return integer, None "
                                 "or unicode")
             try:
+                if 0xD800 <= codepoint <= 0xDFFF:
+                    flag = rutf8.FLAG_HAS_SURROGATES
+                elif codepoint >= 0x80:
+                    flag = self._combine_flags(flag, rutf8.FLAG_REGULAR)
                 rutf8.unichr_as_utf8_append(result, codepoint,
                                             allow_surrogates=True)
                 result_length += 1
             except ValueError:
                 raise oefmt(space.w_TypeError,
                             "character mapping must be in range(0x110000)")
-        return W_UnicodeObject(result.build(), result_length)
+        return W_UnicodeObject(result.build(), result_length, flag)
 
     def descr_find(self, space, w_sub, w_start=None, w_end=None):
         w_result = self._unwrap_and_search(space, w_sub, w_start, w_end)
@@ -472,7 +496,7 @@
             newlen += dist
             oldtoken = token
 
-        return W_UnicodeObject(expanded, newlen)
+        return W_UnicodeObject(expanded, newlen, self._get_flag())
 
     _StringMethods_descr_join = descr_join
     def descr_join(self, space, w_list):
@@ -506,11 +530,14 @@
     def descr_lower(self, space):
         builder = StringBuilder(len(self._utf8))
         pos = 0
+        flag = self._get_flag()
         while pos < len(self._utf8):
             lower = unicodedb.tolower(rutf8.codepoint_at_pos(self._utf8, pos))
+            if lower >= 0x80:
+                flag = self._combine_flags(flag, rutf8.FLAG_REGULAR)
             rutf8.unichr_as_utf8_append(builder, lower) # XXX allow surrogates?
             pos = rutf8.next_codepoint_pos(self._utf8, pos)
-        return W_UnicodeObject(builder.build(), self._len())
+        return W_UnicodeObject(builder.build(), self._len(), flag)
 
     def descr_isdecimal(self, space):
         return self._is_generic(space, '_isdecimal')
@@ -595,6 +622,22 @@
             return True
         return endswith(value, prefix, start, end)
 
+    @staticmethod
+    def _combine_flags(self_flag, other_flag):
+        if self_flag == rutf8.FLAG_ASCII and other_flag == rutf8.FLAG_ASCII:
+            return rutf8.FLAG_ASCII
+        elif (self_flag == rutf8.FLAG_HAS_SURROGATES or
+              other_flag == rutf8.FLAG_HAS_SURROGATES):
+            return rutf8.FLAG_HAS_SURROGATES
+        return rutf8.FLAG_REGULAR
+
+    def _get_flag(self):
+        if self._is_ascii():
+            return rutf8.FLAG_ASCII
+        elif self._has_surrogates():
+            return rutf8.FLAG_HAS_SURROGATES
+        return rutf8.FLAG_REGULAR
+
     def descr_add(self, space, w_other):
         try:
             w_other = self.convert_arg_to_w_unicode(space, w_other)
@@ -602,8 +645,9 @@
             if e.match(space, space.w_TypeError):
                 return space.w_NotImplemented
             raise
+        flag = self._combine_flags(self._get_flag(), w_other._get_flag())
         return W_UnicodeObject(self._utf8 + w_other._utf8,
-                               self._len() + w_other._len())
+                               self._len() + w_other._len(), flag)
 
     @jit.look_inside_iff(lambda self, space, list_w, size:
                          jit.loop_unrolling_heuristic(list_w, size))
@@ -613,6 +657,7 @@
 
         prealloc_size = len(value) * (size - 1)
         unwrapped = newlist_hint(size)
+        flag = self._get_flag()
         for i in range(size):
             w_s = list_w[i]
             check_item = self._join_check_item(space, w_s)
@@ -625,6 +670,7 @@
             # XXX Maybe the extra copy here is okay? It was basically going to
             #     happen anyway, what with being placed into the builder
             w_u = self.convert_arg_to_w_unicode(space, w_s)
+            flag = self._combine_flags(flag, w_u._get_flag())
             unwrapped.append(w_u._utf8)
             lgt += w_u._length
             prealloc_size += len(unwrapped[i])
@@ -634,7 +680,7 @@
             if value and i != 0:
                 sb.append(value)
             sb.append(unwrapped[i])
-        return W_UnicodeObject(sb.build(), lgt)
+        return W_UnicodeObject(sb.build(), lgt, flag)
 
     @unwrap_spec(keepends=bool)
     def descr_splitlines(self, space, keepends=False):
@@ -663,28 +709,33 @@
                     lgt += line_end_chars
             assert eol >= 0
             assert sol >= 0
-            strs_w.append(W_UnicodeObject(value[sol:eol], lgt))
+            # XXX we can do better with flags here, if we want to
+            strs_w.append(W_UnicodeObject(value[sol:eol], lgt, self._get_flag()))
         return space.newlist(strs_w)
 
     def descr_upper(self, space):
         value = self._utf8
         builder = StringBuilder(len(value))
+        flag = self._get_flag()
         i = 0
         while i < len(value):
             uchar = rutf8.codepoint_at_pos(value, i)
+            uchar = unicodedb.toupper(uchar)
+            if uchar >= 0x80:
+                flag = self._combine_flags(flag, rutf8.FLAG_REGULAR)
             i = rutf8.next_codepoint_pos(value, i)
-            rutf8.unichr_as_utf8_append(builder, unicodedb.toupper(uchar))
-        return W_UnicodeObject(builder.build(), self._length)
+            rutf8.unichr_as_utf8_append(builder, uchar)
+        return W_UnicodeObject(builder.build(), self._length, flag)
 
     @unwrap_spec(width=int)
     def descr_zfill(self, space, width):
         selfval = self._utf8
         if len(selfval) == 0:
-            return W_UnicodeObject('0' * width, width)
+            return W_UnicodeObject('0' * width, width, rutf8.FLAG_ASCII)
         num_zeros = width - self._len()
         if num_zeros <= 0:
             # cannot return self, in case it is a subclass of str
-            return W_UnicodeObject(selfval, self._len())
+            return W_UnicodeObject(selfval, self._len(), self._get_flag())
         builder = StringBuilder(num_zeros + len(selfval))
         if len(selfval) > 0 and (selfval[0] == '+' or selfval[0] == '-'):
             # copy sign to first position
@@ -694,7 +745,7 @@
             start = 0
         builder.append_multiple_char('0', num_zeros)
         builder.append_slice(selfval, start, len(selfval))
-        return W_UnicodeObject(builder.build(), width)
+        return W_UnicodeObject(builder.build(), width, self._get_flag())
 
     @unwrap_spec(maxsplit=int)
     def descr_split(self, space, w_sep=None, maxsplit=-1):
@@ -753,7 +804,7 @@
                 break
             i += 1
             byte_pos = self._index_to_byte(start + i * step)
-        return W_UnicodeObject(builder.build(), sl)
+        return W_UnicodeObject(builder.build(), sl, self._get_flag())
 
     def descr_getslice(self, space, w_start, w_stop):
         start, stop = normalize_simple_slice(
@@ -770,22 +821,30 @@
         assert stop >= 0
         byte_start = self._index_to_byte(start)
         byte_stop = self._index_to_byte(stop)
-        return W_UnicodeObject(self._utf8[byte_start:byte_stop], stop - start)
+        return W_UnicodeObject(self._utf8[byte_start:byte_stop], stop - start,
+                               self._get_flag())
 
     def descr_capitalize(self, space):
         value = self._utf8
         if len(value) == 0:
             return self._empty()
 
+        flag = self._get_flag()
         builder = StringBuilder(len(value))
         uchar = rutf8.codepoint_at_pos(value, 0)
         i = rutf8.next_codepoint_pos(value, 0)
-        rutf8.unichr_as_utf8_append(builder, unicodedb.toupper(uchar))
+        ch = unicodedb.toupper(uchar)
+        rutf8.unichr_as_utf8_append(builder, ch)
+        if ch >= 0x80:
+            flag = self._combine_flags(flag, rutf8.FLAG_REGULAR)
         while i < len(value):
             uchar = rutf8.codepoint_at_pos(value, i)
             i = rutf8.next_codepoint_pos(value, i)
-            rutf8.unichr_as_utf8_append(builder, unicodedb.tolower(uchar))
-        return W_UnicodeObject(builder.build(), self._len())
+            ch = unicodedb.tolower(uchar)
+            rutf8.unichr_as_utf8_append(builder, ch)
+            if ch >= 0x80:
+                flag = self._combine_flags(flag, rutf8.FLAG_REGULAR)
+        return W_UnicodeObject(builder.build(), self._len(), flag)
 
     @unwrap_spec(width=int, w_fillchar=WrappedDefault(' '))
     def descr_center(self, space, width, w_fillchar):
@@ -804,7 +863,7 @@
             centered = value
             d = 0
 
-        return W_UnicodeObject(centered, self._len() + d)
+        return W_UnicodeObject(centered, self._len() + d, self._get_flag())
 
     def descr_count(self, space, w_sub, w_start=None, w_end=None):
         value = self._utf8
@@ -830,11 +889,11 @@
         if pos < 0:
             return space.newtuple([self, self._empty(), self._empty()])
         else:
-            lgt = rutf8.check_utf8(value, True, stop=pos)
+            lgt, _ = rutf8.check_utf8(value, True, stop=pos)
             return space.newtuple(
-                [W_UnicodeObject(value[0:pos], lgt), w_sub,
+                [W_UnicodeObject(value[0:pos], lgt, self._get_flag()), w_sub,
                  W_UnicodeObject(value[pos + len(sub._utf8):len(value)],
-                    self._len() - lgt - sublen)])
+                    self._len() - lgt - sublen, self._get_flag())])
 
     def descr_rpartition(self, space, w_sub):
         value = self._utf8
@@ -848,11 +907,11 @@
         if pos < 0:
             return space.newtuple([self._empty(), self._empty(), self])
         else:
-            lgt = rutf8.check_utf8(value, True, stop=pos)
+            lgt, _ = rutf8.check_utf8(value, True, stop=pos)
             return space.newtuple(
-                [W_UnicodeObject(value[0:pos], lgt), w_sub,
+                [W_UnicodeObject(value[0:pos], lgt, self._get_flag()), w_sub,
                  W_UnicodeObject(value[pos + len(sub._utf8):len(value)],
-                    self._len() - lgt - sublen)])
+                    self._len() - lgt - sublen, self._get_flag())])
 
     @unwrap_spec(count=int)
     def descr_replace(self, space, w_old, w_new, count=-1):
@@ -870,8 +929,9 @@
         except OverflowError:
             raise oefmt(space.w_OverflowError, "replace string is too long")
 
+        flag = self._combine_flags(self._get_flag(), w_by._get_flag())
         newlength = self._length + replacements * (w_by._length - w_sub._length)
-        return W_UnicodeObject(res, newlength)
+        return W_UnicodeObject(res, newlength, flag)
 
     def descr_mul(self, space, w_times):
         try:
@@ -883,16 +943,29 @@
         if times <= 0:
             return self._empty()
         if len(self._utf8) == 1:
-            return W_UnicodeObject(self._utf8[0] * times, times)
-        return W_UnicodeObject(self._utf8 * times, times * self._len())
+            return W_UnicodeObject(self._utf8[0] * times, times,
+                                   self._get_flag())
+        return W_UnicodeObject(self._utf8 * times, times * self._len(),
+                               self._get_flag())
 
     descr_rmul = descr_mul
 
     def _get_index_storage(self):
-        storage = jit.conditional_call_elidable(self._index_storage,
-                    rutf8.create_utf8_index_storage, self._utf8, self._length)
+        # XXX write the correct jit.elidable
+        condition = (self._index_storage == rutf8.null_storage() or
+                     not bool(self._index_storage.contents))
+        if condition:
+            storage = rutf8.create_utf8_index_storage(self._utf8, self._length)
+        else:
+            storage = self._index_storage
         if not jit.isconstant(self):
+            prev_storage = self._index_storage
             self._index_storage = storage
+            if prev_storage == rutf8.UTF8_HAS_SURROGATES:
+                flag = rutf8.FLAG_HAS_SURROGATES
+            else:
+                flag = rutf8.FLAG_REGULAR
+            self._index_storage.flag = flag
         return storage
 
     def _getitem_result(self, space, index):
@@ -902,9 +975,19 @@
             raise oefmt(space.w_IndexError, "string index out of range")
         start = self._index_to_byte(index)
         end = rutf8.next_codepoint_pos(self._utf8, start)
-        return W_UnicodeObject(self._utf8[start:end], 1)
+        return W_UnicodeObject(self._utf8[start:end], 1, self._get_flag())
+
+    def _is_ascii(self):
+        return self._index_storage is rutf8.UTF8_IS_ASCII
+
+    def _has_surrogates(self):
+        return (self._index_storage is rutf8.UTF8_HAS_SURROGATES or
+                (bool(self._index_storage) and
+                 self._index_storage.flag == rutf8.FLAG_HAS_SURROGATES))
 
     def _index_to_byte(self, index):
+        if self._is_ascii():
+            return index
         return rutf8.codepoint_position_at_index(
             self._utf8, self._get_index_storage(), index)
 
@@ -967,6 +1050,7 @@
         if w_fillchar._len() != 1:
             raise oefmt(space.w_TypeError,
                         "rjust() argument 2 must be a single character")
+        flag = self._combine_flags(self._get_flag(), w_fillchar._get_flag())
         d = width - lgt
         if d > 0:
             if len(w_fillchar._utf8) == 1:
@@ -974,9 +1058,9 @@
                 value = d * w_fillchar._utf8[0] + value
             else:
                 value = d * w_fillchar._utf8 + value
-            return W_UnicodeObject(value, width)
+            return W_UnicodeObject(value, width, flag)
 
-        return W_UnicodeObject(value, lgt)
+        return W_UnicodeObject(value, lgt, flag)
 
     @unwrap_spec(width=int, w_fillchar=WrappedDefault(' '))
     def descr_ljust(self, space, width, w_fillchar):
@@ -985,6 +1069,7 @@
         if w_fillchar._len() != 1:
             raise oefmt(space.w_TypeError,
                         "ljust() argument 2 must be a single character")
+        flag = self._combine_flags(self._get_flag(), w_fillchar._get_flag())
         d = width - self._len()
         if d > 0:
             if len(w_fillchar._utf8) == 1:
@@ -992,9 +1077,9 @@
                 value = value + d * w_fillchar._utf8[0]
             else:
                 value = value + d * w_fillchar._utf8
-            return W_UnicodeObject(value, width)
+            return W_UnicodeObject(value, width, flag)
 
-        return W_UnicodeObject(value, self._len())
+        return W_UnicodeObject(value, self._len(), flag)
 
     def _utf8_sliced(self, start, stop, lgt):
         assert start >= 0
@@ -1002,7 +1087,7 @@
         #if start == 0 and stop == len(s) and space.is_w(space.type(orig_obj),
         #                                                space.w_bytes):
         #    return orig_obj
-        return W_UnicodeObject(self._utf8[start:stop], lgt)
+        return W_UnicodeObject(self._utf8[start:stop], lgt, self._get_flag())
 
     def _strip_none(self, space, left, right):
         "internal function called by str_xstrip methods"
@@ -1050,7 +1135,7 @@
         return self._utf8_sliced(lpos, rpos, lgt)
 
     def descr_getnewargs(self, space):
-        return space.newtuple([W_UnicodeObject(self._utf8, self._length)])
+        return space.newtuple([W_UnicodeObject(self._utf8, self._length, self._get_flag())])
 
 
 
@@ -1135,11 +1220,11 @@
         if encoding == 'ascii':
             s = space.charbuf_w(w_obj)
             unicodehelper.check_ascii_or_raise(space, s)
-            return space.newutf8(s, len(s))
+            return space.newutf8(s, len(s), rutf8.FLAG_ASCII)
         if encoding == 'utf-8':
             s = space.charbuf_w(w_obj)
-            lgt = unicodehelper.check_utf8_or_raise(space, s)
-            return space.newutf8(s, lgt)
+            lgt, flag = unicodehelper.check_utf8_or_raise(space, s)
+            return space.newutf8(s, lgt, flag)
     w_codecs = space.getbuiltinmodule("_codecs")
     w_decode = space.getattr(w_codecs, space.newtext("decode"))
     if errors is None:
@@ -1194,7 +1279,7 @@
         return unicode_from_encoded_object(space, w_bytes, encoding, "strict")
     s = space.bytes_w(w_bytes)
     unicodehelper.check_ascii_or_raise(space, s)
-    return W_UnicodeObject(s, len(s))
+    return W_UnicodeObject(s, len(s), rutf8.FLAG_ASCII)
 
 
 class UnicodeDocstrings:
@@ -1741,7 +1826,7 @@
     return [s for s in value]
 
 
-W_UnicodeObject.EMPTY = W_UnicodeObject('', 0)
+W_UnicodeObject.EMPTY = W_UnicodeObject('', 0, rutf8.FLAG_ASCII)
 
 
 # Helper for converting int/long
diff --git a/rpython/rlib/rutf8.py b/rpython/rlib/rutf8.py
--- a/rpython/rlib/rutf8.py
+++ b/rpython/rlib/rutf8.py
@@ -305,14 +305,14 @@
 def check_utf8(s, allow_surrogates, start=0, stop=-1):
     """Check that 's' is a utf-8-encoded byte string.
 
-    Returns the length (number of chars) and flags or raise CheckError.
+    Returns the length (number of chars) and flag or raise CheckError.
     If allow_surrogates is False, then also raise if we see any.
     Note also codepoints_in_utf8(), which also computes the length
     faster by assuming that 's' is valid utf-8.
     """
-    res, flags = _check_utf8(s, allow_surrogates, start, stop)
+    res, flag = _check_utf8(s, allow_surrogates, start, stop)
     if res >= 0:
-        return res, flags
+        return res, flag
     raise CheckError(~res)
 
 @jit.elidable
@@ -416,12 +416,13 @@
     return False
 
 
-UTF8_INDEX_STORAGE = lltype.GcArray(lltype.Struct(
-    'utf8_loc',
+UTF8_INDEX_STORAGE = lltype.GcStruct('utf8_loc',
+    ('flag', lltype.Signed),
+    ('contents', lltype.Ptr(lltype.GcArray(lltype.Struct(
+    'utf8_loc_elem',
     ('baseindex', lltype.Signed),
-    ('flag', lltype.Signed),
-    ('ofs', lltype.FixedSizeArray(lltype.Char, 16))
-    ))
+    ('ofs', lltype.FixedSizeArray(lltype.Char, 16)))
+    ))))
 
 FLAG_REGULAR = 0
 FLAG_HAS_SURROGATES = 1
@@ -429,43 +430,47 @@
 # note that we never need index storage if we're pure ascii, but it's useful
 # for passing into W_UnicodeObject.__init__
 
-ASCII_INDEX_STORAGE_BLOCKS = 5
-ASCII_INDEX_STORAGE = lltype.malloc(UTF8_INDEX_STORAGE,
-                                    ASCII_INDEX_STORAGE_BLOCKS,
-                                    immortal=True)
-for _i in range(ASCII_INDEX_STORAGE_BLOCKS):
-    ASCII_INDEX_STORAGE[_i].baseindex = _i * 64
-    for _j in range(16):
-        ASCII_INDEX_STORAGE[_i].ofs[_j] = chr(_j * 4 + 1)
+#ASCII_INDEX_STORAGE_BLOCKS = 5
+#ASCII_INDEX_STORAGE = lltype.malloc(UTF8_INDEX_STORAGE.contents.TO,
+#                                    ASCII_INDEX_STORAGE_BLOCKS,
+#                                    immortal=True)
+#for _i in range(ASCII_INDEX_STORAGE_BLOCKS):
+#    ASCII_INDEX_STORAGE[_i].baseindex = _i * 64
+#    for _j in range(16):
+#        ASCII_INDEX_STORAGE[_i].ofs[_j] = chr(_j * 4 + 1)
 
 def null_storage():
     return lltype.nullptr(UTF8_INDEX_STORAGE)
 
-UTF8_IS_ASCII = lltype.malloc(UTF8_INDEX_STORAGE, 0, immortal=True)
-UTF8_HAS_SURROGATES = lltype.malloc(UTF8_INDEX_STORAGE, 0, immortal=True)
+UTF8_IS_ASCII = lltype.malloc(UTF8_INDEX_STORAGE, immortal=True)
+UTF8_IS_ASCII.contents = lltype.nullptr(UTF8_INDEX_STORAGE.contents.TO)
+UTF8_HAS_SURROGATES = lltype.malloc(UTF8_INDEX_STORAGE, immortal=True)
+UTF8_HAS_SURROGATES.contents = lltype.nullptr(UTF8_INDEX_STORAGE.contents.TO)
 
 def create_utf8_index_storage(utf8, utf8len):
     """ Create an index storage which stores index of each 4th character
     in utf8 encoded unicode string.
     """
-    if len(utf8) == utf8len < ASCII_INDEX_STORAGE_BLOCKS * 64:
-        return ASCII_INDEX_STORAGE
+#    if len(utf8) == utf8len < ASCII_INDEX_STORAGE_BLOCKS * 64:
+#        return ASCII_INDEX_STORAGE
     arraysize = utf8len // 64 + 1
-    storage = lltype.malloc(UTF8_INDEX_STORAGE, arraysize)
+    storage = lltype.malloc(UTF8_INDEX_STORAGE)
+    contents = lltype.malloc(UTF8_INDEX_STORAGE.contents.TO, arraysize)
+    storage.contents = contents
     baseindex = 0
     current = 0
     while True:
-        storage[current].baseindex = baseindex
+        contents[current].baseindex = baseindex
         next = baseindex
         for i in range(16):
             if utf8len == 0:
                 next += 1      # assume there is an extra '\x00' character
             else:
                 next = next_codepoint_pos(utf8, next)
-            storage[current].ofs[i] = chr(next - baseindex)
+            contents[current].ofs[i] = chr(next - baseindex)
             utf8len -= 4
             if utf8len < 0:
-                assert current + 1 == len(storage)
+                assert current + 1 == len(contents)
                 break
             next = next_codepoint_pos(utf8, next)
             next = next_codepoint_pos(utf8, next)
@@ -485,8 +490,8 @@
     this function.
     """
     current = index >> 6
-    ofs = ord(storage[current].ofs[(index >> 2) & 0x0F])
-    bytepos = storage[current].baseindex + ofs
+    ofs = ord(storage.contents[current].ofs[(index >> 2) & 0x0F])
+    bytepos = storage.contents[current].baseindex + ofs
     index &= 0x3
     if index == 0:
         return prev_codepoint_pos(utf8, bytepos)
@@ -504,8 +509,8 @@
     storage of type UTF8_INDEX_STORAGE
     """
     current = index >> 6
-    ofs = ord(storage[current].ofs[(index >> 2) & 0x0F])
-    bytepos = storage[current].baseindex + ofs
+    ofs = ord(storage.contents[current].ofs[(index >> 2) & 0x0F])
+    bytepos = storage.contents[current].baseindex + ofs
     index &= 0x3
     if index == 0:
         return codepoint_before_pos(utf8, bytepos)


More information about the pypy-commit mailing list