summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndy McCurdy <andy@andymccurdy.com>2013-04-22 14:12:25 -0700
committerAndy McCurdy <andy@andymccurdy.com>2013-04-22 14:12:25 -0700
commit73f898c5534407414f250d15535da9f616a468b3 (patch)
tree59ab9e125e3658a0c947e845d35a1e36675060d9
parent80dda11c934b4de63f3027c84cea4e1add47d37e (diff)
parent5586443c04a3e939cabec778a920eaa97991c3dc (diff)
downloadredis-py-73f898c5534407414f250d15535da9f616a468b3.tar.gz
Merge pull request #340 from thruflo/master
Add optional `redis.connection.BlockingConnectionPool` class.
-rw-r--r--redis/__init__.py1
-rw-r--r--redis/_compat.py27
-rw-r--r--redis/connection.py164
-rw-r--r--tests/__init__.py2
-rw-r--r--tests/connection_pool.py76
5 files changed, 269 insertions, 1 deletions
diff --git a/redis/__init__.py b/redis/__init__.py
index 4b10834..5d9f3f1 100644
--- a/redis/__init__.py
+++ b/redis/__init__.py
@@ -1,5 +1,6 @@
from redis.client import Redis, StrictRedis
from redis.connection import (
+ BlockingConnectionPool,
ConnectionPool,
Connection,
UnixDomainSocketConnection
diff --git a/redis/_compat.py b/redis/_compat.py
index b04098c..4a8a14c 100644
--- a/redis/_compat.py
+++ b/redis/_compat.py
@@ -48,3 +48,30 @@ else:
unicode = str
bytes = bytes
long = int
+
+try: # Python 3
+ from queue import LifoQueue, Empty, Full
+except ImportError:
+ from Queue import Empty, Full
+ try: # Python 2.6 - 2.7
+ from Queue import LifoQueue
+ except ImportError: # Python 2.5
+ from Queue import Queue
+ # From the Python 2.7 lib. Python 2.5 already extracted the core
+ # methods to aid implementating different queue organisations.
+ class LifoQueue(Queue):
+ """Override queue methods to implement a last-in first-out queue."""
+
+ def _init(self, maxsize):
+ self.maxsize = maxsize
+ self.queue = []
+
+ def _qsize(self, len=len):
+ return len(self.queue)
+
+ def _put(self, item):
+ self.queue.append(item)
+
+ def _get(self):
+ return self.queue.pop()
+
diff --git a/redis/connection.py b/redis/connection.py
index 60dfde6..ccfedf6 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -4,7 +4,8 @@ import socket
import sys
from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long,
- BytesIO, nativestr, basestring)
+ BytesIO, nativestr, basestring,
+ LifoQueue, Empty, Full)
from redis.exceptions import (
RedisError,
ConnectionError,
@@ -418,3 +419,164 @@ class ConnectionPool(object):
self._in_use_connections)
for connection in all_conns:
connection.disconnect()
+
+
+class BlockingConnectionPool(object):
+ """Thread-safe blocking connection pool::
+
+ >>> from redis.client import Redis
+ >>> client = Redis(connection_pool=BlockingConnectionPool())
+
+ It performs the same function as the default
+ ``:py:class: ~redis.connection.ConnectionPool`` implementation, in that,
+ it maintains a pool of reusable connections that can be shared by
+ multiple redis clients (safely across threads if required).
+
+ The difference is that, in the event that a client tries to get a
+ connection from the pool when all of connections are in use, rather than
+ raising a ``:py:class: ~redis.exceptions.ConnectionError`` (as the default
+ ``:py:class: ~redis.connection.ConnectionPool`` implementation does), it
+ makes the client wait ("blocks") for a specified number of seconds until
+ a connection becomes available.
+
+ Use ``max_connections`` to increase / decrease the pool size::
+
+ >>> pool = BlockingConnectionPool(max_connections=10)
+
+ Use ``timeout`` to tell it either how many seconds to wait for a connection
+ to become available, or to block forever:
+
+ # Block forever.
+ >>> pool = BlockingConnectionPool(timeout=None)
+
+ # Raise a ``ConnectionError`` after five seconds if a connection is
+ # not available.
+ >>> pool = BlockingConnectionPool(timeout=5)
+
+ """
+
+ def __init__(self, max_connections=50, timeout=20, connection_class=None, queue_class=None, **connection_kwargs):
+ """Compose and assign values."""
+
+ # Compose.
+ if connection_class is None:
+ connection_class = Connection
+ if queue_class is None:
+ queue_class = LifoQueue
+
+ # Assign.
+ self.connection_class = connection_class
+ self.connection_kwargs = connection_kwargs
+ self.queue_class = queue_class
+ self.max_connections = max_connections
+ self.timeout = timeout
+
+ # Validate the ``max_connections``. With the "fill up the queue"
+ # algorithm we use, it must be a positive integer.
+ is_valid = isinstance(max_connections, int) and max_connections > 0
+ if not is_valid:
+ raise ValueError('``max_connections`` must be a positive integer')
+
+ # Get the current process id, so we can disconnect and reinstantiate if
+ # it changes.
+ self.pid = os.getpid()
+
+ # Create and fill up a thread safe queue with ``None`` values.
+ self.pool = self.queue_class(max_connections)
+ while True:
+ try:
+ self.pool.put_nowait(None)
+ except Full:
+ break
+
+ # Keep a list of actual connection instances so that we can
+ # disconnect them later.
+ self._connections = []
+
+ def _checkpid(self):
+ """Check the current process id. If it has changed, disconnect and
+ re-instantiate this connection pool instance.
+ """
+
+ # Get the current process id.
+ pid = os.getpid()
+
+ # If it hasn't changed since we were instantiated, then we're fine, so
+ # just exit, remaining connected.
+ if self.pid == pid:
+ return
+
+ # If it has changed, then disconnect and re-instantiate.
+ self.disconnect()
+ self.reinstantiate()
+
+ def make_connection(self):
+ """Make a fresh connection."""
+
+ connection = self.connection_class(**self.connection_kwargs)
+ self._connections.append(connection)
+ return connection
+
+ def get_connection(self, command_name, *keys, **options):
+ """Get a connection, blocking for ``self.timeout`` until a connection is
+ available from the pool.
+
+ If the connection returned is ``None`` then creates a new connection.
+ Because we use a last-in first-out queue, the existing connections (having
+ been returned to the pool after the initial ``None`` values were added)
+ will be returned before ``None`` values. This means we only create new
+ connections when we need to, i.e.: the actual number of connections will
+ only increase in response to demand.
+ """
+
+ # Make sure we haven't changed process.
+ self._checkpid()
+
+ # Try and get a connection from the pool. If one isn't available within
+ # self.timeout then raise a ``ConnectionError``.
+ connection = None
+ try:
+ connection = self.pool.get(block=True, timeout=self.timeout)
+ except Empty:
+ # Note that this is not caught by the redis client and will be raised
+ # unless handled by application code. If you want never to raise
+ raise ConnectionError("No connection available.")
+
+ # If the ``connection`` is actually ``None`` then that's a cue to make a
+ # new connection to add to the pool.
+ if connection is None:
+ connection = self.make_connection()
+
+ return connection
+
+ def release(self, connection):
+ """Releases the connection back to the pool."""
+
+ # Make sure we haven't changed process.
+ self._checkpid()
+
+ # Put the connection back into the pool.
+ try:
+ self.pool.put_nowait(connection)
+ except Full:
+ # This shouldn't normally happen but might perhaps happen after a
+ # reinstantiation. So, we can handle the exception by not putting
+ # the connection back on the pool, because we definitely do not
+ # want to reuse it.
+ pass
+
+ def disconnect(self):
+ """Disconnects all connections in the pool."""
+
+ for connection in self._connections:
+ connection.disconnect()
+
+ def reinstantiate(self):
+ """Reinstatiate this instance within a new process with a new connection
+ pool set.
+ """
+
+ self.__init__(max_connections=self.max_connections, timeout=self.timeout,
+ connection_class=self.connection_class,
+ queue_class=self.queue_class, **self.connection_kwargs)
+
diff --git a/tests/__init__.py b/tests/__init__.py
index 616bdd4..71c0179 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -2,6 +2,7 @@ import unittest
from tests.server_commands import ServerCommandsTestCase
from tests.connection_pool import ConnectionPoolTestCase
+from tests.connection_pool import BlockingConnectionPoolTestCase
from tests.pipeline import PipelineTestCase
from tests.lock import LockTestCase
from tests.pubsub import PubSubTestCase, PubSubRedisDownTestCase
@@ -19,6 +20,7 @@ def all_tests():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(ServerCommandsTestCase))
suite.addTest(unittest.makeSuite(ConnectionPoolTestCase))
+ suite.addTest(unittest.makeSuite(BlockingConnectionPoolTestCase))
suite.addTest(unittest.makeSuite(PipelineTestCase))
suite.addTest(unittest.makeSuite(LockTestCase))
suite.addTest(unittest.makeSuite(PubSubTestCase))
diff --git a/tests/connection_pool.py b/tests/connection_pool.py
index 4f0bc0c..91d4d46 100644
--- a/tests/connection_pool.py
+++ b/tests/connection_pool.py
@@ -36,9 +36,85 @@ class ConnectionPoolTestCase(unittest.TestCase):
c2 = pool.get_connection('_')
self.assertRaises(redis.ConnectionError, pool.get_connection, '_')
+ def test_blocking_max_connections(self):
+ pool = self.get_pool(max_connections=2)
+ c1 = pool.get_connection('_')
+ c2 = pool.get_connection('_')
+ self.assertRaises(redis.ConnectionError, pool.get_connection, '_')
+
+ def test_release(self):
+ pool = self.get_pool()
+ c1 = pool.get_connection('_')
+ pool.release(c1)
+ c2 = pool.get_connection('_')
+ self.assertEquals(c1, c2)
+
+
+class BlockingConnectionPoolTestCase(unittest.TestCase):
+ def get_pool(self, connection_info=None, max_connections=10, timeout=20):
+ connection_info = connection_info or {'a': 1, 'b': 2, 'c': 3}
+ pool = redis.BlockingConnectionPool(connection_class=DummyConnection,
+ max_connections=max_connections,
+ timeout=timeout, **connection_info)
+ return pool
+
+ def test_connection_creation(self):
+ connection_info = {'foo': 'bar', 'biz': 'baz'}
+ pool = self.get_pool(connection_info=connection_info)
+ connection = pool.get_connection('_')
+ self.assertEquals(connection.kwargs, connection_info)
+
+ def test_multiple_connections(self):
+ pool = self.get_pool()
+ c1 = pool.get_connection('_')
+ c2 = pool.get_connection('_')
+ self.assert_(c1 != c2)
+
+ def test_max_connections_blocks(self):
+ """Getting a connection should block for until available."""
+
+ import time
+ from copy import deepcopy
+ from threading import Thread
+
+ # We use a queue for cross thread communication within the unit test.
+ try: # Python 3
+ from queue import Queue
+ except ImportError:
+ from Queue import Queue
+
+ q = Queue()
+ q.put_nowait('Not yet got')
+ pool = self.get_pool(max_connections=2, timeout=5)
+ c1 = pool.get_connection('_')
+ c2 = pool.get_connection('_')
+
+ target = lambda: q.put_nowait(pool.get_connection('_'))
+ Thread(target=target).start()
+
+ # Blocks while non available.
+ time.sleep(0.05)
+ c3 = q.get_nowait()
+ self.assertEquals(c3, 'Not yet got')
+
+ # Then got when available.
+ pool.release(c1)
+ time.sleep(0.05)
+ c3 = q.get_nowait()
+ self.assertEquals(c1, c3)
+
+ def test_max_connections_timeout(self):
+ """Getting a connection raises ``ConnectionError`` after timeout."""
+
+ pool = self.get_pool(max_connections=2, timeout=0.1)
+ c1 = pool.get_connection('_')
+ c2 = pool.get_connection('_')
+ self.assertRaises(redis.ConnectionError, pool.get_connection, '_')
+
def test_release(self):
pool = self.get_pool()
c1 = pool.get_connection('_')
pool.release(c1)
c2 = pool.get_connection('_')
self.assertEquals(c1, c2)
+