summaryrefslogtreecommitdiff
path: root/redis/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'redis/client.py')
-rw-r--r--redis/client.py33
1 files changed, 23 insertions, 10 deletions
diff --git a/redis/client.py b/redis/client.py
index 268a6d5..dd31e3f 100644
--- a/redis/client.py
+++ b/redis/client.py
@@ -444,13 +444,13 @@ class StrictRedis(object):
"""
return Lock(self, name, timeout=timeout, sleep=sleep)
- def pubsub(self, shard_hint=None):
+ def pubsub(self, **kwargs):
"""
Return a Publish/Subscribe object. With this object, you can
subscribe to channels and listen for messages that get published to
them.
"""
- return PubSub(self.connection_pool, shard_hint)
+ return PubSub(self.connection_pool, **kwargs)
#### COMMAND EXECUTION AND PROTOCOL PARSING ####
def execute_command(self, *args, **options):
@@ -1780,12 +1780,19 @@ class PubSub(object):
def reset(self):
if self.connection:
self.connection.disconnect()
+ self.connection.clear_connect_callbacks()
self.connection_pool.release(self.connection)
self.connection = None
def close(self):
self.reset()
+ def on_connect(self, connection):
+ if self.channels:
+ self.subscribe(**self.channels)
+ if self.patterns:
+ self.psubscribe(**self.patterns)
+
@property
def subscribed(self):
"Indicates if there are subscriptions to any channels or patterns"
@@ -1803,6 +1810,10 @@ class PubSub(object):
'pubsub',
self.shard_hint
)
+ # initially connect here so we don't run our callback the first
+ # time. It's primarily there for reconnection purposes.
+ self.connection.connect()
+ self.connection.register_connect_callback(self.on_connect)
connection = self.connection
self._execute(connection, connection.send_command, *args)
@@ -1816,8 +1827,6 @@ class PubSub(object):
connection.connect()
# resubscribe to all channels and patterns before
# resending the current command
- self.subscribe(**self.channels)
- self.psubscribe(**self.patterns)
return command(*args)
def parse_response(self, block=True):
@@ -1840,9 +1849,11 @@ class PubSub(object):
"""
if args:
args = list_or_args(args[0], args[1:])
- self.patterns.update(dict.fromkeys(args))
- self.patterns.update(kwargs)
- return self.execute_command('PSUBSCRIBE', *iterkeys(self.patterns))
+ new_patterns = {}
+ new_patterns.update(dict.fromkeys(args))
+ new_patterns.update(kwargs)
+ self.patterns.update(new_patterns)
+ return self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns))
def punsubscribe(self, *args):
"""
@@ -1869,9 +1880,11 @@ class PubSub(object):
"""
if args:
args = list_or_args(args[0], args[1:])
- self.channels.update(dict.fromkeys(args))
- self.channels.update(kwargs)
- return self.execute_command('SUBSCRIBE', *iterkeys(self.channels))
+ new_channels = {}
+ new_channels.update(dict.fromkeys(args))
+ new_channels.update(kwargs)
+ self.channels.update(new_channels)
+ return self.execute_command('SUBSCRIBE', *iterkeys(new_channels))
def unsubscribe(self, *args):
"""