summaryrefslogtreecommitdiff
path: root/tests/testutils/http_server.py
blob: 8591159f83872989b0221e8a1ab5cb5f85d39820 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import multiprocessing
import os
import posixpath
import html
import base64
from http.server import SimpleHTTPRequestHandler, HTTPServer, HTTPStatus


class Unauthorized(Exception):
    """Signal that a request lacks acceptable credentials.

    Raised while resolving the document root; the request handler catches it
    and answers with a 401 response.
    """


class RequestHandler(SimpleHTTPRequestHandler):
    """Static-file handler that picks its document root from HTTP Basic auth.

    Requests without an ``Authorization`` header are served from
    ``server.anonymous_dir`` (when it is set). Authenticated requests are
    served from the directory stored for that user in ``server.users``, which
    maps a user name to a ``(password, directory)`` tuple. Any authentication
    failure is answered with ``401 Unauthorized``.
    """

    def get_root_dir(self):
        """Return the directory to serve for the current request.

        Raises:
            Unauthorized: if no credentials were sent and anonymous access is
                disabled, or if the credentials are malformed, name an unknown
                user, or carry the wrong password.
        """
        authorization = self.headers.get("authorization")
        if not authorization:
            if not self.server.anonymous_dir:
                raise Unauthorized("unauthorized")
            return self.server.anonymous_dir

        parts = authorization.split()
        if len(parts) != 2 or parts[0].lower() != "basic":
            raise Unauthorized("unauthorized")
        try:
            decoded = base64.decodebytes(parts[1].encode("ascii")).decode("ascii")
            # Basic auth (RFC 7617): the user-id cannot contain ':' but the
            # password may, so split on the first colon only.
            user, password = decoded.split(":", 1)
            expected_password, directory = self.server.users[user]
        except (ValueError, KeyError) as err:
            # ValueError covers non-ASCII input, invalid base64 (binascii.Error)
            # and a missing ':'; KeyError covers an unknown user.
            raise Unauthorized("unauthorized") from err
        if password != expected_password:
            # Reject explicitly: returning None here would make translate_path
            # fail with TypeError instead of producing a 401.
            raise Unauthorized("unauthorized")
        return directory

    def unauthorized(self):
        """Send a ``401 Unauthorized`` response with a Basic-auth challenge.

        Builds the same error page as ``send_error`` and adds a
        ``WWW-Authenticate`` header naming ``self.server.realm``. The body is
        omitted for HEAD requests.
        """
        shortmsg, longmsg = self.responses[HTTPStatus.UNAUTHORIZED]
        self.send_response(HTTPStatus.UNAUTHORIZED, shortmsg)
        self.send_header("Connection", "close")

        content = self.error_message_format % {
            "code": HTTPStatus.UNAUTHORIZED,
            "message": html.escape(shortmsg, quote=False),
            "explain": html.escape(longmsg, quote=False),
        }
        body = content.encode("UTF-8", "replace")
        self.send_header("Content-Type", self.error_content_type)
        self.send_header("Content-Length", str(len(body)))
        self.send_header("WWW-Authenticate", 'Basic realm="{}"'.format(self.server.realm))
        # Exactly one end_headers(): a second call emits an extra CRLF that the
        # client reads as body bytes, overrunning Content-Length.
        self.end_headers()

        if self.command != "HEAD" and body:
            self.wfile.write(body)

    def do_GET(self):
        """Serve a GET request, answering 401 when authentication fails."""
        try:
            super().do_GET()
        except Unauthorized:
            self.unauthorized()

    def do_HEAD(self):
        """Serve a HEAD request, answering 401 when authentication fails."""
        try:
            super().do_HEAD()
        except Unauthorized:
            self.unauthorized()

    def translate_path(self, path):
        """Map a request path onto the file system under the caller's root.

        The query string and fragment are dropped. The path is normalised as an
        absolute POSIX path, so ``..`` segments cannot climb above the root
        returned by get_root_dir().

        Raises:
            Unauthorized: propagated from get_root_dir().
        """
        path = path.split("?", 1)[0]
        path = path.split("#", 1)[0]
        # Anchor relative request targets at "/" rather than asserting: an
        # assert vanishes under -O, and relpath() would then resolve the
        # target against the current working directory.
        if not posixpath.isabs(path):
            path = "/" + path
        path = posixpath.normpath(path)
        path = posixpath.relpath(path, "/")
        return os.path.join(self.get_root_dir(), path)


class AuthHTTPServer(HTTPServer):
    """HTTPServer that carries the authentication settings read by RequestHandler.

    Attributes:
        users: maps a user name to a ``(password, directory)`` tuple.
        anonymous_dir: directory served to requests without credentials, or
            None to require authentication.
        realm: realm name advertised in the ``WWW-Authenticate`` challenge.
    """

    def __init__(self, *args, **kwargs):
        # Set up the auth state before the base-class constructor runs (by
        # default it binds and activates the listening socket).
        self.realm = "Realm"
        self.anonymous_dir = None
        self.users = dict()
        super().__init__(*args, **kwargs)


class SimpleHttpServer(multiprocessing.Process):
    """Run an AuthHTTPServer in a child process.

    The listening socket is bound to 127.0.0.1 on an OS-assigned port when the
    object is created, so base_url() is usable before start().

    NOTE: once start() runs, the child process works on its own copy of the
    server state. Call allow_anonymous() and add_user() before start() for
    them to take effect.
    """

    def __init__(self):
        super().__init__()
        self.server = AuthHTTPServer(("127.0.0.1", 0), RequestHandler)
        self.started = False

    def start(self):
        """Start serving in the child process."""
        self.started = True
        super().start()

    def run(self):
        """Child-process entry point: serve requests until terminated."""
        self.server.serve_forever()

    def stop(self):
        """Terminate the child (if it was started) and close the socket.

        The socket is opened in __init__, i.e. in this parent process, so the
        parent must close it too. Otherwise the port keeps accepting
        connections after the child has exited, and nobody serves them. This
        also releases the socket when start() was never called.
        """
        if self.started:
            self.terminate()
            self.join()
        self.server.server_close()

    def allow_anonymous(self, cwd):
        """Serve *cwd* to requests that carry no credentials."""
        self.server.anonymous_dir = cwd

    def add_user(self, user, password, cwd):
        """Register *user*/*password*; authenticated requests are served from *cwd*."""
        self.server.users[user] = (password, cwd)

    def base_url(self):
        """Return the ``http://127.0.0.1:<port>`` URL of the listening socket."""
        return "http://127.0.0.1:{}".format(self.server.server_port)