[Python-checkins] cpython: Issue #23171: csv.Writer.writerow() now supports arbitrary iterables.

serhiy.storchaka python-checkins at python.org
Mon Mar 30 08:22:02 CEST 2015


https://hg.python.org/cpython/rev/cf5b62036445
changeset:   95275:cf5b62036445
user:        Serhiy Storchaka <storchaka at gmail.com>
date:        Mon Mar 30 09:09:54 2015 +0300
summary:
  Issue #23171: csv.Writer.writerow() now supports arbitrary iterables.

files:
  Doc/library/csv.rst  |   4 +-
  Lib/csv.py           |   7 +--
  Lib/test/test_csv.py |   8 +++
  Misc/NEWS            |   2 +
  Modules/_csv.c       |  79 +++++++++++++++----------------
  5 files changed, 54 insertions(+), 46 deletions(-)


diff --git a/Doc/library/csv.rst b/Doc/library/csv.rst
--- a/Doc/library/csv.rst
+++ b/Doc/library/csv.rst
@@ -419,7 +419,7 @@
 
 :class:`Writer` objects (:class:`DictWriter` instances and objects returned by
 the :func:`writer` function) have the following public methods.  A *row* must be
-a sequence of strings or numbers for :class:`Writer` objects and a dictionary
+an iterable of strings or numbers for :class:`Writer` objects and a dictionary
 mapping fieldnames to strings or numbers (by passing them through :func:`str`
 first) for :class:`DictWriter` objects.  Note that complex numbers are written
 out surrounded by parens. This may cause some problems for other programs which
@@ -431,6 +431,8 @@
    Write the *row* parameter to the writer's file object, formatted according to
    the current dialect.
 
+   .. versionchanged:: 3.5
+      Added support for arbitrary iterables.
 
 .. method:: csvwriter.writerows(rows)
 
diff --git a/Lib/csv.py b/Lib/csv.py
--- a/Lib/csv.py
+++ b/Lib/csv.py
@@ -147,16 +147,13 @@
             if wrong_fields:
                 raise ValueError("dict contains fields not in fieldnames: "
                                  + ", ".join([repr(x) for x in wrong_fields]))
-        return [rowdict.get(key, self.restval) for key in self.fieldnames]
+        return (rowdict.get(key, self.restval) for key in self.fieldnames)
 
     def writerow(self, rowdict):
         return self.writer.writerow(self._dict_to_list(rowdict))
 
     def writerows(self, rowdicts):
-        rows = []
-        for rowdict in rowdicts:
-            rows.append(self._dict_to_list(rowdict))
-        return self.writer.writerows(rows)
+        return self.writer.writerows(map(self._dict_to_list, rowdicts))
 
 # Guard Sniffer's type checking against builds that exclude complex()
 try:
diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py
--- a/Lib/test/test_csv.py
+++ b/Lib/test/test_csv.py
@@ -186,6 +186,14 @@
         self._write_test(['a',1,'p,q'], 'a,1,p\\,q',
                          escapechar='\\', quoting = csv.QUOTE_NONE)
 
+    def test_write_iterable(self):
+        self._write_test(iter(['a', 1, 'p,q']), 'a,1,"p,q"')
+        self._write_test(iter(['a', 1, None]), 'a,1,')
+        self._write_test(iter([]), '')
+        self._write_test(iter([None]), '""')
+        self._write_error_test(csv.Error, iter([None]), quoting=csv.QUOTE_NONE)
+        self._write_test(iter([None, None]), ',')
+
     def test_writerows(self):
         class BrokenFile:
             def write(self, buf):
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -56,6 +56,8 @@
 Library
 -------
 
+- Issue #23171: csv.Writer.writerow() now supports arbitrary iterables.
+
 - Issue #23745: The new email header parser now handles duplicate MIME
   parameter names without error, similar to how get_param behaves.
 
diff --git a/Modules/_csv.c b/Modules/_csv.c
--- a/Modules/_csv.c
+++ b/Modules/_csv.c
@@ -1009,7 +1009,7 @@
  */
 static Py_ssize_t
 join_append_data(WriterObj *self, unsigned int field_kind, void *field_data,
-                 Py_ssize_t field_len, int quote_empty, int *quoted,
+                 Py_ssize_t field_len, int *quoted,
                  int copy_phase)
 {
     DialectObj *dialect = self->dialect;
@@ -1071,18 +1071,6 @@
         ADDCH(c);
     }
 
-    /* If field is empty check if it needs to be quoted.
-     */
-    if (i == 0 && quote_empty) {
-        if (dialect->quoting == QUOTE_NONE) {
-            PyErr_Format(_csvstate_global->error_obj,
-                "single empty field record must be quoted");
-            return -1;
-        }
-        else
-            *quoted = 1;
-    }
-
     if (*quoted) {
         if (copy_phase)
             ADDCH(dialect->quotechar);
@@ -1126,7 +1114,7 @@
 }
 
 static int
-join_append(WriterObj *self, PyObject *field, int *quoted, int quote_empty)
+join_append(WriterObj *self, PyObject *field, int quoted)
 {
     unsigned int field_kind = -1;
     void *field_data = NULL;
@@ -1141,7 +1129,7 @@
         field_len = PyUnicode_GET_LENGTH(field);
     }
     rec_len = join_append_data(self, field_kind, field_data, field_len,
-                               quote_empty, quoted, 0);
+                               &quoted, 0);
     if (rec_len < 0)
         return 0;
 
@@ -1150,7 +1138,7 @@
         return 0;
 
     self->rec_len = join_append_data(self, field_kind, field_data, field_len,
-                                     quote_empty, quoted, 1);
+                                     &quoted, 1);
     self->num_fields++;
 
     return 1;
@@ -1181,37 +1169,30 @@
 }
 
 PyDoc_STRVAR(csv_writerow_doc,
-"writerow(sequence)\n"
+"writerow(iterable)\n"
 "\n"
-"Construct and write a CSV record from a sequence of fields.  Non-string\n"
+"Construct and write a CSV record from an iterable of fields.  Non-string\n"
 "elements will be converted to string.");
 
 static PyObject *
 csv_writerow(WriterObj *self, PyObject *seq)
 {
     DialectObj *dialect = self->dialect;
-    Py_ssize_t len, i;
-    PyObject *line, *result;
+    PyObject *iter, *field, *line, *result;
 
-    if (!PySequence_Check(seq))
-        return PyErr_Format(_csvstate_global->error_obj, "sequence expected");
-
-    len = PySequence_Length(seq);
-    if (len < 0)
-        return NULL;
+    iter = PyObject_GetIter(seq);
+    if (iter == NULL)
+        return PyErr_Format(_csvstate_global->error_obj,
+                            "iterable expected, not %.200s",
+                            seq->ob_type->tp_name);
 
     /* Join all fields in internal buffer.
      */
     join_reset(self);
-    for (i = 0; i < len; i++) {
-        PyObject *field;
+    while ((field = PyIter_Next(iter))) {
         int append_ok;
         int quoted;
 
-        field = PySequence_GetItem(seq, i);
-        if (field == NULL)
-            return NULL;
-
         switch (dialect->quoting) {
         case QUOTE_NONNUMERIC:
             quoted = !PyNumber_Check(field);
@@ -1225,11 +1206,11 @@
         }
 
         if (PyUnicode_Check(field)) {
-            append_ok = join_append(self, field, &quoted, len == 1);
+            append_ok = join_append(self, field, quoted);
             Py_DECREF(field);
         }
         else if (field == Py_None) {
-            append_ok = join_append(self, NULL, &quoted, len == 1);
+            append_ok = join_append(self, NULL, quoted);
             Py_DECREF(field);
         }
         else {
@@ -1237,19 +1218,37 @@
 
             str = PyObject_Str(field);
             Py_DECREF(field);
-            if (str == NULL)
+            if (str == NULL) {
+                Py_DECREF(iter);
                 return NULL;
-            append_ok = join_append(self, str, &quoted, len == 1);
+            }
+            append_ok = join_append(self, str, quoted);
             Py_DECREF(str);
         }
-        if (!append_ok)
+        if (!append_ok) {
+            Py_DECREF(iter);
+            return NULL;
+        }
+    }
+    Py_DECREF(iter);
+    if (PyErr_Occurred())
+        return NULL;
+
+    if (self->num_fields > 0 && self->rec_size == 0) {
+        if (dialect->quoting == QUOTE_NONE) {
+            PyErr_Format(_csvstate_global->error_obj,
+                "single empty field record must be quoted");
+            return NULL;
+        }
+        self->num_fields--;
+        if (!join_append(self, NULL, 1))
             return NULL;
     }
 
     /* Add line terminator.
      */
     if (!join_append_lineterminator(self))
-        return 0;
+        return NULL;
 
     line = PyUnicode_FromKindAndData(PyUnicode_4BYTE_KIND,
                                      (void *) self->rec, self->rec_len);
@@ -1261,9 +1260,9 @@
 }
 
 PyDoc_STRVAR(csv_writerows_doc,
-"writerows(sequence of sequences)\n"
+"writerows(iterable of iterables)\n"
 "\n"
-"Construct and write a series of sequences to a csv file.  Non-string\n"
+"Construct and write a series of iterables to a csv file.  Non-string\n"
 "elements will be converted to string.");
 
 static PyObject *

-- 
Repository URL: https://hg.python.org/cpython


More information about the Python-checkins mailing list