summaryrefslogtreecommitdiff
path: root/redis/connection.py
diff options
context:
space:
mode:
Diffstat (limited to 'redis/connection.py')
-rw-r--r--redis/connection.py153
1 files changed, 107 insertions, 46 deletions
diff --git a/redis/connection.py b/redis/connection.py
index 80945fa..2fb5663 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -27,7 +27,6 @@ if HIREDIS_AVAILABLE:
SYM_STAR = b('*')
SYM_DOLLAR = b('$')
SYM_CRLF = b('\r\n')
-SYM_LF = b('\n')
SYM_EMPTY = b('')
@@ -48,13 +47,89 @@ class BaseParser(object):
return ResponseError(response)
+class SocketBuffer(object):
+ MAX_READ_LENGTH = 1000000
+
+ def __init__(self, socket):
+ self._sock = socket
+ self._buffer = BytesIO()
+ # number of bytes written to the buffer from the socket
+ self.bytes_written = 0
+ # number of bytes read from the buffer
+ self.bytes_read = 0
+
+ @property
+ def length(self):
+ return self.bytes_written - self.bytes_read
+
+ def _read_from_socket(self, length=None):
+ chunksize = 8192
+ if length is None:
+ length = chunksize
+
+ buf = self._buffer
+ buf.seek(self.bytes_written)
+ marker = 0
+
+ try:
+ while length > marker:
+ data = self._sock.recv(chunksize)
+ buf.write(data)
+ self.bytes_written += len(data)
+ marker += chunksize
+ except (socket.error, socket.timeout):
+ e = sys.exc_info()[1]
+ raise ConnectionError("Error while reading from socket: %s" %
+ (e.args,))
+
+ def read(self, length):
+ length = length + 2 # make sure to read the \r\n terminator
+ # make sure we've read enough data from the socket
+ if length > self.length:
+ self._read_from_socket(length - self.length)
+
+ self._buffer.seek(self.bytes_read)
+ data = self._buffer.read(length)
+ self.bytes_read += len(data)
+
+ # purge the buffer when we've consumed it all so it doesn't
+ # grow forever
+ if self.bytes_read == self.bytes_written:
+ self.purge()
+
+ return data[:-2]
+
+ def readline(self):
+ buf = self._buffer
+ buf.seek(self.bytes_read)
+ data = buf.readline()
+ while not data.endswith(SYM_CRLF):
+ # there's more data in the socket that we need
+ self._read_from_socket()
+ buf.seek(self.bytes_read)
+ data = buf.readline()
+
+ self.bytes_read += len(data)
+ return data[:-2]
+
+ def purge(self):
+ self._buffer.seek(0)
+ self._buffer.truncate()
+ self.bytes_written = 0
+ self.bytes_read = 0
+
+ def close(self):
+ self.purge()
+ self._buffer.close()
+
+
class PythonParser(BaseParser):
"Plain Python parsing class"
- MAX_READ_LENGTH = 1000000
encoding = None
def __init__(self):
- self._fp = None
+ self._sock = None
+ self._buffer = None
def __del__(self):
try:
@@ -64,53 +139,25 @@ class PythonParser(BaseParser):
def on_connect(self, connection):
"Called when the socket connects"
- self._fp = connection._sock.makefile('rb')
+ self._sock = connection._sock
+ self._buffer = SocketBuffer(self._sock)
if connection.decode_responses:
self.encoding = connection.encoding
def on_disconnect(self):
"Called when the socket disconnects"
- if self._fp is not None:
- self._fp.close()
- self._fp = None
+ if self._sock is not None:
+ self._sock.close()
+ self._sock = None
+ if self._buffer is not None:
+ self._buffer.close()
+ self._buffer = None
def can_read(self):
- return self._fp and self._fp._rbuf.tell()
-
- def read(self, length=None):
- """
- Read a line from the socket if no length is specified,
- otherwise read ``length`` bytes. Always strip away the newlines.
- """
- try:
- if length is not None:
- bytes_left = length + 2 # read the line ending
- if length > self.MAX_READ_LENGTH:
- # apparently reading more than 1MB or so from a windows
- # socket can cause MemoryErrors. See:
- # https://github.com/andymccurdy/redis-py/issues/205
- # read smaller chunks at a time to work around this
- try:
- buf = BytesIO()
- while bytes_left > 0:
- read_len = min(bytes_left, self.MAX_READ_LENGTH)
- buf.write(self._fp.read(read_len))
- bytes_left -= read_len
- buf.seek(0)
- return buf.read(length)
- finally:
- buf.close()
- return self._fp.read(bytes_left)[:-2]
-
- # no length, read a full line
- return self._fp.readline()[:-2]
- except (socket.error, socket.timeout):
- e = sys.exc_info()[1]
- raise ConnectionError("Error while reading from socket: %s" %
- (e.args,))
+ return self._buffer and bool(self._buffer.length)
def read_response(self):
- response = self.read()
+ response = self._buffer.readline()
if not response:
raise ConnectionError("Socket closed on remote end")
@@ -144,7 +191,7 @@ class PythonParser(BaseParser):
length = int(response)
if length == -1:
return None
- response = self.read(length)
+ response = self._buffer.read(length)
# multi-bulk response
elif byte == '*':
length = int(response)
@@ -161,7 +208,6 @@ class HiredisParser(BaseParser):
def __init__(self):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not installed")
- self._next_response = False
def __del__(self):
try:
@@ -178,10 +224,12 @@ class HiredisParser(BaseParser):
if connection.decode_responses:
kwargs['encoding'] = connection.encoding
self._reader = hiredis.Reader(**kwargs)
+ self._next_response = False
def on_disconnect(self):
self._sock = None
self._reader = None
+ self._next_response = False
def can_read(self):
if not self._reader:
@@ -213,8 +261,8 @@ class HiredisParser(BaseParser):
raise ConnectionError("Socket closed on remote end")
self._reader.feed(buffer)
# proactively, but not conclusively, check if more data is in the
- # buffer. if the data received doesn't end with \n, there's more.
- if not buffer.endswith(SYM_LF):
+ # buffer. if the data received doesn't end with \r\n, there's more.
+ if not buffer.endswith(SYM_CRLF):
continue
response = self._reader.gets()
if isinstance(response, ResponseError):
@@ -256,6 +304,7 @@ class Connection(object):
'port': self.port,
'db': self.db,
}
+ self._connect_callbacks = []
def __repr__(self):
return self.description_format % self._description_args
@@ -266,6 +315,12 @@ class Connection(object):
except Exception:
pass
+ def register_connect_callback(self, callback):
+ self._connect_callbacks.append(callback)
+
+ def clear_connect_callbacks(self):
+ self._connect_callbacks = []
+
def connect(self):
"Connects to the Redis server if not already connected"
if self._sock:
@@ -284,6 +339,11 @@ class Connection(object):
self.disconnect()
raise
+ # run any user callbacks. right now the only internal callback
+ # is for pubsub channel/pattern resubscription
+ for callback in self._connect_callbacks:
+ callback(self)
+
def _connect(self):
"Create a TCP socket connection"
# in 2.6+ try to use IPv6/4 compatibility, else just original code
@@ -361,7 +421,8 @@ class Connection(object):
"Poll the socket to see if there's data that can be read."
sock = self._sock
if not sock:
- return False
+ self.connect()
+ sock = self._sock
return bool(select([sock], [], [], 0)[0]) or self._parser.can_read()
def read_response(self):