diff options
author | Andy McCurdy <andy@andymccurdy.com> | 2014-04-01 12:28:31 -0700 |
---|---|---|
committer | Andy McCurdy <andy@andymccurdy.com> | 2014-04-01 12:28:31 -0700 |
commit | 54e5b062e2eddd674c9a462f763a7854965832d2 (patch) | |
tree | 1996b808be7113e7798e1e9fe85437352c8c416d | |
parent | 8b7cd8dd781c651877a14e9505f20e5553572f59 (diff) | |
download | redis-py-54e5b062e2eddd674c9a462f763a7854965832d2.tar.gz |
automatic message decoding if decode_responses=True. bugfixes, tests.
-rw-r--r-- | redis/client.py | 93 | ||||
-rw-r--r-- | redis/connection.py | 1 | ||||
-rw-r--r-- | tests/test_pubsub.py | 254 |
3 files changed, 289 insertions, 59 deletions
diff --git a/redis/client.py b/redis/client.py index 7837dd3..352be7d 100644 --- a/redis/client.py +++ b/redis/client.py @@ -1762,6 +1762,12 @@ class PubSub(object): self.shard_hint = shard_hint self.ignore_subscribe_messages = ignore_subscribe_messages self.connection = None + # we need to know the encoding options for this connection in order + # to lookup channel and pattern names for callback handlers. + connection_kwargs = self.connection_pool.connection_kwargs + self.encoding = connection_kwargs['encoding'] + self.encoding_errors = connection_kwargs['encoding_errors'] + self.decode_responses = connection_kwargs['decode_responses'] self.reset() def __del__(self): @@ -1786,10 +1792,35 @@ class PubSub(object): self.reset() def on_connect(self, connection): + "Re-subscribe to any channels and patterns previously subscribed to" + # NOTE: for python3, we can't pass bytestrings as keyword arguments + # so we need to decode channel/pattern names back to unicode strings + # before passing them to [p]subscribe. if self.channels: - self.subscribe(**self.channels) + channels = {} + for k, v in iteritems(self.channels): + if not self.decode_responses: + k = k.decode(self.encoding, self.encoding_errors) + channels[k] = v + self.subscribe(**channels) if self.patterns: - self.psubscribe(**self.patterns) + patterns = {} + for k, v in iteritems(self.patterns): + if not self.decode_responses: + k = k.decode(self.encoding, self.encoding_errors) + patterns[k] = v + self.psubscribe(**patterns) + + def encode(self, value): + """ + Encode the value so that it's identical to what we'll + read off the connection + """ + if self.decode_responses and isinstance(value, bytes): + value = value.decode(self.encoding, self.encoding_errors) + elif not self.decode_responses and isinstance(value, unicode): + value = value.encode(self.encoding, self.encoding_errors) + return value @property def subscribed(self): @@ -1808,10 +1839,8 @@ class PubSub(object): 'pubsub', self.shard_hint ) - # initially connect here so we don't run our callback the first - # time. If we did, it would dupe the subscriptions, once from the - # callback and a second time from the actual command invocation - self.connection.connect() + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) connection = self.connection self._execute(connection, connection.send_command, *args) @@ -1847,10 +1876,15 @@ class PubSub(object): if args: args = list_or_args(args[0], args[1:]) new_patterns = {} - new_patterns.update(dict.fromkeys(args)) - new_patterns.update(kwargs) + new_patterns.update(dict.fromkeys(imap(self.encode, args))) + for pattern, handler in iteritems(kwargs): + new_patterns[self.encode(pattern)] = handler + ret_val = self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns)) + # update the patterns dict AFTER we send the command. we don't want to + # subscribe twice to these patterns, once for the command and again + # for the reconnection. self.patterns.update(new_patterns) - return self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns)) + return ret_val def punsubscribe(self, *args): """ @@ -1872,10 +1906,15 @@ class PubSub(object): if args: args = list_or_args(args[0], args[1:]) new_channels = {} - new_channels.update(dict.fromkeys(args)) - new_channels.update(kwargs) + new_channels.update(dict.fromkeys(imap(self.encode, args))) + for channel, handler in iteritems(kwargs): + new_channels[self.encode(channel)] = handler + ret_val = self.execute_command('SUBSCRIBE', *iterkeys(new_channels)) + # update the channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. self.channels.update(new_channels) - return self.execute_command('SUBSCRIBE', *iterkeys(new_channels)) + return ret_val def unsubscribe(self, *args): """ @@ -1910,18 +1949,30 @@ class PubSub(object): if message_type == 'pmessage': message = { 'type': message_type, - 'pattern': nativestr(response[1]), - 'channel': nativestr(response[2]), + 'pattern': response[1], + 'channel': response[2], 'data': response[3] } else: message = { 'type': message_type, 'pattern': None, - 'channel': nativestr(response[1]), + 'channel': response[1], 'data': response[2] } + # if this is an unsubscribe message, remove it from memory + if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: + subscribed_dict = None + if message_type == 'punsubscribe': + subscribed_dict = self.patterns + else: + subscribed_dict = self.channels + try: + del subscribed_dict[message['channel']] + except KeyError: + pass + if message_type in self.PUBLISH_MESSAGE_TYPES: # if there's a message handler, invoke it handler = None @@ -1938,18 +1989,6 @@ class PubSub(object): if ignore_subscribe_messages or self.ignore_subscribe_messages: return None - # if this is an unsubscribe message, remove it from memory - if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: - subscribed_dict = None - if message_type == 'punsubscribe': - subscribed_dict = self.patterns - else: - subscribed_dict = self.channels - try: - del subscribed_dict[message['channel']] - except KeyError: - pass - return message def run_in_thread(self, sleep_time=0): diff --git a/redis/connection.py b/redis/connection.py index e00f6cc..45f3562 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -154,6 +154,7 @@ class PythonParser(BaseParser): if self._buffer is not None: self._buffer.close() self._buffer = None + self.encoding = None def can_read(self): return self._buffer and bool(self._buffer.length) diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 1a3f460..5486b75 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -4,6 +4,9 @@ import time import redis from redis.exceptions import ConnectionError +from redis._compat import basestring, u, unichr + +from .conftest import r as _redis_client def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): @@ -22,9 +25,9 @@ def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): def make_message(type, channel, data, pattern=None): return { 'type': type, - 'pattern': pattern, - 'channel': channel, - 'data': data + 'pattern': pattern and pattern.encode('utf-8') or None, + 'channel': channel.encode('utf-8'), + 'data': data.encode('utf-8') if isinstance(data, basestring) else data } @@ -36,7 +39,7 @@ def make_subscribe_test_data(pubsub, type): 'unsub_type': 'unsubscribe', 'sub_func': pubsub.subscribe, 'unsub_func': pubsub.unsubscribe, - 'keys': ['foo', 'bar'] + 'keys': ['foo', 'bar', u('uni') + unichr(4456) + u('code')] } elif type == 'pattern': return { @@ -45,7 +48,7 @@ def make_subscribe_test_data(pubsub, type): 'unsub_type': 'punsubscribe', 'sub_func': pubsub.psubscribe, 'unsub_func': pubsub.punsubscribe, - 'keys': ['f*', 'b*'] + 'keys': ['f*', 'b*', u('uni') + unichr(4456) + u('*')] } assert False, 'invalid subscribe type: %s' % type @@ -54,19 +57,21 @@ class TestPubSubSubscribeUnsubscribe(object): def _test_subscribe_unsubscribe(self, p, sub_type, unsub_type, sub_func, unsub_func, keys): - assert sub_func(keys[0]) is None - assert sub_func(keys[1]) is None + for key in keys: + assert sub_func(key) is None - # should be 2 messages indicating that we've subscribed - assert wait_for_message(p) == make_message(sub_type, keys[0], 1) - assert wait_for_message(p) == make_message(sub_type, keys[1], 2) + # should be a message for each channel/pattern we just subscribed to + for i, key in enumerate(keys): + assert wait_for_message(p) == make_message(sub_type, key, i + 1) - assert unsub_func(keys[0]) is None - assert unsub_func(keys[1]) is None + for key in keys: + assert unsub_func(key) is None - # should be 2 messages indicating that we've unsubscribed - assert wait_for_message(p) == make_message(unsub_type, keys[0], 1) - assert wait_for_message(p) == make_message(unsub_type, keys[1], 0) + # should be a message for each channel/pattern we just unsubscribed + # from + for i, key in enumerate(keys): + i = len(keys) - 1 - i + assert wait_for_message(p) == make_message(unsub_type, key, i) def test_channel_subscribe_unsubscribe(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'channel') @@ -78,28 +83,36 @@ class TestPubSubSubscribeUnsubscribe(object): def _test_resubscribe_on_reconnection(self, p, sub_type, unsub_type, sub_func, unsub_func, keys): - assert sub_func(keys[0]) is None - assert sub_func(keys[1]) is None - # should be 2 messages indicating that we've subscribed - assert wait_for_message(p) == make_message(sub_type, keys[0], 1) - assert wait_for_message(p) == make_message(sub_type, keys[1], 2) + for key in keys: + assert sub_func(key) is None + + # should be a message for each channel/pattern we just subscribed to + for i, key in enumerate(keys): + assert wait_for_message(p) == make_message(sub_type, key, i + 1) # manually disconnect p.connection.disconnect() # calling get_message again reconnects and resubscribes # note, we may not re-subscribe to channels in exactly the same order - message1 = wait_for_message(p) - message2 = wait_for_message(p) - - assert message1['type'] == sub_type - assert message1['channel'] in keys - assert message1['data'] == 1 - assert message2['type'] == sub_type - assert message2['channel'] in keys - assert message2['data'] == 2 - assert message1['channel'] != message2['channel'] + # so we have to do some extra checks to make sure we got them all + messages = [] + for i in range(len(keys)): + messages.append(wait_for_message(p)) + + unique_channels = set() + assert len(messages) == len(keys) + for i, message in enumerate(messages): + assert message['type'] == sub_type + assert message['data'] == i + 1 + assert isinstance(message['channel'], bytes) + channel = message['channel'].decode('utf-8') + unique_channels.add(channel) + + assert len(unique_channels) == len(keys) + for channel in unique_channels: + assert channel in keys def test_resubscribe_to_channels_on_reconnection(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'channel') @@ -170,12 +183,15 @@ class TestPubSubSubscribeUnsubscribe(object): (p.subscribe, 'foo'), (p.unsubscribe, 'foo'), (p.psubscribe, 'f*'), - (p.unsubscribe, 'f*'), + (p.punsubscribe, 'f*'), ) + assert p.subscribed is False for func, channel in checks: assert func(channel) is None + assert p.subscribed is True assert wait_for_message(p) is None + assert p.subscribed is False def test_ignore_individual_subscribe_messages(self, r): p = r.pubsub() @@ -184,13 +200,187 @@ class TestPubSubSubscribeUnsubscribe(object): (p.subscribe, 'foo'), (p.unsubscribe, 'foo'), (p.psubscribe, 'f*'), - (p.unsubscribe, 'f*'), + (p.punsubscribe, 'f*'), ) + assert p.subscribed is False for func, channel in checks: assert func(channel) is None + assert p.subscribed is True message = wait_for_message(p, ignore_subscribe_messages=True) assert message is None + assert p.subscribed is False + + +class TestPubSubMessages(object): + def setup_method(self, method): + self.message = None + + def message_handler(self, message): + self.message = message + + def test_published_message_to_channel(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe('foo') + assert r.publish('foo', 'test message') == 1 + + message = wait_for_message(p) + assert isinstance(message, dict) + assert message == make_message('message', 'foo', 'test message') + + def test_published_message_to_pattern(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe('foo') + p.psubscribe('f*') + # 1 to pattern, 1 to channel + assert r.publish('foo', 'test message') == 2 + + message1 = wait_for_message(p) + message2 = wait_for_message(p) + assert isinstance(message1, dict) + assert isinstance(message2, dict) + + expected = [ + make_message('message', 'foo', 'test message'), + make_message('pmessage', 'foo', 'test message', pattern='f*') + ] + + assert message1 in expected + assert message2 in expected + assert message1 != message2 + + def test_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe(foo=self.message_handler) + assert r.publish('foo', 'test message') == 1 + assert wait_for_message(p) is None + assert self.message == make_message('message', 'foo', 'test message') + + def test_pattern_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.psubscribe(**{'f*': self.message_handler}) + assert r.publish('foo', 'test message') == 1 + assert wait_for_message(p) is None + assert self.message == make_message('pmessage', 'foo', 'test message', + pattern='f*') + + def test_unicode_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + channel = u('uni') + unichr(4456) + u('code') + channels = {channel: self.message_handler} + p.subscribe(**channels) + assert r.publish(channel, 'test message') == 1 + assert wait_for_message(p) is None + assert self.message == make_message('message', channel, 'test message') + + def test_unicode_pattern_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + pattern = u('uni') + unichr(4456) + u('*') + channel = u('uni') + unichr(4456) + u('code') + p.psubscribe(**{pattern: self.message_handler}) + assert r.publish(channel, 'test message') == 1 + assert wait_for_message(p) is None + assert self.message == make_message('pmessage', channel, + 'test message', pattern=pattern) + + +class TestPubSubAutoDecoding(object): + "These tests only validate that we get unicode values back" + + channel = u('uni') + unichr(4456) + u('code') + pattern = u('uni') + unichr(4456) + u('*') + data = u('abc') + unichr(4458) + u('123') + + def make_message(self, type, channel, data, pattern=None): + return { + 'type': type, + 'channel': channel, + 'pattern': pattern, + 'data': data + } + + def setup_method(self, method): + self.message = None + + def message_handler(self, message): + self.message = message + + @pytest.fixture() + def r(self, request): + return _redis_client(request=request, decode_responses=True) + + def test_channel_subscribe_unsubscribe(self, r): + p = r.pubsub() + p.subscribe(self.channel) + assert wait_for_message(p) == self.make_message('subscribe', + self.channel, 1) + + p.unsubscribe(self.channel) + assert wait_for_message(p) == self.make_message('unsubscribe', + self.channel, 0) + + def test_pattern_subscribe_unsubscribe(self, r): + p = r.pubsub() + p.psubscribe(self.pattern) + assert wait_for_message(p) == self.make_message('psubscribe', + self.pattern, 1) + + p.punsubscribe(self.pattern) + assert wait_for_message(p) == self.make_message('punsubscribe', + self.pattern, 0) + + def test_channel_publish(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe(self.channel) + r.publish(self.channel, self.data) + assert wait_for_message(p) == self.make_message('message', + self.channel, + self.data) + + def test_pattern_publish(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.psubscribe(self.pattern) + r.publish(self.channel, self.data) + assert wait_for_message(p) == self.make_message('pmessage', + self.channel, + self.data, + pattern=self.pattern) + + def test_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe(**{self.channel: self.message_handler}) + r.publish(self.channel, self.data) + assert wait_for_message(p) is None + assert self.message == self.make_message('message', self.channel, + self.data) + + # test that we reconnected to the correct channel + p.connection.disconnect() + assert wait_for_message(p) is None # should reconnect + new_data = self.data + u('new data') + r.publish(self.channel, new_data) + assert wait_for_message(p) is None + assert self.message == self.make_message('message', self.channel, + new_data) + + def test_pattern_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.psubscribe(**{self.pattern: self.message_handler}) + r.publish(self.channel, self.data) + assert wait_for_message(p) is None + assert self.message == self.make_message('pmessage', self.channel, + self.data, + pattern=self.pattern) + + # test that we reconnected to the correct pattern + p.connection.disconnect() + assert wait_for_message(p) is None # should reconnect + new_data = self.data + u('new data') + r.publish(self.channel, new_data) + assert wait_for_message(p) is None + assert self.message == self.make_message('pmessage', self.channel, + new_data, + pattern=self.pattern) class TestPubSubRedisDown(object): |