summaryrefslogtreecommitdiff
path: root/redis/parsers/socket.py
blob: 8147243bba23d95e285d607c2f2e411a47dede99 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import errno
import io
import socket
from io import SEEK_END
from typing import Optional, Union

from ..exceptions import ConnectionError, TimeoutError
from ..utils import SSL_AVAILABLE

# Maps every exception type that can mean "nothing to read right now" on a
# non-blocking socket to the ``errno`` value the exception carries in that
# case. SocketBuffer._read_from_socket compares ``ex.errno`` against the
# value stored here to decide whether the error may be tolerated.
NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK}

if SSL_AVAILABLE:
    import ssl

    if hasattr(ssl, "SSLWantReadError"):
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = (
            ssl.SSL_ERROR_WANT_READ
        )
        # SSLWantWriteError carries SSL_ERROR_WANT_WRITE (3), not
        # SSL_ERROR_WANT_READ (2); registering it under the read code
        # meant the errno comparison could never match it.
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = (
            ssl.SSL_ERROR_WANT_WRITE
        )
    else:
        # No dedicated want-read/want-write subclasses: fall back to the
        # generic SSLError carrying the want-read code.
        NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = ssl.SSL_ERROR_WANT_READ

# Tuple form of the keys above, usable directly in an ``except`` clause.
NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())

SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
# Unique marker used as a default argument to tell "not passed" apart from
# any real value (including None).
SENTINEL = object()

SYM_CRLF = b"\r\n"


class SocketBuffer:
    """
    Read buffer layered on top of a socket.

    Bytes received from the socket are appended to an in-memory
    ``io.BytesIO``; ``read`` and ``readline`` consume from that buffer and
    pull more data from the socket whenever the buffer runs short.
    """

    def __init__(
        self, socket: socket.socket, socket_read_size: int, socket_timeout: float
    ):
        """
        :param socket: socket to receive data from
        :param socket_read_size: number of bytes requested per ``recv`` call
        :param socket_timeout: timeout restored on the socket after a read
            that temporarily used a custom timeout
        """
        self._sock = socket
        self.socket_read_size = socket_read_size
        self.socket_timeout = socket_timeout
        self._buffer = io.BytesIO()

    def unread_bytes(self) -> int:
        """
        Remaining unread length of buffer
        """
        pos = self._buffer.tell()
        end = self._buffer.seek(0, SEEK_END)
        # put the read position back where it was before measuring
        self._buffer.seek(pos)
        return end - pos

    def _read_from_socket(
        self,
        length: Optional[int] = None,
        timeout: Union[float, object] = SENTINEL,
        raise_on_timeout: Optional[bool] = True,
    ) -> bool:
        """
        Receive data from the socket and append it to the buffer.

        The buffer's read position is the same after the call as before it.

        :param length: keep receiving until at least this many bytes have
            arrived; ``None`` returns after the first successful ``recv``
        :param timeout: socket timeout applied for this call only; when
            omitted the socket's timeout is left untouched
        :param raise_on_timeout: when false, a timeout or a tolerated
            non-blocking error returns ``False`` instead of raising
        :return: ``True`` once the requested data was buffered, ``False``
            if the read timed out or would block and ``raise_on_timeout``
            is false
        :raises ConnectionError: ``recv`` returned empty bytes, or a
            non-blocking exception occurred that is not tolerated
        :raises TimeoutError: the read timed out and ``raise_on_timeout``
            is true
        """
        sock = self._sock
        socket_read_size = self.socket_read_size
        marker = 0
        custom_timeout = timeout is not SENTINEL

        buf = self._buffer
        current_pos = buf.tell()
        # new data is always appended after what is already buffered
        buf.seek(0, SEEK_END)
        if custom_timeout:
            sock.settimeout(timeout)
        try:
            while True:
                data = sock.recv(socket_read_size)
                # an empty string indicates the server shutdown the socket
                if isinstance(data, bytes) and len(data) == 0:
                    raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(data)
                marker += len(data)

                if length is not None and length > marker:
                    continue
                return True
        except socket.timeout:
            if raise_on_timeout:
                raise TimeoutError("Timeout reading from socket")
            return False
        except NONBLOCKING_EXCEPTIONS as ex:
            # if we're in nonblocking mode and the recv raises a
            # blocking error, simply return False indicating that
            # there's no data to be read. otherwise raise the
            # original exception.
            allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
            if not raise_on_timeout and ex.errno == allowed:
                return False
            raise ConnectionError(f"Error while reading from socket: {ex.args}")
        finally:
            # restore the read position and, if it was overridden, the
            # socket's regular timeout, whatever the outcome
            buf.seek(current_pos)
            if custom_timeout:
                sock.settimeout(self.socket_timeout)

    def can_read(self, timeout: float) -> bool:
        """
        Return True if the buffer holds unread bytes, or if data could be
        received from the socket using ``timeout``.
        """
        return bool(self.unread_bytes()) or self._read_from_socket(
            timeout=timeout, raise_on_timeout=False
        )

    def read(self, length: int) -> bytes:
        """
        Return the next ``length`` bytes, also consuming (and dropping) the
        two terminator bytes that follow them.
        """
        length = length + 2  # make sure to read the \r\n terminator
        # BufferIO will return less than requested if buffer is short
        data = self._buffer.read(length)
        missing = length - len(data)
        if missing:
            # fill up the buffer and read the remainder
            self._read_from_socket(missing)
            data += self._buffer.read(missing)
        return data[:-2]

    def readline(self) -> bytes:
        """
        Return the next CRLF-terminated line without its terminator,
        receiving from the socket until a complete line is available.
        """
        buf = self._buffer
        data = buf.readline()
        while not data.endswith(SYM_CRLF):
            # there's more data in the socket that we need
            self._read_from_socket()
            data += buf.readline()

        return data[:-2]

    def get_pos(self) -> int:
        """
        Get current read position
        """
        return self._buffer.tell()

    def rewind(self, pos: int) -> None:
        """
        Rewind the buffer to a specific position, to re-start reading
        """
        self._buffer.seek(pos)

    def purge(self) -> None:
        """
        After a successful read, purge the read part of buffer
        """
        # Only if we have read all of the buffer do we truncate, to
        # reduce the amount of memory thrashing.  This heuristic
        # can be changed or removed later.
        if self.unread_bytes() > 0:
            return

        # everything has been consumed: empty the buffer and start again
        # from offset 0
        self._buffer.truncate(0)
        self._buffer.seek(0)

    def close(self) -> None:
        """
        Close the buffer and drop the references to it and to the socket.
        """
        try:
            self._buffer.close()
        except Exception:
            # issue #633 suggests the purge/close somehow raised a
            # BadFileDescriptor error. Perhaps the client ran out of
            # memory or something else? It's probably OK to ignore
            # any error being raised from purge/close since we're
            # removing the reference to the instance below.
            pass
        self._buffer = None
        self._sock = None