diff options
author | Marcin Raczyński <niejako@gmail.com> | 2018-12-17 20:51:04 +0100 |
---|---|---|
committer | Marcin Raczyński <niejako@gmail.com> | 2018-12-17 20:51:04 +0100 |
commit | b5285270b07a56ed92744ca072ebe10fd8ce7d32 (patch) | |
tree | 36d3be40c0732eef3496f244acbc6f4a7e2f5de1 /redis/client.py | |
parent | f1a1386b5799f5a2f731e25fe3473cb1118d5099 (diff) | |
download | redis-py-b5285270b07a56ed92744ca072ebe10fd8ce7d32.tar.gz |
Fix #764 - sub-unsub-resub caused PubSub() to forget the channel
Diffstat (limited to 'redis/client.py')
-rwxr-xr-x | redis/client.py | 47 |
1 files changed, 33 insertions, 14 deletions
diff --git a/redis/client.py b/redis/client.py index 957081c..ec87da8 100755 --- a/redis/client.py +++ b/redis/client.py @@ -2903,7 +2903,9 @@ class PubSub(object): self.connection_pool.release(self.connection) self.connection = None self.channels = {} + self.pending_unsubscribe_channels = set() self.patterns = {} + self.pending_unsubscribe_patterns = set() def close(self): self.reset() @@ -2913,6 +2915,8 @@ class PubSub(object): # NOTE: for python3, we can't pass bytestrings as keyword arguments # so we need to decode channel/pattern names back to unicode strings # before passing them to [p]subscribe. + self.pending_unsubscribe_channels.clear() + self.pending_unsubscribe_patterns.clear() if self.channels: channels = {} for k, v in iteritems(self.channels): @@ -2999,17 +3003,25 @@ class PubSub(object): # update the patterns dict AFTER we send the command. we don't want to # subscribe twice to these patterns, once for the command and again # for the reconnection. - self.patterns.update(self._normalize_keys(new_patterns)) + new_patterns = self._normalize_keys(new_patterns) + self.patterns.update(new_patterns) + self.pending_unsubscribe_patterns.difference_update(new_patterns) return ret_val def punsubscribe(self, *args): """ - Unsubscribe from the supplied patterns. If empy, unsubscribe from + Unsubscribe from the supplied patterns. If empty, unsubscribe from all patterns. """ if args: args = list_or_args(args[0], args[1:]) - return self.execute_command('PUNSUBSCRIBE', *args) + retval = self.execute_command('PUNSUBSCRIBE', *args) + if args: + patterns = self._normalize_keys(dict.fromkeys(args)) + else: + patterns = self.patterns + self.pending_unsubscribe_patterns.update(patterns) + return retval def subscribe(self, *args, **kwargs): """ @@ -3027,7 +3039,9 @@ class PubSub(object): # update the channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again # for the reconnection. - self.channels.update(self._normalize_keys(new_channels)) + new_channels = self._normalize_keys(new_channels) + self.channels.update(new_channels) + self.pending_unsubscribe_channels.difference_update(new_channels) return ret_val def unsubscribe(self, *args): @@ -3037,7 +3051,13 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - return self.execute_command('UNSUBSCRIBE', *args) + retval = self.execute_command('UNSUBSCRIBE', *args) + if args: + channels = self._normalize_keys(dict.fromkeys(args)) + else: + channels = self.channels + self.pending_unsubscribe_channels.update(channels) + return retval def listen(self): "Listen for messages on channels this client has been subscribed to" @@ -3094,22 +3114,21 @@ class PubSub(object): 'channel': response[1], 'data': response[2] } - # if this is an unsubscribe message, remove it from memory if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: - subscribed_dict = None if message_type == 'punsubscribe': - subscribed_dict = self.patterns + pattern = response[1] + if pattern in self.pending_unsubscribe_patterns: + self.pending_unsubscribe_patterns.remove(pattern) + self.patterns.pop(pattern, None) else: - subscribed_dict = self.channels - try: - del subscribed_dict[message['channel']] - except KeyError: - pass + channel = response[1] + if channel in self.pending_unsubscribe_channels: + self.pending_unsubscribe_channels.remove(channel) + self.channels.pop(channel, None) if message_type in self.PUBLISH_MESSAGE_TYPES: # if there's a message handler, invoke it - handler = None if message_type == 'pmessage': handler = self.patterns.get(message['pattern'], None) else: |