diff options
Diffstat (limited to 'redis/connection.py')
-rw-r--r-- | redis/connection.py | 44 |
1 files changed, 22 insertions, 22 deletions
diff --git a/redis/connection.py b/redis/connection.py index 3f65ccd..648f7a1 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -3,28 +3,6 @@ import socket import threading from redis.exceptions import ConnectionError, ResponseError, InvalidResponse -class ConnectionPool(threading.local): - "Manages a list of connections on the local thread" - def __init__(self): - self.connections = {} - - def make_connection_key(self, host, port, db): - "Create a unique key for the specified host, port and db" - return '%s:%s:%s' % (host, port, db) - - def get_connection(self, host, port, db, password, socket_timeout): - "Return a specific connection for the specified host, port and db" - key = self.make_connection_key(host, port, db) - if key not in self.connections: - self.connections[key] = Connection( - host, port, db, password, socket_timeout) - return self.connections[key] - - def get_all_connections(self): - "Return a list of all connection objects the manager knows about" - return self.connections.values() - - class BaseConnection(object): "Manages TCP communication to and from a Redis server" def __init__(self, host='localhost', port=6379, db=0, password=None, @@ -190,3 +168,25 @@ try: Connection = HiredisConnection except ImportError: Connection = PythonConnection + +class ConnectionPool(threading.local): + "Manages a list of connections on the local thread" + def __init__(self, connection_class=None): + self.connections = {} + self.connection_class = connection_class or Connection + + def make_connection_key(self, host, port, db): + "Create a unique key for the specified host, port and db" + return '%s:%s:%s' % (host, port, db) + + def get_connection(self, host, port, db, password, socket_timeout): + "Return a specific connection for the specified host, port and db" + key = self.make_connection_key(host, port, db) + if key not in self.connections: + self.connections[key] = self.connection_class( + host, port, db, password, socket_timeout) + return self.connections[key] + + def get_all_connections(self): + "Return a list of all connection objects the manager knows about" + return self.connections.values() |