diff options
-rw-r--r-- | src/lxml/serializer.pxi | 60 | ||||
-rw-r--r-- | src/lxml/tests/test_incremental_xmlfile.py | 38 |
2 files changed, 93 insertions, 5 deletions
diff --git a/src/lxml/serializer.pxi b/src/lxml/serializer.pxi index 8cee18d8..4ef53bc9 100644 --- a/src/lxml/serializer.pxi +++ b/src/lxml/serializer.pxi @@ -1014,10 +1014,21 @@ cdef class _IncrementalFileWriter: tree.xmlOutputBufferFlush(self._c_out) self._handle_error(self._c_out.error) - def element(self, tag, attrib=None, nsmap=None, **_extra): - """element(self, tag, attrib=None, nsmap=None, **_extra) + def method(self, method): + """method(self, method) + + Returns a context manager that overrides and restores the output method. + method is one of (None, 'xml', 'html') where None means 'xml'. + """ + assert self._c_out is not NULL + c_method = self._method if method is None else _findOutputMethod(method) + return _MethodChanger(self, c_method) + + def element(self, tag, attrib=None, nsmap=None, method=None, **_extra): + """element(self, tag, attrib=None, nsmap=None, method, **_extra) Returns a context manager that writes an opening and closing tag. + method is one of (None, 'xml', 'html') where None means 'xml'. """ assert self._c_out is not NULL attributes = [] @@ -1038,7 +1049,10 @@ cdef class _IncrementalFileWriter: _prefixValidOrRaise(prefix) reversed_nsmap[_utf8(ns)] = prefix ns, name = _getNsTag(tag) - return _FileWriterElement(self, (ns, name, attributes, reversed_nsmap)) + + c_method = self._method if method is None else _findOutputMethod(method) + + return _FileWriterElement(self, (ns, name, attributes, reversed_nsmap), c_method) cdef _write_qname(self, bytes name, bytes prefix): if prefix: # empty bytes for no prefix (not None to allow sorting) @@ -1163,6 +1177,7 @@ cdef class _IncrementalFileWriter: ns in (None, b'http://www.w3.org/1999/xhtml') and name in (b'script', b'style')): tree.xmlOutputBufferWrite(self._c_out, len(content), _cstr(content)) + else: tree.xmlOutputBufferWriteEscape(self._c_out, _xcstr(content), NULL) @@ -1219,14 +1234,51 @@ cdef class _IncrementalFileWriter: @cython.freelist(8) cdef class _FileWriterElement: cdef object _element + cdef int _new_method + cdef int _old_method cdef _IncrementalFileWriter _writer - def __cinit__(self, _IncrementalFileWriter writer not None, element_config): + def __cinit__(self, _IncrementalFileWriter writer not None, element_config, int method): self._writer = writer self._element = element_config + self._new_method = method + self._old_method = writer._method def __enter__(self): + self._writer._method = self._new_method self._writer._write_start_element(self._element) def __exit__(self, exc_type, exc_val, exc_tb): self._writer._write_end_element(self._element) + self._writer._method = self._old_method + +@cython.final +@cython.internal +@cython.freelist(8) +cdef class _MethodChanger: + cdef int _new_method + cdef int _old_method + cdef bint _entered + cdef bint _exited + cdef _IncrementalFileWriter _writer + + def __cinit__(self, _IncrementalFileWriter writer not None, int method): + self._writer = writer + self._new_method = method + self._old_method = writer._method + self._entered = False + self._exited = False + + def __enter__(self): + if self._entered: + raise LxmlSyntaxError("Inconsistent enter action in context manager") + self._writer._method = self._new_method + self._entered = True + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._exited: + raise LxmlSyntaxError("Inconsistent exit action in context manager") + if self._writer._method != self._new_method: + raise LxmlSyntaxError("Method changed outside of context manager") + self._writer._method = self._old_method + self._exited = True diff --git a/src/lxml/tests/test_incremental_xmlfile.py b/src/lxml/tests/test_incremental_xmlfile.py index 6089e93a..c2f162b2 100644 --- a/src/lxml/tests/test_incremental_xmlfile.py +++ b/src/lxml/tests/test_incremental_xmlfile.py @@ -9,6 +9,8 @@ from __future__ import with_statement, absolute_import import unittest import tempfile, os, sys +from lxml.etree import LxmlSyntaxError + this_dir = os.path.dirname(__file__) if this_dir not in sys.path: sys.path.insert(0, this_dir) # needed for Py3 @@ -379,8 +381,31 @@ class HtmlFileTestCase(_XmlFileTestCaseBase): self.assertXml('<%s>' % tag) self._file = BytesIO() + def test_method_context_manager_misuse(self): + with etree.htmlfile(self._file) as xf: + with xf.element('foo'): + cm = xf.method('xml') + cm.__enter__() + + self.assertRaises(LxmlSyntaxError, cm.__enter__) + + cm2 = xf.method('xml') + cm2.__enter__() + cm2.__exit__(None, None, None) + + with self.assertRaises(LxmlSyntaxError): + cm2.__exit__(None, None, None) + + cm3 = xf.method('xml') + cm3.__enter__() + with xf.method('html'): + with self.assertRaises(LxmlSyntaxError): + cm3.__exit__(None, None, None) + def test_xml_mode_write_inside_html(self): - elt = etree.Element("foo", attrib={'selected': 'bar'}) + tag = 'foo' + attrib = {'selected': 'bar'} + elt = etree.Element(tag, attrib=attrib) with etree.htmlfile(self._file) as xf: with xf.element("root"): @@ -392,11 +417,22 @@ class HtmlFileTestCase(_XmlFileTestCaseBase): elt.text = "" xf.write(elt, method='xml') # 3 + with xf.element(tag, attrib=attrib, method='xml'): + pass # 4 + + xf.write(elt) # 5 + + with xf.method('xml'): + xf.write(elt) # 6 + self.assertXml( '<root>' '<foo selected></foo>' # 1 '<foo selected="bar"/>' # 2 '<foo selected="bar"></foo>' # 3 + '<foo selected="bar"></foo>' # 4 + '<foo selected></foo>' # 5 + '<foo selected="bar"></foo>' # 6 '</root>') self._file = BytesIO() |