summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES7
-rwxr-xr-xredis/client.py53
-rwxr-xr-xredis/connection.py61
-rw-r--r--tests/test_commands.py3
-rw-r--r--tests/test_connection_pool.py179
5 files changed, 275 insertions, 28 deletions
diff --git a/CHANGES b/CHANGES
index 15a77b5..b8cdd3e 100644
--- a/CHANGES
+++ b/CHANGES
@@ -25,7 +25,12 @@
* Allow for single connection client instances. These instances
are not thread safe but offer other benefits including a subtle
performance increase.
-
+ * Added extensive health checks that keep the connections lively.
+ Passing the "health_check_interval=N" option to the Redis client class
+ or to a ConnectionPool ensures that a round trip PING/PONG is successful
+ before any command if the underlying connection has been idle for more
+ than N seconds. ConnectionErrors and TimeoutErrors are automatically
+ retried once for health checks.
* 3.2.1
* Fix SentinelConnectionPool to work in multiprocess/forked environments.
* 3.2.0
diff --git a/redis/client.py b/redis/client.py
index a38098f..80ffb68 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -648,7 +648,8 @@ class Redis(object):
decode_responses=False, retry_on_timeout=False,
ssl=False, ssl_keyfile=None, ssl_certfile=None,
ssl_cert_reqs='required', ssl_ca_certs=None,
- max_connections=None, single_connection_client=False):
+ max_connections=None, single_connection_client=False,
+ health_check_interval=0):
if not connection_pool:
if charset is not None:
warnings.warn(DeprecationWarning(
@@ -667,7 +668,8 @@ class Redis(object):
'encoding_errors': encoding_errors,
'decode_responses': decode_responses,
'retry_on_timeout': retry_on_timeout,
- 'max_connections': max_connections
+ 'max_connections': max_connections,
+ 'health_check_interval': health_check_interval,
}
# based on input, setup appropriate connection args
if unix_socket_path is not None:
@@ -3053,6 +3055,7 @@ class PubSub(object):
"""
PUBLISH_MESSAGE_TYPES = ('message', 'pmessage')
UNSUBSCRIBE_MESSAGE_TYPES = ('unsubscribe', 'punsubscribe')
+ HEALTH_CHECK_MESSAGE = 'redis-py-health-check'
def __init__(self, connection_pool, shard_hint=None,
ignore_subscribe_messages=False):
@@ -3063,6 +3066,13 @@ class PubSub(object):
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = self.connection_pool.get_encoder()
+ if self.encoder.decode_responses:
+ self.health_check_response = ['pong', self.HEALTH_CHECK_MESSAGE]
+ else:
+ self.health_check_response = [
+ b'pong',
+ self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
+ ]
self.reset()
def __del__(self):
@@ -3111,7 +3121,7 @@ class PubSub(object):
"Indicates if there are subscriptions to any channels or patterns"
return bool(self.channels or self.patterns)
- def execute_command(self, *args, **kwargs):
+ def execute_command(self, *args):
"Execute a publish/subscribe command"
# NOTE: don't parse the response in this function -- it could pull a
@@ -3127,11 +3137,12 @@ class PubSub(object):
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
connection = self.connection
- self._execute(connection, connection.send_command, *args)
+ kwargs = {'check_health': not self.subscribed}
+ self._execute(connection, connection.send_command, *args, **kwargs)
- def _execute(self, connection, command, *args):
+ def _execute(self, connection, command, *args, **kwargs):
try:
- return command(*args)
+ return command(*args, **kwargs)
except (ConnectionError, TimeoutError) as e:
connection.disconnect()
if not (connection.retry_on_timeout and
@@ -3143,18 +3154,38 @@ class PubSub(object):
# the ``on_connect`` callback should haven been called by the
# connection to resubscribe us to any channels and patterns we were
# previously listening to
- return command(*args)
+ return command(*args, **kwargs)
def parse_response(self, block=True, timeout=0):
"Parse the response from a publish/subscribe command"
- connection = self.connection
- if connection is None:
+ conn = self.connection
+ if conn is None:
raise RuntimeError(
'pubsub connection not set: '
'did you forget to call subscribe() or psubscribe()?')
- if not block and not connection.can_read(timeout=timeout):
+
+ self.check_health()
+
+ if not block and not conn.can_read(timeout=timeout):
+ return None
+ response = self._execute(conn, conn.read_response)
+
+ if conn.health_check_interval and \
+ response == self.health_check_response:
+ # ignore the health check message as user might not expect it
return None
- return self._execute(connection, connection.read_response)
+ return response
+
+ def check_health(self):
+ conn = self.connection
+ if conn is None:
+ raise RuntimeError(
+ 'pubsub connection not set: '
+ 'did you forget to call subscribe() or psubscribe()?')
+
+ if conn.health_check_interval and time.time() > conn.next_health_check:
+ conn.send_command('PING', self.HEALTH_CHECK_MESSAGE,
+ check_health=False)
def _normalize_keys(self, data):
"""
diff --git a/redis/connection.py b/redis/connection.py
index 7d4301a..9c659ac 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -2,6 +2,7 @@ from __future__ import unicode_literals
from distutils.version import StrictVersion
from errno import EWOULDBLOCK
from itertools import chain
+from time import time
import io
import os
import socket
@@ -20,17 +21,17 @@ from redis._compat import (xrange, imap, byte_to_chr, unicode, long,
LifoQueue, Empty, Full, urlparse, parse_qs,
recv, recv_into, unquote, BlockingIOError)
from redis.exceptions import (
- DataError,
- RedisError,
- ConnectionError,
- TimeoutError,
+ AuthenticationError,
BusyLoadingError,
- ResponseError,
+ ConnectionError,
+ DataError,
+ ExecAbortError,
InvalidResponse,
- AuthenticationError,
NoScriptError,
- ExecAbortError,
- ReadOnlyError
+ ReadOnlyError,
+ RedisError,
+ ResponseError,
+ TimeoutError,
)
from redis.utils import HIREDIS_AVAILABLE
if HIREDIS_AVAILABLE:
@@ -460,7 +461,8 @@ class Connection(object):
socket_keepalive=False, socket_keepalive_options=None,
socket_type=0, retry_on_timeout=False, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
- parser_class=DefaultParser, socket_read_size=65536):
+ parser_class=DefaultParser, socket_read_size=65536,
+ health_check_interval=0):
self.pid = os.getpid()
self.host = host
self.port = int(port)
@@ -472,6 +474,8 @@ class Connection(object):
self.socket_keepalive_options = socket_keepalive_options or {}
self.socket_type = socket_type
self.retry_on_timeout = retry_on_timeout
+ self.health_check_interval = health_check_interval
+ self.next_health_check = 0
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self._sock = None
self._parser = parser_class(socket_read_size=socket_read_size)
@@ -579,7 +583,9 @@ class Connection(object):
# if a password is specified, authenticate
if self.password:
- self.send_command('AUTH', self.password)
+ # avoid checking health here -- PING will fail if we try
+ # to check the health prior to the AUTH
+ self.send_command('AUTH', self.password, check_health=False)
if nativestr(self.read_response()) != 'OK':
raise AuthenticationError('Invalid Password')
@@ -602,10 +608,28 @@ class Connection(object):
pass
self._sock = None
- def send_packed_command(self, command):
+ def check_health(self):
+ "Check the health of the connection with a PING/PONG"
+ if self.health_check_interval and time() > self.next_health_check:
+ try:
+ self.send_command('PING', check_health=False)
+ if nativestr(self.read_response()) != 'PONG':
+ raise ConnectionError(
+ 'Bad response from PING health check')
+ except (ConnectionError, TimeoutError) as ex:
+ self.disconnect()
+ self.send_command('PING', check_health=False)
+ if nativestr(self.read_response()) != 'PONG':
+ raise ConnectionError(
+ 'Bad response from PING health check')
+
+ def send_packed_command(self, command, check_health=True):
"Send an already packed command to the Redis server"
if not self._sock:
self.connect()
+ # guard against health check recurrsion
+ if check_health:
+ self.check_health()
try:
if isinstance(command, str):
command = [command]
@@ -628,9 +652,10 @@ class Connection(object):
self.disconnect()
raise
- def send_command(self, *args):
+ def send_command(self, *args, check_health=True):
"Pack and send a command to the Redis server"
- self.send_packed_command(self.pack_command(*args))
+ self.send_packed_command(self.pack_command(*args),
+ check_health=check_health)
def can_read(self, timeout=0):
"Poll the socket to see if there's data that can be read."
@@ -656,6 +681,10 @@ class Connection(object):
except: # noqa: E722
self.disconnect()
raise
+
+ if self.health_check_interval:
+ self.next_health_check = time() + self.health_check_interval
+
if isinstance(response, ResponseError):
raise response
return response
@@ -777,13 +806,16 @@ class UnixDomainSocketConnection(Connection):
socket_timeout=None, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
retry_on_timeout=False,
- parser_class=DefaultParser, socket_read_size=65536):
+ parser_class=DefaultParser, socket_read_size=65536,
+ health_check_interval=0):
self.pid = os.getpid()
self.path = path
self.db = db
self.password = password
self.socket_timeout = socket_timeout
self.retry_on_timeout = retry_on_timeout
+ self.health_check_interval = health_check_interval
+ self.next_health_check = 0
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self._sock = None
self._parser = parser_class(socket_read_size=socket_read_size)
@@ -829,6 +861,7 @@ URL_QUERY_ARGUMENT_PARSERS = {
'socket_keepalive': to_bool,
'retry_on_timeout': to_bool,
'max_connections': int,
+ 'health_check_interval': int,
}
diff --git a/tests/test_commands.py b/tests/test_commands.py
index 931fe9c..ef316af 100644
--- a/tests/test_commands.py
+++ b/tests/test_commands.py
@@ -110,7 +110,8 @@ class TestRedisCommands(object):
clients_by_name = dict([(client.get('name'), client)
for client in clients])
- assert r.client_kill(clients_by_name['redis-py-c2'].get('addr')) is True
+ client_addr = clients_by_name['redis-py-c2'].get('addr')
+ assert r.client_kill(client_addr) is True
clients = [client for client in r.client_list()
if client.get('name') in ['redis-py-c1', 'redis-py-c2']]
diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py
index 0af615f..f580f71 100644
--- a/tests/test_connection_pool.py
+++ b/tests/test_connection_pool.py
@@ -1,12 +1,14 @@
import os
+import mock
import pytest
+import re
import redis
import time
-import re
from threading import Thread
from redis.connection import ssl_available, to_bool
from .conftest import skip_if_server_version_lt, _get_client
+from .test_pubsub import wait_for_message
class DummyConnection(object):
@@ -532,3 +534,178 @@ class TestMultiConnectionClient(object):
assert not r.connection
assert r.set('a', '123')
assert r.get('a') == b'123'
+
+
+class TestHealthCheck(object):
+ interval = 60
+
+ @pytest.fixture()
+ def r(self, request):
+ return _get_client(redis.Redis, request,
+ health_check_interval=self.interval)
+
+ def assert_interval_advanced(self, connection):
+ diff = connection.next_health_check - time.time()
+ assert self.interval > diff > (self.interval - 1)
+
+ def test_health_check_runs(self, r):
+ r.connection.next_health_check = time.time() - 1
+ r.connection.check_health()
+ self.assert_interval_advanced(r.connection)
+
+ def test_arbitrary_command_invokes_health_check(self, r):
+ # invoke a command to make sure the connection is entirely setup
+ r.get('foo')
+ r.connection.next_health_check = time.time()
+ with mock.patch.object(r.connection, 'send_command',
+ wraps=r.connection.send_command) as m:
+ r.get('foo')
+ m.assert_called_with('PING', check_health=False)
+
+ self.assert_interval_advanced(r.connection)
+
+ def test_arbitrary_command_advances_next_health_check(self, r):
+ r.get('foo')
+ next_health_check = r.connection.next_health_check
+ r.get('foo')
+ assert next_health_check < r.connection.next_health_check
+
+ def test_health_check_not_invoked_within_interval(self, r):
+ r.get('foo')
+ with mock.patch.object(r.connection, 'send_command',
+ wraps=r.connection.send_command) as m:
+ r.get('foo')
+ ping_call_spec = (('PING',), {'check_health': False})
+ assert ping_call_spec not in m.call_args_list
+
+ def test_health_check_in_pipeline(self, r):
+ with r.pipeline(transaction=False) as pipe:
+ pipe.connection = pipe.connection_pool.get_connection('_')
+ pipe.connection.next_health_check = 0
+ with mock.patch.object(pipe.connection, 'send_command',
+ wraps=pipe.connection.send_command) as m:
+ responses = pipe.set('foo', 'bar').get('foo').execute()
+ m.assert_any_call('PING', check_health=False)
+ assert responses == [True, b'bar']
+
+ def test_health_check_in_transaction(self, r):
+ with r.pipeline(transaction=True) as pipe:
+ pipe.connection = pipe.connection_pool.get_connection('_')
+ pipe.connection.next_health_check = 0
+ with mock.patch.object(pipe.connection, 'send_command',
+ wraps=pipe.connection.send_command) as m:
+ responses = pipe.set('foo', 'bar').get('foo').execute()
+ m.assert_any_call('PING', check_health=False)
+ assert responses == [True, b'bar']
+
+ def test_health_check_in_watched_pipeline(self, r):
+ r.set('foo', 'bar')
+ with r.pipeline(transaction=False) as pipe:
+ pipe.connection = pipe.connection_pool.get_connection('_')
+ pipe.connection.next_health_check = 0
+ with mock.patch.object(pipe.connection, 'send_command',
+ wraps=pipe.connection.send_command) as m:
+ pipe.watch('foo')
+ # the health check should be called when watching
+ m.assert_called_with('PING', check_health=False)
+ self.assert_interval_advanced(pipe.connection)
+ assert pipe.get('foo') == b'bar'
+
+ # reset the mock to clear the call list and schedule another
+ # health check
+ m.reset_mock()
+ pipe.connection.next_health_check = 0
+
+ pipe.multi()
+ responses = pipe.set('foo', 'not-bar').get('foo').execute()
+ assert responses == [True, b'not-bar']
+ m.assert_any_call('PING', check_health=False)
+
+ def test_health_check_in_pubsub_before_subscribe(self, r):
+ "A health check happens before the first [p]subscribe"
+ p = r.pubsub()
+ p.connection = p.connection_pool.get_connection('_')
+ p.connection.next_health_check = 0
+ with mock.patch.object(p.connection, 'send_command',
+ wraps=p.connection.send_command) as m:
+ assert not p.subscribed
+ p.subscribe('foo')
+ # the connection is not yet in pubsub mode, so the normal
+ # ping/pong within connection.send_command should check
+ # the health of the connection
+ m.assert_any_call('PING', check_health=False)
+ self.assert_interval_advanced(p.connection)
+
+ subscribe_message = wait_for_message(p)
+ assert subscribe_message['type'] == 'subscribe'
+
+ def test_health_check_in_pubsub_after_subscribed(self, r):
+ """
+ Pubsub can handle a new subscribe when it's time to check the
+ connection health
+ """
+ p = r.pubsub()
+ p.connection = p.connection_pool.get_connection('_')
+ p.connection.next_health_check = 0
+ with mock.patch.object(p.connection, 'send_command',
+ wraps=p.connection.send_command) as m:
+ p.subscribe('foo')
+ subscribe_message = wait_for_message(p)
+ assert subscribe_message['type'] == 'subscribe'
+ self.assert_interval_advanced(p.connection)
+ # because we weren't subscribed when sending the subscribe
+ # message to 'foo', the connection's standard check_health ran
+ # prior to subscribing.
+ m.assert_any_call('PING', check_health=False)
+
+ p.connection.next_health_check = 0
+ m.reset_mock()
+
+ p.subscribe('bar')
+ # the second subscribe issues exactly only command (the subscribe)
+ # and the health check is not invoked
+ m.assert_called_once_with('SUBSCRIBE', 'bar', check_health=False)
+
+ # since no message has been read since the health check was
+ # reset, it should still be 0
+ assert p.connection.next_health_check == 0
+
+ subscribe_message = wait_for_message(p)
+ assert subscribe_message['type'] == 'subscribe'
+ assert wait_for_message(p) is None
+ # now that the connection is subscribed, the pubsub health
+ # check should have taken over and include the HEALTH_CHECK_MESSAGE
+ m.assert_any_call('PING', p.HEALTH_CHECK_MESSAGE,
+ check_health=False)
+ self.assert_interval_advanced(p.connection)
+
+ def test_health_check_in_pubsub_poll(self, r):
+ """
+ Polling a pubsub connection that's subscribed will regularly
+ check the connection's health.
+ """
+ p = r.pubsub()
+ p.connection = p.connection_pool.get_connection('_')
+ with mock.patch.object(p.connection, 'send_command',
+ wraps=p.connection.send_command) as m:
+ p.subscribe('foo')
+ subscribe_message = wait_for_message(p)
+ assert subscribe_message['type'] == 'subscribe'
+ self.assert_interval_advanced(p.connection)
+
+ # polling the connection before the health check interval
+ # doesn't result in another health check
+ m.reset_mock()
+ next_health_check = p.connection.next_health_check
+ assert wait_for_message(p) is None
+ assert p.connection.next_health_check == next_health_check
+ m.assert_not_called()
+
+ # reset the health check and poll again
+ # we should not receive a pong message, but the next_health_check
+ # should be advanced
+ p.connection.next_health_check = 0
+ assert wait_for_message(p) is None
+ m.assert_called_with('PING', p.HEALTH_CHECK_MESSAGE,
+ check_health=False)
+ self.assert_interval_advanced(p.connection)