[pypy-commit] pypy default: in-progress: writer

arigo noreply at buildbot.pypy.org
Mon Sep 24 18:10:28 CEST 2012


Author: Armin Rigo <arigo at tunes.org>
Branch: 
Changeset: r57503:c8258d1b52d9
Date: 2012-09-24 16:54 +0200
http://bitbucket.org/pypy/pypy/changeset/c8258d1b52d9/

Log:	in-progress: writer

diff --git a/pypy/module/_csv/__init__.py b/pypy/module/_csv/__init__.py
--- a/pypy/module/_csv/__init__.py
+++ b/pypy/module/_csv/__init__.py
@@ -82,4 +82,6 @@
 
         'reader': 'interp_reader.csv_reader',
         'field_size_limit': 'interp_reader.csv_field_size_limit',
+
+        'writer': 'interp_writer.csv_writer',
         }
diff --git a/pypy/module/_csv/interp_reader.py b/pypy/module/_csv/interp_reader.py
--- a/pypy/module/_csv/interp_reader.py
+++ b/pypy/module/_csv/interp_reader.py
@@ -219,6 +219,20 @@
                   w_skipinitialspace = NoneNotWrapped,
                   w_strict           = NoneNotWrapped,
                   ):
+    """
+    csv_reader = reader(iterable [, dialect='excel']
+                       [optional keyword args])
+    for row in csv_reader:
+        process(row)
+
+    The "iterable" argument can be any object that returns a line
+    of input for each iteration, such as a file object or a list.  The
+    optional \"dialect\" parameter is discussed below.  The function
+    also accepts optional keyword arguments which override settings
+    provided by the dialect.
+
+    The returned object is an iterator.  Each iteration returns a row
+    of the CSV file (which can span multiple input lines)"""
     w_iter = space.iter(w_iterator)
     dialect = _build_dialect(space, w_dialect, w_delimiter, w_doublequote,
                              w_escapechar, w_lineterminator, w_quotechar,
diff --git a/pypy/module/_csv/interp_writer.py b/pypy/module/_csv/interp_writer.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/_csv/interp_writer.py
@@ -0,0 +1,155 @@
+from pypy.rlib.rstring import StringBuilder
+from pypy.interpreter.baseobjspace import Wrappable
+from pypy.interpreter.error import OperationError
+from pypy.interpreter.gateway import NoneNotWrapped
+from pypy.interpreter.typedef import TypeDef, interp2app
+from pypy.interpreter.typedef import interp_attrproperty_w
+from pypy.module._csv.interp_csv import _build_dialect
+from pypy.module._csv.interp_csv import (QUOTE_MINIMAL, QUOTE_ALL,
+                                         QUOTE_NONNUMERIC, QUOTE_NONE)
+
+
+class W_Writer(Wrappable):
+
+    def __init__(self, space, dialect, w_fileobj):
+        self.space = space
+        self.dialect = dialect
+        self.w_filewrite = space.getattr(w_fileobj, space.wrap('write'))
+        # precompute this
+        special = dialect.delimiter + dialect.lineterminator
+        if dialect.escapechar != '\0': special += dialect.escapechar
+        if dialect.quotechar  != '\0': special += dialect.quotechar
+        self.special_characters = special
+
+    def error(self, msg):
+        space = self.space
+        w_module = space.getbuiltinmodule('_csv')
+        w_error = space.getattr(w_module, space.wrap('Error'))
+        raise OperationError(w_error, space.wrap(msg))
+    error._dont_inline_ = True
+
+    def writerow(self, w_fields):
+        space = self.space
+        fields_w = space.listview(w_fields)
+        dialect = self.dialect
+        rec = StringBuilder(80)
+        #
+        for field_index in range(len(fields_w)):
+            w_field = fields_w[field_index]
+            if space.is_w(w_field, space.w_None):
+                field = ""
+            elif space.isinstance_w(w_field, space.w_float):
+                field = space.str_w(space.repr(w_field))
+            else:
+                field = space.str_w(space.str(w_field))
+            #
+            if dialect.quoting == QUOTE_NONNUMERIC:
+                try:
+                    space.float_w(w_field)    # is it an int/long/float?
+                    quoted = False
+                except OperationError, e:
+                    if e.async(self):
+                        raise
+                    quoted = True
+            elif dialect.quoting == QUOTE_ALL:
+                quoted = True
+            elif dialect.quoting == QUOTE_MINIMAL:
+                # Find out if we really quoting
+                special_characters = self.special_characters
+                for c in field:
+                    if c in special_characters:
+                        if c != dialect.quotechar or dialect.doublequote:
+                            quoted = True
+                            break
+                else:
+                    quoted = False
+            else:
+                quoted = False
+
+            # If field is empty check if it needs to be quoted
+            if len(field) == 0 and len(fields_w) == 1:
+                if dialect.quoting == QUOTE_NONE:
+                    raise self.error("single empty field record "
+                                     "must be quoted")
+                quoted = True
+
+            # If this is not the first field we need a field separator
+            if field_index > 0:
+                rec.append(dialect.delimiter)
+
+            # Handle preceding quote
+            if quoted:
+                rec.append(dialect.quotechar)
+
+            # Copy field data
+            special_characters = self.special_characters
+            for c in field:
+                if c in special_characters:
+                    if dialect.quoting == QUOTE_NONE:
+                        want_escape = True
+                    else:
+                        want_escape = False
+                        if c == dialect.quotechar:
+                            if dialect.doublequote:
+                                rec.append(dialect.quotechar)
+                            else:
+                                want_escape = True
+                    if want_escape:
+                        if dialect.escapechar == '\0':
+                            raise self.error("need to escape, "
+                                             "but no escapechar set")
+                        rec.append(dialect.escapechar)
+                    else:
+                        assert quoted
+                # Copy field character into record buffer
+                rec.append(c)
+
+            # Handle final quote
+            if quoted:
+                rec.append(dialect.quotechar)
+
+        # Add line terminator
+        rec.append(dialect.lineterminator)
+
+        line = rec.build()
+        return space.call_function(self.w_filewrite, space.wrap(line))
+
+
+def csv_writer(space, w_fileobj, w_dialect=NoneNotWrapped,
+                  w_delimiter        = NoneNotWrapped,
+                  w_doublequote      = NoneNotWrapped,
+                  w_escapechar       = NoneNotWrapped,
+                  w_lineterminator   = NoneNotWrapped,
+                  w_quotechar        = NoneNotWrapped,
+                  w_quoting          = NoneNotWrapped,
+                  w_skipinitialspace = NoneNotWrapped,
+                  w_strict           = NoneNotWrapped,
+                  ):
+    """
+    csv_writer = csv.writer(fileobj [, dialect='excel']
+                            [optional keyword args])
+    for row in sequence:
+        csv_writer.writerow(row)
+
+    [or]
+
+    csv_writer = csv.writer(fileobj [, dialect='excel']
+                            [optional keyword args])
+    csv_writer.writerows(rows)
+
+    The \"fileobj\" argument can be any object that supports the file API."""
+    dialect = _build_dialect(space, w_dialect, w_delimiter, w_doublequote,
+                             w_escapechar, w_lineterminator, w_quotechar,
+                             w_quoting, w_skipinitialspace, w_strict)
+    return W_Writer(space, dialect, w_fileobj)
+
+W_Writer.typedef = TypeDef(
+        'writer',
+        __module__ = '_csv',
+        dialect = interp_attrproperty_w('dialect', W_Writer),
+        writerow = interp2app(W_Writer.writerow),
+        __doc__ = """CSV writer
+
+Writer objects are responsible for generating tabular data
+in CSV format from sequence input.""")
+W_Writer.typedef.acceptable_as_base_class = False
diff --git a/pypy/module/_csv/test/test_reader.py b/pypy/module/_csv/test/test_reader.py
--- a/pypy/module/_csv/test/test_reader.py
+++ b/pypy/module/_csv/test/test_reader.py
@@ -70,7 +70,7 @@
         import _csv as csv
         limit = csv.field_size_limit()
         try:
-            size = 50
+            size = 150
             bigstring = 'X' * size
             bigline = '%s,%s' % (bigstring, bigstring)
             self._read_test([bigline], [[bigstring, bigstring]])
diff --git a/pypy/module/_csv/test/test_writer.py b/pypy/module/_csv/test/test_writer.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/_csv/test/test_writer.py
@@ -0,0 +1,50 @@
+from pypy.conftest import gettestobjspace
+
+
+class AppTestWriter(object):
+    def setup_class(cls):
+        cls.space = gettestobjspace(usemodules=['_csv'])
+
+        w__write_test = cls.space.appexec([], r"""():
+            import _csv
+
+            class DummyFile(object):
+                def __init__(self):
+                    self._parts = []
+                    self.write = self._parts.append
+                def getvalue(self):
+                    return ''.join(self._parts)
+
+            def _write_test(fields, expect, **kwargs):
+                fileobj = DummyFile()
+                writer = _csv.writer(fileobj, **kwargs)
+                writer.writerow(fields)
+                result = fileobj.getvalue()
+                expect += writer.dialect.lineterminator
+                assert result == expect, 'result: %r\nexpect: %r' % (
+                    result, expect)
+            return _write_test
+        """)
+        if type(w__write_test) is type(lambda:0):
+            w__write_test = staticmethod(w__write_test)
+        cls.w__write_test = w__write_test
+
+    def test_write_arg_valid(self):
+        import _csv as csv
+        raises(TypeError, self._write_test, None, '')    # xxx different API!
+        self._write_test((), '')
+        self._write_test([None], '""')
+        raises(csv.Error, self._write_test,
+                          [None], None, quoting = csv.QUOTE_NONE)
+        # Check that exceptions are passed up the chain
+        class BadList:
+            def __len__(self):
+                return 10;
+            def __getitem__(self, i):
+                if i > 2:
+                    raise IOError
+        raises(IOError, self._write_test, BadList(), '')
+        class BadItem:
+            def __str__(self):
+                raise IOError
+        raises(IOError, self._write_test, [BadItem()], '')


More information about the pypy-commit mailing list