summaryrefslogtreecommitdiff
path: root/paste/cascade.py
blob: d939e4a2da31a9e7391f435d7222e478207d668b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org)
# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php

"""
Cascades through several applications, so long as applications
return ``404 Not Found``.
"""
from paste import httpexceptions
from paste.util import converters
import tempfile
from cStringIO import StringIO

__all__ = ['Cascade']

def make_cascade(loader, global_conf, catch='404', **local_conf):
    """
    Build a ``Cascade`` from deployment configuration.

    Expects configuration like::

        [composite:cascade]
        use = egg:Paste#cascade
        # every key must start with 'app'; the applications are
        # ordered by sorting their keys alphabetically
        app1 = foo
        app2 = bar
        ...
        catch = 404 500 ...

    ``catch`` is split with ``converters.aslist`` and every item is
    converted with ``int``.  Each ``app*`` value is resolved through
    ``loader.get_app(value, global_conf=global_conf)``.

    Raises ``ValueError`` for any local key that does not start with
    ``'app'``.
    """
    codes = map(int, converters.aslist(catch))
    named_apps = []
    for key, spec in local_conf.items():
        if key.startswith('app'):
            loaded = loader.get_app(spec, global_conf=global_conf)
            named_apps.append((key, loaded))
            continue
        raise ValueError(
            "Bad configuration key %r (=%r); all configuration keys "
            "must start with 'app'"
            % (key, spec))
    # Sorting the (key, app) pairs orders the applications by key.
    ordered = [pair[1] for pair in sorted(named_apps)]
    return Cascade(ordered, catch=catch_or_codes(codes)) if False else Cascade(ordered, catch=codes)
    
class Cascade(object):

    """
    Passed a list of applications, ``Cascade`` will try each of them
    in turn.  If one returns a status code listed in ``catch`` (by
    default just ``404 Not Found``) then the next application is
    tried.

    If all applications fail, then the last application's failure
    response is used.

    When the request declares a positive ``CONTENT_LENGTH`` the body
    is copied into a seekable buffer, which is rewound before every
    application (including the last one) is called.
    """

    def __init__(self, applications, catch=(404,)):
        """
        ``applications`` is the ordered sequence of WSGI applications.

        ``catch`` is an iterable whose items may be integer status
        codes, strings whose first whitespace-separated token is a
        status code, or ``httpexceptions.HTTPException`` instances.
        """
        self.apps = applications
        # Maps status code -> exception object for that code.
        self.catch_codes = {}
        self.catch_exceptions = []
        for error in catch:
            if isinstance(error, str):
                # Accept "404" as well as "404 Not Found".
                error = int(error.split(None, 1)[0])
            if isinstance(error, httpexceptions.HTTPException):
                exc = error
                code = error.code
            else:
                exc = httpexceptions.get_exception(error)
                code = error
            self.catch_codes[code] = exc
            self.catch_exceptions.append(exc)
        # ``except`` needs a tuple, not a list.
        self.catch_exceptions = tuple(self.catch_exceptions)

    def _buffer_input(self, body, length):
        """
        Copy ``length`` bytes from ``body`` into a seekable file-like
        object positioned at its start and return it.

        Bodies of up to 4096 bytes are kept in memory; larger ones are
        spooled to a temporary file in 4096-byte chunks.  Raises
        ``IOError`` if a large body ends before ``length`` bytes have
        been read.
        """
        if length <= 4096:
            return StringIO(body.read(length))
        f = tempfile.TemporaryFile()
        remaining = length
        while remaining > 0:
            chunk = body.read(min(remaining, 4096))
            if not chunk:
                # EOF before the declared length: without this check
                # the loop would never terminate.
                f.close()
                raise IOError("Request body truncated")
            f.write(chunk)
            remaining -= len(chunk)
        f.seek(0)
        return f

    def __call__(self, environ, start_response):
        """
        WSGI application interface: try each application in order and
        return the first response whose status is not in ``catch``;
        the last application's response is always returned as-is.
        """
        failed = []

        def repl_start_response(status, headers, exc_info=None):
            code = int(status.split(None, 1)[0])
            if code in self.catch_codes:
                # Remember the failure and swallow the response body.
                failed.append(None)
                return _consuming_writer
            return start_response(status, headers, exc_info)

        # CONTENT_LENGTH may be absent, empty or malformed; treat all
        # of those as "no body to copy".
        try:
            length = int(environ.get('CONTENT_LENGTH') or 0)
        except (TypeError, ValueError):
            length = 0
        copy_wsgi_input = length > 0
        if copy_wsgi_input:
            # Several applications may read the body, so it has to be
            # replaced with a seekable copy.
            environ['wsgi.input'] = self._buffer_input(
                environ['wsgi.input'], length)

        for app in self.apps[:-1]:
            environ_copy = environ.copy()
            if copy_wsgi_input:
                environ_copy['wsgi.input'].seek(0)
            # Cleared in place so the closure above sees the same list.
            del failed[:]
            try:
                v = app(environ_copy, repl_start_response)
                if not failed:
                    return v
                if hasattr(v, 'close'):
                    try:
                        # Exhaust the iterator first:
                        list(v)
                    finally:
                        # then close, even if exhausting it raised:
                        v.close()
            except self.catch_exceptions:
                pass
        if copy_wsgi_input:
            # The previous applications may have consumed the body.
            environ['wsgi.input'].seek(0)
        return self.apps[-1](environ, start_response)

def _consuming_writer(s):
    """
    No-op ``write`` callable: accepts one argument and discards it.

    ``Cascade`` hands this back from its replacement ``start_response``
    when the status code is one it catches, so whatever the failing
    application writes is dropped.
    """
    return None