diff options
Diffstat (limited to 'redis/connection.py')
-rw-r--r-- | redis/connection.py | 153 |
1 files changed, 107 insertions, 46 deletions
diff --git a/redis/connection.py b/redis/connection.py index 80945fa..2fb5663 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -27,7 +27,6 @@ if HIREDIS_AVAILABLE: SYM_STAR = b('*') SYM_DOLLAR = b('$') SYM_CRLF = b('\r\n') -SYM_LF = b('\n') SYM_EMPTY = b('') @@ -48,13 +47,89 @@ class BaseParser(object): return ResponseError(response) +class SocketBuffer(object): + MAX_READ_LENGTH = 1000000 + + def __init__(self, socket): + self._sock = socket + self._buffer = BytesIO() + # number of bytes written to the buffer from the socket + self.bytes_written = 0 + # number of bytes read from the buffer + self.bytes_read = 0 + + @property + def length(self): + return self.bytes_written - self.bytes_read + + def _read_from_socket(self, length=None): + chunksize = 8192 + if length is None: + length = chunksize + + buf = self._buffer + buf.seek(self.bytes_written) + marker = 0 + + try: + while length > marker: + data = self._sock.recv(chunksize) + buf.write(data) + self.bytes_written += len(data) + marker += chunksize + except (socket.error, socket.timeout): + e = sys.exc_info()[1] + raise ConnectionError("Error while reading from socket: %s" % + (e.args,)) + + def read(self, length): + length = length + 2 # make sure to read the \r\n terminator + # make sure we've read enough data from the socket + if length > self.length: + self._read_from_socket(length - self.length) + + self._buffer.seek(self.bytes_read) + data = self._buffer.read(length) + self.bytes_read += len(data) + + # purge the buffer when we've consumed it all so it doesn't + # grow forever + if self.bytes_read == self.bytes_written: + self.purge() + + return data[:-2] + + def readline(self): + buf = self._buffer + buf.seek(self.bytes_read) + data = buf.readline() + while not data.endswith(SYM_CRLF): + # there's more data in the socket that we need + self._read_from_socket() + buf.seek(self.bytes_read) + data = buf.readline() + + self.bytes_read += len(data) + return data[:-2] + + def purge(self): + self._buffer.seek(0) + self._buffer.truncate() + self.bytes_written = 0 + self.bytes_read = 0 + + def close(self): + self.purge() + self._buffer.close() + + class PythonParser(BaseParser): "Plain Python parsing class" - MAX_READ_LENGTH = 1000000 encoding = None def __init__(self): - self._fp = None + self._sock = None + self._buffer = None def __del__(self): try: @@ -64,53 +139,25 @@ class PythonParser(BaseParser): def on_connect(self, connection): "Called when the socket connects" - self._fp = connection._sock.makefile('rb') + self._sock = connection._sock + self._buffer = SocketBuffer(self._sock) if connection.decode_responses: self.encoding = connection.encoding def on_disconnect(self): "Called when the socket disconnects" - if self._fp is not None: - self._fp.close() - self._fp = None + if self._sock is not None: + self._sock.close() + self._sock = None + if self._buffer is not None: + self._buffer.close() + self._buffer = None def can_read(self): - return self._fp and self._fp._rbuf.tell() - - def read(self, length=None): - """ - Read a line from the socket if no length is specified, - otherwise read ``length`` bytes. Always strip away the newlines. - """ - try: - if length is not None: - bytes_left = length + 2 # read the line ending - if length > self.MAX_READ_LENGTH: - # apparently reading more than 1MB or so from a windows - # socket can cause MemoryErrors. See: - # https://github.com/andymccurdy/redis-py/issues/205 - # read smaller chunks at a time to work around this - try: - buf = BytesIO() - while bytes_left > 0: - read_len = min(bytes_left, self.MAX_READ_LENGTH) - buf.write(self._fp.read(read_len)) - bytes_left -= read_len - buf.seek(0) - return buf.read(length) - finally: - buf.close() - return self._fp.read(bytes_left)[:-2] - - # no length, read a full line - return self._fp.readline()[:-2] - except (socket.error, socket.timeout): - e = sys.exc_info()[1] - raise ConnectionError("Error while reading from socket: %s" % - (e.args,)) + return self._buffer and bool(self._buffer.length) def read_response(self): - response = self.read() + response = self._buffer.readline() if not response: raise ConnectionError("Socket closed on remote end") @@ -144,7 +191,7 @@ class PythonParser(BaseParser): length = int(response) if length == -1: return None - response = self.read(length) + response = self._buffer.read(length) # multi-bulk response elif byte == '*': length = int(response) @@ -161,7 +208,6 @@ class HiredisParser(BaseParser): def __init__(self): if not HIREDIS_AVAILABLE: raise RedisError("Hiredis is not installed") - self._next_response = False def __del__(self): try: @@ -178,10 +224,12 @@ class HiredisParser(BaseParser): if connection.decode_responses: kwargs['encoding'] = connection.encoding self._reader = hiredis.Reader(**kwargs) + self._next_response = False def on_disconnect(self): self._sock = None self._reader = None + self._next_response = False def can_read(self): if not self._reader: @@ -213,8 +261,8 @@ class HiredisParser(BaseParser): raise ConnectionError("Socket closed on remote end") self._reader.feed(buffer) # proactively, but not conclusively, check if more data is in the - # buffer. if the data received doesn't end with \n, there's more. - if not buffer.endswith(SYM_LF): + # buffer. if the data received doesn't end with \r\n, there's more. + if not buffer.endswith(SYM_CRLF): continue response = self._reader.gets() if isinstance(response, ResponseError): @@ -256,6 +304,7 @@ class Connection(object): 'port': self.port, 'db': self.db, } + self._connect_callbacks = [] def __repr__(self): return self.description_format % self._description_args @@ -266,6 +315,12 @@ class Connection(object): except Exception: pass + def register_connect_callback(self, callback): + self._connect_callbacks.append(callback) + + def clear_connect_callbacks(self): + self._connect_callbacks = [] + def connect(self): "Connects to the Redis server if not already connected" if self._sock: @@ -284,6 +339,11 @@ class Connection(object): self.disconnect() raise + # run any user callbacks. right now the only internal callback + # is for pubsub channel/pattern resubscription + for callback in self._connect_callbacks: + callback(self) + def _connect(self): "Create a TCP socket connection" # in 2.6+ try to use IPv6/4 compatibility, else just original code @@ -361,7 +421,8 @@ class Connection(object): "Poll the socket to see if there's data that can be read." sock = self._sock if not sock: - return False + self.connect() + sock = self._sock return bool(select([sock], [], [], 0)[0]) or self._parser.can_read() def read_response(self): |