diff options
Diffstat (limited to 'src/lxml/classlookup.pxi')
-rw-r--r-- | src/lxml/classlookup.pxi | 31 |
1 files changed, 22 insertions, 9 deletions
diff --git a/src/lxml/classlookup.pxi b/src/lxml/classlookup.pxi index f1ac43c3..3b6e6d2b 100644 --- a/src/lxml/classlookup.pxi +++ b/src/lxml/classlookup.pxi @@ -186,6 +186,25 @@ cdef class EntityBase(_Entity): self._init() +cdef int _validateNodeClass(xmlNode* c_node, cls) except -1: + if c_node.type == tree.XML_ELEMENT_NODE: + expected = ElementBase + elif c_node.type == tree.XML_COMMENT_NODE: + expected = CommentBase + elif c_node.type == tree.XML_ENTITY_REF_NODE: + expected = EntityBase + elif c_node.type == tree.XML_PI_NODE: + expected = PIBase + else: + assert 0, u"Unknown node type: %s" % c_node.type + + if not (isinstance(cls, type) and issubclass(cls, expected)): + raise TypeError( + "result of class lookup must be subclass of %s, got %s" + % (type(expected), type(cls))) + return 0 + + ################################################################################ # Element class lookup @@ -366,9 +385,7 @@ cdef object _attribute_class_lookup(state, _Document doc, xmlNode* c_node): dict_result = python.PyDict_GetItem(lookup._class_mapping, value) if dict_result is not NULL: cls = <object>dict_result - if not isinstance(cls, type): - raise TypeError("class lookup must return class, got %s" - % type(cls)) + _validateNodeClass(c_node, cls) return cls return _callLookupFallback(lookup, doc, c_node) @@ -440,9 +457,7 @@ cdef object _custom_class_lookup(state, _Document doc, xmlNode* c_node): cls = lookup.lookup(element_type, doc, ns, name) if cls is not None: - if not isinstance(cls, type): - raise TypeError("class lookup must return class, got %s" - % type(cls)) + _validateNodeClass(c_node, cls) return cls return _callLookupFallback(lookup, doc, c_node) @@ -513,9 +528,7 @@ cdef object _python_class_lookup(state, _Document doc, tree.xmlNode* c_node): _freeReadOnlyProxies(proxy) if cls is not None: - if not isinstance(cls, type): - raise TypeError("class lookup must return class, got %s" - % type(cls)) + _validateNodeClass(c_node, cls) return cls return _callLookupFallback(lookup, doc, c_node) |