diff options
-rw-r--r-- | redis/client.py | 9 | ||||
-rw-r--r-- | tests/server_commands.py | 35 |
2 files changed, 28 insertions, 16 deletions
diff --git a/redis/client.py b/redis/client.py index 6405947..41a9660 100644 --- a/redis/client.py +++ b/redis/client.py @@ -1203,10 +1203,15 @@ class Redis(threading.local): 'data': r[3] } else: - msg = {'type': r[0], 'channel': r[1], 'data': r[2]} - yield msg + msg = { + 'type': r[0], + 'pattern': None, + 'channel': r[1], + 'data': r[2] + } if r[0] == 'unsubscribe' and r[2] == 0: self.subscribed = False + yield msg class Pipeline(Redis): diff --git a/tests/server_commands.py b/tests/server_commands.py index b6ae82b..81096e9 100644 --- a/tests/server_commands.py +++ b/tests/server_commands.py @@ -1047,6 +1047,9 @@ class ServerCommandsTestCase(unittest.TestCase): channels = ('a1', 'a2', 'a3') for c in channels: r.subscribe(c) + # state variable should be flipped + self.assertEquals(r.subscribed, True) + channels_to_publish_to = channels + ('a4',) messages_per_channel = 4 def publish(): @@ -1054,23 +1057,26 @@ class ServerCommandsTestCase(unittest.TestCase): for c in channels_to_publish_to: self.client.publish(c, 'a message') time.sleep(0.01) - t = threading.Thread(target=publish) + for c in channels_to_publish_to: + self.client.publish(c, 'unsubscribe') + time.sleep(0.01) + messages = [] - # should receive a message for each subscribe command + # should receive a message for each subscribe/unsubscribe command # plus a message for each iteration of the loop * num channels - num_messages_to_expect = len(channels) + \ + # we hide the data messages that tell the client to unsubscribe + num_messages_to_expect = len(channels)*2 + \ (messages_per_channel*len(channels)) - thread_started = False + t = threading.Thread(target=publish) + t.start() for msg in r.listen(): - if not thread_started: - # start the thread delayed so that we are intermingling - # publish commands with pulling messsages off the socket - # with subscribe - thread_started = True - t.start() - messages.append(msg) - if len(messages) == num_messages_to_expect: - break + if msg['data'] == 'unsubscribe': + r.unsubscribe(msg['channel']) + else: + messages.append(msg) + + self.assertEquals(r.subscribed, False) + self.assertEquals(len(messages), num_messages_to_expect) sent_types, sent_channels = {}, {} for msg in messages: msg_type = msg['type'] @@ -1081,9 +1087,10 @@ class ServerCommandsTestCase(unittest.TestCase): sent_channels.setdefault(channel, 0) sent_channels[channel] += 1 for channel in channels: + self.assert_(channel in sent_channels) self.assertEquals(sent_channels[channel], messages_per_channel) - self.assert_(channel in channels) self.assertEquals(sent_types['subscribe'], len(channels)) + self.assertEquals(sent_types['unsubscribe'], len(channels)) self.assertEquals(sent_types['message'], len(channels) * messages_per_channel) |