From 7a73f197b18584710c9f3b9ee8cbbadbe727bf34 Mon Sep 17 00:00:00 2001 From: Andy McCurdy Date: Tue, 23 Jul 2019 16:54:26 -0700 Subject: PING/PONG health checks The `Redis` class and the `ConnectionPool` class now support the "health_check_interval=N" option. By default N=0, which turns off health checks. `N` should be an integer, and when greater than 0, ensures that a health check is performed just before command execution anytime the underlying connection has been idle for more than N seconds. A health check is a full PING/PONG round trip to the Redis server. If a health check encounters a ConnectionError or TimeoutError, the connection is disconnected and reconnected and the health check is retried exactly once. Any error during the retry is raised to the caller. Health check retries are not governed by any other options such as `retry_on_timeout`. In systems where idle times are common, these health checks are the intended way to reconnect to the Redis server without harming any user data. When this option is enabled for PubSub connections, calling `get_message()` or `listen()` will send a health check anytime a message has not been read on the PubSub connection for `health_check_interval` seconds. Users should call `get_message()` or `listen()` at least every `health_check_interval` seconds in order to keep the connection open. --- CHANGES | 7 +- redis/client.py | 53 ++++++++++--- redis/connection.py | 61 ++++++++++---- tests/test_commands.py | 3 +- tests/test_connection_pool.py | 179 +++++++++++++++++++++++++++++++++++++++++- 5 files changed, 275 insertions(+), 28 deletions(-) diff --git a/CHANGES b/CHANGES index 15a77b5..b8cdd3e 100644 --- a/CHANGES +++ b/CHANGES @@ -25,7 +25,12 @@ * Allow for single connection client instances. These instances are not thread safe but offer other benefits including a subtle performance increase. - + * Added extensive health checks that keep the connections lively. + Passing the "health_check_interval=N" option to the Redis client class + or to a ConnectionPool ensures that a round trip PING/PONG is successful + before any command if the underlying connection has been idle for more + than N seconds. ConnectionErrors and TimeoutErrors are automatically + retried once for health checks. * 3.2.1 * Fix SentinelConnectionPool to work in multiprocess/forked environments. * 3.2.0 diff --git a/redis/client.py b/redis/client.py index a38098f..80ffb68 100755 --- a/redis/client.py +++ b/redis/client.py @@ -648,7 +648,8 @@ class Redis(object): decode_responses=False, retry_on_timeout=False, ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs='required', ssl_ca_certs=None, - max_connections=None, single_connection_client=False): + max_connections=None, single_connection_client=False, + health_check_interval=0): if not connection_pool: if charset is not None: warnings.warn(DeprecationWarning( @@ -667,7 +668,8 @@ class Redis(object): 'encoding_errors': encoding_errors, 'decode_responses': decode_responses, 'retry_on_timeout': retry_on_timeout, - 'max_connections': max_connections + 'max_connections': max_connections, + 'health_check_interval': health_check_interval, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -3053,6 +3055,7 @@ class PubSub(object): """ PUBLISH_MESSAGE_TYPES = ('message', 'pmessage') UNSUBSCRIBE_MESSAGE_TYPES = ('unsubscribe', 'punsubscribe') + HEALTH_CHECK_MESSAGE = 'redis-py-health-check' def __init__(self, connection_pool, shard_hint=None, ignore_subscribe_messages=False): @@ -3063,6 +3066,13 @@ class PubSub(object): # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = self.connection_pool.get_encoder() + if self.encoder.decode_responses: + self.health_check_response = ['pong', self.HEALTH_CHECK_MESSAGE] + else: + self.health_check_response = [ + b'pong', + self.encoder.encode(self.HEALTH_CHECK_MESSAGE) + ] self.reset() def __del__(self): @@ -3111,7 +3121,7 @@ class PubSub(object): "Indicates if there are subscriptions to any channels or patterns" return bool(self.channels or self.patterns) - def execute_command(self, *args, **kwargs): + def execute_command(self, *args): "Execute a publish/subscribe command" # NOTE: don't parse the response in this function -- it could pull a @@ -3127,11 +3137,12 @@ class PubSub(object): # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) connection = self.connection - self._execute(connection, connection.send_command, *args) + kwargs = {'check_health': not self.subscribed} + self._execute(connection, connection.send_command, *args, **kwargs) - def _execute(self, connection, command, *args): + def _execute(self, connection, command, *args, **kwargs): try: - return command(*args) + return command(*args, **kwargs) except (ConnectionError, TimeoutError) as e: connection.disconnect() if not (connection.retry_on_timeout and @@ -3143,18 +3154,38 @@ class PubSub(object): # the ``on_connect`` callback should haven been called by the # connection to resubscribe us to any channels and patterns we were # previously listening to - return command(*args) + return command(*args, **kwargs) def parse_response(self, block=True, timeout=0): "Parse the response from a publish/subscribe command" - connection = self.connection - if connection is None: + conn = self.connection + if conn is None: raise RuntimeError( 'pubsub connection not set: ' 'did you forget to call subscribe() or psubscribe()?') - if not block and not connection.can_read(timeout=timeout): + + self.check_health() + + if not block and not conn.can_read(timeout=timeout): + return None + response = self._execute(conn, conn.read_response) + + if conn.health_check_interval and \ + response == self.health_check_response: + # ignore the health check message as user might not expect it return None - return self._execute(connection, connection.read_response) + return response + + def check_health(self): + conn = self.connection + if conn is None: + raise RuntimeError( + 'pubsub connection not set: ' + 'did you forget to call subscribe() or psubscribe()?') + + if conn.health_check_interval and time.time() > conn.next_health_check: + conn.send_command('PING', self.HEALTH_CHECK_MESSAGE, + check_health=False) def _normalize_keys(self, data): """ diff --git a/redis/connection.py b/redis/connection.py index 7d4301a..9c659ac 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals from distutils.version import StrictVersion from errno import EWOULDBLOCK from itertools import chain +from time import time import io import os import socket @@ -20,17 +21,17 @@ from redis._compat import (xrange, imap, byte_to_chr, unicode, long, LifoQueue, Empty, Full, urlparse, parse_qs, recv, recv_into, unquote, BlockingIOError) from redis.exceptions import ( - DataError, - RedisError, - ConnectionError, - TimeoutError, + AuthenticationError, BusyLoadingError, - ResponseError, + ConnectionError, + DataError, + ExecAbortError, InvalidResponse, - AuthenticationError, NoScriptError, - ExecAbortError, - ReadOnlyError + ReadOnlyError, + RedisError, + ResponseError, + TimeoutError, ) from redis.utils import HIREDIS_AVAILABLE if HIREDIS_AVAILABLE: @@ -460,7 +461,8 @@ class Connection(object): socket_keepalive=False, socket_keepalive_options=None, socket_type=0, retry_on_timeout=False, encoding='utf-8', encoding_errors='strict', decode_responses=False, - parser_class=DefaultParser, socket_read_size=65536): + parser_class=DefaultParser, socket_read_size=65536, + health_check_interval=0): self.pid = os.getpid() self.host = host self.port = int(port) @@ -472,6 +474,8 @@ class Connection(object): self.socket_keepalive_options = socket_keepalive_options or {} self.socket_type = socket_type self.retry_on_timeout = retry_on_timeout + self.health_check_interval = health_check_interval + self.next_health_check = 0 self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) @@ -579,7 +583,9 @@ class Connection(object): # if a password is specified, authenticate if self.password: - self.send_command('AUTH', self.password) + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + self.send_command('AUTH', self.password, check_health=False) if nativestr(self.read_response()) != 'OK': raise AuthenticationError('Invalid Password') @@ -602,10 +608,28 @@ class Connection(object): pass self._sock = None - def send_packed_command(self, command): + def check_health(self): + "Check the health of the connection with a PING/PONG" + if self.health_check_interval and time() > self.next_health_check: + try: + self.send_command('PING', check_health=False) + if nativestr(self.read_response()) != 'PONG': + raise ConnectionError( + 'Bad response from PING health check') + except (ConnectionError, TimeoutError) as ex: + self.disconnect() + self.send_command('PING', check_health=False) + if nativestr(self.read_response()) != 'PONG': + raise ConnectionError( + 'Bad response from PING health check') + + def send_packed_command(self, command, check_health=True): "Send an already packed command to the Redis server" if not self._sock: self.connect() + # guard against health check recurrsion + if check_health: + self.check_health() try: if isinstance(command, str): command = [command] @@ -628,9 +652,10 @@ class Connection(object): self.disconnect() raise - def send_command(self, *args): + def send_command(self, *args, check_health=True): "Pack and send a command to the Redis server" - self.send_packed_command(self.pack_command(*args)) + self.send_packed_command(self.pack_command(*args), + check_health=check_health) def can_read(self, timeout=0): "Poll the socket to see if there's data that can be read." @@ -656,6 +681,10 @@ class Connection(object): except: # noqa: E722 self.disconnect() raise + + if self.health_check_interval: + self.next_health_check = time() + self.health_check_interval + if isinstance(response, ResponseError): raise response return response @@ -777,13 +806,16 @@ class UnixDomainSocketConnection(Connection): socket_timeout=None, encoding='utf-8', encoding_errors='strict', decode_responses=False, retry_on_timeout=False, - parser_class=DefaultParser, socket_read_size=65536): + parser_class=DefaultParser, socket_read_size=65536, + health_check_interval=0): self.pid = os.getpid() self.path = path self.db = db self.password = password self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout + self.health_check_interval = health_check_interval + self.next_health_check = 0 self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) @@ -829,6 +861,7 @@ URL_QUERY_ARGUMENT_PARSERS = { 'socket_keepalive': to_bool, 'retry_on_timeout': to_bool, 'max_connections': int, + 'health_check_interval': int, } diff --git a/tests/test_commands.py b/tests/test_commands.py index 931fe9c..ef316af 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -110,7 +110,8 @@ class TestRedisCommands(object): clients_by_name = dict([(client.get('name'), client) for client in clients]) - assert r.client_kill(clients_by_name['redis-py-c2'].get('addr')) is True + client_addr = clients_by_name['redis-py-c2'].get('addr') + assert r.client_kill(client_addr) is True clients = [client for client in r.client_list() if client.get('name') in ['redis-py-c1', 'redis-py-c2']] diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 0af615f..f580f71 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -1,12 +1,14 @@ import os +import mock import pytest +import re import redis import time -import re from threading import Thread from redis.connection import ssl_available, to_bool from .conftest import skip_if_server_version_lt, _get_client +from .test_pubsub import wait_for_message class DummyConnection(object): @@ -532,3 +534,178 @@ class TestMultiConnectionClient(object): assert not r.connection assert r.set('a', '123') assert r.get('a') == b'123' + + +class TestHealthCheck(object): + interval = 60 + + @pytest.fixture() + def r(self, request): + return _get_client(redis.Redis, request, + health_check_interval=self.interval) + + def assert_interval_advanced(self, connection): + diff = connection.next_health_check - time.time() + assert self.interval > diff > (self.interval - 1) + + def test_health_check_runs(self, r): + r.connection.next_health_check = time.time() - 1 + r.connection.check_health() + self.assert_interval_advanced(r.connection) + + def test_arbitrary_command_invokes_health_check(self, r): + # invoke a command to make sure the connection is entirely setup + r.get('foo') + r.connection.next_health_check = time.time() + with mock.patch.object(r.connection, 'send_command', + wraps=r.connection.send_command) as m: + r.get('foo') + m.assert_called_with('PING', check_health=False) + + self.assert_interval_advanced(r.connection) + + def test_arbitrary_command_advances_next_health_check(self, r): + r.get('foo') + next_health_check = r.connection.next_health_check + r.get('foo') + assert next_health_check < r.connection.next_health_check + + def test_health_check_not_invoked_within_interval(self, r): + r.get('foo') + with mock.patch.object(r.connection, 'send_command', + wraps=r.connection.send_command) as m: + r.get('foo') + ping_call_spec = (('PING',), {'check_health': False}) + assert ping_call_spec not in m.call_args_list + + def test_health_check_in_pipeline(self, r): + with r.pipeline(transaction=False) as pipe: + pipe.connection = pipe.connection_pool.get_connection('_') + pipe.connection.next_health_check = 0 + with mock.patch.object(pipe.connection, 'send_command', + wraps=pipe.connection.send_command) as m: + responses = pipe.set('foo', 'bar').get('foo').execute() + m.assert_any_call('PING', check_health=False) + assert responses == [True, b'bar'] + + def test_health_check_in_transaction(self, r): + with r.pipeline(transaction=True) as pipe: + pipe.connection = pipe.connection_pool.get_connection('_') + pipe.connection.next_health_check = 0 + with mock.patch.object(pipe.connection, 'send_command', + wraps=pipe.connection.send_command) as m: + responses = pipe.set('foo', 'bar').get('foo').execute() + m.assert_any_call('PING', check_health=False) + assert responses == [True, b'bar'] + + def test_health_check_in_watched_pipeline(self, r): + r.set('foo', 'bar') + with r.pipeline(transaction=False) as pipe: + pipe.connection = pipe.connection_pool.get_connection('_') + pipe.connection.next_health_check = 0 + with mock.patch.object(pipe.connection, 'send_command', + wraps=pipe.connection.send_command) as m: + pipe.watch('foo') + # the health check should be called when watching + m.assert_called_with('PING', check_health=False) + self.assert_interval_advanced(pipe.connection) + assert pipe.get('foo') == b'bar' + + # reset the mock to clear the call list and schedule another + # health check + m.reset_mock() + pipe.connection.next_health_check = 0 + + pipe.multi() + responses = pipe.set('foo', 'not-bar').get('foo').execute() + assert responses == [True, b'not-bar'] + m.assert_any_call('PING', check_health=False) + + def test_health_check_in_pubsub_before_subscribe(self, r): + "A health check happens before the first [p]subscribe" + p = r.pubsub() + p.connection = p.connection_pool.get_connection('_') + p.connection.next_health_check = 0 + with mock.patch.object(p.connection, 'send_command', + wraps=p.connection.send_command) as m: + assert not p.subscribed + p.subscribe('foo') + # the connection is not yet in pubsub mode, so the normal + # ping/pong within connection.send_command should check + # the health of the connection + m.assert_any_call('PING', check_health=False) + self.assert_interval_advanced(p.connection) + + subscribe_message = wait_for_message(p) + assert subscribe_message['type'] == 'subscribe' + + def test_health_check_in_pubsub_after_subscribed(self, r): + """ + Pubsub can handle a new subscribe when it's time to check the + connection health + """ + p = r.pubsub() + p.connection = p.connection_pool.get_connection('_') + p.connection.next_health_check = 0 + with mock.patch.object(p.connection, 'send_command', + wraps=p.connection.send_command) as m: + p.subscribe('foo') + subscribe_message = wait_for_message(p) + assert subscribe_message['type'] == 'subscribe' + self.assert_interval_advanced(p.connection) + # because we weren't subscribed when sending the subscribe + # message to 'foo', the connection's standard check_health ran + # prior to subscribing. + m.assert_any_call('PING', check_health=False) + + p.connection.next_health_check = 0 + m.reset_mock() + + p.subscribe('bar') + # the second subscribe issues exactly only command (the subscribe) + # and the health check is not invoked + m.assert_called_once_with('SUBSCRIBE', 'bar', check_health=False) + + # since no message has been read since the health check was + # reset, it should still be 0 + assert p.connection.next_health_check == 0 + + subscribe_message = wait_for_message(p) + assert subscribe_message['type'] == 'subscribe' + assert wait_for_message(p) is None + # now that the connection is subscribed, the pubsub health + # check should have taken over and include the HEALTH_CHECK_MESSAGE + m.assert_any_call('PING', p.HEALTH_CHECK_MESSAGE, + check_health=False) + self.assert_interval_advanced(p.connection) + + def test_health_check_in_pubsub_poll(self, r): + """ + Polling a pubsub connection that's subscribed will regularly + check the connection's health. + """ + p = r.pubsub() + p.connection = p.connection_pool.get_connection('_') + with mock.patch.object(p.connection, 'send_command', + wraps=p.connection.send_command) as m: + p.subscribe('foo') + subscribe_message = wait_for_message(p) + assert subscribe_message['type'] == 'subscribe' + self.assert_interval_advanced(p.connection) + + # polling the connection before the health check interval + # doesn't result in another health check + m.reset_mock() + next_health_check = p.connection.next_health_check + assert wait_for_message(p) is None + assert p.connection.next_health_check == next_health_check + m.assert_not_called() + + # reset the health check and poll again + # we should not receive a pong message, but the next_health_check + # should be advanced + p.connection.next_health_check = 0 + assert wait_for_message(p) is None + m.assert_called_with('PING', p.HEALTH_CHECK_MESSAGE, + check_health=False) + self.assert_interval_advanced(p.connection) -- cgit v1.2.1