summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndy McCurdy <andy@andymccurdy.com>2014-04-01 12:28:31 -0700
committerAndy McCurdy <andy@andymccurdy.com>2014-04-01 12:28:31 -0700
commit54e5b062e2eddd674c9a462f763a7854965832d2 (patch)
tree1996b808be7113e7798e1e9fe85437352c8c416d
parent8b7cd8dd781c651877a14e9505f20e5553572f59 (diff)
downloadredis-py-54e5b062e2eddd674c9a462f763a7854965832d2.tar.gz
automatic message decoding if decode_responses=True. bugfixes, tests.
-rw-r--r--redis/client.py93
-rw-r--r--redis/connection.py1
-rw-r--r--tests/test_pubsub.py254
3 files changed, 289 insertions, 59 deletions
diff --git a/redis/client.py b/redis/client.py
index 7837dd3..352be7d 100644
--- a/redis/client.py
+++ b/redis/client.py
@@ -1762,6 +1762,12 @@ class PubSub(object):
self.shard_hint = shard_hint
self.ignore_subscribe_messages = ignore_subscribe_messages
self.connection = None
+ # we need to know the encoding options for this connection in order
+ # to lookup channel and pattern names for callback handlers.
+ connection_kwargs = self.connection_pool.connection_kwargs
+ self.encoding = connection_kwargs['encoding']
+ self.encoding_errors = connection_kwargs['encoding_errors']
+ self.decode_responses = connection_kwargs['decode_responses']
self.reset()
def __del__(self):
@@ -1786,10 +1792,35 @@ class PubSub(object):
self.reset()
def on_connect(self, connection):
+ "Re-subscribe to any channels and patterns previously subscribed to"
+ # NOTE: for python3, we can't pass bytestrings as keyword arguments
+ # so we need to decode channel/pattern names back to unicode strings
+ # before passing them to [p]subscribe.
if self.channels:
- self.subscribe(**self.channels)
+ channels = {}
+ for k, v in iteritems(self.channels):
+ if not self.decode_responses:
+ k = k.decode(self.encoding, self.encoding_errors)
+ channels[k] = v
+ self.subscribe(**channels)
if self.patterns:
- self.psubscribe(**self.patterns)
+ patterns = {}
+ for k, v in iteritems(self.patterns):
+ if not self.decode_responses:
+ k = k.decode(self.encoding, self.encoding_errors)
+ patterns[k] = v
+ self.psubscribe(**patterns)
+
+ def encode(self, value):
+ """
+ Encode the value so that it's identical to what we'll
+ read off the connection
+ """
+ if self.decode_responses and isinstance(value, bytes):
+ value = value.decode(self.encoding, self.encoding_errors)
+ elif not self.decode_responses and isinstance(value, unicode):
+ value = value.encode(self.encoding, self.encoding_errors)
+ return value
@property
def subscribed(self):
@@ -1808,10 +1839,8 @@ class PubSub(object):
'pubsub',
self.shard_hint
)
- # initially connect here so we don't run our callback the first
- # time. If we did, it would dupe the subscriptions, once from the
- # callback and a second time from the actual command invocation
- self.connection.connect()
+ # register a callback that re-subscribes to any channels we
+ # were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
connection = self.connection
self._execute(connection, connection.send_command, *args)
@@ -1847,10 +1876,15 @@ class PubSub(object):
if args:
args = list_or_args(args[0], args[1:])
new_patterns = {}
- new_patterns.update(dict.fromkeys(args))
- new_patterns.update(kwargs)
+ new_patterns.update(dict.fromkeys(imap(self.encode, args)))
+ for pattern, handler in iteritems(kwargs):
+ new_patterns[self.encode(pattern)] = handler
+ ret_val = self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns))
+ # update the patterns dict AFTER we send the command. we don't want to
+ # subscribe twice to these patterns, once for the command and again
+ # for the reconnection.
self.patterns.update(new_patterns)
- return self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns))
+ return ret_val
def punsubscribe(self, *args):
"""
@@ -1872,10 +1906,15 @@ class PubSub(object):
if args:
args = list_or_args(args[0], args[1:])
new_channels = {}
- new_channels.update(dict.fromkeys(args))
- new_channels.update(kwargs)
+ new_channels.update(dict.fromkeys(imap(self.encode, args)))
+ for channel, handler in iteritems(kwargs):
+ new_channels[self.encode(channel)] = handler
+ ret_val = self.execute_command('SUBSCRIBE', *iterkeys(new_channels))
+ # update the channels dict AFTER we send the command. we don't want to
+ # subscribe twice to these channels, once for the command and again
+ # for the reconnection.
self.channels.update(new_channels)
- return self.execute_command('SUBSCRIBE', *iterkeys(new_channels))
+ return ret_val
def unsubscribe(self, *args):
"""
@@ -1910,18 +1949,30 @@ class PubSub(object):
if message_type == 'pmessage':
message = {
'type': message_type,
- 'pattern': nativestr(response[1]),
- 'channel': nativestr(response[2]),
+ 'pattern': response[1],
+ 'channel': response[2],
'data': response[3]
}
else:
message = {
'type': message_type,
'pattern': None,
- 'channel': nativestr(response[1]),
+ 'channel': response[1],
'data': response[2]
}
+ # if this is an unsubscribe message, remove it from memory
+ if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
+ subscribed_dict = None
+ if message_type == 'punsubscribe':
+ subscribed_dict = self.patterns
+ else:
+ subscribed_dict = self.channels
+ try:
+ del subscribed_dict[message['channel']]
+ except KeyError:
+ pass
+
if message_type in self.PUBLISH_MESSAGE_TYPES:
# if there's a message handler, invoke it
handler = None
@@ -1938,18 +1989,6 @@ class PubSub(object):
if ignore_subscribe_messages or self.ignore_subscribe_messages:
return None
- # if this is an unsubscribe message, remove it from memory
- if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
- subscribed_dict = None
- if message_type == 'punsubscribe':
- subscribed_dict = self.patterns
- else:
- subscribed_dict = self.channels
- try:
- del subscribed_dict[message['channel']]
- except KeyError:
- pass
-
return message
def run_in_thread(self, sleep_time=0):
diff --git a/redis/connection.py b/redis/connection.py
index e00f6cc..45f3562 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -154,6 +154,7 @@ class PythonParser(BaseParser):
if self._buffer is not None:
self._buffer.close()
self._buffer = None
+ self.encoding = None
def can_read(self):
return self._buffer and bool(self._buffer.length)
diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py
index 1a3f460..5486b75 100644
--- a/tests/test_pubsub.py
+++ b/tests/test_pubsub.py
@@ -4,6 +4,9 @@ import time
import redis
from redis.exceptions import ConnectionError
+from redis._compat import basestring, u, unichr
+
+from .conftest import r as _redis_client
def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
@@ -22,9 +25,9 @@ def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
def make_message(type, channel, data, pattern=None):
return {
'type': type,
- 'pattern': pattern,
- 'channel': channel,
- 'data': data
+ 'pattern': pattern and pattern.encode('utf-8') or None,
+ 'channel': channel.encode('utf-8'),
+ 'data': data.encode('utf-8') if isinstance(data, basestring) else data
}
@@ -36,7 +39,7 @@ def make_subscribe_test_data(pubsub, type):
'unsub_type': 'unsubscribe',
'sub_func': pubsub.subscribe,
'unsub_func': pubsub.unsubscribe,
- 'keys': ['foo', 'bar']
+ 'keys': ['foo', 'bar', u('uni') + unichr(4456) + u('code')]
}
elif type == 'pattern':
return {
@@ -45,7 +48,7 @@ def make_subscribe_test_data(pubsub, type):
'unsub_type': 'punsubscribe',
'sub_func': pubsub.psubscribe,
'unsub_func': pubsub.punsubscribe,
- 'keys': ['f*', 'b*']
+ 'keys': ['f*', 'b*', u('uni') + unichr(4456) + u('*')]
}
assert False, 'invalid subscribe type: %s' % type
@@ -54,19 +57,21 @@ class TestPubSubSubscribeUnsubscribe(object):
def _test_subscribe_unsubscribe(self, p, sub_type, unsub_type, sub_func,
unsub_func, keys):
- assert sub_func(keys[0]) is None
- assert sub_func(keys[1]) is None
+ for key in keys:
+ assert sub_func(key) is None
- # should be 2 messages indicating that we've subscribed
- assert wait_for_message(p) == make_message(sub_type, keys[0], 1)
- assert wait_for_message(p) == make_message(sub_type, keys[1], 2)
+ # should be a message for each channel/pattern we just subscribed to
+ for i, key in enumerate(keys):
+ assert wait_for_message(p) == make_message(sub_type, key, i + 1)
- assert unsub_func(keys[0]) is None
- assert unsub_func(keys[1]) is None
+ for key in keys:
+ assert unsub_func(key) is None
- # should be 2 messages indicating that we've unsubscribed
- assert wait_for_message(p) == make_message(unsub_type, keys[0], 1)
- assert wait_for_message(p) == make_message(unsub_type, keys[1], 0)
+ # should be a message for each channel/pattern we just unsubscribed
+ # from
+ for i, key in enumerate(keys):
+ i = len(keys) - 1 - i
+ assert wait_for_message(p) == make_message(unsub_type, key, i)
def test_channel_subscribe_unsubscribe(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
@@ -78,28 +83,36 @@ class TestPubSubSubscribeUnsubscribe(object):
def _test_resubscribe_on_reconnection(self, p, sub_type, unsub_type,
sub_func, unsub_func, keys):
- assert sub_func(keys[0]) is None
- assert sub_func(keys[1]) is None
- # should be 2 messages indicating that we've subscribed
- assert wait_for_message(p) == make_message(sub_type, keys[0], 1)
- assert wait_for_message(p) == make_message(sub_type, keys[1], 2)
+ for key in keys:
+ assert sub_func(key) is None
+
+ # should be a message for each channel/pattern we just subscribed to
+ for i, key in enumerate(keys):
+ assert wait_for_message(p) == make_message(sub_type, key, i + 1)
# manually disconnect
p.connection.disconnect()
# calling get_message again reconnects and resubscribes
# note, we may not re-subscribe to channels in exactly the same order
- message1 = wait_for_message(p)
- message2 = wait_for_message(p)
-
- assert message1['type'] == sub_type
- assert message1['channel'] in keys
- assert message1['data'] == 1
- assert message2['type'] == sub_type
- assert message2['channel'] in keys
- assert message2['data'] == 2
- assert message1['channel'] != message2['channel']
+ # so we have to do some extra checks to make sure we got them all
+ messages = []
+ for i in range(len(keys)):
+ messages.append(wait_for_message(p))
+
+ unique_channels = set()
+ assert len(messages) == len(keys)
+ for i, message in enumerate(messages):
+ assert message['type'] == sub_type
+ assert message['data'] == i + 1
+ assert isinstance(message['channel'], bytes)
+ channel = message['channel'].decode('utf-8')
+ unique_channels.add(channel)
+
+ assert len(unique_channels) == len(keys)
+ for channel in unique_channels:
+ assert channel in keys
def test_resubscribe_to_channels_on_reconnection(self, r):
kwargs = make_subscribe_test_data(r.pubsub(), 'channel')
@@ -170,12 +183,15 @@ class TestPubSubSubscribeUnsubscribe(object):
(p.subscribe, 'foo'),
(p.unsubscribe, 'foo'),
(p.psubscribe, 'f*'),
- (p.unsubscribe, 'f*'),
+ (p.punsubscribe, 'f*'),
)
+ assert p.subscribed is False
for func, channel in checks:
assert func(channel) is None
+ assert p.subscribed is True
assert wait_for_message(p) is None
+ assert p.subscribed is False
def test_ignore_individual_subscribe_messages(self, r):
p = r.pubsub()
@@ -184,13 +200,187 @@ class TestPubSubSubscribeUnsubscribe(object):
(p.subscribe, 'foo'),
(p.unsubscribe, 'foo'),
(p.psubscribe, 'f*'),
- (p.unsubscribe, 'f*'),
+ (p.punsubscribe, 'f*'),
)
+ assert p.subscribed is False
for func, channel in checks:
assert func(channel) is None
+ assert p.subscribed is True
message = wait_for_message(p, ignore_subscribe_messages=True)
assert message is None
+ assert p.subscribed is False
+
+
+class TestPubSubMessages(object):
+ def setup_method(self, method):
+ self.message = None
+
+ def message_handler(self, message):
+ self.message = message
+
+ def test_published_message_to_channel(self, r):
+ p = r.pubsub(ignore_subscribe_messages=True)
+ p.subscribe('foo')
+ assert r.publish('foo', 'test message') == 1
+
+ message = wait_for_message(p)
+ assert isinstance(message, dict)
+ assert message == make_message('message', 'foo', 'test message')
+
+ def test_published_message_to_pattern(self, r):
+ p = r.pubsub(ignore_subscribe_messages=True)
+ p.subscribe('foo')
+ p.psubscribe('f*')
+ # 1 to pattern, 1 to channel
+ assert r.publish('foo', 'test message') == 2
+
+ message1 = wait_for_message(p)
+ message2 = wait_for_message(p)
+ assert isinstance(message1, dict)
+ assert isinstance(message2, dict)
+
+ expected = [
+ make_message('message', 'foo', 'test message'),
+ make_message('pmessage', 'foo', 'test message', pattern='f*')
+ ]
+
+ assert message1 in expected
+ assert message2 in expected
+ assert message1 != message2
+
+ def test_channel_message_handler(self, r):
+ p = r.pubsub(ignore_subscribe_messages=True)
+ p.subscribe(foo=self.message_handler)
+ assert r.publish('foo', 'test message') == 1
+ assert wait_for_message(p) is None
+ assert self.message == make_message('message', 'foo', 'test message')
+
+ def test_pattern_message_handler(self, r):
+ p = r.pubsub(ignore_subscribe_messages=True)
+ p.psubscribe(**{'f*': self.message_handler})
+ assert r.publish('foo', 'test message') == 1
+ assert wait_for_message(p) is None
+ assert self.message == make_message('pmessage', 'foo', 'test message',
+ pattern='f*')
+
+ def test_unicode_channel_message_handler(self, r):
+ p = r.pubsub(ignore_subscribe_messages=True)
+ channel = u('uni') + unichr(4456) + u('code')
+ channels = {channel: self.message_handler}
+ p.subscribe(**channels)
+ assert r.publish(channel, 'test message') == 1
+ assert wait_for_message(p) is None
+ assert self.message == make_message('message', channel, 'test message')
+
+ def test_unicode_pattern_message_handler(self, r):
+ p = r.pubsub(ignore_subscribe_messages=True)
+ pattern = u('uni') + unichr(4456) + u('*')
+ channel = u('uni') + unichr(4456) + u('code')
+ p.psubscribe(**{pattern: self.message_handler})
+ assert r.publish(channel, 'test message') == 1
+ assert wait_for_message(p) is None
+ assert self.message == make_message('pmessage', channel,
+ 'test message', pattern=pattern)
+
+
+class TestPubSubAutoDecoding(object):
+ "These tests only validate that we get unicode values back"
+
+ channel = u('uni') + unichr(4456) + u('code')
+ pattern = u('uni') + unichr(4456) + u('*')
+ data = u('abc') + unichr(4458) + u('123')
+
+ def make_message(self, type, channel, data, pattern=None):
+ return {
+ 'type': type,
+ 'channel': channel,
+ 'pattern': pattern,
+ 'data': data
+ }
+
+ def setup_method(self, method):
+ self.message = None
+
+ def message_handler(self, message):
+ self.message = message
+
+ @pytest.fixture()
+ def r(self, request):
+ return _redis_client(request=request, decode_responses=True)
+
+ def test_channel_subscribe_unsubscribe(self, r):
+ p = r.pubsub()
+ p.subscribe(self.channel)
+ assert wait_for_message(p) == self.make_message('subscribe',
+ self.channel, 1)
+
+ p.unsubscribe(self.channel)
+ assert wait_for_message(p) == self.make_message('unsubscribe',
+ self.channel, 0)
+
+ def test_pattern_subscribe_unsubscribe(self, r):
+ p = r.pubsub()
+ p.psubscribe(self.pattern)
+ assert wait_for_message(p) == self.make_message('psubscribe',
+ self.pattern, 1)
+
+ p.punsubscribe(self.pattern)
+ assert wait_for_message(p) == self.make_message('punsubscribe',
+ self.pattern, 0)
+
+ def test_channel_publish(self, r):
+ p = r.pubsub(ignore_subscribe_messages=True)
+ p.subscribe(self.channel)
+ r.publish(self.channel, self.data)
+ assert wait_for_message(p) == self.make_message('message',
+ self.channel,
+ self.data)
+
+ def test_pattern_publish(self, r):
+ p = r.pubsub(ignore_subscribe_messages=True)
+ p.psubscribe(self.pattern)
+ r.publish(self.channel, self.data)
+ assert wait_for_message(p) == self.make_message('pmessage',
+ self.channel,
+ self.data,
+ pattern=self.pattern)
+
+ def test_channel_message_handler(self, r):
+ p = r.pubsub(ignore_subscribe_messages=True)
+ p.subscribe(**{self.channel: self.message_handler})
+ r.publish(self.channel, self.data)
+ assert wait_for_message(p) is None
+ assert self.message == self.make_message('message', self.channel,
+ self.data)
+
+ # test that we reconnected to the correct channel
+ p.connection.disconnect()
+ assert wait_for_message(p) is None # should reconnect
+ new_data = self.data + u('new data')
+ r.publish(self.channel, new_data)
+ assert wait_for_message(p) is None
+ assert self.message == self.make_message('message', self.channel,
+ new_data)
+
+ def test_pattern_message_handler(self, r):
+ p = r.pubsub(ignore_subscribe_messages=True)
+ p.psubscribe(**{self.pattern: self.message_handler})
+ r.publish(self.channel, self.data)
+ assert wait_for_message(p) is None
+ assert self.message == self.make_message('pmessage', self.channel,
+ self.data,
+ pattern=self.pattern)
+
+ # test that we reconnected to the correct pattern
+ p.connection.disconnect()
+ assert wait_for_message(p) is None # should reconnect
+ new_data = self.data + u('new data')
+ r.publish(self.channel, new_data)
+ assert wait_for_message(p) is None
+ assert self.message == self.make_message('pmessage', self.channel,
+ new_data,
+ pattern=self.pattern)
class TestPubSubRedisDown(object):