From 97c50ccebf1498212c8f8e1fa9b87c3b10c48666 Mon Sep 17 00:00:00 2001 From: Andy McCurdy Date: Thu, 27 Mar 2014 19:57:12 -0700 Subject: fixed a bunch get_message() bugs, refactored the PythonParser to be saner. still need more pubsub tests but all this stuff *seems* to be working now --- redis/client.py | 33 ++++++--- redis/connection.py | 153 ++++++++++++++++++++++++++++------------- tests/test_pubsub.py | 191 +++++++++++++++++++++++++++++++-------------------- 3 files changed, 246 insertions(+), 131 deletions(-) diff --git a/redis/client.py b/redis/client.py index 268a6d5..dd31e3f 100644 --- a/redis/client.py +++ b/redis/client.py @@ -444,13 +444,13 @@ class StrictRedis(object): """ return Lock(self, name, timeout=timeout, sleep=sleep) - def pubsub(self, shard_hint=None): + def pubsub(self, **kwargs): """ Return a Publish/Subscribe object. With this object, you can subscribe to channels and listen for messages that get published to them. """ - return PubSub(self.connection_pool, shard_hint) + return PubSub(self.connection_pool, **kwargs) #### COMMAND EXECUTION AND PROTOCOL PARSING #### def execute_command(self, *args, **options): @@ -1780,12 +1780,19 @@ class PubSub(object): def reset(self): if self.connection: self.connection.disconnect() + self.connection.clear_connect_callbacks() self.connection_pool.release(self.connection) self.connection = None def close(self): self.reset() + def on_connect(self, connection): + if self.channels: + self.subscribe(**self.channels) + if self.patterns: + self.psubscribe(**self.patterns) + @property def subscribed(self): "Indicates if there are subscriptions to any channels or patterns" @@ -1803,6 +1810,10 @@ class PubSub(object): 'pubsub', self.shard_hint ) + # initially connect here so we don't run our callback the first + # time. It's primarily there for reconnection purposes. + self.connection.connect() + self.connection.register_connect_callback(self.on_connect) connection = self.connection self._execute(connection, connection.send_command, *args) @@ -1816,8 +1827,6 @@ class PubSub(object): connection.connect() # resubscribe to all channels and patterns before # resending the current command - self.subscribe(**self.channels) - self.psubscribe(**self.patterns) return command(*args) def parse_response(self, block=True): @@ -1840,9 +1849,11 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - self.patterns.update(dict.fromkeys(args)) - self.patterns.update(kwargs) - return self.execute_command('PSUBSCRIBE', *iterkeys(self.patterns)) + new_patterns = {} + new_patterns.update(dict.fromkeys(args)) + new_patterns.update(kwargs) + self.patterns.update(new_patterns) + return self.execute_command('PSUBSCRIBE', *iterkeys(new_patterns)) def punsubscribe(self, *args): """ @@ -1869,9 +1880,11 @@ class PubSub(object): """ if args: args = list_or_args(args[0], args[1:]) - self.channels.update(dict.fromkeys(args)) - self.channels.update(kwargs) - return self.execute_command('SUBSCRIBE', *iterkeys(self.channels)) + new_channels = {} + new_channels.update(dict.fromkeys(args)) + new_channels.update(kwargs) + self.channels.update(new_channels) + return self.execute_command('SUBSCRIBE', *iterkeys(new_channels)) def unsubscribe(self, *args): """ diff --git a/redis/connection.py b/redis/connection.py index 80945fa..2fb5663 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -27,7 +27,6 @@ if HIREDIS_AVAILABLE: SYM_STAR = b('*') SYM_DOLLAR = b('$') SYM_CRLF = b('\r\n') -SYM_LF = b('\n') SYM_EMPTY = b('') @@ -48,13 +47,89 @@ class BaseParser(object): return ResponseError(response) +class SocketBuffer(object): + MAX_READ_LENGTH = 1000000 + + def __init__(self, socket): + self._sock = socket + self._buffer = BytesIO() + # number of bytes written to the buffer from the socket + self.bytes_written = 0 + # number of bytes read from the buffer + self.bytes_read = 0 + + @property + def length(self): + return self.bytes_written - self.bytes_read + + def _read_from_socket(self, length=None): + chunksize = 8192 + if length is None: + length = chunksize + + buf = self._buffer + buf.seek(self.bytes_written) + marker = 0 + + try: + while length > marker: + data = self._sock.recv(chunksize) + buf.write(data) + self.bytes_written += len(data) + marker += chunksize + except (socket.error, socket.timeout): + e = sys.exc_info()[1] + raise ConnectionError("Error while reading from socket: %s" % + (e.args,)) + + def read(self, length): + length = length + 2 # make sure to read the \r\n terminator + # make sure we've read enough data from the socket + if length > self.length: + self._read_from_socket(length - self.length) + + self._buffer.seek(self.bytes_read) + data = self._buffer.read(length) + self.bytes_read += len(data) + + # purge the buffer when we've consumed it all so it doesn't + # grow forever + if self.bytes_read == self.bytes_written: + self.purge() + + return data[:-2] + + def readline(self): + buf = self._buffer + buf.seek(self.bytes_read) + data = buf.readline() + while not data.endswith(SYM_CRLF): + # there's more data in the socket that we need + self._read_from_socket() + buf.seek(self.bytes_read) + data = buf.readline() + + self.bytes_read += len(data) + return data[:-2] + + def purge(self): + self._buffer.seek(0) + self._buffer.truncate() + self.bytes_written = 0 + self.bytes_read = 0 + + def close(self): + self.purge() + self._buffer.close() + + class PythonParser(BaseParser): "Plain Python parsing class" - MAX_READ_LENGTH = 1000000 encoding = None def __init__(self): - self._fp = None + self._sock = None + self._buffer = None def __del__(self): try: @@ -64,53 +139,25 @@ class PythonParser(BaseParser): def on_connect(self, connection): "Called when the socket connects" - self._fp = connection._sock.makefile('rb') + self._sock = connection._sock + self._buffer = SocketBuffer(self._sock) if connection.decode_responses: self.encoding = connection.encoding def on_disconnect(self): "Called when the socket disconnects" - if self._fp is not None: - self._fp.close() - self._fp = None + if self._sock is not None: + self._sock.close() + self._sock = None + if self._buffer is not None: + self._buffer.close() + self._buffer = None def can_read(self): - return self._fp and self._fp._rbuf.tell() - - def read(self, length=None): - """ - Read a line from the socket if no length is specified, - otherwise read ``length`` bytes. Always strip away the newlines. - """ - try: - if length is not None: - bytes_left = length + 2 # read the line ending - if length > self.MAX_READ_LENGTH: - # apparently reading more than 1MB or so from a windows - # socket can cause MemoryErrors. See: - # https://github.com/andymccurdy/redis-py/issues/205 - # read smaller chunks at a time to work around this - try: - buf = BytesIO() - while bytes_left > 0: - read_len = min(bytes_left, self.MAX_READ_LENGTH) - buf.write(self._fp.read(read_len)) - bytes_left -= read_len - buf.seek(0) - return buf.read(length) - finally: - buf.close() - return self._fp.read(bytes_left)[:-2] - - # no length, read a full line - return self._fp.readline()[:-2] - except (socket.error, socket.timeout): - e = sys.exc_info()[1] - raise ConnectionError("Error while reading from socket: %s" % - (e.args,)) + return self._buffer and bool(self._buffer.length) def read_response(self): - response = self.read() + response = self._buffer.readline() if not response: raise ConnectionError("Socket closed on remote end") @@ -144,7 +191,7 @@ class PythonParser(BaseParser): length = int(response) if length == -1: return None - response = self.read(length) + response = self._buffer.read(length) # multi-bulk response elif byte == '*': length = int(response) @@ -161,7 +208,6 @@ class HiredisParser(BaseParser): def __init__(self): if not HIREDIS_AVAILABLE: raise RedisError("Hiredis is not installed") - self._next_response = False def __del__(self): try: @@ -178,10 +224,12 @@ class HiredisParser(BaseParser): if connection.decode_responses: kwargs['encoding'] = connection.encoding self._reader = hiredis.Reader(**kwargs) + self._next_response = False def on_disconnect(self): self._sock = None self._reader = None + self._next_response = False def can_read(self): if not self._reader: @@ -213,8 +261,8 @@ class HiredisParser(BaseParser): raise ConnectionError("Socket closed on remote end") self._reader.feed(buffer) # proactively, but not conclusively, check if more data is in the - # buffer. if the data received doesn't end with \n, there's more. - if not buffer.endswith(SYM_LF): + # buffer. if the data received doesn't end with \r\n, there's more. + if not buffer.endswith(SYM_CRLF): continue response = self._reader.gets() if isinstance(response, ResponseError): @@ -256,6 +304,7 @@ class Connection(object): 'port': self.port, 'db': self.db, } + self._connect_callbacks = [] def __repr__(self): return self.description_format % self._description_args @@ -266,6 +315,12 @@ class Connection(object): except Exception: pass + def register_connect_callback(self, callback): + self._connect_callbacks.append(callback) + + def clear_connect_callbacks(self): + self._connect_callbacks = [] + def connect(self): "Connects to the Redis server if not already connected" if self._sock: @@ -284,6 +339,11 @@ class Connection(object): self.disconnect() raise + # run any user callbacks. right now the only internal callback + # is for pubsub channel/pattern resubscription + for callback in self._connect_callbacks: + callback(self) + def _connect(self): "Create a TCP socket connection" # in 2.6+ try to use IPv6/4 compatibility, else just original code @@ -361,7 +421,8 @@ class Connection(object): "Poll the socket to see if there's data that can be read." sock = self._sock if not sock: - return False + self.connect() + sock = self._sock return bool(select([sock], [], [], 0)[0]) or self._parser.can_read() def read_response(self): diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 7c81389..52825af 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -1,90 +1,131 @@ from __future__ import with_statement import pytest +import time import redis -from redis._compat import b, next from redis.exceptions import ConnectionError -class TestPubSub(object): +def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): + now = time.time() + timeout = now + timeout + while now < timeout: + message = pubsub.get_message( + ignore_subscribe_messages=ignore_subscribe_messages) + if message is not None: + return message + time.sleep(0.01) + now = time.time() + return None - def test_channel_subscribe(self, r): + +def make_message(type, channel, data, pattern=None): + return { + 'type': type, + 'pattern': pattern, + 'channel': channel, + 'data': data + } + + +class TestPubSubSubscribeUnsubscribe(object): + + def test_subscribe_unsubscribe(self, r): + p = r.pubsub() + + assert p.subscribe('foo', 'bar') is None + + # should be 2 messages indicating that we've subscribed + assert wait_for_message(p) == make_message('subscribe', 'foo', 1) + assert wait_for_message(p) == make_message('subscribe', 'bar', 2) + + assert p.unsubscribe('foo', 'bar') is None + + # should be 2 messages indicating that we've unsubscribed + assert wait_for_message(p) == make_message('unsubscribe', 'foo', 1) + assert wait_for_message(p) == make_message('unsubscribe', 'bar', 0) + + def test_pattern_subscribe_unsubscribe(self, r): + p = r.pubsub() + + assert p.psubscribe('f*', 'b*') is None + + # should be 2 messages indicating that we've subscribed + assert wait_for_message(p) == make_message('psubscribe', 'f*', 1) + assert wait_for_message(p) == make_message('psubscribe', 'b*', 2) + + assert p.punsubscribe('f*', 'b*') is None + + # should be 2 messages indicating that we've unsubscribed + assert wait_for_message(p) == make_message('punsubscribe', 'f*', 1) + assert wait_for_message(p) == make_message('punsubscribe', 'b*', 0) + + def test_resubscribe_to_channels_on_reconnection(self, r): + channels = ['foo', 'bar'] + p = r.pubsub() + + assert p.subscribe(*channels) is None + + for i, channel in enumerate(channels): + i += 1 # enumerate is 0 index, but we want 1 based indexing + assert wait_for_message(p) == make_message('subscribe', channel, i) + + # manually disconnect + p.connection.disconnect() + + # calling get_message again reconnects and resubscribes + for i, channel in enumerate(channels): + i += 1 # enumerate is 0 index, but we want 1 based indexing + assert wait_for_message(p) == make_message('subscribe', channel, i) + + def test_resubscribe_to_patterns_on_reconnection(self, r): + patterns = ['f*', 'b*'] p = r.pubsub() - # subscribe doesn't return anything - assert p.subscribe('foo') is None - - # send a message - assert r.publish('foo', 'hello foo') == 1 - - # there should be now 2 messages in the buffer, a subscribe and the - # one we just published - assert next(p.listen()) == \ - { - 'type': 'subscribe', - 'pattern': None, - 'channel': 'foo', - 'data': 1 - } - - assert next(p.listen()) == \ - { - 'type': 'message', - 'pattern': None, - 'channel': 'foo', - 'data': b('hello foo') - } - - # unsubscribe - assert p.unsubscribe('foo') is None - - # unsubscribe message should be in the buffer - assert next(p.listen()) == \ - { - 'type': 'unsubscribe', - 'pattern': None, - 'channel': 'foo', - 'data': 0 - } - - def test_pattern_subscribe(self, r): + assert p.psubscribe(*patterns) is None + + for i, pattern in enumerate(patterns): + i += 1 # enumerate is 0 index, but we want 1 based indexing + assert wait_for_message(p) == make_message( + 'psubscribe', pattern, i) + + # manually disconnect + p.connection.disconnect() + + # calling get_message again reconnects and resubscribes + for i, pattern in enumerate(patterns): + i += 1 # enumerate is 0 index, but we want 1 based indexing + assert wait_for_message(p) == make_message( + 'psubscribe', pattern, i) + + def test_ignore_all_subscribe_messages(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + + checks = ( + (p.subscribe, 'foo'), + (p.unsubscribe, 'foo'), + (p.psubscribe, 'f*'), + (p.unsubscribe, 'f*'), + ) + + for func, channel in checks: + assert func(channel) is None + assert wait_for_message(p) is None + + def test_ignore_individual_subscribe_messages(self, r): p = r.pubsub() - # psubscribe doesn't return anything - assert p.psubscribe('f*') is None - - # send a message - assert r.publish('foo', 'hello foo') == 1 - - # there should be now 2 messages in the buffer, a subscribe and the - # one we just published - assert next(p.listen()) == \ - { - 'type': 'psubscribe', - 'pattern': None, - 'channel': 'f*', - 'data': 1 - } - - assert next(p.listen()) == \ - { - 'type': 'pmessage', - 'pattern': 'f*', - 'channel': 'foo', - 'data': b('hello foo') - } - - # unsubscribe - assert p.punsubscribe('f*') is None - - # unsubscribe message should be in the buffer - assert next(p.listen()) == \ - { - 'type': 'punsubscribe', - 'pattern': None, - 'channel': 'f*', - 'data': 0 - } + checks = ( + (p.subscribe, 'foo'), + (p.unsubscribe, 'foo'), + (p.psubscribe, 'f*'), + (p.unsubscribe, 'f*'), + ) + + for func, channel in checks: + assert func(channel) is None + message = wait_for_message(p, ignore_subscribe_messages=True) + assert message is None class TestPubSubRedisDown(object): -- cgit v1.2.1