summaryrefslogtreecommitdiff
path: root/src/saml2/mdie.py
blob: 1bbe3e8d787eaaa09184e1aaf2804fc37810da91 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
#!/usr/bin/env python
import six

from saml2 import element_to_extension_element
from saml2 import extension_elements_to_elements
from saml2 import ExtensionElement
from saml2 import SamlBase
from saml2 import md

__author__ = 'rolandh'

# NOTE(review): this string is not the module's __doc__ -- a module docstring
# has to be the first statement, and the imports and the __author__
# assignment come before it. As written it is an expression with no effect.
"""
Functions used to import metadata from and export it to a pysaml2 format
"""

# Attribute names that to_dict leaves out when exporting a SamlBase or
# ExtensionElement instance to a dictionary.
IMP_SKIP = ["_certs", "e_e_", "_extatt"]
# Dictionary keys that _kwa / from_dict ignore when rebuilding objects:
# "__class__" only identifies the type and is not a constructor argument.
EXP_SKIP = ["__class__"]


# From pysaml2 SAML2 metadata format to Python dictionary
def _eval(val, onts, mdb_safe):
    """
    Reduce a single attribute value to basic Python data.

    * strings are stripped; a string that is empty afterwards becomes None
    * dicts, SamlBase and ExtensionElement instances go through to_dict
    * lists are copied, converting each dict/SamlBase/ExtensionElement
      member with to_dict and keeping every other member unchanged
    * anything else is returned as it is

    :param val: The value
    :param onts: Schemas to be used in the conversion
    :param mdb_safe: Passed unchanged to to_dict
    :return: The converted value
    """
    convertible = (dict, SamlBase, ExtensionElement)

    if isinstance(val, six.string_types):
        stripped = val.strip()
        return stripped if stripped else None

    if isinstance(val, convertible):
        return to_dict(val, onts, mdb_safe)

    if isinstance(val, list):
        return [
            to_dict(member, onts, mdb_safe)
            if isinstance(member, convertible) else member
            for member in val
        ]

    return val


def to_dict(_dict, onts, mdb_safe=False):
    """
    Convert a pysaml2 SAML2 message class instance into a basic dictionary
    format.
    The export interface.

    SamlBase and ExtensionElement instances get an extra "__class__" entry
    of the form "<namespace>&<tag>", which from_dict uses to rebuild them.
    Attribute names listed in IMP_SKIP are skipped, and entries whose
    converted value is falsy (None, "", empty list/dict, ...) are omitted.

    :param _dict: The pysaml2 metadata instance (SamlBase or
        ExtensionElement), or a plain dictionary
    :param onts: List of schemas to use for the conversion
    :param mdb_safe: If True, every "." in the keys written for SamlBase
        attributes, extension-attribute names and plain dictionary keys is
        replaced by "__" (presumably because MongoDB field names may not
        contain dots -- confirm with callers). from_dict(..., mdb_safe=True)
        reverses this.
    :return: The converted information
    """
    res = {}
    if isinstance(_dict, SamlBase):
        res["__class__"] = "%s&%s" % (_dict.c_namespace, _dict.c_tag)
        for key in _dict.keyswv():
            if key in IMP_SKIP:
                continue
            val = getattr(_dict, key)
            if key == "extension_elements":
                # Re-parse the raw extension elements with the schemas in
                # onts before exporting. keep_unmatched=True presumably keeps
                # elements no schema recognises -- see
                # extension_elements_to_elements.
                _eel = extension_elements_to_elements(
                    val, onts, keep_unmatched=True
                )
                _val = [_eval(_v, onts, mdb_safe) for _v in _eel]
            elif key == "extension_attributes":
                # Values are exported as-is; only the names may need escaping.
                if mdb_safe:
                    _val = {k.replace(".", "__"): v for k, v in val.items()}
                else:
                    _val = val
            else:
                _val = _eval(val, onts, mdb_safe)

            if _val:
                if mdb_safe:
                    key = key.replace(".", "__")
                res[key] = _val
    elif isinstance(_dict, ExtensionElement):
        for key, val in _dict.__dict__.items():
            # Skip empty raw values before converting, then skip values that
            # are empty after conversion (e.g. whitespace-only text).
            if not val or key in IMP_SKIP:
                continue
            _val = _eval(val, onts, mdb_safe)
            if _val:
                res[key] = _val
        res["__class__"] = "%s&%s" % (_dict.namespace, _dict.tag)
    else:
        for key, val in _dict.items():
            _val = _eval(val, onts, mdb_safe)
            if _val:
                if mdb_safe and "." in key:
                    key = key.replace(".", "__")
                res[key] = _val
    return res


# From Python dictionary to pysaml2 SAML2 metadata format

def _kwa(val, onts, mdb_safe=False):
    """
    Key word argument conversion

    Turn a dictionary produced by to_dict into keyword arguments for a
    pysaml2 class constructor. Every value is converted with from_dict, and
    keys listed in EXP_SKIP are dropped.

    :param val: A dictionary
    :param onts: dictionary with schemas to use in the conversion
        schema namespace is the key in the dictionary
    :param mdb_safe: If True, also drop the "_id" key and turn every "__" in
        a key back into ".", undoing to_dict(..., mdb_safe=True). The flag is
        passed on to the nested from_dict calls, so nested objects and
        dictionaries (e.g. extension_attributes) are unescaped as well.
    :return: A converted dictionary
    """
    if not mdb_safe:
        return {k: from_dict(v, onts, mdb_safe)
                for k, v in val.items() if k not in EXP_SKIP}

    _skip = ["_id"]
    _skip.extend(EXP_SKIP)
    # mdb_safe must be forwarded: to_dict escapes keys at every nesting
    # level, so the inverse has to unescape at every level too.
    return {k.replace("__", "."): from_dict(v, onts, mdb_safe)
            for k, v in val.items() if k not in _skip}


def from_dict(val, onts, mdb_safe=False):
    """
    Converts a dictionary into a pysaml2 object
    The import interface; the inverse of to_dict.

    * A dictionary with a "__class__" entry ("<namespace>&<tag>") becomes an
      instance of the class found as attribute <tag> of onts[<namespace>],
      built from the remaining entries via _kwa. md.Extensions is special:
      instead of one instance, a list is returned holding every item found
      under its non-skipped keys, each passed through
      element_to_extension_element.
    * A dictionary without "__class__" is converted value by value.
    * A list is converted element by element.
    * Any other value (strings included) is returned unchanged.

    :param val: A dictionary
    :param onts: Dictionary of schemas to use in the conversion, keyed by
        namespace
    :param mdb_safe: If True, undo the key escaping of
        to_dict(..., mdb_safe=True): "__" in keys becomes "." again, and
        "_id" is dropped from class dictionaries. Applied at every nesting
        level.
    :return: The pysaml2 object instance
    """
    if isinstance(val, dict):
        if "__class__" in val:
            ns, typ = val["__class__"].split("&")
            cls = getattr(onts[ns], typ)
            if cls is md.Extensions:
                lv = []
                for key, ditems in val.items():
                    if key in EXP_SKIP:
                        continue
                    for item in ditems:
                        ns, typ = item["__class__"].split("&")
                        item_cls = getattr(onts[ns], typ)
                        kwargs = _kwa(item, onts, mdb_safe)
                        inst = item_cls(**kwargs)
                        lv.append(element_to_extension_element(inst))
                return lv
            kwargs = _kwa(val, onts, mdb_safe)
            return cls(**kwargs)

        # Plain dictionary: convert every value. mdb_safe is forwarded so
        # that escaped keys are restored at every nesting level, mirroring
        # to_dict.
        res = {}
        for key, v in val.items():
            if mdb_safe:
                key = key.replace("__", ".")
            res[key] = from_dict(v, onts, mdb_safe)
        return res

    if isinstance(val, list):
        return [from_dict(v, onts, mdb_safe) for v in val]

    # Strings and other scalar values need no conversion.
    return val