summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--redis/client.py9
-rw-r--r--tests/server_commands.py35
2 files changed, 28 insertions, 16 deletions
diff --git a/redis/client.py b/redis/client.py
index 6405947..41a9660 100644
--- a/redis/client.py
+++ b/redis/client.py
@@ -1203,10 +1203,15 @@ class Redis(threading.local):
'data': r[3]
}
else:
- msg = {'type': r[0], 'channel': r[1], 'data': r[2]}
- yield msg
+ msg = {
+ 'type': r[0],
+ 'pattern': None,
+ 'channel': r[1],
+ 'data': r[2]
+ }
if r[0] == 'unsubscribe' and r[2] == 0:
self.subscribed = False
+ yield msg
class Pipeline(Redis):
diff --git a/tests/server_commands.py b/tests/server_commands.py
index b6ae82b..81096e9 100644
--- a/tests/server_commands.py
+++ b/tests/server_commands.py
@@ -1047,6 +1047,9 @@ class ServerCommandsTestCase(unittest.TestCase):
channels = ('a1', 'a2', 'a3')
for c in channels:
r.subscribe(c)
+ # state variable should be flipped
+ self.assertEquals(r.subscribed, True)
+
channels_to_publish_to = channels + ('a4',)
messages_per_channel = 4
def publish():
@@ -1054,23 +1057,26 @@ class ServerCommandsTestCase(unittest.TestCase):
for c in channels_to_publish_to:
self.client.publish(c, 'a message')
time.sleep(0.01)
- t = threading.Thread(target=publish)
+ for c in channels_to_publish_to:
+ self.client.publish(c, 'unsubscribe')
+ time.sleep(0.01)
+
messages = []
- # should receive a message for each subscribe command
+ # should receive a message for each subscribe/unsubscribe command
# plus a message for each iteration of the loop * num channels
- num_messages_to_expect = len(channels) + \
+ # we hide the data messages that tell the client to unsubscribe
+ num_messages_to_expect = len(channels)*2 + \
(messages_per_channel*len(channels))
- thread_started = False
+ t = threading.Thread(target=publish)
+ t.start()
for msg in r.listen():
- if not thread_started:
- # start the thread delayed so that we are intermingling
- # publish commands with pulling messsages off the socket
- # with subscribe
- thread_started = True
- t.start()
- messages.append(msg)
- if len(messages) == num_messages_to_expect:
- break
+ if msg['data'] == 'unsubscribe':
+ r.unsubscribe(msg['channel'])
+ else:
+ messages.append(msg)
+
+ self.assertEquals(r.subscribed, False)
+ self.assertEquals(len(messages), num_messages_to_expect)
sent_types, sent_channels = {}, {}
for msg in messages:
msg_type = msg['type']
@@ -1081,9 +1087,10 @@ class ServerCommandsTestCase(unittest.TestCase):
sent_channels.setdefault(channel, 0)
sent_channels[channel] += 1
for channel in channels:
+ self.assert_(channel in sent_channels)
self.assertEquals(sent_channels[channel], messages_per_channel)
- self.assert_(channel in channels)
self.assertEquals(sent_types['subscribe'], len(channels))
+ self.assertEquals(sent_types['unsubscribe'], len(channels))
self.assertEquals(sent_types['message'],
len(channels) * messages_per_channel)