summaryrefslogtreecommitdiff
path: root/redis/client.py
diff options
context:
space:
mode:
authorMarcin Raczyński <niejako@gmail.com>2018-12-17 20:51:04 +0100
committerMarcin Raczyński <niejako@gmail.com>2018-12-17 20:51:04 +0100
commitb5285270b07a56ed92744ca072ebe10fd8ce7d32 (patch)
tree36d3be40c0732eef3496f244acbc6f4a7e2f5de1 /redis/client.py
parentf1a1386b5799f5a2f731e25fe3473cb1118d5099 (diff)
downloadredis-py-b5285270b07a56ed92744ca072ebe10fd8ce7d32.tar.gz
Fix #764 - sub-unsub-resub caused PubSub() to forget the channel
Diffstat (limited to 'redis/client.py')
-rwxr-xr-xredis/client.py47
1 files changed, 33 insertions, 14 deletions
diff --git a/redis/client.py b/redis/client.py
index 957081c..ec87da8 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -2903,7 +2903,9 @@ class PubSub(object):
self.connection_pool.release(self.connection)
self.connection = None
self.channels = {}
+ self.pending_unsubscribe_channels = set()
self.patterns = {}
+ self.pending_unsubscribe_patterns = set()
def close(self):
self.reset()
@@ -2913,6 +2915,8 @@ class PubSub(object):
# NOTE: for python3, we can't pass bytestrings as keyword arguments
# so we need to decode channel/pattern names back to unicode strings
# before passing them to [p]subscribe.
+ self.pending_unsubscribe_channels.clear()
+ self.pending_unsubscribe_patterns.clear()
if self.channels:
channels = {}
for k, v in iteritems(self.channels):
@@ -2999,17 +3003,25 @@ class PubSub(object):
# update the patterns dict AFTER we send the command. we don't want to
# subscribe twice to these patterns, once for the command and again
# for the reconnection.
- self.patterns.update(self._normalize_keys(new_patterns))
+ new_patterns = self._normalize_keys(new_patterns)
+ self.patterns.update(new_patterns)
+ self.pending_unsubscribe_patterns.difference_update(new_patterns)
return ret_val
def punsubscribe(self, *args):
"""
- Unsubscribe from the supplied patterns. If empy, unsubscribe from
+ Unsubscribe from the supplied patterns. If empty, unsubscribe from
all patterns.
"""
if args:
args = list_or_args(args[0], args[1:])
- return self.execute_command('PUNSUBSCRIBE', *args)
+ retval = self.execute_command('PUNSUBSCRIBE', *args)
+ if args:
+ patterns = self._normalize_keys(dict.fromkeys(args))
+ else:
+ patterns = self.patterns
+ self.pending_unsubscribe_patterns.update(patterns)
+ return retval
def subscribe(self, *args, **kwargs):
"""
@@ -3027,7 +3039,9 @@ class PubSub(object):
# update the channels dict AFTER we send the command. we don't want to
# subscribe twice to these channels, once for the command and again
# for the reconnection.
- self.channels.update(self._normalize_keys(new_channels))
+ new_channels = self._normalize_keys(new_channels)
+ self.channels.update(new_channels)
+ self.pending_unsubscribe_channels.difference_update(new_channels)
return ret_val
def unsubscribe(self, *args):
@@ -3037,7 +3051,13 @@ class PubSub(object):
"""
if args:
args = list_or_args(args[0], args[1:])
- return self.execute_command('UNSUBSCRIBE', *args)
+ retval = self.execute_command('UNSUBSCRIBE', *args)
+ if args:
+ channels = self._normalize_keys(dict.fromkeys(args))
+ else:
+ channels = self.channels
+ self.pending_unsubscribe_channels.update(channels)
+ return retval
def listen(self):
"Listen for messages on channels this client has been subscribed to"
@@ -3094,22 +3114,21 @@ class PubSub(object):
'channel': response[1],
'data': response[2]
}
-
# if this is an unsubscribe message, remove it from memory
if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
- subscribed_dict = None
if message_type == 'punsubscribe':
- subscribed_dict = self.patterns
+ pattern = response[1]
+ if pattern in self.pending_unsubscribe_patterns:
+ self.pending_unsubscribe_patterns.remove(pattern)
+ self.patterns.pop(pattern, None)
else:
- subscribed_dict = self.channels
- try:
- del subscribed_dict[message['channel']]
- except KeyError:
- pass
+ channel = response[1]
+ if channel in self.pending_unsubscribe_channels:
+ self.pending_unsubscribe_channels.remove(channel)
+ self.channels.pop(channel, None)
if message_type in self.PUBLISH_MESSAGE_TYPES:
# if there's a message handler, invoke it
- handler = None
if message_type == 'pmessage':
handler = self.patterns.get(message['pattern'], None)
else: