summaryrefslogtreecommitdiff
path: root/pecan/middleware/errordocument.py
blob: a4a70219416938977a6631297daa48e9c0c82f49 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import sys
from recursive import ForwardRequestException, RecursionLoop


class StatusPersist(object):
    '''
    WSGI wrapper that re-dispatches the current request to an error-document
    URL while reporting a fixed status line to the client.

    :param app: the WSGI application that will serve the error document.
    :param status: the status line handed to ``start_response`` no matter
                   what status ``app`` itself tries to send.
    :param url: path of the error document, optionally followed by ``?`` and
                a query string.
    '''

    def __init__(self, app, status, url):
        self.app = app
        self.status = status
        self.url = url

    def __call__(self, environ, start_response):
        '''
        Rewrite ``PATH_INFO``/``QUERY_STRING`` in ``environ`` to point at
        ``self.url`` and call the wrapped application.

        Returns the wrapped application's response iterable, or a one-item
        plain-text body if fetching the error page raised ``RecursionLoop``.
        '''
        def keep_status_start_response(status, headers, exc_info=None):
            # The status chosen by the error page is ignored on purpose:
            # the client must see the status that triggered the error page.
            return start_response(self.status, headers, exc_info)

        # Split on the first '?' only so that a query string which itself
        # contains '?' is forwarded intact instead of being truncated.
        path, _, query = self.url.partition('?')
        environ['PATH_INFO'] = path
        environ['QUERY_STRING'] = query

        try:
            return self.app(environ, keep_status_start_response)
        except RecursionLoop as e:
            environ['wsgi.errors'].write(
                'Recursion error getting error page: %s\n' % e
            )
            # The status argument is replaced by self.status inside
            # keep_status_start_response; exc_info is forwarded so the
            # server can deal with headers that were already sent.
            keep_status_start_response(
                '500 Server Error',
                [('Content-type', 'text/plain')],
                sys.exc_info()
            )
            message = (
                'Error: %s.  (Error page could not be fetched)' % self.status
            )
            # PEP 3333 requires response bodies to be byte strings. On
            # Python 2 a plain ``str`` already is one, so this is a no-op.
            if not isinstance(message, bytes):
                message = message.encode('utf-8')
            return [message]


class ErrorDocumentMiddleware(object):
    '''
    Intercepts the HTTP response status code, looks it up in the error map
    defined in the Pecan app config.py, and routes to the controller assigned
    to that status.

    :param app: the WSGI application being wrapped.
    :param error_map: container keyed by integer status code; the value for a
                      matching code is passed to ``StatusPersist`` as the URL
                      of the error document.
    '''
    def __init__(self, app, error_map):
        self.app = app
        self.error_map = error_map

    def __call__(self, environ, start_response):
        '''
        Call the wrapped application with a ``start_response`` that diverts
        mapped error statuses and passes every other status straight through.
        '''

        def replacement_start_response(status, headers, exc_info=None):
            '''
            Overrides the default response if the status is defined in the
            Pecan app error map configuration.

            Raises ``ForwardRequestException`` for mapped statuses and a
            plain ``Exception`` when the status line cannot be parsed.
            '''
            try:
                status_code = int(status.split(' ')[0])
            # AttributeError covers non-string statuses (e.g. None), which
            # fail on ``.split`` before ``int`` is ever reached.
            except (AttributeError, ValueError, TypeError):  # pragma: nocover
                # Wrap in a 1-tuple so a tuple-valued status cannot break
                # the %-formatting of the message itself.
                raise Exception(
                    'ErrorDocumentMiddleware received an invalid '
                    'status %s' % (status,)
                )

            if status_code in self.error_map:
                def factory(app):
                    # Keep the original status line while serving the
                    # mapped error document.
                    return StatusPersist(
                        app,
                        status,
                        self.error_map[status_code]
                    )
                # NOTE(review): presumably caught by the recursive
                # middleware further up the stack -- confirm against
                # the ``recursive`` module.
                raise ForwardRequestException(factory=factory)
            return start_response(status, headers, exc_info)

        app_iter = self.app(environ, replacement_start_response)
        return app_iter