summaryrefslogtreecommitdiff
path: root/amqp/sasl.py
blob: 407ccb8e27322430e8e178da98d584198976204d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""SASL mechanisms for AMQP authentication."""

import socket
import warnings
from io import BytesIO

from amqp.serialization import _write_table


class SASL:
    """Base class for every AMQP SASL authentication mechanism.

    Custom authentication schemes are implemented by subclassing this class
    and supplying both a ``mechanism`` name and a ``start`` method.
    """

    @property
    def mechanism(self):
        """The SASL mechanism name as ``bytes``.

        Subclasses must provide it; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError

    def start(self, connection):
        """Compute the initial response to the server's SASL challenge.

        Subclasses must override this to return a ``bytes`` object; the base
        implementation raises NotImplementedError.
        """
        raise NotImplementedError


class PLAIN(SASL):
    """PLAIN SASL authentication mechanism.

    The initial response is a NUL byte, the UTF-8 encoded username, another
    NUL byte, and the UTF-8 encoded password.

    See https://tools.ietf.org/html/rfc4616 for details
    """

    __slots__ = (
        "username",
        "password",
        )

    mechanism = b'PLAIN'

    def __init__(self, username, password):
        self.username = username
        self.password = password

    def start(self, connection):
        """Return the PLAIN initial response.

        Returns NotImplemented when either the username or the password is
        None, so another mechanism can be tried.
        """
        user, secret = self.username, self.password
        if user is None or secret is None:
            return NotImplemented
        # Leading NUL = empty authorization identity (see RFC 4616 above).
        return b'\0' + user.encode('utf-8') + b'\0' + secret.encode('utf-8')


class AMQPLAIN(SASL):
    """AMQPLAIN SASL authentication mechanism.

    This is a non-standard mechanism used by AMQP servers. The credentials
    are serialized as an AMQP table with ``LOGIN`` and ``PASSWORD`` entries.
    """

    __slots__ = (
        "username",
        "password",
        )

    mechanism = b'AMQPLAIN'

    def __init__(self, username, password):
        self.username = username
        self.password = password

    def start(self, connection):
        """Return the serialized credential table without its length prefix.

        Returns NotImplemented when either the username or the password is
        None, so another mechanism can be tried.
        """
        if self.username is None or self.password is None:
            return NotImplemented
        credentials = {b'LOGIN': self.username, b'PASSWORD': self.password}
        buffer = BytesIO()
        _write_table(credentials, buffer.write, [])
        encoded = buffer.getvalue()
        # The serialized table starts with its length; drop those 4 bytes.
        return encoded[4:]


def _get_gssapi_mechanism():
    """Build the GSSAPI SASL mechanism class suited to this environment.

    If ``gssapi`` (including ``gssapi.raw.misc``) can be imported, return a
    working ``GSSAPI`` mechanism class. Otherwise return a stand-in class
    with the same constructor signature whose ``mechanism`` is ``None``.
    Constructing the stand-in raises NotImplementedError unless
    ``fail_soft=True`` is passed.
    """
    try:
        import gssapi
        import gssapi.raw.misc  # Fail if the old python-gssapi is installed
    except ImportError:
        class FakeGSSAPI(SASL):
            """A no-op SASL mechanism for when gssapi isn't available."""

            mechanism = None

            def __init__(self, client_name=None, service=b'amqp',
                         rdns=False, fail_soft=False):
                if not fail_soft:
                    raise NotImplementedError(
                        "You need to install the `gssapi` module for GSSAPI "
                        "SASL support")

            # Accept ``connection`` like every other SASL.start so a
            # fail-soft instance answers NotImplemented instead of raising
            # TypeError; the default keeps the old zero-argument call valid.
            def start(self, connection=None):  # pragma: no cover
                return NotImplemented
        return FakeGSSAPI
    else:
        class GSSAPI(SASL):
            """GSSAPI SASL authentication mechanism.

            See https://tools.ietf.org/html/rfc4752 for details
            """

            mechanism = b'GSSAPI'

            def __init__(self, client_name=None, service=b'amqp',
                         rdns=False, fail_soft=False):
                if client_name and not isinstance(client_name, bytes):
                    client_name = client_name.encode('ascii')
                # start() joins service and hostname with b'@', so a str
                # service must be encoded just like client_name above.
                if isinstance(service, str):
                    service = service.encode('ascii')
                self.client_name = client_name
                self.fail_soft = fail_soft
                self.service = service
                self.rdns = rdns

            __slots__ = (
                "client_name",
                "fail_soft",
                "service",
                "rdns"
                )

            def get_hostname(self, connection):
                """Return the server host name as bytes.

                With ``rdns`` enabled on an IPv4/IPv6 socket, the peer
                address is resolved via ``socket.gethostbyaddr``; otherwise
                ``connection.transport.host`` is used.
                """
                sock = connection.transport.sock
                if self.rdns and sock.family in (socket.AF_INET,
                                                 socket.AF_INET6):
                    peer = sock.getpeername()
                    hostname, _, _ = socket.gethostbyaddr(peer[0])
                else:
                    hostname = connection.transport.host
                if not isinstance(hostname, bytes):
                    hostname = hostname.encode('ascii')
                return hostname

            def start(self, connection):
                """Return the first security-context token for the server.

                The target is the host-based service name
                ``service@hostname``. On ``GSSError``, return NotImplemented
                if ``fail_soft`` is set, otherwise re-raise.
                """
                try:
                    if self.client_name:
                        creds = gssapi.Credentials(
                            name=gssapi.Name(self.client_name))
                    else:
                        creds = None
                    hostname = self.get_hostname(connection)
                    name = gssapi.Name(b'@'.join([self.service, hostname]),
                                       gssapi.NameType.hostbased_service)
                    context = gssapi.SecurityContext(name=name, creds=creds)
                    return context.step(None)
                except gssapi.raw.misc.GSSError:
                    if self.fail_soft:
                        return NotImplemented
                    else:
                        raise
        return GSSAPI


# Chosen once at import time: the real GSSAPI mechanism when the `gssapi`
# package imports cleanly, otherwise a stand-in whose constructor raises
# NotImplementedError unless fail_soft=True is passed.
GSSAPI = _get_gssapi_mechanism()


class EXTERNAL(SASL):
    """EXTERNAL SASL mechanism.

    Authentication is performed externally rather than through this
    protocol. The client only announces ``EXTERNAL`` as its mechanism and
    sends no further authentication data.
    """

    mechanism = b'EXTERNAL'

    def start(self, connection):
        """Return an empty initial response; *connection* is not used."""
        empty_response = bytes()
        return empty_response


class RAW(SASL):
    """A generic custom SASL mechanism.

    This mechanism takes a mechanism name and response to send to the server,
    so can be used for simple custom authentication schemes.

    Constructing it emits a DeprecationWarning; implementing a SASL
    subclass is the recommended replacement.
    """

    mechanism = None

    def __init__(self, mechanism, response):
        """Store the mechanism name and the fixed response to send.

        Raises:
            TypeError: if ``mechanism`` or ``response`` is not ``bytes``.
        """
        # Explicit checks instead of ``assert``: asserts are stripped when
        # Python runs with -O, which would silently skip this validation.
        if not isinstance(mechanism, bytes):
            raise TypeError(
                "mechanism must be bytes, not {}".format(
                    type(mechanism).__name__))
        if not isinstance(response, bytes):
            raise TypeError(
                "response must be bytes, not {}".format(
                    type(response).__name__))
        self.mechanism, self.response = mechanism, response
        # stacklevel=2 attributes the warning to the code constructing RAW
        # rather than to this module.
        warnings.warn("Passing login_method and login_response to Connection "
                      "is deprecated. Please implement a SASL subclass "
                      "instead.", DeprecationWarning, stacklevel=2)

    def start(self, connection):
        """Return the response given at construction, unchanged."""
        return self.response