diff options
author | Andy McCurdy <andy@andymccurdy.com> | 2013-04-22 14:12:25 -0700 |
---|---|---|
committer | Andy McCurdy <andy@andymccurdy.com> | 2013-04-22 14:12:25 -0700 |
commit | 73f898c5534407414f250d15535da9f616a468b3 (patch) | |
tree | 59ab9e125e3658a0c947e845d35a1e36675060d9 | |
parent | 80dda11c934b4de63f3027c84cea4e1add47d37e (diff) | |
parent | 5586443c04a3e939cabec778a920eaa97991c3dc (diff) | |
download | redis-py-73f898c5534407414f250d15535da9f616a468b3.tar.gz |
Merge pull request #340 from thruflo/master
Add optional `redis.connection.BlockingConnectionPool` class.
-rw-r--r-- | redis/__init__.py | 1 | ||||
-rw-r--r-- | redis/_compat.py | 27 | ||||
-rw-r--r-- | redis/connection.py | 164 | ||||
-rw-r--r-- | tests/__init__.py | 2 | ||||
-rw-r--r-- | tests/connection_pool.py | 76 |
5 files changed, 269 insertions, 1 deletions
diff --git a/redis/__init__.py b/redis/__init__.py index 4b10834..5d9f3f1 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -1,5 +1,6 @@ from redis.client import Redis, StrictRedis from redis.connection import ( + BlockingConnectionPool, ConnectionPool, Connection, UnixDomainSocketConnection diff --git a/redis/_compat.py b/redis/_compat.py index b04098c..4a8a14c 100644 --- a/redis/_compat.py +++ b/redis/_compat.py @@ -48,3 +48,30 @@ else: unicode = str bytes = bytes long = int + +try: # Python 3 + from queue import LifoQueue, Empty, Full +except ImportError: + from Queue import Empty, Full + try: # Python 2.6 - 2.7 + from Queue import LifoQueue + except ImportError: # Python 2.5 + from Queue import Queue + # From the Python 2.7 lib. Python 2.5 already extracted the core + # methods to aid implementating different queue organisations. + class LifoQueue(Queue): + """Override queue methods to implement a last-in first-out queue.""" + + def _init(self, maxsize): + self.maxsize = maxsize + self.queue = [] + + def _qsize(self, len=len): + return len(self.queue) + + def _put(self, item): + self.queue.append(item) + + def _get(self): + return self.queue.pop() + diff --git a/redis/connection.py b/redis/connection.py index 60dfde6..ccfedf6 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -4,7 +4,8 @@ import socket import sys from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long, - BytesIO, nativestr, basestring) + BytesIO, nativestr, basestring, + LifoQueue, Empty, Full) from redis.exceptions import ( RedisError, ConnectionError, @@ -418,3 +419,164 @@ class ConnectionPool(object): self._in_use_connections) for connection in all_conns: connection.disconnect() + + +class BlockingConnectionPool(object): + """Thread-safe blocking connection pool:: + + >>> from redis.client import Redis + >>> client = Redis(connection_pool=BlockingConnectionPool()) + + It performs the same function as the default + ``:py:class: ~redis.connection.ConnectionPool`` implementation, in that, + it maintains a pool of reusable connections that can be shared by + multiple redis clients (safely across threads if required). + + The difference is that, in the event that a client tries to get a + connection from the pool when all of connections are in use, rather than + raising a ``:py:class: ~redis.exceptions.ConnectionError`` (as the default + ``:py:class: ~redis.connection.ConnectionPool`` implementation does), it + makes the client wait ("blocks") for a specified number of seconds until + a connection becomes available. + + Use ``max_connections`` to increase / decrease the pool size:: + + >>> pool = BlockingConnectionPool(max_connections=10) + + Use ``timeout`` to tell it either how many seconds to wait for a connection + to become available, or to block forever: + + # Block forever. + >>> pool = BlockingConnectionPool(timeout=None) + + # Raise a ``ConnectionError`` after five seconds if a connection is + # not available. + >>> pool = BlockingConnectionPool(timeout=5) + + """ + + def __init__(self, max_connections=50, timeout=20, connection_class=None, queue_class=None, **connection_kwargs): + """Compose and assign values.""" + + # Compose. + if connection_class is None: + connection_class = Connection + if queue_class is None: + queue_class = LifoQueue + + # Assign. + self.connection_class = connection_class + self.connection_kwargs = connection_kwargs + self.queue_class = queue_class + self.max_connections = max_connections + self.timeout = timeout + + # Validate the ``max_connections``. With the "fill up the queue" + # algorithm we use, it must be a positive integer. + is_valid = isinstance(max_connections, int) and max_connections > 0 + if not is_valid: + raise ValueError('``max_connections`` must be a positive integer') + + # Get the current process id, so we can disconnect and reinstantiate if + # it changes. + self.pid = os.getpid() + + # Create and fill up a thread safe queue with ``None`` values. + self.pool = self.queue_class(max_connections) + while True: + try: + self.pool.put_nowait(None) + except Full: + break + + # Keep a list of actual connection instances so that we can + # disconnect them later. + self._connections = [] + + def _checkpid(self): + """Check the current process id. If it has changed, disconnect and + re-instantiate this connection pool instance. + """ + + # Get the current process id. + pid = os.getpid() + + # If it hasn't changed since we were instantiated, then we're fine, so + # just exit, remaining connected. + if self.pid == pid: + return + + # If it has changed, then disconnect and re-instantiate. + self.disconnect() + self.reinstantiate() + + def make_connection(self): + """Make a fresh connection.""" + + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + return connection + + def get_connection(self, command_name, *keys, **options): + """Get a connection, blocking for ``self.timeout`` until a connection is + available from the pool. + + If the connection returned is ``None`` then creates a new connection. + Because we use a last-in first-out queue, the existing connections (having + been returned to the pool after the initial ``None`` values were added) + will be returned before ``None`` values. This means we only create new + connections when we need to, i.e.: the actual number of connections will + only increase in response to demand. + """ + + # Make sure we haven't changed process. + self._checkpid() + + # Try and get a connection from the pool. If one isn't available within + # self.timeout then raise a ``ConnectionError``. + connection = None + try: + connection = self.pool.get(block=True, timeout=self.timeout) + except Empty: + # Note that this is not caught by the redis client and will be raised + # unless handled by application code. If you want never to raise + raise ConnectionError("No connection available.") + + # If the ``connection`` is actually ``None`` then that's a cue to make a + # new connection to add to the pool. + if connection is None: + connection = self.make_connection() + + return connection + + def release(self, connection): + """Releases the connection back to the pool.""" + + # Make sure we haven't changed process. + self._checkpid() + + # Put the connection back into the pool. + try: + self.pool.put_nowait(connection) + except Full: + # This shouldn't normally happen but might perhaps happen after a + # reinstantiation. So, we can handle the exception by not putting + # the connection back on the pool, because we definitely do not + # want to reuse it. + pass + + def disconnect(self): + """Disconnects all connections in the pool.""" + + for connection in self._connections: + connection.disconnect() + + def reinstantiate(self): + """Reinstatiate this instance within a new process with a new connection + pool set. + """ + + self.__init__(max_connections=self.max_connections, timeout=self.timeout, + connection_class=self.connection_class, + queue_class=self.queue_class, **self.connection_kwargs) + diff --git a/tests/__init__.py b/tests/__init__.py index 616bdd4..71c0179 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,6 +2,7 @@ import unittest from tests.server_commands import ServerCommandsTestCase from tests.connection_pool import ConnectionPoolTestCase +from tests.connection_pool import BlockingConnectionPoolTestCase from tests.pipeline import PipelineTestCase from tests.lock import LockTestCase from tests.pubsub import PubSubTestCase, PubSubRedisDownTestCase @@ -19,6 +20,7 @@ def all_tests(): suite = unittest.TestSuite() suite.addTest(unittest.makeSuite(ServerCommandsTestCase)) suite.addTest(unittest.makeSuite(ConnectionPoolTestCase)) + suite.addTest(unittest.makeSuite(BlockingConnectionPoolTestCase)) suite.addTest(unittest.makeSuite(PipelineTestCase)) suite.addTest(unittest.makeSuite(LockTestCase)) suite.addTest(unittest.makeSuite(PubSubTestCase)) diff --git a/tests/connection_pool.py b/tests/connection_pool.py index 4f0bc0c..91d4d46 100644 --- a/tests/connection_pool.py +++ b/tests/connection_pool.py @@ -36,9 +36,85 @@ class ConnectionPoolTestCase(unittest.TestCase): c2 = pool.get_connection('_') self.assertRaises(redis.ConnectionError, pool.get_connection, '_') + def test_blocking_max_connections(self): + pool = self.get_pool(max_connections=2) + c1 = pool.get_connection('_') + c2 = pool.get_connection('_') + self.assertRaises(redis.ConnectionError, pool.get_connection, '_') + + def test_release(self): + pool = self.get_pool() + c1 = pool.get_connection('_') + pool.release(c1) + c2 = pool.get_connection('_') + self.assertEquals(c1, c2) + + +class BlockingConnectionPoolTestCase(unittest.TestCase): + def get_pool(self, connection_info=None, max_connections=10, timeout=20): + connection_info = connection_info or {'a': 1, 'b': 2, 'c': 3} + pool = redis.BlockingConnectionPool(connection_class=DummyConnection, + max_connections=max_connections, + timeout=timeout, **connection_info) + return pool + + def test_connection_creation(self): + connection_info = {'foo': 'bar', 'biz': 'baz'} + pool = self.get_pool(connection_info=connection_info) + connection = pool.get_connection('_') + self.assertEquals(connection.kwargs, connection_info) + + def test_multiple_connections(self): + pool = self.get_pool() + c1 = pool.get_connection('_') + c2 = pool.get_connection('_') + self.assert_(c1 != c2) + + def test_max_connections_blocks(self): + """Getting a connection should block for until available.""" + + import time + from copy import deepcopy + from threading import Thread + + # We use a queue for cross thread communication within the unit test. + try: # Python 3 + from queue import Queue + except ImportError: + from Queue import Queue + + q = Queue() + q.put_nowait('Not yet got') + pool = self.get_pool(max_connections=2, timeout=5) + c1 = pool.get_connection('_') + c2 = pool.get_connection('_') + + target = lambda: q.put_nowait(pool.get_connection('_')) + Thread(target=target).start() + + # Blocks while non available. + time.sleep(0.05) + c3 = q.get_nowait() + self.assertEquals(c3, 'Not yet got') + + # Then got when available. + pool.release(c1) + time.sleep(0.05) + c3 = q.get_nowait() + self.assertEquals(c1, c3) + + def test_max_connections_timeout(self): + """Getting a connection raises ``ConnectionError`` after timeout.""" + + pool = self.get_pool(max_connections=2, timeout=0.1) + c1 = pool.get_connection('_') + c2 = pool.get_connection('_') + self.assertRaises(redis.ConnectionError, pool.get_connection, '_') + def test_release(self): pool = self.get_pool() c1 = pool.get_connection('_') pool.release(c1) c2 = pool.get_connection('_') self.assertEquals(c1, c2) + |