summaryrefslogtreecommitdiff
path: root/src/saml2/authn.py
blob: 480d89965ca768f230d86c51e84bd16e97c7926a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
import logging
import six
import time
from saml2 import SAMLError
import saml2.cryptography.symmetric
from saml2.httputil import Response
from saml2.httputil import make_cookie
from saml2.httputil import Redirect
from saml2.httputil import Unauthorized
from saml2.httputil import parse_cookie

from six.moves.urllib.parse import urlencode, parse_qs, urlsplit

__author__ = 'rolandh'

logger = logging.getLogger(__name__)


class AuthnFailure(SAMLError):
    """SAMLError subclass that names an authentication failure."""


class EncodeError(SAMLError):
    """SAMLError subclass raised when query parameters cannot be encoded."""


class UserAuthnMethod(object):
    """Base class for user authentication methods.

    Only the server reference is stored here; every operation raises
    NotImplementedError and has to be supplied by a subclass.
    """

    def __init__(self, srv):
        """
        :param srv: The server instance, kept as ``self.srv``
        """
        self.srv = srv

    def __call__(self, *args, **kwargs):
        """Entry point of the method; must be overridden."""
        raise NotImplementedError()

    def authenticated_as(self, **kwargs):
        """Tell who the caller is authenticated as; must be overridden."""
        raise NotImplementedError()

    def verify(self, **kwargs):
        """Verify submitted credentials; must be overridden."""
        raise NotImplementedError()


def is_equal(a, b):
    """
    Compare two secrets for equality without short-circuiting on the first
    differing position.

    Text and bytes-like values are compared with ``hmac.compare_digest``;
    text is compared by its UTF-8 encoding. The previous implementation
    XOR-ed the elements directly, which raises TypeError for text (the type
    ``parse_qs`` delivers for form passwords).

    Any other pair of equally typed integer sequences is still compared with
    the element-wise XOR loop.

    :param a: First value (text, bytes, bytearray or a sequence of ints)
    :param b: Second value, same kinds as ``a``
    :return: True if both values are equal, otherwise False
    """
    import hmac  # local import: nothing else in this module needs it

    def _to_bytes(value):
        # bytes-like values are used as they are, text by its UTF-8 encoding;
        # anything else is reported as "not convertible" with None
        if isinstance(value, (bytes, bytearray)):
            return bytes(value)
        if hasattr(value, "encode"):
            return value.encode("utf-8")
        return None

    raw_a = _to_bytes(a)
    raw_b = _to_bytes(b)
    if raw_a is not None and raw_b is not None:
        return hmac.compare_digest(raw_a, raw_b)

    # Fallback for sequences of integers: accumulate all differences so the
    # loop always runs over the full length.
    if len(a) != len(b):
        return False

    result = 0
    for x, y in zip(a, b):
        result |= x ^ y
    return result == 0


def url_encode_params(params=None):
    """
    Url-encode a dictionary of query parameters.

    A value that is a list contributes one ``key=value`` pair per element;
    every other value contributes exactly one pair.

    :param params: Dictionary of parameter names to values
    :return: The url-encoded query string
    :raises EncodeError: if ``params`` is not a dictionary
    """
    if not isinstance(params, dict):
        raise EncodeError("You must pass in a dictionary!")

    pairs = []
    for name, value in params.items():
        # Only real lists are expanded; tuples and other values stay whole.
        items = value if isinstance(value, list) else [value]
        pairs.extend((name, item) for item in items)
    return urlencode(pairs)


def create_return_url(base, query, **kwargs):
    """
    Add a query string plus extra parameters to a base URL which may contain
    a query part already.

    Parameters are merged in this order: the extra keyword arguments, then
    the values from ``query``, then the values already present in ``base``.
    Values for the same key are collected into one list.

    :param base: redirect_uri may contain a query part, no fragment allowed.
    :param query: Old query part as a string
    :param kwargs: extra query parameters
    :return: ``base`` without its own query part, followed by "?" and all
        merged parameters url-encoded
    :raises ValueError: if ``base`` contains a fragment
    """
    part = urlsplit(base)
    if part.fragment:
        raise ValueError("Base URL contained parts it shouldn't")

    for source in (query, part.query):
        if not source:
            continue
        for key, values in parse_qs(source).items():
            if key not in kwargs:
                kwargs[key] = values
                continue
            current = kwargs[key]
            if isinstance(current, (list, tuple)):
                # copy, so a list owned by the caller is never modified
                current = list(current)
            else:
                current = [current]
            kwargs[key] = current + values

    # Always cut at the first "?": this also handles a base that ends in a
    # bare "?" (empty query), which would otherwise produce "??".
    _pre = base.split("?", 1)[0]

    logger.debug("kwargs: %s", kwargs)

    return "%s?%s" % (_pre, url_encode_params(kwargs))


class UsernamePasswordMako(UserAuthnMethod):
    """Do user authentication using the normal username password form
    using Mako as template system.

    Successful logins are recorded in ``self.active`` (encrypted token ->
    timestamp) and handed to the client as a cookie named ``cookie_name``.
    """
    cookie_name = "userpassmako"

    def __init__(self, srv, mako_template, template_lookup, pwd, return_to):
        """
        :param srv: The server instance; ``srv.symkey`` is read here and
            ``srv.seed`` when cookies are created or parsed
        :param mako_template: Which Mako template to use
        :param template_lookup: Object whose ``get_template`` returns the
            template to render
        :param pwd: Username/password dictionary like database
        :param return_to: Where to send the user after authentication
        :return:
        """
        UserAuthnMethod.__init__(self, srv)
        self.mako_template = mako_template
        self.template_lookup = template_lookup
        self.passwd = pwd
        self.return_to = return_to
        # encrypted "<uid>::<timestamp>" token -> timestamp of the login
        self.active = {}
        # query parameter added to the return URL after a successful login
        self.query_param = "upm_answer"
        self.symmetric = saml2.cryptography.symmetric.Default(srv.symkey)

    def __call__(self, cookie=None, policy_url=None, logo_url=None,
                 query="", **kwargs):
        """
        Put up the login form

        :param cookie: Optional header entry to send along with the form
        :param policy_url: Passed to the template as ``policy_url``
        :param logo_url: Passed to the template as ``logo_url``
        :param query: Passed to the template as ``query``
        :param kwargs: Ignored
        :return: A Response whose message is the rendered template
        """
        headers = [cookie] if cookie else []
        resp = Response(headers=headers)

        argv = {"login": "",
                "password": "",
                "action": "verify",
                "policy_url": policy_url,
                "logo_url": logo_url,
                "query": query}
        logger.info("do_authentication argv: %s", argv)
        mte = self.template_lookup.get_template(self.mako_template)
        resp.message = mte.render(**argv)
        return resp

    def _verify(self, pwd, user):
        """
        Check a password against the one stored for the user.

        :param pwd: The password to check
        :param user: The username
        :raises KeyError: if the user is not in the password database
        :raises ValueError: if the password does not match
        """
        if not is_equal(pwd, self.passwd[user]):
            raise ValueError("Wrong password")

    def verify(self, request, **kwargs):
        """
        Verifies that the given username and password was correct
        :param request: Either the query part of a URL a urlencoded
            body of a HTTP message or a parse such.
        :param kwargs: Catch whatever else is sent.
        :return: redirect back to where ever the base applications
            wants the user after authentication, or an Unauthorized
            response if the credentials could not be verified.
        :raises ValueError: if ``request`` is neither a string nor a dict
        """

        # logger.debug("verify(%s)" % request)
        if isinstance(request, six.string_types):
            _dict = parse_qs(request)
        elif isinstance(request, dict):
            _dict = request
        else:
            raise ValueError("Wrong type of input")

        # verify username and password
        try:
            self._verify(_dict["password"][0], _dict["login"][0])
            timestamp = str(int(time.mktime(time.gmtime())))
            msg = "::".join([_dict["login"][0], timestamp])
            info = self.symmetric.encrypt(msg.encode())
            self.active[info] = timestamp
            cookie = make_cookie(self.cookie_name, info, self.srv.seed)
            return_to = create_return_url(self.return_to, _dict["query"][0],
                                          **{self.query_param: "true"})
            resp = Redirect(return_to, headers=[cookie])
        except (ValueError, KeyError, IndexError, AssertionError):
            # AssertionError: a subclass' _verify (LDAPAuthn in this module)
            # reports failed verification that way.
            # IndexError: a parameter was present but had no values.
            resp = Unauthorized("Unknown user or wrong password")

        return resp

    def authenticated_as(self, cookie=None, **kwargs):
        """
        Find out who a cookie belongs to.

        :param cookie: The cookie received from the client
        :param kwargs: Only logged
        :return: ``{"uid": <username>}`` if the cookie matches a login
            recorded in ``self.active``, otherwise None
        """
        if cookie is None:
            return None

        logger.debug("kwargs: %s", kwargs)
        try:
            info, timestamp = parse_cookie(self.cookie_name,
                                           self.srv.seed, cookie)
            if self.active[info] == timestamp:
                msg = self.symmetric.decrypt(info).decode()
                # Split at the last separator only: the timestamp never
                # contains "::" but the username might.
                uid, _ts = msg.rsplit("::", 1)
                if timestamp == _ts:
                    return {"uid": uid}
        except Exception:
            # Best effort: whatever goes wrong, the caller is simply not
            # authenticated. Leave a trace for debugging.
            logger.debug("Could not establish identity from cookie",
                         exc_info=True)

        return None

    def done(self, areq):
        """
        Report whether ``areq`` lacks the ``self.query_param`` key.

        ``verify`` adds that key to the return URL it redirects to.

        :param areq: Mapping to look the key up in
        :return: False if the key is present, True if the lookup raises
            KeyError
        """
        try:
            _ = areq[self.query_param]
            return False
        except KeyError:
            return True


class SocialService(UserAuthnMethod):
    """Authentication method that forwards its calls to a ``social`` object."""

    def __init__(self, social):
        """
        :param social: Object providing ``begin`` and ``callback``
        """
        # This method is not tied to a server instance.
        UserAuthnMethod.__init__(self, None)
        self.social = social

    def __call__(self, server_env, cookie=None, sid="", query="", **kwargs):
        """Forward to ``social.begin``; extra keyword arguments are dropped."""
        backend = self.social
        return backend.begin(server_env, cookie, sid, query)

    def callback(self, server_env, cookie=None, sid="", query="", **kwargs):
        """Forward to ``social.callback`` including extra keyword arguments."""
        backend = self.social
        return backend.callback(server_env, cookie, sid, query, **kwargs)


class AuthnMethodChooser(object):
    """Pick an authentication method out of a configured collection."""

    def __init__(self, methods=None):
        """
        :param methods: Sequence of authentication methods, or None
        """
        self.methods = methods

    def __call__(self, **kwargs):
        """
        :param kwargs: Ignored
        :return: The only configured method; None if several are configured
        :raises SAMLError: if no methods are configured
        """
        candidates = self.methods
        if not candidates:
            raise SAMLError("No authentication methods defined")
        if len(candidates) == 1:
            return candidates[0]
        # TODO: choosing between several methods is not implemented
        return None

try:
    import ldap


    class _LDAPVerificationError(AssertionError, ValueError):
        """Failed LDAP verification.

        Still an AssertionError, as LDAPAuthn._verify always raised, and also
        a ValueError so that UsernamePasswordMako.verify answers with
        Unauthorized instead of letting the exception escape.
        """


    class LDAPAuthn(UsernamePasswordMako):
        """Username/password form whose credentials are checked by binding
        to an LDAP server."""

        def __init__(self, srv, ldapsrv, return_to,
                     dn_pattern, mako_template, template_lookup):
            """
            :param srv: The server instance
            :param ldapsrv: Which LDAP server to use
            :param return_to: Where to send the user after authentication
            :param dn_pattern: %-format pattern the username is inserted
                into to build the DN to bind as
            :param mako_template: Which Mako template to use
            :param template_lookup: Object used to look the template up
            :return:
            """
            # No local password database: the LDAP server does the checking.
            UsernamePasswordMako.__init__(self, srv, mako_template,
                                          template_lookup, None, return_to)

            self.ldap = ldap.initialize(ldapsrv)
            self.ldap.protocol_version = 3
            self.ldap.set_option(ldap.OPT_REFERRALS, 0)
            self.dn_pattern = dn_pattern

        def _verify(self, pwd, user):
            """
            Verifies the username and password against a LDAP server
            :param pwd: The password
            :param user: The username
            :raises AssertionError: if the LDAP verification failed (the
                raised exception is also a ValueError).
            """
            if not pwd:
                # A simple bind with an empty password is an
                # "unauthenticated" bind, which servers commonly accept
                # regardless of the DN - never treat that as a login.
                raise _LDAPVerificationError("Empty password")

            # NOTE(review): user is inserted into the DN unescaped; consider
            # ldap.dn.escape_dn_chars if usernames come from untrusted input.
            _dn = self.dn_pattern % user
            try:
                self.ldap.simple_bind_s(_dn, pwd)
            except Exception:
                raise _LDAPVerificationError("LDAP bind failed")
except ImportError:
    class LDAPAuthn(UserAuthnMethod):
        """Placeholder used when the ``ldap`` module is not installed."""