diff options
-rw-r--r-- | redis/__init__.py | 4 | ||||
-rw-r--r-- | redis/client.py | 9 | ||||
-rw-r--r-- | tests/connection_pool.py | 32 | ||||
-rw-r--r-- | tests/pipeline.py | 13 |
4 files changed, 25 insertions, 33 deletions
diff --git a/redis/__init__.py b/redis/__init__.py index 0090559..93155fb 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -1,10 +1,10 @@ # legacy imports -from redis.client import Redis, connection_manager +from redis.client import Redis, ConnectionPool from redis.exceptions import RedisError, ConnectionError, AuthenticationError from redis.exceptions import ResponseError, InvalidResponse, InvalidData __all__ = [ - 'Redis', 'connection_manager', + 'Redis', 'ConnectionPool', 'RedisError', 'ConnectionError', 'ResponseError', 'AuthenticationError' 'InvalidResponse', 'InvalidData', ] diff --git a/redis/client.py b/redis/client.py index 7dc5e06..09a21fb 100644 --- a/redis/client.py +++ b/redis/client.py @@ -8,7 +8,7 @@ from redis.exceptions import ConnectionError, ResponseError, InvalidResponse from redis.exceptions import RedisError, AuthenticationError -class ConnectionManager(threading.local): +class ConnectionPool(threading.local): "Manages a list of connections on the local thread" def __init__(self): self.connections = {} @@ -30,9 +30,6 @@ class ConnectionManager(threading.local): return self.connections.values() -connection_manager = ConnectionManager() - - class Connection(object): "Manages TCP communication to and from a Redis server" def __init__(self, host='localhost', port=6379, db=0, password=None, @@ -236,11 +233,13 @@ class Redis(threading.local): def __init__(self, host='localhost', port=6379, db=0, password=None, socket_timeout=None, + connection_pool=None, charset='utf-8', errors='strict'): self.encoding = charset self.errors = errors self.connection = None self.subscribed = False + self.connection_pool = connection_pool and connection_pool or ConnectionPool() self.select(db, host, port, password, socket_timeout) #### Legacty accessors of connection information #### @@ -376,7 +375,7 @@ class Redis(threading.local): #### CONNECTION HANDLING #### def get_connection(self, host, port, db, password, socket_timeout): "Returns a connection object" - conn = connection_manager.get_connection( + conn = self.connection_pool.get_connection( host, port, db, password, socket_timeout) # if for whatever reason the connection gets a bad password, make # sure a subsequent attempt with the right password makes its way diff --git a/tests/connection_pool.py b/tests/connection_pool.py index 64e637c..56f5f43 100644 --- a/tests/connection_pool.py +++ b/tests/connection_pool.py @@ -5,49 +5,49 @@ import unittest class ConnectionPoolTestCase(unittest.TestCase): def test_multiple_connections(self): - # 2 clients to the same host/port/db should use the same connection - r1 = redis.Redis(host='localhost', port=6379, db=9) - r2 = redis.Redis(host='localhost', port=6379, db=9) + # 2 clients to the same host/port/db/pool should use the same connection + pool = redis.ConnectionPool() + r1 = redis.Redis(host='localhost', port=6379, db=9, connection_pool=pool) + r2 = redis.Redis(host='localhost', port=6379, db=9, connection_pool=pool) self.assertEquals(r1.connection, r2.connection) - - # if one o them switches, they should have + + # if one of them switches, they should have # separate conncetion objects r2.select(db=10, host='localhost', port=6379) self.assertNotEqual(r1.connection, r2.connection) - + conns = [r1.connection, r2.connection] conns.sort() - + # but returning to the original state shares the object again r2.select(db=9, host='localhost', port=6379) self.assertEquals(r1.connection, r2.connection) - + # the connection manager should still have just 2 connections - mgr_conns = redis.connection_manager.get_all_connections() + mgr_conns = pool.get_all_connections() mgr_conns.sort() self.assertEquals(conns, mgr_conns) - + def test_threaded_workers(self): r = redis.Redis(host='localhost', port=6379, db=9) r.set('a', 'foo') r.set('b', 'bar') - + def _info_worker(): for i in range(50): _ = r.info() time.sleep(0.01) - + def _keys_worker(): for i in range(50): _ = r.keys() time.sleep(0.01) - + t1 = threading.Thread(target=_info_worker) t2 = threading.Thread(target=_keys_worker) t1.start() t2.start() - + for i in [t1, t2]: i.join() - -
\ No newline at end of file + diff --git a/tests/pipeline.py b/tests/pipeline.py index f9c8dfa..dcbfb0a 100644 --- a/tests/pipeline.py +++ b/tests/pipeline.py @@ -24,20 +24,13 @@ class PipelineTestCase(unittest.TestCase): ] ) - def test_pipeline_with_fresh_connection(self): - redis.client.connection_manager.connections.clear() - self.client = redis.Redis(host='localhost', port=6379, db=9) - pipe = self.client.pipeline() - pipe.set('a', 'b') - self.assertEquals(pipe.execute(), [True]) - def test_invalid_command_in_pipeline(self): # all commands but the invalid one should be excuted correctly self.client['c'] = 'a' pipe = self.client.pipeline() pipe.set('a', 1).set('b', 2).lpush('c', 3).set('d', 4) result = pipe.execute() - + self.assertEquals(result[0], True) self.assertEquals(self.client['a'], '1') self.assertEquals(result[1], True) @@ -48,11 +41,11 @@ class PipelineTestCase(unittest.TestCase): self.assertEquals(self.client['c'], 'a') self.assertEquals(result[3], True) self.assertEquals(self.client['d'], '4') - + # make sure the pipe was restored to a working state self.assertEquals(pipe.set('z', 'zzz').execute(), [True]) self.assertEquals(self.client['z'], 'zzz') - + def test_pipeline_cannot_select(self): pipe = self.client.pipeline() self.assertRaises(redis.RedisError, |