author    Andy McCurdy <andy@andymccurdy.com>  2014-03-27 19:57:12 -0700
committer Andy McCurdy <andy@andymccurdy.com>  2014-03-27 19:57:12 -0700
commit    97c50ccebf1498212c8f8e1fa9b87c3b10c48666 (patch)
tree      a31f0d596cc3a47e9fcfd45a6558fdcebd8701e7
parent    81c2402b1438517e53d2d85abf3b5bc326f9bb82 (diff)
download  redis-py-97c50ccebf1498212c8f8e1fa9b87c3b10c48666.tar.gz
Fixed a bunch of get_message() bugs and refactored the PythonParser to be saner.
Still need more pubsub tests, but all this stuff *seems* to be working now.
-rw-r--r--  redis/client.py        33
-rw-r--r--  redis/connection.py   153
-rw-r--r--  tests/test_pubsub.py  191
3 files changed, 246 insertions, 131 deletions
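
The headline change is non-blocking message consumption via PubSub.get_message(). A minimal usage sketch follows, assuming a local Redis server; the message dict layout matches the make_message() helper in the tests below:

import time
import redis

r = redis.StrictRedis()
# keyword arguments are now forwarded to PubSub, e.g. to suppress
# subscribe/unsubscribe confirmation messages
p = r.pubsub(ignore_subscribe_messages=True)
p.subscribe('foo')

r.publish('foo', 'hello foo')

# get_message() polls the connection without blocking and returns None
# when nothing is waiting, so callers loop with a short sleep
# (the same pattern as wait_for_message() in the new tests)
message = None
while message is None:
    message = p.get_message()
    time.sleep(0.01)
print(message)  # a dict with 'type', 'pattern', 'channel' and 'data' keys
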
diff --git a/redis/client.py b/redis/client.py
index 268a6d5..dd31e3f 100644
--- a/redis/client.py
+++ b/redis/client.py
@@ -444,13 +444,13 @@ class StrictRedis(object):
"""
return Lock(self, name, timeout=timeout, sleep=sleep)
- def pubsub(self, shard_hint=None):
+ def pubsub(self, **kwargs):
"""
Return a Publish/Subscribe object. With this object, you can
subscribe to channels and listen for messages that get published to
them.
"""
- return PubSub(self.connection_pool, shard_hint)
+ return PubSub(self.connection_pool, **kwargs)
#### COMMAND EXECUTION AND PROTOCOL PARSING ####
def execute_command(self, *args, **options):
@@ -1780,12 +1780,19 @@ class PubSub(object):
def reset(self):
if self.connection:
self.connection.disconnect()
+ self.connection.clear_connect_callbacks()
self.connection_pool.release(self.connection)
self.connection = None
def close(self):
self.reset()
+ def on_connect(self, connection):
+ if self.channels:
+ self.subscribe(**self.channels)
+ if self.patterns:
+ self.psubscribe(**self.patterns)
+
@property
def subscribed(self):
"Indicates if there are subscriptions to any channels or patterns"
@@ -1803,6 +1810,10 @@ class PubSub(object):
'pubsub',
self.shard_hint
)
+ # initially connect here so we don't run our callback the first
+ # time. It's primarily there for reconnection purposes.
+ self.connection.connect()
+ self.connection.register_connect_callback(self.on_connect)
connection = self.connection
self._execute(connection, connection.send_command, *args)
@@ -1816,8 +1827,6 @@ class PubSub(object):
connection.connect()
# resubscribe to all channels and patterns before
# resending the current command
- self.subscribe(**self.channels)
- self.psubscribe(**self.patterns)
return command(*args)
def parse_response(self, block=True):
@@ -1840,9 +1849,11 @@ class PubSub(object):
"""
if args:
args = list_or_args(args[0], args[1:])
- self.patterns.update(dict.fromkeys(args))
- self.patterns.update(kwargs)
- return self.execute_command('PSUBSCRIBE', *iterkeys(self.patterns))
+ new_patterns = {}
+ new_patterns.update(dict.fromkeys(args))
+ new_patterns.update(kwargs)
+ self.patterns.update(new_patterns)
+ return self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns))
def punsubscribe(self, *args):
"""
@@ -1869,9 +1880,11 @@ class PubSub(object):
"""
if args:
args = list_or_args(args[0], args[1:])
- self.channels.update(dict.fromkeys(args))
- self.channels.update(kwargs)
- return self.execute_command('SUBSCRIBE', *iterkeys(self.channels))
+ new_channels = {}
+ new_channels.update(dict.fromkeys(args))
+ new_channels.update(kwargs)
+ self.channels.update(new_channels)
+ return self.execute_command('SUBSCRIBE', *iterkeys(new_channels))
def unsubscribe(self, *args):
"""
diff --git a/redis/connection.py b/redis/connection.py
index 80945fa..2fb5663 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -27,7 +27,6 @@ if HIREDIS_AVAILABLE:
SYM_STAR = b('*')
SYM_DOLLAR = b('$')
SYM_CRLF = b('\r\n')
-SYM_LF = b('\n')
SYM_EMPTY = b('')
@@ -48,13 +47,89 @@ class BaseParser(object):
return ResponseError(response)
+class SocketBuffer(object):
+ MAX_READ_LENGTH = 1000000
+
+ def __init__(self, socket):
+ self._sock = socket
+ self._buffer = BytesIO()
+ # number of bytes written to the buffer from the socket
+ self.bytes_written = 0
+ # number of bytes read from the buffer
+ self.bytes_read = 0
+
+ @property
+ def length(self):
+ return self.bytes_written - self.bytes_read
+
+ def _read_from_socket(self, length=None):
+ chunksize = 8192
+ if length is None:
+ length = chunksize
+
+ buf = self._buffer
+ buf.seek(self.bytes_written)
+ marker = 0
+
+ try:
+ while length > marker:
+ data = self._sock.recv(chunksize)
+ buf.write(data)
+ self.bytes_written += len(data)
+ marker += chunksize
+ except (socket.error, socket.timeout):
+ e = sys.exc_info()[1]
+ raise ConnectionError("Error while reading from socket: %s" %
+ (e.args,))
+
+ def read(self, length):
+ length = length + 2 # make sure to read the \r\n terminator
+ # make sure we've read enough data from the socket
+ if length > self.length:
+ self._read_from_socket(length - self.length)
+
+ self._buffer.seek(self.bytes_read)
+ data = self._buffer.read(length)
+ self.bytes_read += len(data)
+
+ # purge the buffer when we've consumed it all so it doesn't
+ # grow forever
+ if self.bytes_read == self.bytes_written:
+ self.purge()
+
+ return data[:-2]
+
+ def readline(self):
+ buf = self._buffer
+ buf.seek(self.bytes_read)
+ data = buf.readline()
+ while not data.endswith(SYM_CRLF):
+ # there's more data in the socket that we need
+ self._read_from_socket()
+ buf.seek(self.bytes_read)
+ data = buf.readline()
+
+ self.bytes_read += len(data)
+ return data[:-2]
+
+ def purge(self):
+ self._buffer.seek(0)
+ self._buffer.truncate()
+ self.bytes_written = 0
+ self.bytes_read = 0
+
+ def close(self):
+ self.purge()
+ self._buffer.close()
+
+
class PythonParser(BaseParser):
"Plain Python parsing class"
- MAX_READ_LENGTH = 1000000
encoding = None
def __init__(self):
- self._fp = None
+ self._sock = None
+ self._buffer = None
def __del__(self):
try:
@@ -64,53 +139,25 @@ class PythonParser(BaseParser):
def on_connect(self, connection):
"Called when the socket connects"
- self._fp = connection._sock.makefile('rb')
+ self._sock = connection._sock
+ self._buffer = SocketBuffer(self._sock)
if connection.decode_responses:
self.encoding = connection.encoding
def on_disconnect(self):
"Called when the socket disconnects"
- if self._fp is not None:
- self._fp.close()
- self._fp = None
+ if self._sock is not None:
+ self._sock.close()
+ self._sock = None
+ if self._buffer is not None:
+ self._buffer.close()
+ self._buffer = None
def can_read(self):
- return self._fp and self._fp._rbuf.tell()
-
- def read(self, length=None):
- """
- Read a line from the socket if no length is specified,
- otherwise read ``length`` bytes. Always strip away the newlines.
- """
- try:
- if length is not None:
- bytes_left = length + 2 # read the line ending
- if length > self.MAX_READ_LENGTH:
- # apparently reading more than 1MB or so from a windows
- # socket can cause MemoryErrors. See:
- # https://github.com/andymccurdy/redis-py/issues/205
- # read smaller chunks at a time to work around this
- try:
- buf = BytesIO()
- while bytes_left > 0:
- read_len = min(bytes_left, self.MAX_READ_LENGTH)
- buf.write(self._fp.read(read_len))
- bytes_left -= read_len
- buf.seek(0)
- return buf.read(length)
- finally:
- buf.close()
- return self._fp.read(bytes_left)[:-2]
-
- # no length, read a full line
- return self._fp.readline()[:-2]
- except (socket.error, socket.timeout):
- e = sys.exc_info()[1]
- raise ConnectionError("Error while reading from socket: %s" %
- (e.args,))
+ return self._buffer and bool(self._buffer.length)
def read_response(self):
- response = self.read()
+ response = self._buffer.readline()
if not response:
raise ConnectionError("Socket closed on remote end")
@@ -144,7 +191,7 @@ class PythonParser(BaseParser):
length = int(response)
if length == -1:
return None
- response = self.read(length)
+ response = self._buffer.read(length)
# multi-bulk response
elif byte == '*':
length = int(response)
@@ -161,7 +208,6 @@ class HiredisParser(BaseParser):
def __init__(self):
if not HIREDIS_AVAILABLE:
raise RedisError("Hiredis is not installed")
- self._next_response = False
def __del__(self):
try:
@@ -178,10 +224,12 @@ class HiredisParser(BaseParser):
if connection.decode_responses:
kwargs['encoding'] = connection.encoding
self._reader = hiredis.Reader(**kwargs)
+ self._next_response = False
def on_disconnect(self):
self._sock = None
self._reader = None
+ self._next_response = False
def can_read(self):
if not self._reader:
@@ -213,8 +261,8 @@ class HiredisParser(BaseParser):
raise ConnectionError("Socket closed on remote end")
self._reader.feed(buffer)
# proactively, but not conclusively, check if more data is in the
- # buffer. if the data received doesn't end with \n, there's more.
- if not buffer.endswith(SYM_LF):
+ # buffer. if the data received doesn't end with \r\n, there's more.
+ if not buffer.endswith(SYM_CRLF):
continue
response = self._reader.gets()
if isinstance(response, ResponseError):
@@ -256,6 +304,7 @@ class Connection(object):
'port': self.port,
'db': self.db,
}
+ self._connect_callbacks = []
def __repr__(self):
return self.description_format % self._description_args
@@ -266,6 +315,12 @@ class Connection(object):
except Exception:
pass
+ def register_connect_callback(self, callback):
+ self._connect_callbacks.append(callback)
+
+ def clear_connect_callbacks(self):
+ self._connect_callbacks = []
+
def connect(self):
"Connects to the Redis server if not already connected"
if self._sock:
@@ -284,6 +339,11 @@ class Connection(object):
self.disconnect()
raise
+ # run any user callbacks. right now the only internal callback
+ # is for pubsub channel/pattern resubscription
+ for callback in self._connect_callbacks:
+ callback(self)
+
def _connect(self):
"Create a TCP socket connection"
# in 2.6+ try to use IPv6/4 compatibility, else just original code
@@ -361,7 +421,8 @@ class Connection(object):
"Poll the socket to see if there's data that can be read."
sock = self._sock
if not sock:
- return False
+ self.connect()
+ sock = self._sock
return bool(select([sock], [], [], 0)[0]) or self._parser.can_read()
def read_response(self):
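
For context on the new SocketBuffer: Redis bulk replies are framed as "$<len>\r\n<payload>\r\n", which is why read(length) fetches length + 2 bytes and both read() and readline() strip the trailing CRLF. A self-contained illustration of that framing (plain Python, not library code):

reply = b"$5\r\nhello\r\n"           # RESP bulk reply carrying the string "hello"

header, rest = reply.split(b"\r\n", 1)
assert header == b"$5"               # what readline() hands to read_response()
length = int(header[1:])             # 5
payload = rest[:length + 2][:-2]     # read(5) consumes 5 + 2 bytes, strips \r\n
assert payload == b"hello"
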
diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py
index 7c81389..52825af 100644
--- a/tests/test_pubsub.py
+++ b/tests/test_pubsub.py
@@ -1,90 +1,131 @@
from __future__ import with_statement
import pytest
+import time
import redis
-from redis._compat import b, next
from redis.exceptions import ConnectionError
-class TestPubSub(object):
+def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
+ now = time.time()
+ timeout = now + timeout
+ while now < timeout:
+ message = pubsub.get_message(
+ ignore_subscribe_messages=ignore_subscribe_messages)
+ if message is not None:
+ return message
+ time.sleep(0.01)
+ now = time.time()
+ return None
- def test_channel_subscribe(self, r):
+
+def make_message(type, channel, data, pattern=None):
+ return {
+ 'type': type,
+ 'pattern': pattern,
+ 'channel': channel,
+ 'data': data
+ }
+
+
+class TestPubSubSubscribeUnsubscribe(object):
+
+ def test_subscribe_unsubscribe(self, r):
+ p = r.pubsub()
+
+ assert p.subscribe('foo', 'bar') is None
+
+ # should be 2 messages indicating that we've subscribed
+ assert wait_for_message(p) == make_message('subscribe', 'foo', 1)
+ assert wait_for_message(p) == make_message('subscribe', 'bar', 2)
+
+ assert p.unsubscribe('foo', 'bar') is None
+
+ # should be 2 messages indicating that we've unsubscribed
+ assert wait_for_message(p) == make_message('unsubscribe', 'foo', 1)
+ assert wait_for_message(p) == make_message('unsubscribe', 'bar', 0)
+
+ def test_pattern_subscribe_unsubscribe(self, r):
+ p = r.pubsub()
+
+ assert p.psubscribe('f*', 'b*') is None
+
+ # should be 2 messages indicating that we've subscribed
+ assert wait_for_message(p) == make_message('psubscribe', 'f*', 1)
+ assert wait_for_message(p) == make_message('psubscribe', 'b*', 2)
+
+ assert p.punsubscribe('f*', 'b*') is None
+
+ # should be 2 messages indicating that we've unsubscribed
+ assert wait_for_message(p) == make_message('punsubscribe', 'f*', 1)
+ assert wait_for_message(p) == make_message('punsubscribe', 'b*', 0)
+
+ def test_resubscribe_to_channels_on_reconnection(self, r):
+ channels = ['foo', 'bar']
+ p = r.pubsub()
+
+ assert p.subscribe(*channels) is None
+
+ for i, channel in enumerate(channels):
+ i += 1 # enumerate is 0 index, but we want 1 based indexing
+ assert wait_for_message(p) == make_message('subscribe', channel, i)
+
+ # manually disconnect
+ p.connection.disconnect()
+
+ # calling get_message again reconnects and resubscribes
+ for i, channel in enumerate(channels):
+ i += 1 # enumerate is 0 index, but we want 1 based indexing
+ assert wait_for_message(p) == make_message('subscribe', channel, i)
+
+ def test_resubscribe_to_patterns_on_reconnection(self, r):
+ patterns = ['f*', 'b*']
p = r.pubsub()
- # subscribe doesn't return anything
- assert p.subscribe('foo') is None
-
- # send a message
- assert r.publish('foo', 'hello foo') == 1
-
- # there should be now 2 messages in the buffer, a subscribe and the
- # one we just published
- assert next(p.listen()) == \
- {
- 'type': 'subscribe',
- 'pattern': None,
- 'channel': 'foo',
- 'data': 1
- }
-
- assert next(p.listen()) == \
- {
- 'type': 'message',
- 'pattern': None,
- 'channel': 'foo',
- 'data': b('hello foo')
- }
-
- # unsubscribe
- assert p.unsubscribe('foo') is None
-
- # unsubscribe message should be in the buffer
- assert next(p.listen()) == \
- {
- 'type': 'unsubscribe',
- 'pattern': None,
- 'channel': 'foo',
- 'data': 0
- }
-
- def test_pattern_subscribe(self, r):
+ assert p.psubscribe(*patterns) is None
+
+ for i, pattern in enumerate(patterns):
+ i += 1 # enumerate is 0 index, but we want 1 based indexing
+ assert wait_for_message(p) == make_message(
+ 'psubscribe', pattern, i)
+
+ # manually disconnect
+ p.connection.disconnect()
+
+ # calling get_message again reconnects and resubscribes
+ for i, pattern in enumerate(patterns):
+ i += 1 # enumerate is 0 index, but we want 1 based indexing
+ assert wait_for_message(p) == make_message(
+ 'psubscribe', pattern, i)
+
+ def test_ignore_all_subscribe_messages(self, r):
+ p = r.pubsub(ignore_subscribe_messages=True)
+
+ checks = (
+ (p.subscribe, 'foo'),
+ (p.unsubscribe, 'foo'),
+ (p.psubscribe, 'f*'),
+ (p.unsubscribe, 'f*'),
+ )
+
+ for func, channel in checks:
+ assert func(channel) is None
+ assert wait_for_message(p) is None
+
+ def test_ignore_individual_subscribe_messages(self, r):
p = r.pubsub()
- # psubscribe doesn't return anything
- assert p.psubscribe('f*') is None
-
- # send a message
- assert r.publish('foo', 'hello foo') == 1
-
- # there should be now 2 messages in the buffer, a subscribe and the
- # one we just published
- assert next(p.listen()) == \
- {
- 'type': 'psubscribe',
- 'pattern': None,
- 'channel': 'f*',
- 'data': 1
- }
-
- assert next(p.listen()) == \
- {
- 'type': 'pmessage',
- 'pattern': 'f*',
- 'channel': 'foo',
- 'data': b('hello foo')
- }
-
- # unsubscribe
- assert p.punsubscribe('f*') is None
-
- # unsubscribe message should be in the buffer
- assert next(p.listen()) == \
- {
- 'type': 'punsubscribe',
- 'pattern': None,
- 'channel': 'f*',
- 'data': 0
- }
+ checks = (
+ (p.subscribe, 'foo'),
+ (p.unsubscribe, 'foo'),
+ (p.psubscribe, 'f*'),
+ (p.unsubscribe, 'f*'),
+ )
+
+ for func, channel in checks:
+ assert func(channel) is None
+ message = wait_for_message(p, ignore_subscribe_messages=True)
+ assert message is None
class TestPubSubRedisDown(object):