diff options
Diffstat (limited to 'redis/client.py')
-rw-r--r-- | redis/client.py | 33 |
1 files changed, 23 insertions, 10 deletions
diff --git a/redis/client.py b/redis/client.py index 268a6d5..dd31e3f 100644 --- a/redis/client.py +++ b/redis/client.py @@ -444,13 +444,13 @@ class StrictRedis(object): """ return Lock(self, name, timeout=timeout, sleep=sleep) - def pubsub(self, shard_hint=None): + def pubsub(self, **kwargs): """ Return a Publish/Subscribe object. With this object, you can subscribe to channels and listen for messages that get published to them. """ - return PubSub(self.connection_pool, shard_hint) + return PubSub(self.connection_pool, **kwargs) #### COMMAND EXECUTION AND PROTOCOL PARSING #### def execute_command(self, *args, **options): @@ -1780,12 +1780,19 @@ class PubSub(object): def reset(self): if self.connection: self.connection.disconnect() + self.connection.clear_connect_callbacks() self.connection_pool.release(self.connection) self.connection = None def close(self): self.reset() + def on_connect(self, connection): + if self.channels: + self.subscribe(**self.channels) + if self.patterns: + self.psubscribe(**self.patterns) + @property def subscribed(self): "Indicates if there are subscriptions to any channels or patterns" @@ -1803,6 +1810,10 @@ class PubSub(object): 'pubsub', self.shard_hint ) + # initially connect here so we don't run our callback the first + # time. It's primarily there for reconnection purposes. + self.connection.connect() + self.connection.register_connect_callback(self.on_connect) connection = self.connection self._execute(connection, connection.send_command, *args) @@ -1816,8 +1827,6 @@ class PubSub(object): connection.connect() # resubscribe to all channels and patterns before # resending the current command - self.subscribe(**self.channels) - self.psubscribe(**self.patterns) return command(*args) def parse_response(self, block=True): @@ -1840,9 +1849,11 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - self.patterns.update(dict.fromkeys(args)) - self.patterns.update(kwargs) - return self.execute_command('PSUBSCRIBE', *iterkeys(self.patterns)) + new_patterns = {} + new_patterns.update(dict.fromkeys(args)) + new_patterns.update(kwargs) + self.patterns.update(new_patterns) + return self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns)) def punsubscribe(self, *args): """ @@ -1869,9 +1880,11 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - self.channels.update(dict.fromkeys(args)) - self.channels.update(kwargs) - return self.execute_command('SUBSCRIBE', *iterkeys(self.channels)) + new_channels = {} + new_channels.update(dict.fromkeys(args)) + new_channels.update(kwargs) + self.channels.update(new_channels) + return self.execute_command('SUBSCRIBE', *iterkeys(new_channels)) def unsubscribe(self, *args): """ |