diff options
Diffstat (limited to 'redis/client.py')
-rwxr-xr-x | redis/client.py | 53 |
1 files changed, 42 insertions, 11 deletions
diff --git a/redis/client.py b/redis/client.py index a38098f..80ffb68 100755 --- a/redis/client.py +++ b/redis/client.py @@ -648,7 +648,8 @@ class Redis(object): decode_responses=False, retry_on_timeout=False, ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs='required', ssl_ca_certs=None, - max_connections=None, single_connection_client=False): + max_connections=None, single_connection_client=False, + health_check_interval=0): if not connection_pool: if charset is not None: warnings.warn(DeprecationWarning( @@ -667,7 +668,8 @@ class Redis(object): 'encoding_errors': encoding_errors, 'decode_responses': decode_responses, 'retry_on_timeout': retry_on_timeout, - 'max_connections': max_connections + 'max_connections': max_connections, + 'health_check_interval': health_check_interval, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -3053,6 +3055,7 @@ class PubSub(object): """ PUBLISH_MESSAGE_TYPES = ('message', 'pmessage') UNSUBSCRIBE_MESSAGE_TYPES = ('unsubscribe', 'punsubscribe') + HEALTH_CHECK_MESSAGE = 'redis-py-health-check' def __init__(self, connection_pool, shard_hint=None, ignore_subscribe_messages=False): @@ -3063,6 +3066,13 @@ class PubSub(object): # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = self.connection_pool.get_encoder() + if self.encoder.decode_responses: + self.health_check_response = ['pong', self.HEALTH_CHECK_MESSAGE] + else: + self.health_check_response = [ + b'pong', + self.encoder.encode(self.HEALTH_CHECK_MESSAGE) + ] self.reset() def __del__(self): @@ -3111,7 +3121,7 @@ class PubSub(object): "Indicates if there are subscriptions to any channels or patterns" return bool(self.channels or self.patterns) - def execute_command(self, *args, **kwargs): + def execute_command(self, *args): "Execute a publish/subscribe command" # NOTE: don't parse the response in this function -- it could pull a @@ -3127,11 +3137,12 @@ class PubSub(object): # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) connection = self.connection - self._execute(connection, connection.send_command, *args) + kwargs = {'check_health': not self.subscribed} + self._execute(connection, connection.send_command, *args, **kwargs) - def _execute(self, connection, command, *args): + def _execute(self, connection, command, *args, **kwargs): try: - return command(*args) + return command(*args, **kwargs) except (ConnectionError, TimeoutError) as e: connection.disconnect() if not (connection.retry_on_timeout and @@ -3143,18 +3154,38 @@ class PubSub(object): # the ``on_connect`` callback should haven been called by the # connection to resubscribe us to any channels and patterns we were # previously listening to - return command(*args) + return command(*args, **kwargs) def parse_response(self, block=True, timeout=0): "Parse the response from a publish/subscribe command" - connection = self.connection - if connection is None: + conn = self.connection + if conn is None: raise RuntimeError( 'pubsub connection not set: ' 'did you forget to call subscribe() or psubscribe()?') - if not block and not connection.can_read(timeout=timeout): + + self.check_health() + + if not block and not conn.can_read(timeout=timeout): + return None + response = self._execute(conn, conn.read_response) + + if conn.health_check_interval and \ + response == self.health_check_response: + # ignore the health check message as user might not expect it return None - return self._execute(connection, connection.read_response) + return response + + def check_health(self): + conn = self.connection + if conn is None: + raise RuntimeError( + 'pubsub connection not set: ' + 'did you forget to call subscribe() or psubscribe()?') + + if conn.health_check_interval and time.time() > conn.next_health_check: + conn.send_command('PING', self.HEALTH_CHECK_MESSAGE, + check_health=False) def _normalize_keys(self, data): """ |