summaryrefslogtreecommitdiff
path: root/Lib/xml
diff options
context:
space:
mode:
authorEli Bendersky <eliben@gmail.com>2012-07-15 06:02:22 +0300
committerEli Bendersky <eliben@gmail.com>2012-07-15 06:02:22 +0300
commitd712b71adf6e52f6834dd2c1a7ab5b97dc6017fd (patch)
tree8194611fbf057e1219151f8541e1c4f7af4016ec /Lib/xml
parent7135b18f7f0e88d734cf18fec37baba885687df5 (diff)
downloadcpython-d712b71adf6e52f6834dd2c1a7ab5b97dc6017fd.tar.gz
Close #1767933: Badly formed XML using etree and utf-16. Patch by Serhiy Storchaka, with some minor fixes by me
Diffstat (limited to 'Lib/xml')
-rw-r--r--Lib/xml/etree/ElementTree.py138
1 files changed, 82 insertions, 56 deletions
diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py
index 61fe155035..10bf84992e 100644
--- a/Lib/xml/etree/ElementTree.py
+++ b/Lib/xml/etree/ElementTree.py
@@ -100,6 +100,8 @@ VERSION = "1.3.0"
import sys
import re
import warnings
+import io
+import contextlib
from . import ElementPath
@@ -792,59 +794,38 @@ class ElementTree:
# None for only if not US-ASCII or UTF-8 or Unicode. None is default.
def write(self, file_or_filename,
- # keyword arguments
encoding=None,
xml_declaration=None,
default_namespace=None,
method=None):
- # assert self._root is not None
if not method:
method = "xml"
elif method not in _serialize:
- # FIXME: raise an ImportError for c14n if ElementC14N is missing?
raise ValueError("unknown method %r" % method)
if not encoding:
if method == "c14n":
encoding = "utf-8"
else:
encoding = "us-ascii"
- elif encoding == str: # lxml.etree compatibility.
- encoding = "unicode"
else:
encoding = encoding.lower()
- if hasattr(file_or_filename, "write"):
- file = file_or_filename
- else:
- if encoding != "unicode":
- file = open(file_or_filename, "wb")
+ with _get_writer(file_or_filename, encoding) as write:
+ if method == "xml" and (xml_declaration or
+ (xml_declaration is None and
+ encoding not in ("utf-8", "us-ascii", "unicode"))):
+ declared_encoding = encoding
+ if encoding == "unicode":
+ # Retrieve the default encoding for the xml declaration
+ import locale
+ declared_encoding = locale.getpreferredencoding()
+ write("<?xml version='1.0' encoding='%s'?>\n" % (
+ declared_encoding,))
+ if method == "text":
+ _serialize_text(write, self._root)
else:
- file = open(file_or_filename, "w")
- if encoding != "unicode":
- def write(text):
- try:
- return file.write(text.encode(encoding,
- "xmlcharrefreplace"))
- except (TypeError, AttributeError):
- _raise_serialization_error(text)
- else:
- write = file.write
- if method == "xml" and (xml_declaration or
- (xml_declaration is None and
- encoding not in ("utf-8", "us-ascii", "unicode"))):
- declared_encoding = encoding
- if encoding == "unicode":
- # Retrieve the default encoding for the xml declaration
- import locale
- declared_encoding = locale.getpreferredencoding()
- write("<?xml version='1.0' encoding='%s'?>\n" % declared_encoding)
- if method == "text":
- _serialize_text(write, self._root)
- else:
- qnames, namespaces = _namespaces(self._root, default_namespace)
- serialize = _serialize[method]
- serialize(write, self._root, qnames, namespaces)
- if file_or_filename is not file:
- file.close()
+ qnames, namespaces = _namespaces(self._root, default_namespace)
+ serialize = _serialize[method]
+ serialize(write, self._root, qnames, namespaces)
def write_c14n(self, file):
# lxml.etree compatibility. use output method instead
@@ -853,6 +834,58 @@ class ElementTree:
# --------------------------------------------------------------------
# serialization support
+@contextlib.contextmanager
+def _get_writer(file_or_filename, encoding):
+ # returns text write method and release all resourses after using
+ try:
+ write = file_or_filename.write
+ except AttributeError:
+ # file_or_filename is a file name
+ if encoding == "unicode":
+ file = open(file_or_filename, "w")
+ else:
+ file = open(file_or_filename, "w", encoding=encoding,
+ errors="xmlcharrefreplace")
+ with file:
+ yield file.write
+ else:
+ # file_or_filename is a file-like object
+ # encoding determines if it is a text or binary writer
+ if encoding == "unicode":
+ # use a text writer as is
+ yield write
+ else:
+ # wrap a binary writer with TextIOWrapper
+ with contextlib.ExitStack() as stack:
+ if isinstance(file_or_filename, io.BufferedIOBase):
+ file = file_or_filename
+ elif isinstance(file_or_filename, io.RawIOBase):
+ file = io.BufferedWriter(file_or_filename)
+ # Keep the original file open when the BufferedWriter is
+ # destroyed
+ stack.callback(file.detach)
+ else:
+ # This is to handle passed objects that aren't in the
+ # IOBase hierarchy, but just have a write method
+ file = io.BufferedIOBase()
+ file.writable = lambda: True
+ file.write = write
+ try:
+ # TextIOWrapper uses this methods to determine
+ # if BOM (for UTF-16, etc) should be added
+ file.seekable = file_or_filename.seekable
+ file.tell = file_or_filename.tell
+ except AttributeError:
+ pass
+ file = io.TextIOWrapper(file,
+ encoding=encoding,
+ errors="xmlcharrefreplace",
+ newline="\n")
+ # Keep the original file open when the TextIOWrapper is
+ # destroyed
+ stack.callback(file.detach)
+ yield file.write
+
def _namespaces(elem, default_namespace=None):
# identify namespaces used in this tree
@@ -1134,22 +1167,13 @@ def _escape_attrib_html(text):
# @defreturn string
def tostring(element, encoding=None, method=None):
- class dummy:
- pass
- data = []
- file = dummy()
- file.write = data.append
- ElementTree(element).write(file, encoding, method=method)
- if encoding in (str, "unicode"):
- return "".join(data)
- else:
- return b"".join(data)
+ stream = io.StringIO() if encoding == 'unicode' else io.BytesIO()
+ ElementTree(element).write(stream, encoding, method=method)
+ return stream.getvalue()
##
# Generates a string representation of an XML element, including all
-# subelements. If encoding is False, the string is returned as a
-# sequence of string fragments; otherwise it is a sequence of
-# bytestrings.
+# subelements.
#
# @param element An Element instance.
# @keyparam encoding Optional output encoding (default is US-ASCII).
@@ -1161,13 +1185,15 @@ def tostring(element, encoding=None, method=None):
# @since 1.3
def tostringlist(element, encoding=None, method=None):
- class dummy:
- pass
data = []
- file = dummy()
- file.write = data.append
- ElementTree(element).write(file, encoding, method=method)
- # FIXME: merge small fragments into larger parts
+ class DataStream(io.BufferedIOBase):
+ def writable(self):
+ return True
+
+ def write(self, b):
+ data.append(b)
+
+ ElementTree(element).write(DataStream(), encoding, method=method)
return data
##