[Jython-checkins] jython: Accept unicode arguments at a csv.writer (fixes #2632).

jeff.allen jython-checkins at python.org
Tue Nov 21 17:39:12 EST 2017


https://hg.python.org/jython/rev/08978c4d1ab0
changeset:   8140:08978c4d1ab0
user:        Jeff Allen <ja.py at farowl.co.uk>
date:        Tue Nov 21 19:37:02 2017 +0000
summary:
  Accept unicode arguments at a csv.writer (fixes #2632).

The CPython csv.writer accepts unicode strings and encodes them using
the current default encoding. This is not documented, but some users
rely on it, and we can easily reproduce the behaviour. A simple test,
test_csv_jy, is added for a UTF-8 default encoding. Because the tests
re-expose sys.setdefaultencoding, they hide it again after use;
otherwise test_site fails. The same latent fault is corrected in
test_unicode_jy.

files:
  Lib/test/test_csv_jy.py                    |  96 ++++++++++
  Lib/test/test_unicode_jy.py                |   8 +-
  src/org/python/modules/_csv/PyDialect.java |  33 +-
  src/org/python/modules/_csv/PyWriter.java  |  48 ++--
  4 files changed, 145 insertions(+), 40 deletions(-)


diff --git a/Lib/test/test_csv_jy.py b/Lib/test/test_csv_jy.py
new file mode 100644
--- /dev/null
+++ b/Lib/test/test_csv_jy.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+# Copyright (C) 2017 Jython Developers
+
+# Additional csv module unit tests for Jython
+
+import csv
+import io
+import sys
+from tempfile import TemporaryFile
+from test import test_support
+import unittest
+
+# TestUnicode (below) is adapted from Python 3 test_csv.TestUnicode. In
+# Python 3 the csv module supports Unicode directly. In Python 2 it does not,
+# except that it is transparent to byte data; many tools, however, accept
+# UTF-8 encoded text in a CSV file. EncodingContext is a helper for it.
+#
+class EncodingContext(object):
+    """Context manager to save and restore the encoding.
+
+    Use like this:
+
+        with EncodingContext("utf-8"):
+            self.assertEqual("'caf\xc3\xa9'", u"'caf\xe9'")
+    """
+
+    def __init__(self, encoding):
+        if not hasattr(sys, "setdefaultencoding"):
+            reload(sys)
+        self.original_encoding = sys.getdefaultencoding()
+        sys.setdefaultencoding(encoding)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *ignore_exc):
+        sys.setdefaultencoding(self.original_encoding)
+
+class TestUnicode(unittest.TestCase):
+
+    names = [u"Martin von Löwis",
+             u"Marc André Lemburg",
+             u"Guido van Rossum",
+             u"François Pinard",
+             u"稲田直樹"]
+
+    def test_decode_read(self):
+        # The user code receives byte data and takes care of the decoding
+        with TemporaryFile("w+b") as fileobj:
+            line = u",".join(self.names) + u"\r\n"
+            fileobj.write(line.encode('utf-8'))
+            fileobj.seek(0)
+            reader = csv.reader(fileobj)
+            # The reader yields rows of byte strings that decode to the data
+            table = [[e.decode('utf-8') for e in row] for row in reader]
+            self.assertEqual(table, [self.names])
+
+    def test_encode_write(self):
+        # The user encodes unicode objects to byte data that csv writes
+        with TemporaryFile("w+b") as fileobj:
+            writer = csv.writer(fileobj)
+            # We present a row of encoded strings to the writer
+            writer.writerow([n.encode('utf-8') for n in self.names])
+            # We expect the file contents to be the UTF-8 of the csv data
+            expected = u",".join(self.names) + u"\r\n"
+            fileobj.seek(0)
+            self.assertEqual(fileobj.read().decode('utf-8'), expected)
+
+    def test_unicode_write(self):
+        # The user supplies unicode data that csv.writer default-encodes
+        # (undocumented feature relied upon by client code).
+        # See Issue #2632  https://github.com/jythontools/jython/issues/90
+        with TemporaryFile("w+b") as fileobj:
+            with EncodingContext('utf-8'):
+                writer = csv.writer(fileobj)
+                # We present a row of unicode strings to the writer
+                writer.writerow(self.names)
+                # We expect the file contents to be the UTF-8 of the csv data
+                expected = u",".join(self.names) + u"\r\n"
+                fileobj.seek(0)
+                self.assertEqual(fileobj.read().decode(), expected)
+
+
+def test_main():
+    # Tests may expose sys.setdefaultencoding: hide it again if it was hidden.
+    had_set = hasattr(sys, "setdefaultencoding")
+    try:
+        test_support.run_unittest(
+            TestUnicode,
+        )
+    finally:
+        if not had_set:
+            delattr(sys, "setdefaultencoding")
+
+if __name__ == "__main__":
+    test_main()
diff --git a/Lib/test/test_unicode_jy.py b/Lib/test/test_unicode_jy.py
--- a/Lib/test/test_unicode_jy.py
+++ b/Lib/test/test_unicode_jy.py
@@ -1341,7 +1341,10 @@
 
 
 def test_main():
-    test_support.run_unittest(
+    # Tests may expose sys.setdefaultencoding: hide it again if it was hidden.
+    had_set = hasattr(sys, "setdefaultencoding")
+    try:
+        test_support.run_unittest(
                 UnicodeTestCase,
                 UnicodeIndexMixTest,
                 UnicodeFormatTestCase,
@@ -1353,6 +1356,9 @@
                 DefaultDecodingUTF8,
                 DefaultDecodingCp850,
             )
+    finally:
+        if not had_set:
+            delattr(sys, "setdefaultencoding")
 
 
 if __name__ == "__main__":
diff --git a/src/org/python/modules/_csv/PyDialect.java b/src/org/python/modules/_csv/PyDialect.java
--- a/src/org/python/modules/_csv/PyDialect.java
+++ b/src/org/python/modules/_csv/PyDialect.java
@@ -1,4 +1,4 @@
-/* Copyright (c) Jython Developers */
+/* Copyright (c) 2017 Jython Developers */
 package org.python.modules._csv;
 
 import org.python.core.ArgParser;
@@ -9,6 +9,7 @@
 import org.python.core.PyObject;
 import org.python.core.PyString;
 import org.python.core.PyType;
+import org.python.core.PyUnicode;
 import org.python.core.Untraversable;
 import org.python.expose.ExposedDelete;
 import org.python.expose.ExposedGet;
@@ -153,17 +154,21 @@
     private static char toChar(String name, PyObject src, char dflt) {
         if (src == null) {
             return dflt;
-        }
-        boolean isStr = Py.isInstance(src, PyString.TYPE);
-        if (src == Py.None || isStr && src.__len__() == 0) {
+        } else if (src == Py.None) {
             return '\0';
-        } else if (!isStr || src.__len__() != 1) {
-            throw Py.TypeError(String.format("\"%s\" must be an 1-character string", name));
+        } else if (src instanceof PyString) {
+            String s = (src instanceof PyUnicode) ? ((PyUnicode) src).encode() : src.toString();
+            if (s.length() == 0) {
+                return '\0';
+            } else if (s.length() == 1) {
+                return s.charAt(0);
+            }
         }
-        return src.toString().charAt(0);
+        // Anything else is rejected; note a char result cannot hold a non-BMP character.
+        throw Py.TypeError(String.format("\"%s\" must be a 1-character string", name));
     }
 
-    private static int toInt(String name, PyObject src, int dflt) {
+       private static int toInt(String name, PyObject src, int dflt) {
         if (src == null) {
             return dflt;
         }
@@ -176,14 +181,14 @@
     private static String toStr(String name, PyObject src, String dflt) {
         if (src == null) {
             return dflt;
-        }
-        if (src == Py.None) {
+        } else if (src == Py.None) {
             return null;
+        } else if (src instanceof PyUnicode) {
+            return ((PyUnicode) src).encode().toString();
+        } else if (src instanceof PyString) {
+            return src.toString();
         }
-        if (!(src instanceof PyBaseString)) {
-            throw Py.TypeError(String.format("\"%s\" must be an string", name));
-        }
-        return src.toString();
+        throw Py.TypeError(String.format("\"%s\" must be a string", name));
     }
 
     @ExposedGet(name = "escapechar")
diff --git a/src/org/python/modules/_csv/PyWriter.java b/src/org/python/modules/_csv/PyWriter.java
--- a/src/org/python/modules/_csv/PyWriter.java
+++ b/src/org/python/modules/_csv/PyWriter.java
@@ -1,4 +1,4 @@
-/* Copyright (c) Jython Developers */
+/* Copyright (c) 2017 Jython Developers */
 package org.python.modules._csv;
 
 import org.python.core.Py;
@@ -7,6 +7,7 @@
 import org.python.core.PyObject;
 import org.python.core.PyString;
 import org.python.core.PyType;
+import org.python.core.PyUnicode;
 import org.python.core.Traverseproc;
 import org.python.core.Visitproc;
 import org.python.expose.ExposedType;
@@ -21,11 +22,9 @@
 @ExposedType(name = "_csv.writer", doc = PyWriter.writer_doc)
 public class PyWriter extends PyObject implements Traverseproc {
 
-    public static final String writer_doc =
-    "CSV writer\n" +
-    "\n" +
-    "Writer objects are responsible for generating tabular data\n" +
-    "in CSV format from sequence input.\n";
+    public static final String writer_doc = "CSV writer\n\n"//
+            + "Writer objects are responsible for generating tabular data\n"
+            + "in CSV format from sequence input.\n";
 
     public static final PyType TYPE = PyType.fromClass(PyWriter.class);
 
@@ -53,11 +52,10 @@
         this.dialect = dialect;
     }
 
-    public static PyString __doc__writerows = Py.newString(
-            "writerows(sequence of sequences)\n" +
-            "\n" +
-            "Construct and write a series of sequences to a csv file.  Non-string\n" +
-            "elements will be converted to string.");
+    public static PyString __doc__writerows = Py.newString(//
+            "writerows(sequence of sequences)\n\n"
+            + "Construct and write a series of sequences to a csv file.  Non-string\n"
+            + "elements will be converted to string.");
 
     public void writerows(PyObject seqseq) {
         writer_writerows(seqseq);
@@ -82,12 +80,10 @@
         }
     }
 
-    public static PyString __doc__writerow = Py.newString(
-            "writerow(sequence)\n" +
-            "\n" +
-            "Construct and write a CSV record from a sequence of fields.  Non-string\n" +
-            "elements will be converted to string."
-            );
+    public static PyString __doc__writerow = Py.newString(//
+            "writerow(sequence)\n\n"
+            + "Construct and write a CSV record from a sequence of fields.  Non-string\n"
+            + "elements will be converted to string.");
 
     public boolean writerow(PyObject seq) {
         return writer_writerow(seq);
@@ -134,14 +130,17 @@
                     quoted = false;
             }
 
-            if (field instanceof PyString) {
+            if (field instanceof PyUnicode) {
+                // Unicode: encode with the default encoding (a String of byte values).
+                append_ok = join_append(((PyString) field).encode(), len == 1);
+            } else if (field instanceof PyString) {
+                // Not unicode: a byte string whose chars each hold one byte value.
                 append_ok = join_append(field.toString(), len == 1);
             } else if (field == Py.None) {
                 append_ok = join_append("", len == 1);
             } else {
                 PyObject str;
-                //XXX: in 3.x this check can go away and we can just always use
-                //     __str__
+                // XXX: in 3.x this check can go away and we can just always use __str__
                 if (field.getClass() == PyFloat.class) {
                     str = field.__repr__();
                 } else {
@@ -195,9 +194,9 @@
     }
 
     /**
-     * This method behaves differently depending on the value of copy_phase: if copy_phase
-     * is false, then the method determines the new record length. If copy_phase is true
-     * then the new field is appended to the record.
+     * This method behaves differently depending on the value of copy_phase: if copy_phase is false,
+     * then the method determines the new record length. If copy_phase is true then the new field is
+     * appended to the record.
      */
     private int join_append_data(String field, boolean quote_empty, boolean copy_phase) {
         int i;
@@ -225,7 +224,7 @@
                 break;
             }
             if (c == dialect.delimiter || c == dialect.escapechar || c == dialect.quotechar
-                || dialect.lineterminator.indexOf(c) > -1) {
+                    || dialect.lineterminator.indexOf(c) > -1) {
                 if (dialect.quoting == QuoteStyle.QUOTE_NONE) {
                     want_escape = true;
                 } else {
@@ -282,7 +281,6 @@
         rec_len++;
     }
 
-
     /* Traverseproc implementation */
     @Override
     public int traverse(Visitproc visit, Object arg) {

-- 
Repository URL: https://hg.python.org/jython


More information about the Jython-checkins mailing list