[pypy-commit] pypy union-side-effects: Fix str vs unicode correctness issues

rlamy pypy.commits at gmail.com
Sun Sep 4 22:25:01 EDT 2016


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: union-side-effects
Changeset: r86874:350957bdef41
Date: 2016-09-05 03:24 +0100
http://bitbucket.org/pypy/pypy/changeset/350957bdef41/

Log:	Fix str vs unicode correctness issues

diff --git a/pypy/module/_io/interp_textio.py b/pypy/module/_io/interp_textio.py
--- a/pypy/module/_io/interp_textio.py
+++ b/pypy/module/_io/interp_textio.py
@@ -231,16 +231,16 @@
             while True:
                 # Fast path for non-control chars. The loop always ends
                 # since the Py_UNICODE storage is NUL-terminated.
-                while i < size and line[start + i] > '\r':
+                while i < size and line[start + i] > u'\r':
                     i += 1
                 if i >= size:
                     return -1, size
                 ch = line[start + i]
                 i += 1
-                if ch == '\n':
+                if ch == u'\n':
                     return i, 0
-                if ch == '\r':
-                    if line[start + i] == '\n':
+                if ch == u'\r':
+                    if line[start + i] == u'\n':
                         return i + 1, 0
                     else:
                         return i, 0
diff --git a/pypy/objspace/std/formatting.py b/pypy/objspace/std/formatting.py
--- a/pypy/objspace/std/formatting.py
+++ b/pypy/objspace/std/formatting.py
@@ -184,11 +184,11 @@
                 except IndexError:
                     space = self.space
                     raise oefmt(space.w_ValueError, "incomplete format key")
-                if c == ')':
+                if c == const(')'):
                     pcount -= 1
                     if pcount == 0:
                         break
-                elif c == '(':
+                elif c == const('('):
                     pcount += 1
                 i += 1
             self.fmtpos = i + 1   # first character after ')'
@@ -203,7 +203,7 @@
             return space.getitem(self.w_valuedict, w_key)
 
         def parse_fmt(self):
-            if self.peekchr() == '(':
+            if self.peekchr() == const('('):
                 w_value = self.getmappingvalue(self.getmappingkey())
             else:
                 w_value = None
@@ -216,7 +216,7 @@
                 self.f_ljust = True
                 self.width = -self.width
 
-            if self.peekchr() == '.':
+            if self.peekchr() == const('.'):
                 self.forward()
                 self.prec = self.peel_num('prec', INT_MAX)
                 if self.prec < 0:
@@ -225,7 +225,7 @@
                 self.prec = -1
 
             c = self.peekchr()
-            if c == 'h' or c == 'l' or c == 'L':
+            if c == const('h') or c == const('l') or c == const('L'):
                 self.forward()
 
             return w_value
@@ -240,15 +240,15 @@
             self.f_zero  = False
             while True:
                 c = self.peekchr()
-                if c == '-':
+                if c == const('-'):
                     self.f_ljust = True
-                elif c == '+':
+                elif c == const('+'):
                     self.f_sign = True
-                elif c == ' ':
+                elif c == const(' '):
                     self.f_blank = True
-                elif c == '#':
+                elif c == const('#'):
                     self.f_alt = True
-                elif c == '0':
+                elif c == const('0'):
                     self.f_zero = True
                 else:
                     break
@@ -259,7 +259,7 @@
         def peel_num(self, name, maxval):
             space = self.space
             c = self.peekchr()
-            if c == '*':
+            if c == const('*'):
                 self.forward()
                 w_value = self.nextinputvalue()
                 if name == 'width':
@@ -293,7 +293,7 @@
                 fmt = self.fmt
                 i = i0 = self.fmtpos
                 while i < len(fmt):
-                    if fmt[i] == '%':
+                    if fmt[i] == const('%'):
                         break
                     i += 1
                 else:
@@ -306,7 +306,7 @@
                 w_value = self.parse_fmt()
                 c = self.peekchr()
                 self.forward()
-                if c == '%':
+                if c == const('%'):
                     self.std_wp(const('%'))
                     continue
                 if w_value is None:
@@ -315,7 +315,7 @@
                 # dispatch on the formatter
                 # (this turns into a switch after translation)
                 for c1 in FORMATTER_CHARS:
-                    if c == c1:
+                    if c == const(c1):
                         # 'c1' is an annotation constant here,
                         # so this getattr() is ok
                         do_fmt = getattr(self, 'fmt_' + c1)
diff --git a/pypy/objspace/std/newformat.py b/pypy/objspace/std/newformat.py
--- a/pypy/objspace/std/newformat.py
+++ b/pypy/objspace/std/newformat.py
@@ -45,6 +45,7 @@
 
 
 def make_template_formatting_class(for_unicode):
+    STR = unicode if for_unicode else str
     class TemplateFormatter(object):
         is_unicode = for_unicode
 
@@ -90,19 +91,19 @@
             while i < end:
                 c = s[i]
                 i += 1
-                if c == "{" or c == "}":
+                if c == STR("{") or c == STR("}"):
                     at_end = i == end
                     # Find escaped "{" and "}"
                     markup_follows = True
-                    if c == "}":
-                        if at_end or s[i] != "}":
+                    if c == STR("}"):
+                        if at_end or s[i] != STR("}"):
                             raise oefmt(space.w_ValueError, "Single '}'")
                         i += 1
                         markup_follows = False
-                    if c == "{":
+                    if c == STR("{"):
                         if at_end:
                             raise oefmt(space.w_ValueError, "Single '{'")
-                        if s[i] == "{":
+                        if s[i] == STR("{"):
                             i += 1
                             markup_follows = False
                     # Attach literal data, ending with { or }
@@ -124,10 +125,10 @@
                     recursive = False
                     while i < end:
                         c = s[i]
-                        if c == "{":
+                        if c == STR("{"):
                             recursive = True
                             nested += 1
-                        elif c == "}":
+                        elif c == STR("}"):
                             nested -= 1
                             if not nested:
                                 break
@@ -150,9 +151,9 @@
             i = start
             while i < end:
                 c = s[i]
-                if c == ":" or c == "!":
+                if c == STR(":") or c == STR("!"):
                     end_name = i
-                    if c == "!":
+                    if c == STR("!"):
                         i += 1
                         if i == end:
                             raise oefmt(self.space.w_ValueError,
@@ -160,7 +161,7 @@
                         conversion = s[i]
                         i += 1
                         if i < end:
-                            if s[i] != ':':
+                            if s[i] != STR(':'):
                                 raise oefmt(self.space.w_ValueError,
                                             "expected ':' after format "
                                             "specifier")
@@ -180,7 +181,7 @@
             end = len(name)
             while i < end:
                 c = name[i]
-                if c == "[" or c == ".":
+                if c == STR("[") or c == STR("."):
                     break
                 i += 1
             empty = not i
@@ -232,12 +233,12 @@
             i = start
             while i < end:
                 c = name[i]
-                if c == ".":
+                if c == STR("."):
                     i += 1
                     start = i
                     while i < end:
                         c = name[i]
-                        if c == "[" or c == ".":
+                        if c == STR("[") or c == STR("."):
                             break
                         i += 1
                     if start == i:
@@ -249,13 +250,13 @@
                     else:
                         self.parser_list_w.append(space.newtuple([
                             space.w_True, w_attr]))
-                elif c == "[":
+                elif c == STR("["):
                     got_bracket = False
                     i += 1
                     start = i
                     while i < end:
                         c = name[i]
-                        if c == "]":
+                        if c == STR("]"):
                             got_bracket = True
                             break
                         i += 1
@@ -284,7 +285,7 @@
             end = len(name)
             while i < end:
                 c = name[i]
-                if c == "[" or c == ".":
+                if c == STR("[") or c == STR("."):
                     break
                 i += 1
             if i == 0:
@@ -307,9 +308,9 @@
         def _convert(self, w_obj, conversion):
             space = self.space
             conv = conversion[0]
-            if conv == "r":
+            if conv == STR("r"):
                 return space.repr(w_obj)
-            elif conv == "s":
+            elif conv == STR("s"):
                 if self.is_unicode:
                     return space.call_function(space.w_unicode, w_obj)
                 return space.str(w_obj)
@@ -400,6 +401,7 @@
 LONG_DIGITS = string.digits + string.ascii_lowercase
 
 def make_formatting_class(for_unicode):
+    _lit = unicode if for_unicode else str
     class Formatter(BaseFormatter):
         """__format__ implementation for builtin types."""
 
@@ -412,22 +414,22 @@
             self.spec = spec
 
         def _is_alignment(self, c):
-            return (c == "<" or
-                    c == ">" or
-                    c == "=" or
-                    c == "^")
+            return (c == _lit("<") or
+                    c == _lit(">") or
+                    c == _lit("=") or
+                    c == _lit("^"))
 
         def _is_sign(self, c):
-            return (c == " " or
-                    c == "+" or
-                    c == "-")
+            return (c == _lit(" ") or
+                    c == _lit("+") or
+                    c == _lit("-"))
 
         def _parse_spec(self, default_type, default_align):
             space = self.space
             self._fill_char = self._lit(" ")[0]
             self._align = default_align
             self._alternate = False
-            self._sign = "\0"
+            self._sign = _lit("\0")
             self._thousands_sep = False
             self._precision = -1
             the_type = default_type
@@ -451,19 +453,19 @@
             if length - i >= 1 and self._is_sign(spec[i]):
                 self._sign = spec[i]
                 i += 1
-            if length - i >= 1 and spec[i] == "#":
+            if length - i >= 1 and spec[i] == _lit("#"):
                 self._alternate = True
                 i += 1
-            if not got_fill_char and length - i >= 1 and spec[i] == "0":
+            if not got_fill_char and length - i >= 1 and spec[i] == _lit("0"):
                 self._fill_char = self._lit("0")[0]
                 if not got_align:
-                    self._align = "="
+                    self._align = _lit("=")
                 i += 1
             self._width, i = _parse_int(self.space, spec, i, length)
-            if length != i and spec[i] == ",":
+            if length != i and spec[i] == _lit(","):
                 self._thousands_sep = True
                 i += 1
-            if length != i and spec[i] == ".":
+            if length != i and spec[i] == _lit("."):
                 i += 1
                 self._precision, i = _parse_int(self.space, spec, i, length)
                 if self._precision == -1:
@@ -471,28 +473,20 @@
             if length - i > 1:
                 raise oefmt(space.w_ValueError, "invalid format spec")
             if length - i == 1:
-                presentation_type = spec[i]
-                if self.is_unicode:
-                    try:
-                        the_type = spec[i].encode("ascii")[0]
-                    except UnicodeEncodeError:
-                        raise oefmt(space.w_ValueError,
-                                    "invalid presentation type")
-                else:
-                    the_type = presentation_type
+                the_type = spec[i]
                 i += 1
             self._type = the_type
             if self._thousands_sep:
                 tp = self._type
-                if (tp == "d" or
-                    tp == "e" or
-                    tp == "f" or
-                    tp == "g" or
-                    tp == "E" or
-                    tp == "G" or
-                    tp == "%" or
-                    tp == "F" or
-                    tp == "\0"):
+                if (tp == _lit("d") or
+                    tp == _lit("e") or
+                    tp == _lit("f") or
+                    tp == _lit("g") or
+                    tp == _lit("E") or
+                    tp == _lit("G") or
+                    tp == _lit("%") or
+                    tp == _lit("F") or
+                    tp == _lit("\0")):
                     # ok
                     pass
                 else:
@@ -506,11 +500,11 @@
             else:
                 total = length
             align = self._align
-            if align == ">":
+            if align == _lit(">"):
                 left = total - length
-            elif align == "^":
+            elif align == _lit("^"):
                 left = (total - length) / 2
-            elif align == "<" or align == "=":
+            elif align == _lit("<") or align == _lit("="):
                 left = 0
             else:
                 raise AssertionError("shouldn't be here")
@@ -539,23 +533,24 @@
                 return rstring.StringBuilder()
 
         def _unknown_presentation(self, tp):
+            spec = self._type.encode('ascii') if for_unicode else self._type
             raise oefmt(self.space.w_ValueError,
-                        "unknown presentation for %s: '%s'", tp, self._type)
+                        "unknown presentation for %s: '%s'", tp, spec)
 
         def format_string(self, string):
             space = self.space
-            if self._parse_spec("s", "<"):
+            if self._parse_spec(_lit("s"), _lit("<")):
                 return space.wrap(string)
-            if self._type != "s":
+            if self._type != _lit("s"):
                 self._unknown_presentation("string")
-            if self._sign != "\0":
+            if self._sign != _lit("\0"):
                 raise oefmt(space.w_ValueError,
                             "Sign not allowed in string format specifier")
             if self._alternate:
                 raise oefmt(space.w_ValueError,
                             "Alternate form (#) not allowed in string format "
                             "specifier")
-            if self._align == "=":
+            if self._align == _lit("="):
                 raise oefmt(space.w_ValueError,
                             "'=' alignment not allowed in string format "
                             "specifier")
@@ -760,8 +755,8 @@
                             "precision not allowed in integer type")
             sign_char = "\0"
             tp = self._type
-            if tp == "c":
-                if self._sign != "\0":
+            if tp == _lit("c"):
+                if self._sign != _lit("\0"):
                     raise oefmt(space.w_ValueError,
                                 "sign not allowed with 'c' presentation type")
                 value = space.int_w(w_num)
@@ -776,16 +771,16 @@
                 to_prefix = 0
                 to_numeric = 0
             else:
-                if tp == "b":
+                if tp == _lit("b"):
                     base = 2
                     skip_leading = 2
-                elif tp == "o":
+                elif tp == _lit("o"):
                     base = 8
                     skip_leading = 2
-                elif tp == "x" or tp == "X":
+                elif tp == _lit("x") or tp == _lit("X"):
                     base = 16
                     skip_leading = 2
-                elif tp == "n" or tp == "d":
+                elif tp == _lit("n") or tp == _lit("d"):
                     base = 10
                     skip_leading = 0
                 else:
@@ -808,7 +803,7 @@
             spec = self._calc_num_width(n_prefix, sign_char, to_numeric, n_digits,
                                         n_remainder, False, result)
             fill = self._fill_char
-            upper = self._type == "X"
+            upper = self._type == _lit("X")
             return self.space.wrap(self._fill_number(spec, result, to_numeric,
                                      to_prefix, fill, to_remainder, upper))
 
@@ -874,7 +869,7 @@
 
         def format_int_or_long(self, w_num, kind):
             space = self.space
-            if self._parse_spec("d", ">"):
+            if self._parse_spec(_lit("d"), _lit(">")):
                 if self.is_unicode:
                     return space.call_function(space.w_unicode, w_num)
                 return self.space.str(w_num)
@@ -922,15 +917,15 @@
                             "Alternate form (#) not allowed in float formats")
             tp = self._type
             self._get_locale(tp)
-            if tp == "\0":
-                tp = "g"
+            if tp == _lit("\0"):
+                tp = _lit("g")
                 default_precision = 12
                 flags |= rfloat.DTSF_ADD_DOT_0
-            elif tp == "n":
-                tp = "g"
+            elif tp == _lit("n"):
+                tp = _lit("g")
             value = space.float_w(w_float)
-            if tp == "%":
-                tp = "f"
+            if tp == _lit("%"):
+                tp = _lit("f")
                 value *= 100
                 add_pct = True
             else:
@@ -963,7 +958,7 @@
 
         def format_float(self, w_float):
             space = self.space
-            if self._parse_spec("\0", ">"):
+            if self._parse_spec(_lit("\0"), _lit(">")):
                 if self.is_unicode:
                     return space.call_function(space.w_unicode, w_float)
                 return space.str(w_float)
@@ -1002,9 +997,9 @@
                             "specifier")
             skip_re = 0
             add_parens = 0
-            if tp == "\0":
+            if tp == _lit("\0"):
                 #should mirror str() output
-                tp = "g"
+                tp = _lit("g")
                 default_precision = 12
                 #test if real part is non-zero
                 if (w_complex.realval == 0 and
@@ -1013,9 +1008,9 @@
                 else:
                     add_parens = 1
 
-            if tp == "n":
+            if tp == _lit("n"):
                 #same as 'g' except for locale, taken care of later
-                tp = "g"
+                tp = _lit("g")
 
             #check if precision not set
             if self._precision == -1:
@@ -1073,7 +1068,7 @@
             #self._grouped_digits will get overwritten in imaginary calc_num_width
             re_grouped_digits = self._grouped_digits
             if not skip_re:
-                self._sign = "+"
+                self._sign = _lit("+")
             im_spec = self._calc_num_width(0, im_sign, to_imag_number, n_im_digits,
                                            im_n_remainder, im_have_dec,
                                            im_num)
@@ -1127,7 +1122,7 @@
             """return the string representation of a complex number"""
             space = self.space
             #parse format specification, set associated variables
-            if self._parse_spec("\0", ">"):
+            if self._parse_spec(_lit("\0"), _lit(">")):
                 return space.str(w_complex)
             tp = self._type
             if (tp == "\0" or
@@ -1148,9 +1143,5 @@
 
 @specialize.arg(2)
 def run_formatter(space, w_format_spec, meth, *args):
-    if space.isinstance_w(w_format_spec, space.w_unicode):
-        formatter = unicode_formatter(space, space.unicode_w(w_format_spec))
-        return getattr(formatter, meth)(*args)
-    else:
-        formatter = str_formatter(space, space.str_w(w_format_spec))
-        return getattr(formatter, meth)(*args)
+    formatter = str_formatter(space, space.str_w(w_format_spec))
+    return getattr(formatter, meth)(*args)
diff --git a/pypy/objspace/std/stringmethods.py b/pypy/objspace/std/stringmethods.py
--- a/pypy/objspace/std/stringmethods.py
+++ b/pypy/objspace/std/stringmethods.py
@@ -220,7 +220,7 @@
             offset = len(token)
 
             while 1:
-                if token[offset-1] == "\n" or token[offset-1] == "\r":
+                if token[offset-1] == self._chr("\n") or token[offset-1] == self._chr("\r"):
                     break
                 distance += 1
                 offset -= 1
@@ -598,7 +598,8 @@
             eol = pos
             pos += 1
             # read CRLF as one line break
-            if pos < length and value[eol] == '\r' and value[pos] == '\n':
+            if (pos < length and value[eol] == self._chr('\r')
+                    and value[pos] == self._chr('\n')):
                 pos += 1
             if keepends:
                 eol = pos
@@ -780,7 +781,8 @@
             return self._new(selfval)
 
         builder = self._builder(width)
-        if len(selfval) > 0 and (selfval[0] == '+' or selfval[0] == '-'):
+        if len(selfval) > 0 and (
+                selfval[0] == self._chr('+') or selfval[0] == self._chr('-')):
             # copy sign to first position
             builder.append(selfval[0])
             start = 1


More information about the pypy-commit mailing list