summaryrefslogtreecommitdiff
path: root/redis/connection.py
diff options
context:
space:
mode:
authorPieter Noordhuis <pcnoordhuis@gmail.com>2011-01-30 15:28:25 +0100
committerPieter Noordhuis <pcnoordhuis@gmail.com>2011-01-30 15:28:25 +0100
commitabb59667e8aa140c1120edbeca57bfde69ea540a (patch)
tree6e4fd5fe9079ad38bad6617e1650cee9a0962957 /redis/connection.py
parent7d96a315ef08381649903b51d75127338421f76f (diff)
downloadredis-py-abb59667e8aa140c1120edbeca57bfde69ea540a.tar.gz
Move Connection and ConnectionPool to different file
Diffstat (limited to 'redis/connection.py')
-rw-r--r--redis/connection.py157
1 files changed, 157 insertions, 0 deletions
diff --git a/redis/connection.py b/redis/connection.py
new file mode 100644
index 0000000..667e7c2
--- /dev/null
+++ b/redis/connection.py
@@ -0,0 +1,157 @@
+import errno
+import socket
+import threading
+from redis.exceptions import ConnectionError, ResponseError, InvalidResponse
+
+class ConnectionPool(threading.local):
+ "Manages a list of connections on the local thread"
+ def __init__(self):
+ self.connections = {}
+
+ def make_connection_key(self, host, port, db):
+ "Create a unique key for the specified host, port and db"
+ return '%s:%s:%s' % (host, port, db)
+
+ def get_connection(self, host, port, db, password, socket_timeout):
+ "Return a specific connection for the specified host, port and db"
+ key = self.make_connection_key(host, port, db)
+ if key not in self.connections:
+ self.connections[key] = Connection(
+ host, port, db, password, socket_timeout)
+ return self.connections[key]
+
+ def get_all_connections(self):
+ "Return a list of all connection objects the manager knows about"
+ return self.connections.values()
+
+
+class Connection(object):
+ "Manages TCP communication to and from a Redis server"
+ def __init__(self, host='localhost', port=6379, db=0, password=None,
+ socket_timeout=None):
+ self.host = host
+ self.port = port
+ self.db = db
+ self.password = password
+ self.socket_timeout = socket_timeout
+ self._sock = None
+ self._fp = None
+
+ def connect(self, redis_instance):
+ "Connects to the Redis server if not already connected"
+ if self._sock:
+ return
+ try:
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(self.socket_timeout)
+ sock.connect((self.host, self.port))
+ except socket.error, e:
+ # args for socket.error can either be (errno, "message")
+ # or just "message"
+ if len(e.args) == 1:
+ error_message = "Error connecting to %s:%s. %s." % \
+ (self.host, self.port, e.args[0])
+ else:
+ error_message = "Error %s connecting %s:%s. %s." % \
+ (e.args[0], self.host, self.port, e.args[1])
+ raise ConnectionError(error_message)
+ sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
+ self._sock = sock
+ self._fp = sock.makefile('r')
+ redis_instance._setup_connection()
+
+ def disconnect(self):
+ "Disconnects from the Redis server"
+ if self._sock is None:
+ return
+ try:
+ self._sock.close()
+ except socket.error:
+ pass
+ self._sock = None
+ self._fp = None
+
+ def send(self, command, redis_instance):
+ "Send ``command`` to the Redis server. Return the result."
+ self.connect(redis_instance)
+ try:
+ self._sock.sendall(command)
+ except socket.error, e:
+ if e.args[0] == errno.EPIPE:
+ self.disconnect()
+ if isinstance(e.args, basestring):
+ _errno, errmsg = 'UNKNOWN', e.args
+ else:
+ _errno, errmsg = e.args
+ raise ConnectionError("Error %s while writing to socket. %s." % \
+ (_errno, errmsg))
+
+ def read(self, length=None):
+ """
+ Read a line from the socket is length is None,
+ otherwise read ``length`` bytes
+ """
+ try:
+ if length is not None:
+ return self._fp.read(length)
+ return self._fp.readline()
+ except socket.error, e:
+ self.disconnect()
+ if e.args and e.args[0] == errno.EAGAIN:
+ raise ConnectionError("Error while reading from socket: %s" % \
+ e.args[1])
+ return ''
+
+ def read_response(self, command_name, catch_errors):
+ response = self.read()[:-2] # strip last two characters (\r\n)
+ if not response:
+ self.disconnect()
+ raise ConnectionError("Socket closed on remote end")
+
+ # server returned a null value
+ if response in ('$-1', '*-1'):
+ return None
+ byte, response = response[0], response[1:]
+
+ # server returned an error
+ if byte == '-':
+ if response.startswith('ERR '):
+ response = response[4:]
+ raise ResponseError(response)
+ # single value
+ elif byte == '+':
+ return response
+ # int value
+ elif byte == ':':
+ return int(response)
+ # bulk response
+ elif byte == '$':
+ length = int(response)
+ if length == -1:
+ return None
+ response = length and self.read(length) or ''
+ self.read(2) # read the \r\n delimiter
+ return response
+ # multi-bulk response
+ elif byte == '*':
+ length = int(response)
+ if length == -1:
+ return None
+ if not catch_errors:
+ return [self.read_response(command_name, catch_errors)
+ for i in range(length)]
+ else:
+ # for pipelines, we need to read everything,
+ # including response errors. otherwise we'd
+ # completely mess up the receive buffer
+ data = []
+ for i in range(length):
+ try:
+ data.append(
+ self.read_response(command_name, catch_errors)
+ )
+ except Exception, e:
+ data.append(e)
+ return data
+
+ raise InvalidResponse("Unknown response type for: %s" % command_name)