summaryrefslogtreecommitdiff
path: root/oslo_rootwrap/jsonrpc.py
blob: 195febb1f361c4d780408262606327a9059478c5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# Copyright (c) 2014 Mirantis Inc.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import base64
import errno
import json
from multiprocessing import connection
from multiprocessing import managers
import socket
import struct
import weakref

from oslo_rootwrap import wrapper


class RpcJSONEncoder(json.JSONEncoder):
    """JSON encoder for the values exchanged over the rootwrap RPC channel.

    On top of what ``json.JSONEncoder`` handles natively it knows how to
    represent ``bytes`` and the two ``wrapper`` exceptions related to command
    execution. Each of them becomes a small tagged dictionary.
    """

    def default(self, o):
        """Return a JSON-serializable stand-in for ``o``.

        :param o: object the base encoder could not serialize on its own
        :returns: a tagged ``dict`` for supported types
        :raises TypeError: (from the base class) for any other type
        """
        # Bytes have to survive the trip unchanged: they are expected in
        # arguments for, and are the result of, Popen.communicate()
        if isinstance(o, bytes):
            encoded = base64.b64encode(o)
            return {"__bytes__": encoded.decode('ascii')}

        # The two exception types relevant to command execution
        if isinstance(o, wrapper.NoFilterMatched):
            return {"__exception__": "NoFilterMatched"}

        if isinstance(o, wrapper.FilterMatchNotExecutable):
            payload = {"__exception__": "FilterMatchNotExecutable"}
            payload["match"] = o.match
            return payload

        # Anything else fails JSON encoding, which makes the error visible
        # on the client side
        return super(RpcJSONEncoder, self).default(o)


# Counterpart of RpcJSONEncoder: turn its tagged dictionaries back into objects
def rpc_object_hook(obj):
    """Decode one JSON object (``dict``) produced by ``json.loads``.

    :param obj: dictionary handed over by the JSON decoder; the
                ``"__exception__"`` key, when present, is removed from it
    :returns: an exception instance from ``wrapper`` for the two known
              exception tags, ``bytes`` for a ``"__bytes__"`` entry, and the
              dictionary itself otherwise
    """
    if "__exception__" in obj:
        exc_name = obj.pop("__exception__")
        if exc_name in ("NoFilterMatched", "FilterMatchNotExecutable"):
            # Whatever is left in the dictionary becomes keyword arguments
            # of the exception constructor
            return getattr(wrapper, exc_name)(**obj)
        # Unknown exception tag: hand back what remains of the dictionary
        return obj

    if "__bytes__" not in obj:
        return obj

    ascii_payload = obj["__bytes__"].encode('ascii')
    return base64.b64decode(ascii_payload)


class JsonListener(object):
    """Listening AF_UNIX socket that hands out ``JsonConnection`` objects.

    It provides ``accept()`` and ``close()`` plus ``get_accepted()``, which
    exposes the connections accepted so far that are still alive.
    """

    def __init__(self, address, backlog=1):
        """Create the socket, bind it to ``address`` and start listening.

        :param address: AF_UNIX address passed to ``socket.bind``
        :param backlog: value passed to ``socket.listen``
        :raises socket.error: if configuring, binding or listening fails; the
                              socket is closed before the error propagates
        """
        self.address = address
        self._socket = socket.socket(socket.AF_UNIX)
        try:
            self._socket.setblocking(True)
            self._socket.bind(address)
            self._socket.listen(backlog)
        except socket.error:
            self._socket.close()
            raise
        self.closed = False
        # Weak references only: tracking a connection here must not keep it
        # alive once everybody else has dropped it
        self._accepted = weakref.WeakSet()

    def accept(self):
        """Block until a client connects and return its ``JsonConnection``.

        :raises EOFError: when accept() fails with EINVAL or EBADF
        :raises socket.error: for any other error except EINTR, which is
                              retried
        """
        while True:
            try:
                s, _ = self._socket.accept()
            except socket.error as e:
                if e.errno in (errno.EINVAL, errno.EBADF):
                    raise EOFError
                elif e.errno != errno.EINTR:
                    raise
                # EINTR: interrupted by a signal, just try again
            else:
                break
        s.setblocking(True)
        conn = JsonConnection(s)
        self._accepted.add(conn)
        return conn

    def close(self):
        """Shut down and close the listening socket; later calls are no-ops.

        :raises socket.error: if ``shutdown()`` fails. The descriptor is
                              released and the listener marked closed even in
                              that case, so a failed shutdown cannot leak it.
        """
        if not self.closed:
            try:
                # Presumably this wakes up a thread blocked in accept(),
                # which maps the resulting EINVAL/EBADF to EOFError
                self._socket.shutdown(socket.SHUT_RDWR)
            finally:
                self._socket.close()
                self.closed = True

    def get_accepted(self):
        """Return the live ``WeakSet`` of accepted connections."""
        return self._accepted


if hasattr(managers.Server, 'accepter'):
    # In Python 3 the accepter() thread runs an infinite loop. We break out
    # of it with EOFError, so that error has to be silenced here.
    old_accepter = managers.Server.accepter

    def silent_accepter(self):
        """Run the original ``Server.accepter`` and swallow ``EOFError``."""
        try:
            old_accepter(self)
        except EOFError:
            # Expected way for the accept loop to end, not a failure
            pass

    managers.Server.accepter = silent_accepter


class JsonConnection(object):
    """Length-prefixed message channel over a socket, with JSON payloads.

    Every message on the wire is an 8-byte unsigned integer in network byte
    order (``struct`` format ``'!Q'``) holding the payload length, followed
    by the payload itself. ``send``/``recv`` add JSON (de)serialization on
    top of the raw ``send_bytes``/``recv_bytes``.
    """

    def __init__(self, sock):
        """Wrap ``sock`` and switch it to blocking mode."""
        sock.setblocking(True)
        self._socket = sock

    def send_bytes(self, s):
        """Send ``s`` (bytes) preceded by its 8-byte length header."""
        self._socket.sendall(struct.pack('!Q', len(s)))
        self._socket.sendall(s)

    def recv_bytes(self, maxsize=None):
        """Receive one length-prefixed message and return its payload.

        :param maxsize: optional upper bound for the announced payload size
        :raises RuntimeError: if the announced size exceeds ``maxsize``; the
                              payload is not read in that case
        :raises EOFError: if the peer closes the connection mid-message
        """
        item = struct.unpack('!Q', self.recvall(8))[0]
        if maxsize is not None and item > maxsize:
            raise RuntimeError("Too big message received")
        s = self.recvall(item)
        return s

    def send(self, obj):
        """Serialize ``obj`` with ``dumps`` and send it as one message."""
        s = self.dumps(obj)
        self.send_bytes(s)

    def recv(self):
        """Receive one message and deserialize it with ``loads``."""
        s = self.recv_bytes()
        return self.loads(s)

    def close(self):
        """Close the underlying socket."""
        self._socket.close()

    def half_close(self):
        """Shut down only the receiving side of the socket."""
        self._socket.shutdown(socket.SHUT_RD)

    # We have to use slow version of recvall with eventlet because of a bug in
    # GreenSocket.recv_into:
    # https://bitbucket.org/eventlet/eventlet/pull-request/41
    def _recvall_slow(self, size):
        """Read exactly ``size`` bytes using ``recv`` and joining the pieces.

        :raises EOFError: if the peer closes before ``size`` bytes arrive
        """
        remaining = size
        res = []
        while remaining:
            piece = self._socket.recv(remaining)
            if not piece:
                raise EOFError
            res.append(piece)
            remaining -= len(piece)
        return b''.join(res)

    def recvall(self, size):
        """Read exactly ``size`` bytes into a preallocated buffer.

        :raises EOFError: if the peer closes before ``size`` bytes arrive
        """
        buf = bytearray(size)
        mem = memoryview(buf)
        got = 0
        while got < size:
            # Slicing the memoryview lets recv_into fill the rest of the
            # buffer in place, without intermediate copies
            piece_size = self._socket.recv_into(mem[got:])
            if not piece_size:
                raise EOFError
            got += piece_size
        # bytearray is mostly compatible with bytes and we could avoid copying
        # data here, but hmac doesn't like it in Python 3.3 (not in 2.7 or 3.4)
        return bytes(buf)

    @staticmethod
    def dumps(obj):
        """Serialize ``obj`` to UTF-8 encoded JSON using RpcJSONEncoder."""
        return json.dumps(obj, cls=RpcJSONEncoder).encode('utf-8')

    @staticmethod
    def loads(s):
        """Deserialize UTF-8 encoded JSON, decoding tags via rpc_object_hook.

        Only a list with at least two items can be a tagged
        ``[kind, payload]`` message, so the Python 2 fix-up below is limited
        to that shape. Any other result (a JSON object, a scalar, a short
        list) is returned untouched instead of failing with ``KeyError`` or
        ``IndexError`` while probing it.
        """
        res = json.loads(s.decode('utf-8'), object_hook=rpc_object_hook)
        if isinstance(res, list) and len(res) > 1:
            kind = res[0]
            # In Python 2 json returns unicode while multiprocessing needs str
            if (kind in ("#TRACEBACK", "#UNSERIALIZABLE") and
                    not isinstance(res[1], str)):
                res[1] = res[1].encode('utf-8', 'replace')
        return res


class JsonClient(JsonConnection):
    """Client end of a JSON connection: connects to an AF_UNIX address."""

    def __init__(self, address, authkey=None):
        """Connect to ``address`` and optionally authenticate.

        :param address: AF_UNIX address to connect to
        :param authkey: when not ``None``, key used to answer the peer's
                        challenge and then to challenge the peer in turn
        :raises Exception: whatever connecting or the handshake raises; the
                           socket is closed before the error propagates, so
                           a failed construction does not leak a descriptor
        """
        sock = socket.socket(socket.AF_UNIX)
        try:
            sock.setblocking(True)
            sock.connect(address)
            super(JsonClient, self).__init__(sock)
            if authkey is not None:
                connection.answer_challenge(self, authkey)
                connection.deliver_challenge(self, authkey)
        except Exception:
            # The caller never receives this half-built object, so nobody
            # else could close the socket
            sock.close()
            raise