From 4acf5925f22cc8fb8622d95065c57dff663e8483 Mon Sep 17 00:00:00 2001 From: Ivan Kanakarakis Date: Thu, 9 Sep 2021 01:44:11 +0300 Subject: Keep unknown metadata extensions Signed-off-by: Ivan Kanakarakis --- src/saml2/__init__.py | 24 ++++++++++++++++-------- src/saml2/mdie.py | 26 +++++++++++++++++++++++--- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/src/saml2/__init__.py b/src/saml2/__init__.py index 6c11e200..6588a741 100644 --- a/src/saml2/__init__.py +++ b/src/saml2/__init__.py @@ -981,7 +981,7 @@ def extension_element_to_element(extension_element, translation_functions, return None -def extension_elements_to_elements(extension_elements, schemas): +def extension_elements_to_elements(extension_elements, schemas, keep_unmatched=False): """ Create a list of elements each one matching one of the given extension elements. This is of course dependent on the access to schemas that describe the extension elements. @@ -989,6 +989,8 @@ def extension_elements_to_elements(extension_elements, schemas): :param extension_elements: The list of extension elements :param schemas: Imported Python modules that represent the different known schemas used for the extension elements + :param keep_unmatched: Whether to keep extension elements that did not match any + schemas :return: A list of elements, representing the set of extension elements that was possible to match against a Class in the given schemas. The elements returned are the native representation of the elements @@ -1004,13 +1006,19 @@ def extension_elements_to_elements(extension_elements, schemas): return res for extension_element in extension_elements: - for schema in schemas: - inst = extension_element_to_element(extension_element, - schema.ELEMENT_FROM_STRING, - schema.NAMESPACE) - if inst: - res.append(inst) - break + convert_results = ( + inst + for schema in schemas + for inst in [ + extension_element_to_element( + extension_element, schema.ELEMENT_FROM_STRING, schema.NAMESPACE + ) + ] + if inst + ) + result = next(convert_results, extension_element if keep_unmatched else None) + if result: + res.append(result) return res diff --git a/src/saml2/mdie.py b/src/saml2/mdie.py index 89197315..1bbe3e8d 100644 --- a/src/saml2/mdie.py +++ b/src/saml2/mdie.py @@ -3,6 +3,7 @@ import six from saml2 import element_to_extension_element from saml2 import extension_elements_to_elements +from saml2 import ExtensionElement from saml2 import SamlBase from saml2 import md @@ -30,12 +31,20 @@ def _eval(val, onts, mdb_safe): return None else: return val - elif isinstance(val, dict) or isinstance(val, SamlBase): + elif ( + isinstance(val, dict) + or isinstance(val, SamlBase) + or isinstance(val, ExtensionElement) + ): return to_dict(val, onts, mdb_safe) elif isinstance(val, list): lv = [] for v in val: - if isinstance(v, dict) or isinstance(v, SamlBase): + if ( + isinstance(v, dict) + or isinstance(v, SamlBase) + or isinstance(v, ExtensionElement) + ): lv.append(to_dict(v, onts, mdb_safe)) else: lv.append(v) @@ -61,7 +70,9 @@ def to_dict(_dict, onts, mdb_safe=False): continue val = getattr(_dict, key) if key == "extension_elements": - _eel = extension_elements_to_elements(val, onts) + _eel = extension_elements_to_elements( + val, onts, keep_unmatched=True + ) _val = [_eval(_v, onts, mdb_safe) for _v in _eel] elif key == "extension_attributes": if mdb_safe: @@ -77,6 +88,15 @@ def to_dict(_dict, onts, mdb_safe=False): if mdb_safe: key = key.replace(".", "__") res[key] = _val + elif isinstance(_dict, ExtensionElement): + res = { + key: _val + for key, val in _dict.__dict__.items() + if val and key not in IMP_SKIP + for _val in [_eval(val, onts, mdb_safe)] + if _val + } + res["__class__"] = "%s&%s" % (_dict.namespace, _dict.tag) else: for key, val in _dict.items(): _val = _eval(val, onts, mdb_safe) -- cgit v1.2.1