Close #1767933: Badly formed XML using etree and utf-16. Patch by Serhiy Storchaka, with some minor fixes by me

This commit is contained in:
Eli Bendersky 2012-07-15 06:02:22 +03:00
parent 1191709b13
commit 00f402bfcb
3 changed files with 258 additions and 123 deletions

View File

@ -659,7 +659,6 @@ ElementTree Objects
should be added to the file. Use False for never, True for always, None should be added to the file. Use False for never, True for always, None
for only if not US-ASCII or UTF-8 or Unicode (default is None). *method* is for only if not US-ASCII or UTF-8 or Unicode (default is None). *method* is
either ``"xml"``, ``"html"`` or ``"text"`` (default is ``"xml"``). either ``"xml"``, ``"html"`` or ``"text"`` (default is ``"xml"``).
Returns an (optionally) encoded string.
This is the XML file that is going to be manipulated:: This is the XML file that is going to be manipulated::

View File

@ -21,7 +21,7 @@ import unittest
import weakref import weakref
from test import support from test import support
from test.support import findfile, import_fresh_module, gc_collect from test.support import TESTFN, findfile, unlink, import_fresh_module, gc_collect
pyET = None pyET = None
ET = None ET = None
@ -888,65 +888,6 @@ def check_encoding(encoding):
""" """
ET.XML("<?xml version='1.0' encoding='%s'?><xml />" % encoding) ET.XML("<?xml version='1.0' encoding='%s'?><xml />" % encoding)
def encoding():
r"""
Test encoding issues.
>>> elem = ET.Element("tag")
>>> elem.text = "abc"
>>> serialize(elem)
'<tag>abc</tag>'
>>> serialize(elem, encoding="utf-8")
b'<tag>abc</tag>'
>>> serialize(elem, encoding="us-ascii")
b'<tag>abc</tag>'
>>> serialize(elem, encoding="iso-8859-1")
b"<?xml version='1.0' encoding='iso-8859-1'?>\n<tag>abc</tag>"
>>> elem.text = "<&\"\'>"
>>> serialize(elem)
'<tag>&lt;&amp;"\'&gt;</tag>'
>>> serialize(elem, encoding="utf-8")
b'<tag>&lt;&amp;"\'&gt;</tag>'
>>> serialize(elem, encoding="us-ascii") # cdata characters
b'<tag>&lt;&amp;"\'&gt;</tag>'
>>> serialize(elem, encoding="iso-8859-1")
b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag>&lt;&amp;"\'&gt;</tag>'
>>> elem.attrib["key"] = "<&\"\'>"
>>> elem.text = None
>>> serialize(elem)
'<tag key="&lt;&amp;&quot;\'&gt;" />'
>>> serialize(elem, encoding="utf-8")
b'<tag key="&lt;&amp;&quot;\'&gt;" />'
>>> serialize(elem, encoding="us-ascii")
b'<tag key="&lt;&amp;&quot;\'&gt;" />'
>>> serialize(elem, encoding="iso-8859-1")
b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag key="&lt;&amp;&quot;\'&gt;" />'
>>> elem.text = '\xe5\xf6\xf6<>'
>>> elem.attrib.clear()
>>> serialize(elem)
'<tag>\xe5\xf6\xf6&lt;&gt;</tag>'
>>> serialize(elem, encoding="utf-8")
b'<tag>\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;</tag>'
>>> serialize(elem, encoding="us-ascii")
b'<tag>&#229;&#246;&#246;&lt;&gt;</tag>'
>>> serialize(elem, encoding="iso-8859-1")
b"<?xml version='1.0' encoding='iso-8859-1'?>\n<tag>\xe5\xf6\xf6&lt;&gt;</tag>"
>>> elem.attrib["key"] = '\xe5\xf6\xf6<>'
>>> elem.text = None
>>> serialize(elem)
'<tag key="\xe5\xf6\xf6&lt;&gt;" />'
>>> serialize(elem, encoding="utf-8")
b'<tag key="\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;" />'
>>> serialize(elem, encoding="us-ascii")
b'<tag key="&#229;&#246;&#246;&lt;&gt;" />'
>>> serialize(elem, encoding="iso-8859-1")
b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag key="\xe5\xf6\xf6&lt;&gt;" />'
"""
def methods(): def methods():
r""" r"""
Test serialization methods. Test serialization methods.
@ -2166,16 +2107,185 @@ class ElementSlicingTest(unittest.TestCase):
self.assertEqual(self._subelem_tags(e), ['a1']) self.assertEqual(self._subelem_tags(e), ['a1'])
class StringIOTest(unittest.TestCase): class IOTest(unittest.TestCase):
def tearDown(self):
unlink(TESTFN)
def test_encoding(self):
# Test encoding issues.
elem = ET.Element("tag")
elem.text = "abc"
self.assertEqual(serialize(elem), '<tag>abc</tag>')
self.assertEqual(serialize(elem, encoding="utf-8"),
b'<tag>abc</tag>')
self.assertEqual(serialize(elem, encoding="us-ascii"),
b'<tag>abc</tag>')
for enc in ("iso-8859-1", "utf-16", "utf-32"):
self.assertEqual(serialize(elem, encoding=enc),
("<?xml version='1.0' encoding='%s'?>\n"
"<tag>abc</tag>" % enc).encode(enc))
elem = ET.Element("tag")
elem.text = "<&\"\'>"
self.assertEqual(serialize(elem), '<tag>&lt;&amp;"\'&gt;</tag>')
self.assertEqual(serialize(elem, encoding="utf-8"),
b'<tag>&lt;&amp;"\'&gt;</tag>')
self.assertEqual(serialize(elem, encoding="us-ascii"),
b'<tag>&lt;&amp;"\'&gt;</tag>')
for enc in ("iso-8859-1", "utf-16", "utf-32"):
self.assertEqual(serialize(elem, encoding=enc),
("<?xml version='1.0' encoding='%s'?>\n"
"<tag>&lt;&amp;\"'&gt;</tag>" % enc).encode(enc))
elem = ET.Element("tag")
elem.attrib["key"] = "<&\"\'>"
self.assertEqual(serialize(elem), '<tag key="&lt;&amp;&quot;\'&gt;" />')
self.assertEqual(serialize(elem, encoding="utf-8"),
b'<tag key="&lt;&amp;&quot;\'&gt;" />')
self.assertEqual(serialize(elem, encoding="us-ascii"),
b'<tag key="&lt;&amp;&quot;\'&gt;" />')
for enc in ("iso-8859-1", "utf-16", "utf-32"):
self.assertEqual(serialize(elem, encoding=enc),
("<?xml version='1.0' encoding='%s'?>\n"
"<tag key=\"&lt;&amp;&quot;'&gt;\" />" % enc).encode(enc))
elem = ET.Element("tag")
elem.text = '\xe5\xf6\xf6<>'
self.assertEqual(serialize(elem), '<tag>\xe5\xf6\xf6&lt;&gt;</tag>')
self.assertEqual(serialize(elem, encoding="utf-8"),
b'<tag>\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;</tag>')
self.assertEqual(serialize(elem, encoding="us-ascii"),
b'<tag>&#229;&#246;&#246;&lt;&gt;</tag>')
for enc in ("iso-8859-1", "utf-16", "utf-32"):
self.assertEqual(serialize(elem, encoding=enc),
("<?xml version='1.0' encoding='%s'?>\n"
"<tag>åöö&lt;&gt;</tag>" % enc).encode(enc))
elem = ET.Element("tag")
elem.attrib["key"] = '\xe5\xf6\xf6<>'
self.assertEqual(serialize(elem), '<tag key="\xe5\xf6\xf6&lt;&gt;" />')
self.assertEqual(serialize(elem, encoding="utf-8"),
b'<tag key="\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;" />')
self.assertEqual(serialize(elem, encoding="us-ascii"),
b'<tag key="&#229;&#246;&#246;&lt;&gt;" />')
for enc in ("iso-8859-1", "utf-16", "utf-16le", "utf-16be", "utf-32"):
self.assertEqual(serialize(elem, encoding=enc),
("<?xml version='1.0' encoding='%s'?>\n"
"<tag key=\"åöö&lt;&gt;\" />" % enc).encode(enc))
def test_write_to_filename(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
tree.write(TESTFN)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''<site />''')
def test_write_to_text_file(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
with open(TESTFN, 'w', encoding='utf-8') as f:
tree.write(f, encoding='unicode')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''<site />''')
def test_write_to_binary_file(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
with open(TESTFN, 'wb') as f:
tree.write(f)
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''<site />''')
def test_write_to_binary_file_with_bom(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
# test BOM writing to buffered file
with open(TESTFN, 'wb') as f:
tree.write(f, encoding='utf-16')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(),
'''<?xml version='1.0' encoding='utf-16'?>\n'''
'''<site />'''.encode("utf-16"))
# test BOM writing to non-buffered file
with open(TESTFN, 'wb', buffering=0) as f:
tree.write(f, encoding='utf-16')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(),
'''<?xml version='1.0' encoding='utf-16'?>\n'''
'''<site />'''.encode("utf-16"))
def test_read_from_stringio(self): def test_read_from_stringio(self):
tree = ET.ElementTree() tree = ET.ElementTree()
stream = io.StringIO() stream = io.StringIO('''<?xml version="1.0"?><site></site>''')
stream.write('''<?xml version="1.0"?><site></site>''')
stream.seek(0)
tree.parse(stream) tree.parse(stream)
self.assertEqual(tree.getroot().tag, 'site') self.assertEqual(tree.getroot().tag, 'site')
def test_write_to_stringio(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
stream = io.StringIO()
tree.write(stream, encoding='unicode')
self.assertEqual(stream.getvalue(), '''<site />''')
def test_read_from_bytesio(self):
tree = ET.ElementTree()
raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''')
tree.parse(raw)
self.assertEqual(tree.getroot().tag, 'site')
def test_write_to_bytesio(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
raw = io.BytesIO()
tree.write(raw)
self.assertEqual(raw.getvalue(), b'''<site />''')
class dummy:
pass
def test_read_from_user_text_reader(self):
stream = io.StringIO('''<?xml version="1.0"?><site></site>''')
reader = self.dummy()
reader.read = stream.read
tree = ET.ElementTree()
tree.parse(reader)
self.assertEqual(tree.getroot().tag, 'site')
def test_write_to_user_text_writer(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
stream = io.StringIO()
writer = self.dummy()
writer.write = stream.write
tree.write(writer, encoding='unicode')
self.assertEqual(stream.getvalue(), '''<site />''')
def test_read_from_user_binary_reader(self):
raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''')
reader = self.dummy()
reader.read = raw.read
tree = ET.ElementTree()
tree.parse(reader)
self.assertEqual(tree.getroot().tag, 'site')
tree = ET.ElementTree()
def test_write_to_user_binary_writer(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
raw = io.BytesIO()
writer = self.dummy()
writer.write = raw.write
tree.write(writer)
self.assertEqual(raw.getvalue(), b'''<site />''')
def test_write_to_user_binary_writer_with_bom(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
raw = io.BytesIO()
writer = self.dummy()
writer.write = raw.write
writer.seekable = lambda: True
writer.tell = raw.tell
tree.write(writer, encoding="utf-16")
self.assertEqual(raw.getvalue(),
'''<?xml version='1.0' encoding='utf-16'?>\n'''
'''<site />'''.encode("utf-16"))
class ParseErrorTest(unittest.TestCase): class ParseErrorTest(unittest.TestCase):
def test_subclass(self): def test_subclass(self):
@ -2299,7 +2409,7 @@ def test_main(module=None):
test_classes = [ test_classes = [
ElementSlicingTest, ElementSlicingTest,
BasicElementTest, BasicElementTest,
StringIOTest, IOTest,
ParseErrorTest, ParseErrorTest,
XincludeTest, XincludeTest,
ElementTreeTest, ElementTreeTest,

View File

@ -100,6 +100,8 @@ VERSION = "1.3.0"
import sys import sys
import re import re
import warnings import warnings
import io
import contextlib
from . import ElementPath from . import ElementPath
@ -792,59 +794,38 @@ class ElementTree:
# None for only if not US-ASCII or UTF-8 or Unicode. None is default. # None for only if not US-ASCII or UTF-8 or Unicode. None is default.
def write(self, file_or_filename, def write(self, file_or_filename,
# keyword arguments
encoding=None, encoding=None,
xml_declaration=None, xml_declaration=None,
default_namespace=None, default_namespace=None,
method=None): method=None):
# assert self._root is not None
if not method: if not method:
method = "xml" method = "xml"
elif method not in _serialize: elif method not in _serialize:
# FIXME: raise an ImportError for c14n if ElementC14N is missing?
raise ValueError("unknown method %r" % method) raise ValueError("unknown method %r" % method)
if not encoding: if not encoding:
if method == "c14n": if method == "c14n":
encoding = "utf-8" encoding = "utf-8"
else: else:
encoding = "us-ascii" encoding = "us-ascii"
elif encoding == str: # lxml.etree compatibility.
encoding = "unicode"
else: else:
encoding = encoding.lower() encoding = encoding.lower()
if hasattr(file_or_filename, "write"): with _get_writer(file_or_filename, encoding) as write:
file = file_or_filename if method == "xml" and (xml_declaration or
else: (xml_declaration is None and
if encoding != "unicode": encoding not in ("utf-8", "us-ascii", "unicode"))):
file = open(file_or_filename, "wb") declared_encoding = encoding
if encoding == "unicode":
# Retrieve the default encoding for the xml declaration
import locale
declared_encoding = locale.getpreferredencoding()
write("<?xml version='1.0' encoding='%s'?>\n" % (
declared_encoding,))
if method == "text":
_serialize_text(write, self._root)
else: else:
file = open(file_or_filename, "w") qnames, namespaces = _namespaces(self._root, default_namespace)
if encoding != "unicode": serialize = _serialize[method]
def write(text): serialize(write, self._root, qnames, namespaces)
try:
return file.write(text.encode(encoding,
"xmlcharrefreplace"))
except (TypeError, AttributeError):
_raise_serialization_error(text)
else:
write = file.write
if method == "xml" and (xml_declaration or
(xml_declaration is None and
encoding not in ("utf-8", "us-ascii", "unicode"))):
declared_encoding = encoding
if encoding == "unicode":
# Retrieve the default encoding for the xml declaration
import locale
declared_encoding = locale.getpreferredencoding()
write("<?xml version='1.0' encoding='%s'?>\n" % declared_encoding)
if method == "text":
_serialize_text(write, self._root)
else:
qnames, namespaces = _namespaces(self._root, default_namespace)
serialize = _serialize[method]
serialize(write, self._root, qnames, namespaces)
if file_or_filename is not file:
file.close()
def write_c14n(self, file): def write_c14n(self, file):
# lxml.etree compatibility. use output method instead # lxml.etree compatibility. use output method instead
@ -853,6 +834,58 @@ class ElementTree:
# -------------------------------------------------------------------- # --------------------------------------------------------------------
# serialization support # serialization support
@contextlib.contextmanager
def _get_writer(file_or_filename, encoding):
# returns text write method and release all resourses after using
try:
write = file_or_filename.write
except AttributeError:
# file_or_filename is a file name
if encoding == "unicode":
file = open(file_or_filename, "w")
else:
file = open(file_or_filename, "w", encoding=encoding,
errors="xmlcharrefreplace")
with file:
yield file.write
else:
# file_or_filename is a file-like object
# encoding determines if it is a text or binary writer
if encoding == "unicode":
# use a text writer as is
yield write
else:
# wrap a binary writer with TextIOWrapper
with contextlib.ExitStack() as stack:
if isinstance(file_or_filename, io.BufferedIOBase):
file = file_or_filename
elif isinstance(file_or_filename, io.RawIOBase):
file = io.BufferedWriter(file_or_filename)
# Keep the original file open when the BufferedWriter is
# destroyed
stack.callback(file.detach)
else:
# This is to handle passed objects that aren't in the
# IOBase hierarchy, but just have a write method
file = io.BufferedIOBase()
file.writable = lambda: True
file.write = write
try:
# TextIOWrapper uses this methods to determine
# if BOM (for UTF-16, etc) should be added
file.seekable = file_or_filename.seekable
file.tell = file_or_filename.tell
except AttributeError:
pass
file = io.TextIOWrapper(file,
encoding=encoding,
errors="xmlcharrefreplace",
newline="\n")
# Keep the original file open when the TextIOWrapper is
# destroyed
stack.callback(file.detach)
yield file.write
def _namespaces(elem, default_namespace=None): def _namespaces(elem, default_namespace=None):
# identify namespaces used in this tree # identify namespaces used in this tree
@ -1134,22 +1167,13 @@ def _escape_attrib_html(text):
# @defreturn string # @defreturn string
def tostring(element, encoding=None, method=None): def tostring(element, encoding=None, method=None):
class dummy: stream = io.StringIO() if encoding == 'unicode' else io.BytesIO()
pass ElementTree(element).write(stream, encoding, method=method)
data = [] return stream.getvalue()
file = dummy()
file.write = data.append
ElementTree(element).write(file, encoding, method=method)
if encoding in (str, "unicode"):
return "".join(data)
else:
return b"".join(data)
## ##
# Generates a string representation of an XML element, including all # Generates a string representation of an XML element, including all
# subelements. If encoding is False, the string is returned as a # subelements.
# sequence of string fragments; otherwise it is a sequence of
# bytestrings.
# #
# @param element An Element instance. # @param element An Element instance.
# @keyparam encoding Optional output encoding (default is US-ASCII). # @keyparam encoding Optional output encoding (default is US-ASCII).
@ -1161,13 +1185,15 @@ def tostring(element, encoding=None, method=None):
# @since 1.3 # @since 1.3
def tostringlist(element, encoding=None, method=None): def tostringlist(element, encoding=None, method=None):
class dummy:
pass
data = [] data = []
file = dummy() class DataStream(io.BufferedIOBase):
file.write = data.append def writable(self):
ElementTree(element).write(file, encoding, method=method) return True
# FIXME: merge small fragments into larger parts
def write(self, b):
data.append(b)
ElementTree(element).write(DataStream(), encoding, method=method)
return data return data
## ##