[issue1767933] Badly formed XML using etree and utf-16

Serhiy Storchaka report at bugs.python.org
Fri Jul 13 22:25:34 CEST 2012


Serhiy Storchaka <storchaka at gmail.com> added the comment:

Patch updated with some comments.

----------
Added file: http://bugs.python.org/file26377/etree_write_utf16_5.patch

_______________________________________
Python tracker <report at bugs.python.org>
<http://bugs.python.org/issue1767933>
_______________________________________
-------------- next part --------------
diff -r 677a9326b4d4 Lib/test/test_xml_etree.py
--- a/Lib/test/test_xml_etree.py	Mon Jul 09 18:16:11 2012 -0700
+++ b/Lib/test/test_xml_etree.py	Fri Jul 13 23:23:04 2012 +0300
@@ -21,7 +21,7 @@
 import weakref
 
 from test import support
-from test.support import findfile, import_fresh_module, gc_collect
+from test.support import TESTFN, findfile, unlink, import_fresh_module, gc_collect
 
 pyET = None
 ET = None
@@ -888,65 +888,6 @@
     """
     ET.XML("<?xml version='1.0' encoding='%s'?><xml />" % encoding)
 
-def encoding():
-    r"""
-    Test encoding issues.
-
-    >>> elem = ET.Element("tag")
-    >>> elem.text = "abc"
-    >>> serialize(elem)
-    '<tag>abc</tag>'
-    >>> serialize(elem, encoding="utf-8")
-    b'<tag>abc</tag>'
-    >>> serialize(elem, encoding="us-ascii")
-    b'<tag>abc</tag>'
-    >>> serialize(elem, encoding="iso-8859-1")
-    b"<?xml version='1.0' encoding='iso-8859-1'?>\n<tag>abc</tag>"
-
-    >>> elem.text = "<&\"\'>"
-    >>> serialize(elem)
-    '<tag>&lt;&amp;"\'&gt;</tag>'
-    >>> serialize(elem, encoding="utf-8")
-    b'<tag>&lt;&amp;"\'&gt;</tag>'
-    >>> serialize(elem, encoding="us-ascii") # cdata characters
-    b'<tag>&lt;&amp;"\'&gt;</tag>'
-    >>> serialize(elem, encoding="iso-8859-1")
-    b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag>&lt;&amp;"\'&gt;</tag>'
-
-    >>> elem.attrib["key"] = "<&\"\'>"
-    >>> elem.text = None
-    >>> serialize(elem)
-    '<tag key="&lt;&amp;&quot;\'&gt;" />'
-    >>> serialize(elem, encoding="utf-8")
-    b'<tag key="&lt;&amp;&quot;\'&gt;" />'
-    >>> serialize(elem, encoding="us-ascii")
-    b'<tag key="&lt;&amp;&quot;\'&gt;" />'
-    >>> serialize(elem, encoding="iso-8859-1")
-    b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag key="&lt;&amp;&quot;\'&gt;" />'
-
-    >>> elem.text = '\xe5\xf6\xf6<>'
-    >>> elem.attrib.clear()
-    >>> serialize(elem)
-    '<tag>\xe5\xf6\xf6&lt;&gt;</tag>'
-    >>> serialize(elem, encoding="utf-8")
-    b'<tag>\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;</tag>'
-    >>> serialize(elem, encoding="us-ascii")
-    b'<tag>&#229;&#246;&#246;&lt;&gt;</tag>'
-    >>> serialize(elem, encoding="iso-8859-1")
-    b"<?xml version='1.0' encoding='iso-8859-1'?>\n<tag>\xe5\xf6\xf6&lt;&gt;</tag>"
-
-    >>> elem.attrib["key"] = '\xe5\xf6\xf6<>'
-    >>> elem.text = None
-    >>> serialize(elem)
-    '<tag key="\xe5\xf6\xf6&lt;&gt;" />'
-    >>> serialize(elem, encoding="utf-8")
-    b'<tag key="\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;" />'
-    >>> serialize(elem, encoding="us-ascii")
-    b'<tag key="&#229;&#246;&#246;&lt;&gt;" />'
-    >>> serialize(elem, encoding="iso-8859-1")
-    b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag key="\xe5\xf6\xf6&lt;&gt;" />'
-    """
-
 def methods():
     r"""
     Test serialization methods.
@@ -2166,16 +2107,184 @@
         self.assertEqual(self._subelem_tags(e), ['a1'])
 
 
-class StringIOTest(unittest.TestCase):
+class IOTest(unittest.TestCase):
+    def tearDown(self):
+        unlink(TESTFN)
+
+    def test_encoding(self):
+        # Test encoding issues.
+        elem = ET.Element("tag")
+        elem.text = "abc"
+        self.assertEqual(serialize(elem), '<tag>abc</tag>')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                b'<tag>abc</tag>')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                b'<tag>abc</tag>')
+        for enc in ("iso-8859-1", "utf-16", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                    ("<?xml version='1.0' encoding='%s'?>\n"
+                     "<tag>abc</tag>" % enc).encode(enc))
+
+        elem = ET.Element("tag")
+        elem.text = "<&\"\'>"
+        self.assertEqual(serialize(elem), '<tag>&lt;&amp;"\'&gt;</tag>')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                b'<tag>&lt;&amp;"\'&gt;</tag>')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                b'<tag>&lt;&amp;"\'&gt;</tag>')
+        for enc in ("iso-8859-1", "utf-16", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                    ("<?xml version='1.0' encoding='%s'?>\n"
+                     "<tag>&lt;&amp;\"'&gt;</tag>" % enc).encode(enc))
+
+        elem = ET.Element("tag")
+        elem.attrib["key"] = "<&\"\'>"
+        self.assertEqual(serialize(elem), '<tag key="&lt;&amp;&quot;\'&gt;" />')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                b'<tag key="&lt;&amp;&quot;\'&gt;" />')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                b'<tag key="&lt;&amp;&quot;\'&gt;" />')
+        for enc in ("iso-8859-1", "utf-16", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                    ("<?xml version='1.0' encoding='%s'?>\n"
+                     "<tag key=\"&lt;&amp;&quot;'&gt;\" />" % enc).encode(enc))
+
+        elem = ET.Element("tag")
+        elem.text = '\xe5\xf6\xf6<>'
+        self.assertEqual(serialize(elem), '<tag>\xe5\xf6\xf6&lt;&gt;</tag>')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                b'<tag>\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;</tag>')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                b'<tag>&#229;&#246;&#246;&lt;&gt;</tag>')
+        for enc in ("iso-8859-1", "utf-16", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                    ("<?xml version='1.0' encoding='%s'?>\n"
+                     "<tag>\xe5\xf6\xf6&lt;&gt;</tag>" % enc).encode(enc))
+
+        elem = ET.Element("tag")
+        elem.attrib["key"] = '\xe5\xf6\xf6<>'
+        self.assertEqual(serialize(elem), '<tag key="\xe5\xf6\xf6&lt;&gt;" />')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                b'<tag key="\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;" />')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                b'<tag key="&#229;&#246;&#246;&lt;&gt;" />')
+        for enc in ("iso-8859-1", "utf-16", "utf-16le", "utf-16be", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                    ("<?xml version='1.0' encoding='%s'?>\n"
+                     "<tag key=\"\xe5\xf6\xf6&lt;&gt;\" />" % enc).encode(enc))
+
+    def test_write_to_filename(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        tree.write(TESTFN)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(), b'''<site />''')
+
+    def test_write_to_text_file(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        with open(TESTFN, 'w', encoding='utf-8') as f:
+            tree.write(f, encoding='unicode')
+            self.assertFalse(f.closed)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(), b'''<site />''')
+
+    def test_write_to_binary_file(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        with open(TESTFN, 'wb') as f:
+            tree.write(f)
+            self.assertFalse(f.closed)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(), b'''<site />''')
+
+    def test_write_to_binary_file_with_bom(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        # test BOM writing to buffered file
+        with open(TESTFN, 'wb') as f:
+            tree.write(f, encoding='utf-16')
+            self.assertFalse(f.closed)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(),
+                    '''<?xml version='1.0' encoding='utf-16'?>\n'''
+                    '''<site />'''.encode("utf-16"))
+        # test BOM writing to non-buffered file
+        with open(TESTFN, 'wb', buffering=0) as f:
+            tree.write(f, encoding='utf-16')
+            self.assertFalse(f.closed)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(),
+                    '''<?xml version='1.0' encoding='utf-16'?>\n'''
+                    '''<site />'''.encode("utf-16"))
+
     def test_read_from_stringio(self):
         tree = ET.ElementTree()
+        stream = io.StringIO('''<?xml version="1.0"?><site></site>''')
+        tree.parse(stream)
+        self.assertEqual(tree.getroot().tag, 'site')
+
+    def test_write_to_stringio(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
         stream = io.StringIO()
-        stream.write('''<?xml version="1.0"?><site></site>''')
-        stream.seek(0)
-        tree.parse(stream)
+        tree.write(stream, encoding='unicode')
+        self.assertEqual(stream.getvalue(), '''<site />''')
 
+    def test_read_from_bytesio(self):
+        tree = ET.ElementTree()
+        raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''')
+        tree.parse(raw)
         self.assertEqual(tree.getroot().tag, 'site')
 
+    def test_write_to_bytesio(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        raw = io.BytesIO()
+        tree.write(raw)
+        self.assertEqual(raw.getvalue(), b'''<site />''')
+
+    class dummy:
+        pass
+
+    def test_read_from_user_text_reader(self):
+        stream = io.StringIO('''<?xml version="1.0"?><site></site>''')
+        reader = self.dummy()
+        reader.read = stream.read
+        tree = ET.ElementTree()
+        tree.parse(reader)
+        self.assertEqual(tree.getroot().tag, 'site')
+
+    def test_write_to_user_text_writer(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        stream = io.StringIO()
+        writer = self.dummy()
+        writer.write = stream.write
+        tree.write(writer, encoding='unicode')
+        self.assertEqual(stream.getvalue(), '''<site />''')
+
+    def test_read_from_user_binary_reader(self):
+        raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''')
+        reader = self.dummy()
+        reader.read = raw.read
+        tree = ET.ElementTree()
+        tree.parse(reader)
+        self.assertEqual(tree.getroot().tag, 'site')
+
+    def test_write_to_user_binary_writer(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        raw = io.BytesIO()
+        writer = self.dummy()
+        writer.write = raw.write
+        tree.write(writer)
+        self.assertEqual(raw.getvalue(), b'''<site />''')
+
+    def test_write_to_user_binary_writer_with_bom(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        raw = io.BytesIO()
+        writer = self.dummy()
+        writer.write = raw.write
+        writer.seekable = lambda: True
+        writer.tell = raw.tell
+        tree.write(writer, encoding="utf-16")
+        self.assertEqual(raw.getvalue(),
+                '''<?xml version='1.0' encoding='utf-16'?>\n'''
+                '''<site />'''.encode("utf-16"))
+
 
 class ParseErrorTest(unittest.TestCase):
     def test_subclass(self):
@@ -2299,7 +2408,7 @@
     test_classes = [
         ElementSlicingTest,
         BasicElementTest,
-        StringIOTest,
+        IOTest,
         ParseErrorTest,
         XincludeTest,
         ElementTreeTest,
diff -r 677a9326b4d4 Lib/xml/etree/ElementTree.py
--- a/Lib/xml/etree/ElementTree.py	Mon Jul 09 18:16:11 2012 -0700
+++ b/Lib/xml/etree/ElementTree.py	Fri Jul 13 23:23:04 2012 +0300
@@ -100,6 +100,8 @@
 import sys
 import re
 import warnings
+import io
+import contextlib
 
 from . import ElementPath
 
@@ -812,39 +814,22 @@
             encoding = "unicode"
         else:
             encoding = encoding.lower()
-        if hasattr(file_or_filename, "write"):
-            file = file_or_filename
-        else:
-            if encoding != "unicode":
-                file = open(file_or_filename, "wb")
+        with _get_writer(file_or_filename, encoding) as write:
+            if method == "xml" and (xml_declaration or
+                    (xml_declaration is None and
+                     encoding not in ("utf-8", "us-ascii", "unicode"))):
+                declared_encoding = encoding
+                if encoding == "unicode":
+                    # Retrieve the default encoding for the xml declaration
+                    import locale
+                    declared_encoding = locale.getpreferredencoding()
+                write("<?xml version='1.0' encoding='%s'?>\n" % declared_encoding)
+            if method == "text":
+                _serialize_text(write, self._root)
             else:
-                file = open(file_or_filename, "w")
-        if encoding != "unicode":
-            def write(text):
-                try:
-                    return file.write(text.encode(encoding,
-                                                  "xmlcharrefreplace"))
-                except (TypeError, AttributeError):
-                    _raise_serialization_error(text)
-        else:
-            write = file.write
-        if method == "xml" and (xml_declaration or
-                (xml_declaration is None and
-                 encoding not in ("utf-8", "us-ascii", "unicode"))):
-            declared_encoding = encoding
-            if encoding == "unicode":
-                # Retrieve the default encoding for the xml declaration
-                import locale
-                declared_encoding = locale.getpreferredencoding()
-            write("<?xml version='1.0' encoding='%s'?>\n" % declared_encoding)
-        if method == "text":
-            _serialize_text(write, self._root)
-        else:
-            qnames, namespaces = _namespaces(self._root, default_namespace)
-            serialize = _serialize[method]
-            serialize(write, self._root, qnames, namespaces)
-        if file_or_filename is not file:
-            file.close()
+                qnames, namespaces = _namespaces(self._root, default_namespace)
+                serialize = _serialize[method]
+                serialize(write, self._root, qnames, namespaces)
 
     def write_c14n(self, file):
         # lxml.etree compatibility.  use output method instead
@@ -853,6 +838,54 @@
 # --------------------------------------------------------------------
 # serialization support
 
+@contextlib.contextmanager
+def _get_writer(file_or_filename, encoding):
+    # yields a text write() function and releases all resources on exit
+    try:
+        write = file_or_filename.write
+    except AttributeError:
+        # file_or_filename is a file name
+        if encoding == "unicode":
+            file = open(file_or_filename, "w")
+        else:
+            file = open(file_or_filename, "w", encoding=encoding,
+                        errors="xmlcharrefreplace", newline="\n")
+        with file:
+            yield file.write
+    else:
+        # file_or_filename is a file-like object
+        # encoding determines if it is a text or binary writer
+        if encoding == "unicode":
+            # use a text writer as is
+            yield write
+        else:
+            # wrap a binary writer with TextIOWrapper
+            with contextlib.ExitStack() as stack:
+                if isinstance(file_or_filename, io.BufferedIOBase):
+                    file = file_or_filename
+                elif isinstance(file_or_filename, io.RawIOBase):
+                    file = io.BufferedWriter(file_or_filename)
+                    # detach on exit so the caller's raw file stays open
+                    stack.callback(file.detach)
+                else:
+                    file = io.BufferedIOBase()
+                    file.writable = lambda: True
+                    file.write = write
+                    try:
+                        # TextIOWrapper uses these methods to decide if a BOM
+                        # is needed; set tell first so seekable never lacks it
+                        file.tell = file_or_filename.tell
+                        file.seekable = file_or_filename.seekable
+                    except AttributeError:
+                        pass
+                file = io.TextIOWrapper(file,
+                                        encoding=encoding,
+                                        errors="xmlcharrefreplace",
+                                        newline="\n")
+                # detach on exit so the wrapped binary file stays open
+                stack.callback(file.detach)
+                yield file.write
+
 def _namespaces(elem, default_namespace=None):
     # identify namespaces used in this tree
 
@@ -1134,10 +1167,9 @@
 # @defreturn string
 
 def tostring(element, encoding=None, method=None):
-    class dummy:
-        pass
     data = []
-    file = dummy()
+    file = io.BufferedIOBase()
+    file.writable = lambda: True
     file.write = data.append
     ElementTree(element).write(file, encoding, method=method)
     if encoding in (str, "unicode"):
@@ -1161,10 +1193,9 @@
 # @since 1.3
 
 def tostringlist(element, encoding=None, method=None):
-    class dummy:
-        pass
     data = []
-    file = dummy()
+    file = io.BufferedIOBase()
+    file.writable = lambda: True
     file.write = data.append
     ElementTree(element).write(file, encoding, method=method)
     # FIXME: merge small fragments into larger parts


More information about the Python-bugs-list mailing list