summaryrefslogtreecommitdiff
path: root/lib/py/src/transport/TSocket.py
blob: 50ee67e76659923fd90e15e3f404ac9efb9195d9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import errno
import logging
import os
import socket
import sys

from .TTransport import TTransportBase, TTransportException, TServerTransportBase

logger = logging.getLogger(__name__)


class TSocketBase(TTransportBase):
    """Helpers shared by the client (TSocket) and server (TServerSocket) sockets.

    The methods here read ``self._unix_socket``, ``self.host``, ``self.port``,
    ``self._socket_family`` and ``self.handle``, which subclasses initialise.
    """

    def _resolveAddr(self):
        """Return a list of getaddrinfo-style 5-tuples to try.

        For a unix-domain socket this is a single synthetic entry whose
        sockaddr is the socket path. Otherwise it is ``socket.getaddrinfo``
        for host/port, limited to stream sockets of the configured family and
        called with ``AI_PASSIVE``.
        """
        unix_path = self._unix_socket
        if unix_path is None:
            return socket.getaddrinfo(self.host, self.port, self._socket_family,
                                      socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
        return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, unix_path)]

    def close(self):
        """Close the underlying socket, if there is one, and drop the reference."""
        sock = self.handle
        if not sock:
            return
        sock.close()
        self.handle = None


class TSocket(TSocketBase):
    """Socket implementation of TTransport base."""

    def __init__(self, host='localhost', port=9090, unix_socket=None,
                 socket_family=socket.AF_UNSPEC,
                 socket_keepalive=False):
        """Initialize a TSocket

        @param host(str)  The host to connect to.
        @param port(int)  The (TCP) port to connect to.
        @param unix_socket(str)  The filename of a unix socket to connect to.
                                 (host and port will be ignored.)
        @param socket_family(int)  The socket family to use with this socket.
        @param socket_keepalive(bool) enable TCP keepalive, default off.
        """
        self.host = host
        self.port = port
        self.handle = None
        self._unix_socket = unix_socket
        self._timeout = None  # seconds (float) or None for blocking
        self._socket_family = socket_family
        self._socket_keepalive = socket_keepalive

    def setHandle(self, h):
        """Adopt an existing socket object as this transport's handle."""
        self.handle = h

    def isOpen(self):
        """Return True if there is a handle and the peer has not closed it."""
        if self.handle is None:
            return False

        # this lets us cheaply see if the other end of the socket is still
        # connected. if disconnected, we'll get EOF back (expressed as zero
        # bytes of data) otherwise we'll get one byte or an error indicating
        # we'd have to block for data.
        #
        # note that we're not doing this with socket.MSG_DONTWAIT because 1)
        # it's linux-specific and 2) gevent-patched sockets hide EAGAIN from us
        # when timeout is non-zero.
        original_timeout = self.handle.gettimeout()
        try:
            self.handle.settimeout(0)
            try:
                peeked_bytes = self.handle.recv(1, socket.MSG_PEEK)
            except (socket.error, OSError) as exc:  # on modern python this is just BlockingIOError
                if exc.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
                    return True
                return False
            except ValueError:
                # SSLSocket fails on recv with non-zero flags; fallback to the old behavior
                return True
        finally:
            self.handle.settimeout(original_timeout)

        # the length will be zero if we got EOF (indicating connection closed)
        return len(peeked_bytes) == 1

    def setTimeout(self, ms):
        """Set the socket timeout in milliseconds (None means block forever).

        Applied immediately if already open, and to sockets created by open().
        """
        if ms is None:
            self._timeout = None
        else:
            self._timeout = ms / 1000.0

        if self.handle is not None:
            self.handle.settimeout(self._timeout)

    def _do_open(self, family, socktype):
        """Create a new, unconnected socket (factory hook kept separate so it can be overridden)."""
        return socket.socket(family, socktype)

    @property
    def _address(self):
        """Human-readable address used in error messages."""
        return self._unix_socket if self._unix_socket else '%s:%d' % (self.host, self.port)

    def open(self):
        """Connect to the first reachable resolved address.

        @raise TTransportException  ALREADY_OPEN if a handle exists; NOT_OPEN
            if name resolution fails or no address accepts the connection.
        """
        if self.handle:
            raise TTransportException(type=TTransportException.ALREADY_OPEN, message="already open")
        try:
            addrs = self._resolveAddr()
        except socket.gaierror as gai:
            msg = 'failed to resolve sockaddr for ' + str(self._address)
            logger.exception(msg)
            raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=gai)
        for family, socktype, _, _, sockaddr in addrs:
            handle = self._do_open(family, socktype)
            try:
                # TCP keep-alive
                if self._socket_keepalive:
                    handle.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                handle.settimeout(self._timeout)
            except Exception:
                # Don't leak the descriptor if configuring the socket fails.
                handle.close()
                raise
            try:
                handle.connect(sockaddr)
            except socket.error:
                handle.close()
                logger.info('Could not connect to %s', sockaddr, exc_info=True)
                continue
            self.handle = handle
            return
        msg = 'Could not connect to any of %s' % [addr[4] for addr in addrs]
        logger.error(msg)
        raise TTransportException(type=TTransportException.NOT_OPEN, message=msg)

    def read(self, sz):
        """Receive up to ``sz`` bytes from the socket.

        @param sz(int)  Maximum number of bytes to receive.
        @return  A non-empty bytes object.
        @raise TTransportException  NOT_OPEN if there is no handle, TIMED_OUT
            if the socket timeout expires, END_OF_FILE if the peer closed the
            connection, and a generic TTransportException for other socket
            errors.
        """
        if not self.handle:
            raise TTransportException(type=TTransportException.NOT_OPEN,
                                      message='Transport not open')
        try:
            buff = self.handle.recv(sz)
        except socket.timeout as e:
            # A timeout set via settimeout() raises socket.timeout whose args
            # are ('timed out',), so the errno check below never matches it.
            raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e)
        except socket.error as e:
            err = e.errno
            if (err == errno.ECONNRESET and
                    (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
                # freebsd and Mach don't follow POSIX semantic of recv
                # and fail with ECONNRESET if peer performed shutdown.
                # See corresponding comment and code in TSocket::read()
                # in lib/cpp/src/transport/TSocket.cpp.
                self.close()
                # Trigger the check to raise the END_OF_FILE exception below.
                buff = b''
            elif err == errno.ETIMEDOUT:
                raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e)
            else:
                raise TTransportException(message="unexpected exception", inner=e)
        if len(buff) == 0:
            raise TTransportException(type=TTransportException.END_OF_FILE,
                                      message='TSocket read 0 bytes')
        return buff

    def write(self, buff):
        """Send all of ``buff``, looping over partial sends.

        @raise TTransportException  NOT_OPEN if there is no handle, TIMED_OUT
            if the socket timeout expires, END_OF_FILE if send() reports 0
            bytes, and a generic TTransportException for other socket errors.
        """
        if not self.handle:
            raise TTransportException(type=TTransportException.NOT_OPEN,
                                      message='Transport not open')
        sent = 0
        have = len(buff)
        while sent < have:
            try:
                plus = self.handle.send(buff)
            except socket.timeout as e:
                raise TTransportException(type=TTransportException.TIMED_OUT,
                                          message="write timeout", inner=e)
            except socket.error as e:
                raise TTransportException(message="unexpected exception", inner=e)
            if plus == 0:
                raise TTransportException(type=TTransportException.END_OF_FILE,
                                          message='TSocket sent 0 bytes')
            sent += plus
            buff = buff[plus:]

    def flush(self):
        """No-op: writes go straight to the socket."""
        pass


class TServerSocket(TSocketBase, TServerTransportBase):
    """Socket implementation of TServerTransport base."""

    def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
        """Initialize a listening socket transport.

        @param host(str)  Address to bind to (None lets getaddrinfo pick the
                          wildcard address, since AI_PASSIVE is used).
        @param port(int)  TCP port to bind to.
        @param unix_socket(str)  Path of a unix socket to bind instead.
        @param socket_family(int)  Address family; AF_UNSPEC prefers IPv6.
        """
        self.host = host
        self.port = port
        self._unix_socket = unix_socket
        self._socket_family = socket_family
        self.handle = None
        self._backlog = 128  # passed to socket.listen()

    def setBacklog(self, backlog=None):
        """Set the listen() backlog; only effective before listen() is called."""
        if not self.handle:
            self._backlog = backlog
        else:
            # We can't update the backlog once listening, since the handle
            # has already been created.
            logger.warning('You have to set backlog before listen.')

    def listen(self):
        """Create, bind and start listening on the server socket."""
        res0 = self._resolveAddr()
        if self._socket_family == socket.AF_UNSPEC:
            # Prefer IPv6; IPV6_V6ONLY is cleared below so that, where the OS
            # supports dual-stack, IPv4 clients can connect as well.
            socket_family = socket.AF_INET6
        else:
            socket_family = self._socket_family
        # First resolved address of the preferred family, else the last one.
        res = next((r for r in res0 if r[0] == socket_family), res0[-1])

        # We need to remove the old unix socket if the file exists and
        # nobody is listening on it.
        if self._unix_socket:
            tmp = socket.socket(res[0], res[1])
            try:
                tmp.connect(res[4])
            except socket.error as err:
                if err.errno == errno.ECONNREFUSED:
                    os.unlink(res[4])
            finally:
                # Always release the probe socket, whether or not it connected.
                tmp.close()

        self.handle = s = socket.socket(res[0], res[1])
        if s.family == socket.AF_INET6:
            s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if hasattr(s, 'settimeout'):
            s.settimeout(None)
        s.bind(res[4])
        s.listen(self._backlog)

    def accept(self):
        """Block for the next client and return it wrapped in a TSocket."""
        client, _addr = self.handle.accept()
        result = TSocket()
        result.setHandle(client)
        return result