summaryrefslogtreecommitdiff
path: root/tests/testutils/http_server.py
blob: b72e745c56302c91eff0e1ef50a80803ede9ecd2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import base64
import html
import multiprocessing
import os
import posixpath
import urllib.parse
from http.server import SimpleHTTPRequestHandler, HTTPServer, HTTPStatus


class Unauthorized(Exception):
    """Raised when a request cannot be matched to a served directory.

    RequestHandler's do_GET/do_HEAD catch it and reply with a
    ``401 Unauthorized`` challenge instead of serving a file.
    """


class RequestHandler(SimpleHTTPRequestHandler):
    """Static-file handler that picks its document root per request.

    get_root_dir() maps the request's HTTP Basic credentials (or their
    absence) to a directory using the server's ``users`` and
    ``anonymous_dir`` attributes, and translate_path() resolves the URL
    path inside that directory.  Authentication failures are raised as
    Unauthorized and answered with a 401 challenge carrying the server's
    ``realm``.
    """

    def get_root_dir(self):
        """Return the directory that should serve the current request.

        Without an ``Authorization`` header the server's ``anonymous_dir``
        is used.  With HTTP Basic credentials, the directory registered for
        that user in ``self.server.users`` (``user -> (password, dir)``) is
        used.

        Raises:
            Unauthorized: if no credentials were sent and anonymous access
                is disabled, or if the header is not well-formed Basic auth,
                names an unknown user, or carries the wrong password.
        """
        authorization = self.headers.get('authorization')
        if not authorization:
            if not self.server.anonymous_dir:
                raise Unauthorized('unauthorized')
            return self.server.anonymous_dir

        parts = authorization.split()
        if len(parts) != 2 or parts[0].lower() != 'basic':
            raise Unauthorized('unauthorized')
        try:
            decoded = base64.decodebytes(parts[1].encode('ascii'))
            # RFC 7617: the user-id may not contain ':' but the password
            # may, so only the first colon separates the two.
            user, password = decoded.decode('ascii').split(':', 1)
            expected_password, directory = self.server.users[user]
        except (ValueError, KeyError) as exc:
            # ValueError covers non-ASCII input (Unicode*Error), invalid
            # base64 (binascii.Error) and a missing ':'; KeyError covers an
            # unknown user.
            raise Unauthorized('unauthorized') from exc
        if password != expected_password:
            # A wrong password must be rejected like any other bad
            # credential; returning None here would crash translate_path().
            raise Unauthorized('unauthorized')
        return directory

    def unauthorized(self):
        """Send a 401 response challenging the client for Basic credentials.

        The body is rendered from ``error_message_format`` the same way
        ``send_error`` renders it, and is omitted for HEAD requests.  The
        response asks for the connection to be closed.
        """
        shortmsg, longmsg = self.responses[HTTPStatus.UNAUTHORIZED]
        self.send_response(HTTPStatus.UNAUTHORIZED, shortmsg)
        self.send_header('Connection', 'close')

        content = (self.error_message_format % {
            # Plain int so '%(code)s' renders as "401" on every Python version.
            'code': HTTPStatus.UNAUTHORIZED.value,
            'message': html.escape(shortmsg, quote=False),
            'explain': html.escape(longmsg, quote=False),
        })
        body = content.encode('UTF-8', 'replace')
        self.send_header('Content-Type', self.error_content_type)
        self.send_header('Content-Length', str(len(body)))
        self.send_header('WWW-Authenticate', 'Basic realm="{}"'.format(self.server.realm))
        # Exactly one end_headers(): each call emits a CRLF, and a second one
        # would be read by the client as the first bytes of the body.
        self.end_headers()

        if self.command != 'HEAD' and body:
            self.wfile.write(body)

    def do_GET(self):
        """Serve a GET request, replying 401 when authentication fails."""
        try:
            super().do_GET()
        except Unauthorized:
            self.unauthorized()

    def do_HEAD(self):
        """Serve a HEAD request, replying 401 when authentication fails."""
        try:
            super().do_HEAD()
        except Unauthorized:
            self.unauthorized()

    def translate_path(self, path):
        """Map a URL path onto a filesystem path under the caller's root.

        The query string and fragment are dropped, percent-escapes are
        decoded, and the path is normalised as if rooted at ``/``, so ``..``
        segments can never climb above the served directory.

        Raises:
            Unauthorized: propagated from get_root_dir().
        """
        path = path.split('?', 1)[0]
        path = path.split('#', 1)[0]
        path = urllib.parse.unquote(path)
        # Force the path to be rooted before normalising.  normpath() then
        # collapses leading '..' at '/', and relpath() never falls back to
        # the process cwd.  Unlike an assert, this still holds under -O.
        path = posixpath.normpath('/' + path.lstrip('/'))
        path = posixpath.relpath(path, '/')
        return os.path.join(self.get_root_dir(), path)


class AuthHTTPServer(HTTPServer):
    """HTTPServer that also carries the access configuration for its handler.

    Attributes (plain, mutable after construction):
      users: maps a user name to a ``(password, directory)`` pair.
      anonymous_dir: directory served to requests without credentials;
          a falsy value means credentials are always required.
      realm: realm name advertised in the ``WWW-Authenticate`` challenge.
    """

    def __init__(self, *args, **kwargs):
        # Initialise the auth state, then hand every constructor argument
        # through unchanged to HTTPServer.
        self.realm = 'Realm'
        self.anonymous_dir = None
        self.users = dict()
        super().__init__(*args, **kwargs)


class SimpleHttpServer(multiprocessing.Process):
    """Serve files over HTTP on 127.0.0.1 from a child process.

    The listening socket is created and bound in __init__ on an
    OS-assigned port, so base_url() is usable before start().
    allow_anonymous() and add_user() modify the server object held by this
    process.  The child works on its own copy once launched, so configure
    access before calling start().
    """

    def __init__(self):
        super().__init__()
        # Port 0 lets the OS pick a free port; base_url() reads it back.
        self.server = AuthHTTPServer(('127.0.0.1', 0), RequestHandler)
        self.started = False

    def start(self):
        """Launch the child process that runs the server."""
        super().start()
        # Mark as started only once the launch succeeded, so stop() never
        # tries to terminate a process that was not created.
        self.started = True

    def run(self):
        """Child-process entry point: serve requests until terminated."""
        self.server.serve_forever()

    def stop(self):
        """Terminate the child process, if any, and release the socket.

        The listening socket created in __init__ is also open in this
        process, so it is closed even when start() was never called.
        """
        if self.started:
            self.terminate()
            self.join()
        self.server.server_close()

    def allow_anonymous(self, cwd):
        """Serve ``cwd`` to requests that send no credentials."""
        self.server.anonymous_dir = cwd

    def add_user(self, user, password, cwd):
        """Serve ``cwd`` to requests authenticating as ``user``/``password``."""
        self.server.users[user] = (password, cwd)

    def base_url(self):
        """Return the ``http://127.0.0.1:<port>`` URL of the server."""
        return 'http://127.0.0.1:{}'.format(self.server.server_port)