diff options
author | Andy McCurdy <andy@andymccurdy.com> | 2011-05-12 02:43:25 -0700 |
---|---|---|
committer | Andy McCurdy <andy@andymccurdy.com> | 2011-05-12 02:43:25 -0700 |
commit | 4d6320a1666972e90f4225bd5d26fd4afad21099 (patch) | |
tree | da32829649e2333b386ecd119eb7263776c9f072 /redis/connection.py | |
parent | 43aa23146ad287959369161038a4e5cda985a259 (diff) | |
download | redis-py-4d6320a1666972e90f4225bd5d26fd4afad21099.tar.gz |
connection class completely refactored. encoding and command packing moved from client to connection. introduced concept of protocol parsers and implemented both a PythonParse and a hiredis parser. the parser class can be overridden in the __init__ of the connection if desired.
Diffstat (limited to 'redis/connection.py')
-rw-r--r-- | redis/connection.py | 261 |
1 files changed, 148 insertions, 113 deletions
diff --git a/redis/connection.py b/redis/connection.py index 85b4a7c..5bca883 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,22 +1,117 @@ import errno +import os import socket import threading +from itertools import imap from redis.exceptions import ConnectionError, ResponseError, InvalidResponse +class PythonParser(object): + def on_connect(self, connection): + "Called when the socket connects" + self._fp = connection._sock.makefile('r') -class BaseConnection(object): + def on_disconnect(self): + "Called when the socket disconnects" + self._fp.close() + self._fp = None + + def read(self, length=None): + """ + Read a line from the socket is no length is specified, + otherwise read ``length`` bytes. Always strip away the newlines. + """ + try: + if length is not None: + return self._fp.read(length+2)[:-2] + return self._fp.readline()[:-2] + except socket.error, e: + raise ConnectionError("Error while reading from socket: %s" % \ + e.args[1]) + + def read_response(self): + response = self.read() + if not response: + raise ConnectionError("Socket closed on remote end") + + byte, response = response[0], response[1:] + + # server returned an error + if byte == '-': + if response.startswith('ERR '): + response = response[4:] + return ResponseError(response) + if response.startswith('LOADING '): + # If we're loading the dataset into memory, kill the socket + # so we re-initialize (and re-SELECT) next time. + raise ConnectionError("Redis is loading data into memory") + # single value + elif byte == '+': + return response + # int value + elif byte == ':': + return long(response) + # bulk response + elif byte == '$': + length = int(response) + if length == -1: + return None + response = length and self.read(length) or '' + return response + # multi-bulk response + elif byte == '*': + length = int(response) + if length == -1: + return None + return [self.read_response() for i in xrange(length)] + raise InvalidResponse("Protocol Error") + +class HiredisParser(object): + def on_connect(self, connection): + self._sock = connection._sock + self._reader = hiredis.Reader( + protocolError=InvalidResponse, + replyError=ResponseError) + + def on_disconnect(self): + self._sock = None + self._reader = None + + def read_response(self): + response = self._reader.gets() + while response is False: + buffer = self._sock.recv(4096) + if not buffer: + raise ConnectionError("Socket closed on remote end") + self._reader.feed(buffer) + # if the data received doesn't end with \r\n, then there's more in + # the socket + if not buffer.endswith('\r\n'): + continue + response = self._reader.gets() + return response + +try: + import hiredis + DefaultParser = HiredisParser +except ImportError: + DefaultParser = PythonParser + +class Connection(object): "Manages TCP communication to and from a Redis server" def __init__(self, host='localhost', port=6379, db=0, password=None, - socket_timeout=None): + socket_timeout=None, encoding='utf-8', + encoding_errors='strict', parser_class=DefaultParser): self.host = host self.port = port self.db = db self.password = password self.socket_timeout = socket_timeout + self.encoding = encoding + self.encoding_errors = encoding_errors self._sock = None - self._fp = None + self._parser = parser_class() - def connect(self, redis_instance): + def connect(self): "Connects to the Redis server if not already connected" if self._sock: return @@ -34,13 +129,29 @@ class BaseConnection(object): error_message = "Error %s connecting %s:%s. %s." % \ (e.args[0], self.host, self.port, e.args[1]) raise ConnectionError(error_message) - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + self._sock = sock - self._fp = sock.makefile('r') - redis_instance._setup_connection() + self._parser.on_connect(self) + + self.setup_connection() + + def setup_connection(self): + "Authenticates and Selects the appropriate database" + # if a password is specified, authenticate + if self.password: + self.send_command('AUTH', self.password) + if self.read_response() != 'OK': + raise ConnectionError('Invalid Password') + + # if a database is specified, switch to it + if self.db: + self.send_command('SELECT', self.db) + if self.read_response() != 'OK': + raise ConnectionError('Invalid Database') def disconnect(self): "Disconnects from the Redis server" + self._parser.on_disconnect() if self._sock is None: return try: @@ -48,11 +159,11 @@ class BaseConnection(object): except socket.error: pass self._sock = None - self._fp = None - def send(self, command, redis_instance): - "Send ``command`` to the Redis server. Return the result." - self.connect(redis_instance) + def _send(self, command): + "Send the command to the socket" + if not self._sock: + self.connect() try: self._sock.sendall(command) except socket.error, e: @@ -65,115 +176,39 @@ class BaseConnection(object): raise ConnectionError("Error %s while writing to socket. %s." % \ (_errno, errmsg)) - def read(self, length=None): - """ - Read a line from the socket is length is None, - otherwise read ``length`` bytes - """ + def send_packed_command(self, command): + "Send an already packed command to the Redis server" try: - if length is not None: - return self._fp.read(length) - return self._fp.readline() - except socket.error, e: - self.disconnect() - if e.args and e.args[0] == errno.EAGAIN: - raise ConnectionError("Error while reading from socket: %s" % \ - e.args[1]) - return '' - -class PythonConnection(BaseConnection): - def read_response(self, command_name, catch_errors): - response = self.read()[:-2] # strip last two characters (\r\n) - if not response: + self._send(command) + except ConnectionError: + # retry the command once in case the socket connection simply + # timed out self.disconnect() - raise ConnectionError("Socket closed on remote end") + # if this _send() call fails, then the error will be raised + self._send(command) - # server returned a null value - if response in ('$-1', '*-1'): - return None - byte, response = response[0], response[1:] + def send_command(self, *args): + "Pack and send a command to the Redis server" + self.send_packed_command(self.pack_command(*args)) - # server returned an error - if byte == '-': - if response.startswith('ERR '): - response = response[4:] - if response.startswith('LOADING '): - # If we're loading the dataset into memory, kill the socket - # so we re-initialize (and re-SELECT) next time. - self.disconnect() - response = response[8:] - raise ResponseError(response) - # single value - elif byte == '+': - return response - # int value - elif byte == ':': - return long(response) - # bulk response - elif byte == '$': - length = int(response) - if length == -1: - return None - response = length and self.read(length) or '' - self.read(2) # read the \r\n delimiter - return response - # multi-bulk response - elif byte == '*': - length = int(response) - if length == -1: - return None - if not catch_errors: - return [self.read_response(command_name, catch_errors) - for i in range(length)] - else: - # for pipelines, we need to read everything, - # including response errors. otherwise we'd - # completely mess up the receive buffer - data = [] - for i in range(length): - try: - data.append( - self.read_response(command_name, catch_errors) - ) - except Exception, e: - data.append(e) - return data - - raise InvalidResponse("Unknown response type for: %s" % command_name) - -class HiredisConnection(BaseConnection): - def connect(self, redis_instance): - if self._sock == None: - self._reader = hiredis.Reader( - protocolError=InvalidResponse, - replyError=ResponseError) - super(HiredisConnection, self).connect(redis_instance) - - def disconnect(self): - if self._sock: - self._reader = None - super(HiredisConnection, self).disconnect() - - def read_response(self, command_name, catch_errors): - response = self._reader.gets() - while response is False: - buffer = self._sock.recv(4096) - if not buffer: - self.disconnect() - raise ConnectionError("Socket closed on remote end") - self._reader.feed(buffer) - response = self._reader.gets() - - if isinstance(response, ResponseError): + def read_response(self): + "Read the response from a previously sent command" + response = self._parser.read_response() + if response.__class__ == ResponseError: raise response - return response -try: - import hiredis - Connection = HiredisConnection -except ImportError: - Connection = PythonConnection + def encode(self, value): + "Return a bytestring of the value" + if isinstance(value, unicode): + return value.encode(self.encoding, self.encoding_errors) + return str(value) + + def pack_command(self, *args): + "Pack a series of arguments into a value Redis command" + command = ['$%s\r\n%s\r\n' % (len(enc_value), enc_value) + for enc_value in imap(self.encode, args)] + return '*%s\r\n%s' % (len(command), ''.join(command)) class ConnectionPool(threading.local): "Manages a list of connections on the local thread" |