diff options
author | Andy McCurdy <andy@andymccurdy.com> | 2014-03-13 16:05:16 -0700 |
---|---|---|
committer | Andy McCurdy <andy@andymccurdy.com> | 2014-03-13 16:05:16 -0700 |
commit | 81c2402b1438517e53d2d85abf3b5bc326f9bb82 (patch) | |
tree | 4aa6ebd748d432098784abe5bbf7ab1a1ce0270b /redis/connection.py | |
parent | 15327afcbcc7f8f66e2d3a95ff8925db04f4a90d (diff) | |
parent | a5dbfc1e3ff246dfd3188223e3e11c59da6f3e2d (diff) | |
download | redis-py-81c2402b1438517e53d2d85abf3b5bc326f9bb82.tar.gz |
Merge branch 'master' into pubsub
Conflicts:
redis/connection.py
Diffstat (limited to 'redis/connection.py')
-rw-r--r-- | redis/connection.py | 180 |
1 files changed, 125 insertions, 55 deletions
diff --git a/redis/connection.py b/redis/connection.py index d949fef..80945fa 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,46 +1,58 @@ +from __future__ import with_statement +from itertools import chain +from select import select import os import socket import sys -from itertools import chain -from select import select +import threading + from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long, BytesIO, nativestr, basestring, LifoQueue, Empty, Full) from redis.exceptions import ( RedisError, ConnectionError, + BusyLoadingError, ResponseError, InvalidResponse, AuthenticationError, NoScriptError, ExecAbortError, ) - -try: +from redis.utils import HIREDIS_AVAILABLE +if HIREDIS_AVAILABLE: import hiredis - hiredis_available = True -except ImportError: - hiredis_available = False SYM_STAR = b('*') SYM_DOLLAR = b('$') SYM_CRLF = b('\r\n') SYM_LF = b('\n') +SYM_EMPTY = b('') -class PythonParser(object): - "Plain Python parsing class" - MAX_READ_LENGTH = 1000000 - encoding = None - +class BaseParser(object): EXCEPTION_CLASSES = { 'ERR': ResponseError, 'EXECABORT': ExecAbortError, - 'LOADING': ConnectionError, + 'LOADING': BusyLoadingError, 'NOSCRIPT': NoScriptError, } + def parse_error(self, response): + "Parse an error response" + error_code = response.split(' ')[0] + if error_code in self.EXCEPTION_CLASSES: + response = response[len(error_code) + 1:] + return self.EXCEPTION_CLASSES[error_code](response) + return ResponseError(response) + + +class PythonParser(BaseParser): + "Plain Python parsing class" + MAX_READ_LENGTH = 1000000 + encoding = None + def __init__(self): self._fp = None @@ -97,14 +109,6 @@ class PythonParser(object): raise ConnectionError("Error while reading from socket: %s" % (e.args,)) - def parse_error(self, response): - "Parse an error response" - error_code = response.split(' ')[0] - if error_code in self.EXCEPTION_CLASSES: - response = response[len(error_code) + 1:] - return self.EXCEPTION_CLASSES[error_code](response) - return ResponseError(response) - def read_response(self): response = self.read() if not response: @@ -113,14 +117,22 @@ class PythonParser(object): byte, response = byte_to_chr(response[0]), response[1:] if byte not in ('-', '+', ':', '$', '*'): - raise InvalidResponse("Protocol Error") + raise InvalidResponse("Protocol Error: %s, %s" % + (str(byte), str(response))) # server returned an error if byte == '-': response = nativestr(response) - # *return*, not raise the exception class. if it is meant to be - # raised, it will be at a higher level. - return self.parse_error(response) + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error # single value elif byte == '+': pass @@ -144,10 +156,10 @@ class PythonParser(object): return response -class HiredisParser(object): +class HiredisParser(BaseParser): "Parser class for connections using Hiredis" def __init__(self): - if not hiredis_available: + if not HIREDIS_AVAILABLE: raise RedisError("Hiredis is not installed") self._next_response = False @@ -205,9 +217,16 @@ class HiredisParser(object): if not buffer.endswith(SYM_LF): continue response = self._reader.gets() + if isinstance(response, ResponseError): + response = self.parse_error(response.args[0]) + # hiredis only knows about ResponseErrors. + # self.parse_error() might turn the exception into a ConnectionError + # which needs raising. + if isinstance(response, ConnectionError): + raise response return response -if hiredis_available: +if HIREDIS_AVAILABLE: DefaultParser = HiredisParser else: DefaultParser = PythonParser @@ -215,6 +234,8 @@ else: class Connection(object): "Manages TCP communication to and from a Redis server" + description_format = "Connection<host=%(host)s,port=%(port)s,db=%(db)s>" + def __init__(self, host='localhost', port=6379, db=0, password=None, socket_timeout=None, encoding='utf-8', encoding_errors='strict', decode_responses=False, @@ -230,6 +251,14 @@ class Connection(object): self.decode_responses = decode_responses self._sock = None self._parser = parser_class() + self._description_args = { + 'host': self.host, + 'port': self.port, + 'db': self.db, + } + + def __repr__(self): + return self.description_format % self._description_args def __del__(self): try: @@ -248,13 +277,23 @@ class Connection(object): raise ConnectionError(self._error_message(e)) self._sock = sock - self.on_connect() + try: + self.on_connect() + except RedisError: + # clean up after any error in on_connect + self.disconnect() + raise def _connect(self): "Create a TCP socket connection" - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(self.socket_timeout) - sock.connect((self.host, self.port)) + # in 2.6+ try to use IPv6/4 compatibility, else just original code + if hasattr(socket, 'create_connection'): + sock = socket.create_connection((self.host, self.port), + self.socket_timeout) + else: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(self.socket_timeout) + sock.connect((self.host, self.port)) return sock def _error_message(self, exception): @@ -289,6 +328,7 @@ class Connection(object): if self._sock is None: return try: + self._sock.shutdown(socket.SHUT_RDWR) self._sock.close() except socket.error: pass @@ -309,7 +349,7 @@ class Connection(object): _errno, errmsg = e.args raise ConnectionError("Error %s while writing to socket. %s." % (_errno, errmsg)) - except Exception: + except: self.disconnect() raise @@ -328,7 +368,7 @@ class Connection(object): "Read the response from a previously sent command" try: response = self._parser.read_response() - except Exception: + except: self.disconnect() raise if isinstance(response, ResponseError): @@ -349,17 +389,17 @@ class Connection(object): def pack_command(self, *args): "Pack a series of arguments into a value Redis command" - output = SYM_STAR + b(str(len(args))) + SYM_CRLF - for enc_value in imap(self.encode, args): - output += SYM_DOLLAR - output += b(str(len(enc_value))) - output += SYM_CRLF - output += enc_value - output += SYM_CRLF + args_output = SYM_EMPTY.join([ + SYM_EMPTY.join((SYM_DOLLAR, b(str(len(k))), SYM_CRLF, k, SYM_CRLF)) + for k in imap(self.encode, args)]) + output = SYM_EMPTY.join( + (SYM_STAR, b(str(len(args))), SYM_CRLF, args_output)) return output class UnixDomainSocketConnection(Connection): + description_format = "UnixDomainSocketConnection<path=%(path)s,db=%(db)s>" + def __init__(self, path='', db=0, password=None, socket_timeout=None, encoding='utf-8', encoding_errors='strict', decode_responses=False, @@ -374,6 +414,10 @@ class UnixDomainSocketConnection(Connection): self.decode_responses = decode_responses self._sock = None self._parser = parser_class() + self._description_args = { + 'path': self.path, + 'db': self.db, + } def _connect(self): "Create a Unix domain socket connection" @@ -393,11 +437,20 @@ class UnixDomainSocketConnection(Connection): (exception.args[0], self.path, exception.args[1]) -# TODO: add ability to block waiting on a connection to be released class ConnectionPool(object): "Generic connection pool" def __init__(self, connection_class=Connection, max_connections=None, **connection_kwargs): + """ + Create a connection pool. If max_connections is set, then this + object raises redis.ConnectionError when the pool's limit is reached. + + By default, TCP connections are created connection_class is specified. + Use redis.UnixDomainSocketConnection for unix sockets. + + Any additional keyword arguments are passed to the constructor of + connection_class. + """ self.pid = os.getpid() self.connection_class = connection_class self.connection_kwargs = connection_kwargs @@ -405,12 +458,24 @@ class ConnectionPool(object): self._created_connections = 0 self._available_connections = [] self._in_use_connections = set() + self._check_lock = threading.Lock() + + def __repr__(self): + return "%s<%s>" % ( + type(self).__name__, + self.connection_class.description_format % self.connection_kwargs, + ) def _checkpid(self): if self.pid != os.getpid(): - self.disconnect() - self.__init__(self.connection_class, self.max_connections, - **self.connection_kwargs) + with self._check_lock: + if self.pid == os.getpid(): + # another thread already did the work while we waited + # on the lock. + return + self.disconnect() + self.__init__(self.connection_class, self.max_connections, + **self.connection_kwargs) def get_connection(self, command_name, *keys, **options): "Get a connection from the pool" @@ -502,6 +567,7 @@ class BlockingConnectionPool(object): # Get the current process id, so we can disconnect and reinstantiate if # it changes. self.pid = os.getpid() + self._check_lock = threading.Lock() # Create and fill up a thread safe queue with ``None`` values. self.pool = self.queue_class(max_connections) @@ -515,22 +581,26 @@ class BlockingConnectionPool(object): # disconnect them later. self._connections = [] + def __repr__(self): + return "%s<%s>" % ( + type(self).__name__, + self.connection_class.description_format % self.connection_kwargs, + ) + def _checkpid(self): """ Check the current process id. If it has changed, disconnect and re-instantiate this connection pool instance. """ - # Get the current process id. pid = os.getpid() - - # If it hasn't changed since we were instantiated, then we're fine, so - # just exit, remaining connected. - if self.pid == pid: - return - - # If it has changed, then disconnect and re-instantiate. - self.disconnect() - self.reinstantiate() + if self.pid != pid: + with self._check_lock: + if self.pid == os.getpid(): + # another thread already did the work while we waited + # on the lock. + return + self.disconnect() + self.reinstantiate() def make_connection(self): "Make a fresh connection." |