summaryrefslogtreecommitdiff
path: root/src/saml2/httputil.py
blob: 7217b147191c2411cbe7e36145087b1480b727ad (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
import cgi
import hashlib
import hmac
from http.cookies import SimpleCookie
import logging
import time
from urllib.parse import parse_qs
from urllib.parse import quote

from saml2 import BINDING_HTTP_ARTIFACT
from saml2 import BINDING_HTTP_POST
from saml2 import BINDING_HTTP_REDIRECT
from saml2 import BINDING_SOAP
from saml2 import BINDING_URI
from saml2 import SAMLError
from saml2 import time_util


# Author handle recorded as module metadata.
__author__ = "rohe0002"

# Logger named after this module's import path.
logger = logging.getLogger(__name__)


class Response:
    """Callable WSGI-style response holding a status line, headers and a body.

    Subclasses customise the defaults through the underscore-prefixed class
    attributes (``_status``, ``_template``, ``_content_type``,
    ``_mako_template``, ``_mako_lookup``); each of them can also be overridden
    per instance through keyword arguments to ``__init__``.
    """

    _template = None
    _status = "200 OK"
    _content_type = "text/html"
    _mako_template = None
    _mako_lookup = None

    def __init__(self, message=None, **kwargs):
        """
        :param message: Body (or template argument) of the response
        :param kwargs: Optional overrides: ``status``, ``response`` (callable
            that renders the body), ``template`` (%-format string),
            ``mako_template``, ``template_lookup``, ``headers`` (list of
            ``(name, value)`` tuples) and ``content`` (content type that is
            added when ``headers`` carries no Content-Type entry)
        """
        self.status = kwargs.get("status", self._status)
        self.response = kwargs.get("response", self._response)
        self.template = kwargs.get("template", self._template)
        self.mako_template = kwargs.get("mako_template", self._mako_template)
        self.mako_lookup = kwargs.get("template_lookup", self._mako_lookup)

        self.message = message

        # Work on a copy: headers are appended to below (and by add_header /
        # subclasses), which must not leak into the list the caller passed in.
        self.headers = list(kwargs.get("headers") or [])
        _content_type = kwargs.get("content", self._content_type)
        has_content_type = any(
            header[0].lower() == "content-type" for header in self.headers
        )
        if not has_content_type:
            self.headers.append(("Content-type", _content_type))

    def __call__(self, environ, start_response, **kwargs):
        """Start the response and return the rendered body.

        When no message was given, the request URL rebuilt from ``environ``
        is rendered instead.

        :param environ: WSGI environ
        :param start_response: WSGI ``start_response`` callable
        :param kwargs: Extra keyword arguments passed on to the renderer
        :return: The value produced by ``self.response``
        """
        try:
            start_response(self.status, self.headers)
        except TypeError:
            # A TypeError from start_response is ignored on purpose by the
            # original code; the body is still rendered and returned.
            pass
        return self.response(self.message or geturl(environ), **kwargs)

    def _response(self, message="", **argv):
        """Default renderer: apply the template, then normalise the body.

        The %-template takes precedence; otherwise a mako template is rendered
        when both a lookup and a template name are configured.

        :param message: Body or template argument
        :param argv: Extra variables handed to the mako template
        :return: ``[bytes]`` for str/bytes bodies, anything else unchanged
        """
        if self.template:
            message = self.template % message
        elif self.mako_lookup and self.mako_template:
            argv["message"] = message
            mte = self.mako_lookup.get_template(self.mako_template)
            message = mte.render(**argv)

        if isinstance(message, str):
            return [message.encode("utf-8")]
        if isinstance(message, bytes):
            return [message]
        return message

    def add_header(self, ava):
        """
        Does *NOT* replace a header of the same type, just adds a new
        :param ava: (type, value) tuple
        """
        self.headers.append(ava)

    def reply(self, **kwargs):
        """Render the stored message without calling ``start_response``."""
        return self.response(self.message, **kwargs)


class Created(Response):
    """Response whose default status line is ``201 Created``."""

    _status = "201 Created"


class Redirect(Response):
    """``302 Found`` response that redirects to the URL given as message.

    The body is a small HTML page linking to the target; the target is also
    sent in a ``location`` header.
    """

    _template = (
        "<html>\n<head><title>Redirecting to %s</title></head>\n"
        '<body>\nYou are being redirected to <a href="%s">%s</a>\n'
        "</body>\n</html>"
    )
    _status = "302 Found"

    def __call__(self, environ, start_response, **kwargs):
        """Send the redirect.

        :param environ: WSGI environ (not used here)
        :param start_response: WSGI ``start_response`` callable
        :return: The rendered HTML body
        """
        location = self.message
        location_header = ("location", location)
        # The instance may be invoked more than once; only add the header the
        # first time so repeated calls do not emit duplicate location headers.
        if location_header not in self.headers:
            self.headers.append(location_header)
        start_response(self.status, self.headers)
        return self.response((location, location, location))


class SeeOther(Response):
    """``303 See Other`` response.

    The target is taken from the message when one is set; otherwise from a
    ``location`` header that was already supplied.
    """

    _template = (
        "<html>\n<head><title>Redirecting to %s</title></head>\n"
        '<body>\nYou are being redirected to <a href="%s">%s</a>\n'
        "</body>\n</html>"
    )
    _status = "303 See Other"

    def __call__(self, environ, start_response, **kwargs):
        """Send the redirect.

        :param environ: WSGI environ (not used here)
        :param start_response: WSGI ``start_response`` callable
        :return: The rendered HTML body
        """
        location = ""
        if self.message:
            location = self.message
            location_header = ("location", location)
            # Avoid stacking up identical headers when called repeatedly.
            if location_header not in self.headers:
                self.headers.append(location_header)
        else:
            # Header names are case-insensitive, so accept e.g. "Location".
            for param, item in self.headers:
                if param.lower() == "location":
                    location = item
                    break
        start_response(self.status, self.headers)
        return self.response((location, location, location))


class Forbidden(Response):
    """Response with default status ``403 Forbidden``.

    The message is interpolated into a one-line HTML notice.
    """

    _status = "403 Forbidden"
    _template = "<html>Not allowed to mess with: '%s'</html>"


class BadRequest(Response):
    """Response with default status ``400 Bad Request``; wraps the message in ``<html>`` tags."""

    _status = "400 Bad Request"
    _template = "<html>%s</html>"


class Unauthorized(Response):
    """Response with default status ``401 Unauthorized``; wraps the message in ``<html>`` tags."""

    _status = "401 Unauthorized"
    _template = "<html>%s</html>"


class NotFound(Response):
    """Response whose default status line is ``404 NOT FOUND``."""

    _status = "404 NOT FOUND"


class NotAcceptable(Response):
    """Response whose default status line is ``406 Not Acceptable``."""

    _status = "406 Not Acceptable"


class ServiceError(Response):
    """Response with a default 500 status line."""

    # NOTE(review): the standard reason phrase for 500 is "Internal Server
    # Error"; confirm whether "Service" is intentional before changing it.
    _status = "500 Internal Service Error"


class NotImplemented(Response):
    """Response whose default status line is ``501 Not Implemented``.

    NOTE(review): the class name shadows the ``NotImplemented`` builtin within
    this module.
    """

    _status = "501 Not Implemented"
    # override template since we need an environment variable
    # NOTE(review): this assigns ``template``, not ``_template``.
    # ``Response.__init__`` sets ``self.template`` from the ``template`` keyword
    # or ``self._template``, so that instance attribute hides this class
    # attribute and the string below appears never to be used. The string has
    # two ``%s`` placeholders, so renaming it to ``_template`` would also
    # require a 2-tuple message -- confirm the intent before changing.
    template = "The request method %s is not implemented " "for this server.\r\n%s"


class BadGateway(Response):
    """Response whose default status line is ``502 Bad Gateway``."""

    _status = "502 Bad Gateway"


class HttpParameters:
    """GET or POST signature parameters for Redirect or POST-SimpleSign bindings
    because they are not contained in XML unlike the POST binding
    """

    signature = None
    sigalg = None
    # Relaystate and SAML message are stored elsewhere

    def __init__(self, dict):
        """Pick the first ``Signature`` and ``SigAlg`` values out of ``dict``.

        The keys are read in that order and reading stops at the first one
        that is missing, so ``sigalg`` is only set when ``Signature`` is
        present too. Attributes that are not set keep the class default
        ``None``.

        :param dict: Mapping from parameter name to a sequence of values
        """
        for attribute, key in (("signature", "Signature"), ("sigalg", "SigAlg")):
            try:
                first_value = dict[key][0]
            except KeyError:
                break
            setattr(self, attribute, first_value)


def extract(environ, empty=False, err=False):
    """Parse the request's form data into a dict.

    Fields that occur exactly once are collapsed from a one-element list to
    the bare value; repeated fields stay lists.

    :param environ: WSGI environ
    :param empty: Passed to ``cgi.parse`` as ``keep_blank_values``
        (default: False)
    :param err: Passed to ``cgi.parse`` as ``strict_parsing``
        (default: False)
    :return: The parsed form data
    """
    fields = cgi.parse(environ["wsgi.input"], environ, empty, err)
    # Snapshot the keys; only values are replaced, the key set is unchanged.
    for field_name in list(fields):
        values = fields[field_name]
        if len(values) == 1:
            fields[field_name] = values[0]
    return fields


def geturl(environ, query=True, path=True, use_server_name=False):
    """Rebuilds a request URL (from PEP 333).
    You may want to choose to use the environment variables
    server_name and server_port instead of http_host in some cases.
    The parameter use_server_name allows you to choose.

    When HTTP_HOST is missing or empty (e.g. a request without a Host
    header), SERVER_NAME/SERVER_PORT are used as a fallback instead of
    failing with a KeyError.

    :param environ: WSGI environ
    :param query: Is QUERY_STRING included in URI (default: True)
    :param path: Is path included in URI (default: True)
    :param use_server_name: If SERVER_NAME/_PORT should be used instead of
        HTTP_HOST
    :return: The rebuilt URL as a string
    """
    scheme = environ["wsgi.url_scheme"]
    url = [f"{scheme}://"]

    host = None if use_server_name else environ.get("HTTP_HOST")
    if host:
        url.append(host)
    else:
        url.append(environ["SERVER_NAME"])
        # Only spell out the port when it is not the scheme's default.
        default_port = "443" if scheme == "https" else "80"
        if environ["SERVER_PORT"] != default_port:
            url.append(f":{environ['SERVER_PORT']}")

    if path:
        url.append(getpath(environ))
    if query and environ.get("QUERY_STRING"):
        url.append(f"?{environ['QUERY_STRING']}")
    return "".join(url)


def getpath(environ):
    """Return the URL-quoted SCRIPT_NAME followed by the URL-quoted PATH_INFO.

    Either variable counts as an empty string when it is absent.
    """
    script_name = quote(environ.get("SCRIPT_NAME", ""))
    path_info = quote(environ.get("PATH_INFO", ""))
    return script_name + path_info


def get_post(environ):
    """Read the request body from ``wsgi.input``.

    Exactly CONTENT_LENGTH bytes are read. A missing, empty, ``None``,
    non-numeric or negative CONTENT_LENGTH is treated as 0, so nothing is
    read in those cases.

    :param environ: WSGI environ
    :return: Whatever ``environ["wsgi.input"].read()`` returns
    """
    # the environment variable CONTENT_LENGTH may be empty or missing
    try:
        request_body_size = int(environ.get("CONTENT_LENGTH") or 0)
    except (TypeError, ValueError):
        request_body_size = 0
    # A negative size would make read() consume the whole stream.
    request_body_size = max(request_body_size, 0)

    # When the method is POST the query string will be sent
    # in the HTTP request body which is passed by the WSGI server
    # in the file like wsgi.input environment variable.
    return environ["wsgi.input"].read(request_body_size)


def get_response(environ, start_response):
    """Return the raw query of a GET or POST request.

    For GET this is QUERY_STRING, for POST the request body. Any other
    method starts a ``BadRequest`` response and returns its rendered body
    instead.

    :param environ: WSGI environ
    :param start_response: WSGI ``start_response`` callable, only used for
        unsupported methods
    """
    if environ.get("REQUEST_METHOD") == "GET":
        return environ.get("QUERY_STRING")
    if environ.get("REQUEST_METHOD") == "POST":
        return get_post(environ)

    rejection = BadRequest("Unsupported method")
    return rejection(environ, start_response)


def unpack_redirect(environ):
    """Parse QUERY_STRING into a dict holding the first value of each key.

    :param environ: WSGI environ
    :return: The dict, or None when QUERY_STRING is not in ``environ``
    """
    if "QUERY_STRING" not in environ:
        return None

    parsed = parse_qs(environ["QUERY_STRING"])
    return {key: values[0] for key, values in parsed.items()}


def unpack_post(environ):
    """Parse a form-encoded request body into a dict.

    Only the first value of each key is kept. A bytes body is decoded as
    UTF-8 first so that keys and values are ``str``.

    :param environ: WSGI environ
    :return: Dict mapping each form key to its first value
    """
    body = get_post(environ)
    if isinstance(body, bytes):
        body = body.decode("utf-8")
    # Iterate items(): iterating the dict itself yields only the keys.
    return {key: values[0] for key, values in parse_qs(body).items()}


def unpack_soap(environ):
    """Wrap the raw request body as a SAMLRequest with empty RelayState.

    :param environ: WSGI environ
    :return: ``{"SAMLRequest": <body>, "RelayState": ""}``, or None when
        reading the body raises any ``Exception``
    """
    try:
        body = get_post(environ)
    except Exception:
        return None
    return {"SAMLRequest": body, "RelayState": ""}


def unpack_artifact(environ):
    """Unpack request parameters according to REQUEST_METHOD.

    GET is unpacked from the query string, POST from the body.

    :param environ: WSGI environ; must contain REQUEST_METHOD
    :return: The unpacked dict, or None for any other method
    """
    method = environ["REQUEST_METHOD"]
    if method == "GET":
        return unpack_redirect(environ)
    if method == "POST":
        return unpack_post(environ)
    return None


def unpack_any(environ):
    """Unpack a request and work out which SAML binding it uses.

    GET requests are classified by the query parameters present (``ID``,
    ``SAMLart``, otherwise redirect). Other methods are treated as SOAP when
    the media type is ``application/soap+xml`` (also the default when
    CONTENT_TYPE is absent) and as a form POST otherwise.

    :param environ: WSGI environ; must contain REQUEST_METHOD
    :return: Tuple ``(parameters, binding)``; ``parameters`` is a dict, or
        None when the SOAP body could not be read
    """
    if environ["REQUEST_METHOD"].upper() == "GET":
        # Could be either redirect or artifact
        # unpack_redirect returns None without a QUERY_STRING; use an empty
        # dict so the membership tests below do not raise TypeError.
        _dict = unpack_redirect(environ) or {}
        if "ID" in _dict:
            binding = BINDING_URI
        elif "SAMLart" in _dict:
            binding = BINDING_HTTP_ARTIFACT
        else:
            binding = BINDING_HTTP_REDIRECT
    else:
        content_type = environ.get("CONTENT_TYPE", "application/soap+xml")
        # Compare the media type only, so parameters such as
        # "; charset=utf-8" and case differences do not hide a SOAP request.
        media_type = content_type.split(";", 1)[0].strip().lower()
        if media_type != "application/soap+xml":
            # normal post
            _dict = unpack_post(environ)
            if "SAMLart" in _dict:
                binding = BINDING_HTTP_ARTIFACT
            else:
                binding = BINDING_HTTP_POST
        else:
            _dict = unpack_soap(environ)
            binding = BINDING_SOAP

    return _dict, binding


def _expiration(timeout, time_format=None):
    """Return an expiry time computed by ``time_util``.

    :param timeout: The string "now", or a number of minutes from now
    :param time_format: Format handed through to ``time_util``
    """
    if timeout != "now":
        # validity time should match lifetime of assertions
        return time_util.in_a_while(minutes=timeout, format=time_format)
    return time_util.instant(time_format)


def cookie_signature(seed, *parts):
    """Generates a cookie signature.

    Computes an HMAC-SHA1 over the concatenation of all non-empty parts.
    ``str`` values (seed or parts) are encoded as UTF-8 first, because the
    HMAC only accepts bytes; bytes values are used unchanged.

    :param seed: Key for the HMAC function (bytes or str)
    :param parts: Values to sign (bytes or str); empty/None parts are skipped
    :return: The signature as a hexadecimal string
    """
    key = seed.encode("utf-8") if isinstance(seed, str) else seed
    mac = hmac.new(key, digestmod=hashlib.sha1)
    for part in parts:
        if part:
            mac.update(part.encode("utf-8") if isinstance(part, str) else part)
    return mac.hexdigest()


def make_cookie(name, load, seed, expire=0, domain="", path="", timestamp=""):
    """
    Build a signed cookie and return it as a header tuple.

    The cookie value is ``load|timestamp|signature``, where the signature
    comes from ``cookie_signature(seed, load, timestamp)``.

    :param name: Cookie name
    :param load: Cookie load
    :param seed: A seed for the HMAC function
    :param expire: Number of minutes before this cookie goes stale
    :param domain: The domain of the cookie
    :param path: The path specification for the cookie
    :param timestamp: Timestamp string to embed; computed when empty
    :return: A tuple to be added to headers
    """
    if not timestamp:
        # NOTE(review): mktime() treats its argument as local time while
        # gmtime() returns UTC -- confirm this is the intended timestamp.
        timestamp = str(int(time.mktime(time.gmtime())))
    signature = cookie_signature(seed, load, timestamp)

    jar = SimpleCookie()
    jar[name] = "|".join((load, timestamp, signature))
    morsel = jar[name]

    # Optional attributes are only emitted when a value was supplied.
    for attribute, value in (("path", path), ("domain", domain)):
        if value:
            morsel[attribute] = value
    if expire:
        morsel["expires"] = _expiration(expire, "%a, %d-%b-%Y %H:%M:%S GMT")

    header_name_and_value = jar.output().split(": ", 1)
    return tuple(header_name_and_value)


def parse_cookie(name, seed, kaka):
    """Parses and verifies a cookie value

    The cookie value is expected to look like ``payload|timestamp|signature``.

    :param name: Name of the cookie to look for
    :param seed: A seed used for the HMAC signature
    :param kaka: The cookie
    :return: A tuple consisting of (payload, timestamp), or None when there
        is no cookie, no cookie with that name, or the value does not have
        exactly three parts
    :raises SAMLError: If the signature does not match
    """
    if not kaka:
        return None

    morsel = SimpleCookie(kaka).get(name)
    if not morsel:
        return None

    parts = morsel.value.split("|")
    if len(parts) != 3:
        return None
    payload, timestamp, received_signature = parts

    # verify the cookie signature
    expected_signature = cookie_signature(seed, payload, timestamp)
    # Constant-time comparison; both sides are encoded because
    # compare_digest rejects non-ASCII str arguments.
    signatures_match = hmac.compare_digest(
        expected_signature.encode("utf-8"), received_signature.encode("utf-8")
    )
    if not signatures_match:
        raise SAMLError("Invalid cookie signature")

    return payload.strip(), timestamp


def cookie_parts(name, kaka):
    """Split the value of the named cookie on ``|`` without verifying it.

    :param name: Name of the cookie to look for
    :param kaka: The cookie
    :return: List of the ``|``-separated parts, or None when the cookie is
        not present
    """
    morsel = SimpleCookie(kaka).get(name)
    if not morsel:
        return None
    return morsel.value.split("|")