diff options
-rw-r--r-- | CHANGES | 7 | ||||
-rwxr-xr-x | redis/client.py | 53 | ||||
-rwxr-xr-x | redis/connection.py | 61 | ||||
-rw-r--r-- | tests/test_commands.py | 3 | ||||
-rw-r--r-- | tests/test_connection_pool.py | 179 |
5 files changed, 275 insertions, 28 deletions
@@ -25,7 +25,12 @@ * Allow for single connection client instances. These instances are not thread safe but offer other benefits including a subtle performance increase. - + * Added extensive health checks that keep the connections lively. + Passing the "health_check_interval=N" option to the Redis client class + or to a ConnectionPool ensures that a round trip PING/PONG is successful + before any command if the underlying connection has been idle for more + than N seconds. ConnectionErrors and TimeoutErrors are automatically + retried once for health checks. * 3.2.1 * Fix SentinelConnectionPool to work in multiprocess/forked environments. * 3.2.0 diff --git a/redis/client.py b/redis/client.py index a38098f..80ffb68 100755 --- a/redis/client.py +++ b/redis/client.py @@ -648,7 +648,8 @@ class Redis(object): decode_responses=False, retry_on_timeout=False, ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs='required', ssl_ca_certs=None, - max_connections=None, single_connection_client=False): + max_connections=None, single_connection_client=False, + health_check_interval=0): if not connection_pool: if charset is not None: warnings.warn(DeprecationWarning( @@ -667,7 +668,8 @@ class Redis(object): 'encoding_errors': encoding_errors, 'decode_responses': decode_responses, 'retry_on_timeout': retry_on_timeout, - 'max_connections': max_connections + 'max_connections': max_connections, + 'health_check_interval': health_check_interval, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -3053,6 +3055,7 @@ class PubSub(object): """ PUBLISH_MESSAGE_TYPES = ('message', 'pmessage') UNSUBSCRIBE_MESSAGE_TYPES = ('unsubscribe', 'punsubscribe') + HEALTH_CHECK_MESSAGE = 'redis-py-health-check' def __init__(self, connection_pool, shard_hint=None, ignore_subscribe_messages=False): @@ -3063,6 +3066,13 @@ class PubSub(object): # we need to know the encoding options for this connection in order # to lookup channel and pattern names for 
callback handlers. self.encoder = self.connection_pool.get_encoder() + if self.encoder.decode_responses: + self.health_check_response = ['pong', self.HEALTH_CHECK_MESSAGE] + else: + self.health_check_response = [ + b'pong', + self.encoder.encode(self.HEALTH_CHECK_MESSAGE) + ] self.reset() def __del__(self): @@ -3111,7 +3121,7 @@ class PubSub(object): "Indicates if there are subscriptions to any channels or patterns" return bool(self.channels or self.patterns) - def execute_command(self, *args, **kwargs): + def execute_command(self, *args): "Execute a publish/subscribe command" # NOTE: don't parse the response in this function -- it could pull a @@ -3127,11 +3137,12 @@ class PubSub(object): # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) connection = self.connection - self._execute(connection, connection.send_command, *args) + kwargs = {'check_health': not self.subscribed} + self._execute(connection, connection.send_command, *args, **kwargs) - def _execute(self, connection, command, *args): + def _execute(self, connection, command, *args, **kwargs): try: - return command(*args) + return command(*args, **kwargs) except (ConnectionError, TimeoutError) as e: connection.disconnect() if not (connection.retry_on_timeout and @@ -3143,18 +3154,38 @@ class PubSub(object): # the ``on_connect`` callback should haven been called by the # connection to resubscribe us to any channels and patterns we were # previously listening to - return command(*args) + return command(*args, **kwargs) def parse_response(self, block=True, timeout=0): "Parse the response from a publish/subscribe command" - connection = self.connection - if connection is None: + conn = self.connection + if conn is None: raise RuntimeError( 'pubsub connection not set: ' 'did you forget to call subscribe() or psubscribe()?') - if not block and not connection.can_read(timeout=timeout): + + self.check_health() + + if not block and not 
conn.can_read(timeout=timeout): + return None + response = self._execute(conn, conn.read_response) + + if conn.health_check_interval and \ + response == self.health_check_response: + # ignore the health check message as user might not expect it return None - return self._execute(connection, connection.read_response) + return response + + def check_health(self): + conn = self.connection + if conn is None: + raise RuntimeError( + 'pubsub connection not set: ' + 'did you forget to call subscribe() or psubscribe()?') + + if conn.health_check_interval and time.time() > conn.next_health_check: + conn.send_command('PING', self.HEALTH_CHECK_MESSAGE, + check_health=False) def _normalize_keys(self, data): """ diff --git a/redis/connection.py b/redis/connection.py index 7d4301a..9c659ac 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals from distutils.version import StrictVersion from errno import EWOULDBLOCK from itertools import chain +from time import time import io import os import socket @@ -20,17 +21,17 @@ from redis._compat import (xrange, imap, byte_to_chr, unicode, long, LifoQueue, Empty, Full, urlparse, parse_qs, recv, recv_into, unquote, BlockingIOError) from redis.exceptions import ( - DataError, - RedisError, - ConnectionError, - TimeoutError, + AuthenticationError, BusyLoadingError, - ResponseError, + ConnectionError, + DataError, + ExecAbortError, InvalidResponse, - AuthenticationError, NoScriptError, - ExecAbortError, - ReadOnlyError + ReadOnlyError, + RedisError, + ResponseError, + TimeoutError, ) from redis.utils import HIREDIS_AVAILABLE if HIREDIS_AVAILABLE: @@ -460,7 +461,8 @@ class Connection(object): socket_keepalive=False, socket_keepalive_options=None, socket_type=0, retry_on_timeout=False, encoding='utf-8', encoding_errors='strict', decode_responses=False, - parser_class=DefaultParser, socket_read_size=65536): + parser_class=DefaultParser, socket_read_size=65536, + 
health_check_interval=0): self.pid = os.getpid() self.host = host self.port = int(port) @@ -472,6 +474,8 @@ class Connection(object): self.socket_keepalive_options = socket_keepalive_options or {} self.socket_type = socket_type self.retry_on_timeout = retry_on_timeout + self.health_check_interval = health_check_interval + self.next_health_check = 0 self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) @@ -579,7 +583,9 @@ class Connection(object): # if a password is specified, authenticate if self.password: - self.send_command('AUTH', self.password) + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + self.send_command('AUTH', self.password, check_health=False) if nativestr(self.read_response()) != 'OK': raise AuthenticationError('Invalid Password') @@ -602,10 +608,28 @@ class Connection(object): pass self._sock = None - def send_packed_command(self, command): + def check_health(self): + "Check the health of the connection with a PING/PONG" + if self.health_check_interval and time() > self.next_health_check: + try: + self.send_command('PING', check_health=False) + if nativestr(self.read_response()) != 'PONG': + raise ConnectionError( + 'Bad response from PING health check') + except (ConnectionError, TimeoutError) as ex: + self.disconnect() + self.send_command('PING', check_health=False) + if nativestr(self.read_response()) != 'PONG': + raise ConnectionError( + 'Bad response from PING health check') + + def send_packed_command(self, command, check_health=True): "Send an already packed command to the Redis server" if not self._sock: self.connect() + # guard against health check recursion + if check_health: + self.check_health() try: if isinstance(command, str): command = [command] @@ -628,9 +652,10 @@ class Connection(object): self.disconnect() raise - def send_command(self, *args): + def send_command(self, *args, 
check_health=True): "Pack and send a command to the Redis server" - self.send_packed_command(self.pack_command(*args)) + self.send_packed_command(self.pack_command(*args), + check_health=check_health) def can_read(self, timeout=0): "Poll the socket to see if there's data that can be read." @@ -656,6 +681,10 @@ class Connection(object): except: # noqa: E722 self.disconnect() raise + + if self.health_check_interval: + self.next_health_check = time() + self.health_check_interval + if isinstance(response, ResponseError): raise response return response @@ -777,13 +806,16 @@ class UnixDomainSocketConnection(Connection): socket_timeout=None, encoding='utf-8', encoding_errors='strict', decode_responses=False, retry_on_timeout=False, - parser_class=DefaultParser, socket_read_size=65536): + parser_class=DefaultParser, socket_read_size=65536, + health_check_interval=0): self.pid = os.getpid() self.path = path self.db = db self.password = password self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout + self.health_check_interval = health_check_interval + self.next_health_check = 0 self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None self._parser = parser_class(socket_read_size=socket_read_size) @@ -829,6 +861,7 @@ URL_QUERY_ARGUMENT_PARSERS = { 'socket_keepalive': to_bool, 'retry_on_timeout': to_bool, 'max_connections': int, + 'health_check_interval': int, } diff --git a/tests/test_commands.py b/tests/test_commands.py index 931fe9c..ef316af 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -110,7 +110,8 @@ class TestRedisCommands(object): clients_by_name = dict([(client.get('name'), client) for client in clients]) - assert r.client_kill(clients_by_name['redis-py-c2'].get('addr')) is True + client_addr = clients_by_name['redis-py-c2'].get('addr') + assert r.client_kill(client_addr) is True clients = [client for client in r.client_list() if client.get('name') in ['redis-py-c1', 'redis-py-c2']] diff --git 
a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 0af615f..f580f71 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -1,12 +1,14 @@ import os +import mock import pytest +import re import redis import time -import re from threading import Thread from redis.connection import ssl_available, to_bool from .conftest import skip_if_server_version_lt, _get_client +from .test_pubsub import wait_for_message class DummyConnection(object): @@ -532,3 +534,178 @@ class TestMultiConnectionClient(object): assert not r.connection assert r.set('a', '123') assert r.get('a') == b'123' + + +class TestHealthCheck(object): + interval = 60 + + @pytest.fixture() + def r(self, request): + return _get_client(redis.Redis, request, + health_check_interval=self.interval) + + def assert_interval_advanced(self, connection): + diff = connection.next_health_check - time.time() + assert self.interval > diff > (self.interval - 1) + + def test_health_check_runs(self, r): + r.connection.next_health_check = time.time() - 1 + r.connection.check_health() + self.assert_interval_advanced(r.connection) + + def test_arbitrary_command_invokes_health_check(self, r): + # invoke a command to make sure the connection is entirely setup + r.get('foo') + r.connection.next_health_check = time.time() + with mock.patch.object(r.connection, 'send_command', + wraps=r.connection.send_command) as m: + r.get('foo') + m.assert_called_with('PING', check_health=False) + + self.assert_interval_advanced(r.connection) + + def test_arbitrary_command_advances_next_health_check(self, r): + r.get('foo') + next_health_check = r.connection.next_health_check + r.get('foo') + assert next_health_check < r.connection.next_health_check + + def test_health_check_not_invoked_within_interval(self, r): + r.get('foo') + with mock.patch.object(r.connection, 'send_command', + wraps=r.connection.send_command) as m: + r.get('foo') + ping_call_spec = (('PING',), {'check_health': False}) + assert 
ping_call_spec not in m.call_args_list + + def test_health_check_in_pipeline(self, r): + with r.pipeline(transaction=False) as pipe: + pipe.connection = pipe.connection_pool.get_connection('_') + pipe.connection.next_health_check = 0 + with mock.patch.object(pipe.connection, 'send_command', + wraps=pipe.connection.send_command) as m: + responses = pipe.set('foo', 'bar').get('foo').execute() + m.assert_any_call('PING', check_health=False) + assert responses == [True, b'bar'] + + def test_health_check_in_transaction(self, r): + with r.pipeline(transaction=True) as pipe: + pipe.connection = pipe.connection_pool.get_connection('_') + pipe.connection.next_health_check = 0 + with mock.patch.object(pipe.connection, 'send_command', + wraps=pipe.connection.send_command) as m: + responses = pipe.set('foo', 'bar').get('foo').execute() + m.assert_any_call('PING', check_health=False) + assert responses == [True, b'bar'] + + def test_health_check_in_watched_pipeline(self, r): + r.set('foo', 'bar') + with r.pipeline(transaction=False) as pipe: + pipe.connection = pipe.connection_pool.get_connection('_') + pipe.connection.next_health_check = 0 + with mock.patch.object(pipe.connection, 'send_command', + wraps=pipe.connection.send_command) as m: + pipe.watch('foo') + # the health check should be called when watching + m.assert_called_with('PING', check_health=False) + self.assert_interval_advanced(pipe.connection) + assert pipe.get('foo') == b'bar' + + # reset the mock to clear the call list and schedule another + # health check + m.reset_mock() + pipe.connection.next_health_check = 0 + + pipe.multi() + responses = pipe.set('foo', 'not-bar').get('foo').execute() + assert responses == [True, b'not-bar'] + m.assert_any_call('PING', check_health=False) + + def test_health_check_in_pubsub_before_subscribe(self, r): + "A health check happens before the first [p]subscribe" + p = r.pubsub() + p.connection = p.connection_pool.get_connection('_') + p.connection.next_health_check = 0 + with 
mock.patch.object(p.connection, 'send_command', + wraps=p.connection.send_command) as m: + assert not p.subscribed + p.subscribe('foo') + # the connection is not yet in pubsub mode, so the normal + # ping/pong within connection.send_command should check + # the health of the connection + m.assert_any_call('PING', check_health=False) + self.assert_interval_advanced(p.connection) + + subscribe_message = wait_for_message(p) + assert subscribe_message['type'] == 'subscribe' + + def test_health_check_in_pubsub_after_subscribed(self, r): + """ + Pubsub can handle a new subscribe when it's time to check the + connection health + """ + p = r.pubsub() + p.connection = p.connection_pool.get_connection('_') + p.connection.next_health_check = 0 + with mock.patch.object(p.connection, 'send_command', + wraps=p.connection.send_command) as m: + p.subscribe('foo') + subscribe_message = wait_for_message(p) + assert subscribe_message['type'] == 'subscribe' + self.assert_interval_advanced(p.connection) + # because we weren't subscribed when sending the subscribe + # message to 'foo', the connection's standard check_health ran + # prior to subscribing. 
+ m.assert_any_call('PING', check_health=False) + + p.connection.next_health_check = 0 + m.reset_mock() + + p.subscribe('bar') + # the second subscribe issues exactly one command (the subscribe) + # and the health check is not invoked + m.assert_called_once_with('SUBSCRIBE', 'bar', check_health=False) + + # since no message has been read since the health check was + # reset, it should still be 0 + assert p.connection.next_health_check == 0 + + subscribe_message = wait_for_message(p) + assert subscribe_message['type'] == 'subscribe' + assert wait_for_message(p) is None + # now that the connection is subscribed, the pubsub health + # check should have taken over and include the HEALTH_CHECK_MESSAGE + m.assert_any_call('PING', p.HEALTH_CHECK_MESSAGE, + check_health=False) + self.assert_interval_advanced(p.connection) + + def test_health_check_in_pubsub_poll(self, r): + """ + Polling a pubsub connection that's subscribed will regularly + check the connection's health. + """ + p = r.pubsub() + p.connection = p.connection_pool.get_connection('_') + with mock.patch.object(p.connection, 'send_command', + wraps=p.connection.send_command) as m: + p.subscribe('foo') + subscribe_message = wait_for_message(p) + assert subscribe_message['type'] == 'subscribe' + self.assert_interval_advanced(p.connection) + + # polling the connection before the health check interval + # doesn't result in another health check + m.reset_mock() + next_health_check = p.connection.next_health_check + assert wait_for_message(p) is None + assert p.connection.next_health_check == next_health_check + m.assert_not_called() + + # reset the health check and poll again + # we should not receive a pong message, but the next_health_check + # should be advanced + p.connection.next_health_check = 0 + assert wait_for_message(p) is None + m.assert_called_with('PING', p.HEALTH_CHECK_MESSAGE, + check_health=False) + self.assert_interval_advanced(p.connection) |