summaryrefslogtreecommitdiff
path: root/redis/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'redis/client.py')
-rw-r--r--redis/client.py45
1 files changed, 31 insertions, 14 deletions
diff --git a/redis/client.py b/redis/client.py
index 83c4f32..f97bc88 100644
--- a/redis/client.py
+++ b/redis/client.py
@@ -10,14 +10,15 @@ _connection_manager = threading.local()
class Connection(object):
"Manages TCP communication to and from a Redis server"
- def __init__(self, host='localhost', port=6379, db=0):
+ def __init__(self, host='localhost', port=6379, db=0, password=None):
self.host = host
self.port = port
self.db = db
+ self.password = password
self._sock = None
self._fp = None
- def connect(self):
+ def connect(self, redis_instance):
"Connects to the Redis server is not already connected"
if self._sock:
return
@@ -30,6 +31,7 @@ class Connection(object):
sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
self._sock = sock
self._fp = sock.makefile('r')
+ redis_instance._setup_connection()
def disconnect(self):
"Disconnects from the Redis server"
@@ -42,9 +44,9 @@ class Connection(object):
self._sock = None
self._fp = None
- def send(self, command):
+ def send(self, command, redis_instance):
"Send ``command`` to the Redis server. Return the result."
- self.connect()
+ self.connect(redis_instance)
try:
self._sock.sendall(command)
except socket.error, e:
@@ -196,7 +198,7 @@ class Redis(object):
#### COMMAND EXECUTION AND PROTOCOL PARSING ####
def _execute_command(self, command_name, command, **options):
- self.connection.send(command)
+ self.connection.send(command, self)
return self.parse_response(command_name, **options)
def execute_command(self, command_name, command, **options):
@@ -292,12 +294,31 @@ class Redis(object):
)
#### CONNECTION HANDLING ####
- def get_connection(self, host, port, db):
+ def get_connection(self, host, port, db, password):
"Returns a connection object"
key = '%s:%s:%s' % (host, port, db)
if not hasattr(_connection_manager, key):
- setattr(_connection_manager, key, Connection(host, port, db))
- return getattr(_connection_manager, key)
+ setattr(
+ _connection_manager,
+ key,
+ Connection(host, port, db, password))
+ conn = getattr(_connection_manager, key)
+ # if for whatever reason the connection gets a bad password, make
+ # sure a subsequent attempt with the right password makes its way
+ # to the connection
+ conn.password = password
+ return conn
+
+ def _setup_connection(self):
+ """
+ After successfully opening a socket to the Redis server, the
+ connection object calls this method to authenticate and select
+ the appropriate database.
+ """
+ if self.connection.password:
+ if not self.auth(self.connection.password):
+ raise AuthenticationError("Invalid Password")
+ self.format_inline('SELECT', self.connection.db)
def select(self, host, port, db, password=None):
"""
@@ -307,11 +328,7 @@ class Redis(object):
prior to issuing the SELECT command. This makes sure we protect
the thread-safe connections
"""
- self.connection = self.get_connection(host, port, db)
- if password:
- if not self.auth(password):
- raise AuthenticationError("Invalid Password")
- return self.format_inline('SELECT', db)
+ self.connection = self.get_connection(host, port, db, password)
#### SERVER INFORMATION ####
def auth(self, password):
@@ -866,7 +883,7 @@ class Pipeline(Redis):
def _execute(self, commands):
for _, command, options in commands:
- self.connection.send(command)
+ self.connection.send(command, self)
return [self.parse_response(name, **options)
for name, _, options in commands]