summaryrefslogtreecommitdiff
path: root/redis/connection.py
diff options
context:
space:
mode:
authorAndy McCurdy <andy@andymccurdy.com>2011-05-12 02:43:25 -0700
committerAndy McCurdy <andy@andymccurdy.com>2011-05-12 02:43:25 -0700
commit4d6320a1666972e90f4225bd5d26fd4afad21099 (patch)
treeda32829649e2333b386ecd119eb7263776c9f072 /redis/connection.py
parent43aa23146ad287959369161038a4e5cda985a259 (diff)
downloadredis-py-4d6320a1666972e90f4225bd5d26fd4afad21099.tar.gz
connection class completely refactored. encoding and command packing moved from client to connection. introduced concept of protocol parsers and implemented both a PythonParse and a hiredis parser. the parser class can be overridden in the __init__ of the connection if desired.
Diffstat (limited to 'redis/connection.py')
-rw-r--r--redis/connection.py261
1 files changed, 148 insertions, 113 deletions
diff --git a/redis/connection.py b/redis/connection.py
index 85b4a7c..5bca883 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -1,22 +1,117 @@
import errno
+import os
import socket
import threading
+from itertools import imap
from redis.exceptions import ConnectionError, ResponseError, InvalidResponse
+class PythonParser(object):
+ def on_connect(self, connection):
+ "Called when the socket connects"
+ self._fp = connection._sock.makefile('r')
-class BaseConnection(object):
+ def on_disconnect(self):
+ "Called when the socket disconnects"
+ self._fp.close()
+ self._fp = None
+
+ def read(self, length=None):
+ """
+ Read a line from the socket is no length is specified,
+ otherwise read ``length`` bytes. Always strip away the newlines.
+ """
+ try:
+ if length is not None:
+ return self._fp.read(length+2)[:-2]
+ return self._fp.readline()[:-2]
+ except socket.error, e:
+ raise ConnectionError("Error while reading from socket: %s" % \
+ e.args[1])
+
+ def read_response(self):
+ response = self.read()
+ if not response:
+ raise ConnectionError("Socket closed on remote end")
+
+ byte, response = response[0], response[1:]
+
+ # server returned an error
+ if byte == '-':
+ if response.startswith('ERR '):
+ response = response[4:]
+ return ResponseError(response)
+ if response.startswith('LOADING '):
+ # If we're loading the dataset into memory, kill the socket
+ # so we re-initialize (and re-SELECT) next time.
+ raise ConnectionError("Redis is loading data into memory")
+ # single value
+ elif byte == '+':
+ return response
+ # int value
+ elif byte == ':':
+ return long(response)
+ # bulk response
+ elif byte == '$':
+ length = int(response)
+ if length == -1:
+ return None
+ response = length and self.read(length) or ''
+ return response
+ # multi-bulk response
+ elif byte == '*':
+ length = int(response)
+ if length == -1:
+ return None
+ return [self.read_response() for i in xrange(length)]
+ raise InvalidResponse("Protocol Error")
+
+class HiredisParser(object):
+ def on_connect(self, connection):
+ self._sock = connection._sock
+ self._reader = hiredis.Reader(
+ protocolError=InvalidResponse,
+ replyError=ResponseError)
+
+ def on_disconnect(self):
+ self._sock = None
+ self._reader = None
+
+ def read_response(self):
+ response = self._reader.gets()
+ while response is False:
+ buffer = self._sock.recv(4096)
+ if not buffer:
+ raise ConnectionError("Socket closed on remote end")
+ self._reader.feed(buffer)
+ # if the data received doesn't end with \r\n, then there's more in
+ # the socket
+ if not buffer.endswith('\r\n'):
+ continue
+ response = self._reader.gets()
+ return response
+
+try:
+ import hiredis
+ DefaultParser = HiredisParser
+except ImportError:
+ DefaultParser = PythonParser
+
+class Connection(object):
"Manages TCP communication to and from a Redis server"
def __init__(self, host='localhost', port=6379, db=0, password=None,
- socket_timeout=None):
+ socket_timeout=None, encoding='utf-8',
+ encoding_errors='strict', parser_class=DefaultParser):
self.host = host
self.port = port
self.db = db
self.password = password
self.socket_timeout = socket_timeout
+ self.encoding = encoding
+ self.encoding_errors = encoding_errors
self._sock = None
- self._fp = None
+ self._parser = parser_class()
- def connect(self, redis_instance):
+ def connect(self):
"Connects to the Redis server if not already connected"
if self._sock:
return
@@ -34,13 +129,29 @@ class BaseConnection(object):
error_message = "Error %s connecting %s:%s. %s." % \
(e.args[0], self.host, self.port, e.args[1])
raise ConnectionError(error_message)
- sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
self._sock = sock
- self._fp = sock.makefile('r')
- redis_instance._setup_connection()
+ self._parser.on_connect(self)
+
+ self.setup_connection()
+
+ def setup_connection(self):
+ "Authenticates and Selects the appropriate database"
+ # if a password is specified, authenticate
+ if self.password:
+ self.send_command('AUTH', self.password)
+ if self.read_response() != 'OK':
+ raise ConnectionError('Invalid Password')
+
+ # if a database is specified, switch to it
+ if self.db:
+ self.send_command('SELECT', self.db)
+ if self.read_response() != 'OK':
+ raise ConnectionError('Invalid Database')
def disconnect(self):
"Disconnects from the Redis server"
+ self._parser.on_disconnect()
if self._sock is None:
return
try:
@@ -48,11 +159,11 @@ class BaseConnection(object):
except socket.error:
pass
self._sock = None
- self._fp = None
- def send(self, command, redis_instance):
- "Send ``command`` to the Redis server. Return the result."
- self.connect(redis_instance)
+ def _send(self, command):
+ "Send the command to the socket"
+ if not self._sock:
+ self.connect()
try:
self._sock.sendall(command)
except socket.error, e:
@@ -65,115 +176,39 @@ class BaseConnection(object):
raise ConnectionError("Error %s while writing to socket. %s." % \
(_errno, errmsg))
- def read(self, length=None):
- """
- Read a line from the socket is length is None,
- otherwise read ``length`` bytes
- """
+ def send_packed_command(self, command):
+ "Send an already packed command to the Redis server"
try:
- if length is not None:
- return self._fp.read(length)
- return self._fp.readline()
- except socket.error, e:
- self.disconnect()
- if e.args and e.args[0] == errno.EAGAIN:
- raise ConnectionError("Error while reading from socket: %s" % \
- e.args[1])
- return ''
-
-class PythonConnection(BaseConnection):
- def read_response(self, command_name, catch_errors):
- response = self.read()[:-2] # strip last two characters (\r\n)
- if not response:
+ self._send(command)
+ except ConnectionError:
+ # retry the command once in case the socket connection simply
+ # timed out
self.disconnect()
- raise ConnectionError("Socket closed on remote end")
+ # if this _send() call fails, then the error will be raised
+ self._send(command)
- # server returned a null value
- if response in ('$-1', '*-1'):
- return None
- byte, response = response[0], response[1:]
+ def send_command(self, *args):
+ "Pack and send a command to the Redis server"
+ self.send_packed_command(self.pack_command(*args))
- # server returned an error
- if byte == '-':
- if response.startswith('ERR '):
- response = response[4:]
- if response.startswith('LOADING '):
- # If we're loading the dataset into memory, kill the socket
- # so we re-initialize (and re-SELECT) next time.
- self.disconnect()
- response = response[8:]
- raise ResponseError(response)
- # single value
- elif byte == '+':
- return response
- # int value
- elif byte == ':':
- return long(response)
- # bulk response
- elif byte == '$':
- length = int(response)
- if length == -1:
- return None
- response = length and self.read(length) or ''
- self.read(2) # read the \r\n delimiter
- return response
- # multi-bulk response
- elif byte == '*':
- length = int(response)
- if length == -1:
- return None
- if not catch_errors:
- return [self.read_response(command_name, catch_errors)
- for i in range(length)]
- else:
- # for pipelines, we need to read everything,
- # including response errors. otherwise we'd
- # completely mess up the receive buffer
- data = []
- for i in range(length):
- try:
- data.append(
- self.read_response(command_name, catch_errors)
- )
- except Exception, e:
- data.append(e)
- return data
-
- raise InvalidResponse("Unknown response type for: %s" % command_name)
-
-class HiredisConnection(BaseConnection):
- def connect(self, redis_instance):
- if self._sock == None:
- self._reader = hiredis.Reader(
- protocolError=InvalidResponse,
- replyError=ResponseError)
- super(HiredisConnection, self).connect(redis_instance)
-
- def disconnect(self):
- if self._sock:
- self._reader = None
- super(HiredisConnection, self).disconnect()
-
- def read_response(self, command_name, catch_errors):
- response = self._reader.gets()
- while response is False:
- buffer = self._sock.recv(4096)
- if not buffer:
- self.disconnect()
- raise ConnectionError("Socket closed on remote end")
- self._reader.feed(buffer)
- response = self._reader.gets()
-
- if isinstance(response, ResponseError):
+ def read_response(self):
+ "Read the response from a previously sent command"
+ response = self._parser.read_response()
+ if response.__class__ == ResponseError:
raise response
-
return response
-try:
- import hiredis
- Connection = HiredisConnection
-except ImportError:
- Connection = PythonConnection
+ def encode(self, value):
+ "Return a bytestring of the value"
+ if isinstance(value, unicode):
+ return value.encode(self.encoding, self.encoding_errors)
+ return str(value)
+
+ def pack_command(self, *args):
+ "Pack a series of arguments into a value Redis command"
+ command = ['$%s\r\n%s\r\n' % (len(enc_value), enc_value)
+ for enc_value in imap(self.encode, args)]
+ return '*%s\r\n%s' % (len(command), ''.join(command))
class ConnectionPool(threading.local):
"Manages a list of connections on the local thread"