diff options
author | Andy McCurdy <andy@andymccurdy.com> | 2010-04-27 21:51:40 -0700 |
---|---|---|
committer | Andy McCurdy <andy@andymccurdy.com> | 2010-04-27 21:51:40 -0700 |
commit | 60063e8ef022ea500788a572d7b3c07080557c54 (patch) | |
tree | f6c1b927924740dd232520b5c3889ebe05a7ec69 | |
parent | c3d531e9d279d0d3532267b1d1887ac8e68f5a43 (diff) | |
download | redis-py-60063e8ef022ea500788a572d7b3c07080557c54.tar.gz |
added pubsub tests
fixed a typo in hmset from the previous commit
-rw-r--r-- | redis/client.py | 26 | ||||
-rw-r--r-- | tests/server_commands.py | 51 |
2 files changed, 74 insertions, 3 deletions
diff --git a/redis/client.py b/redis/client.py index 34b55dd..16e357e 100644 --- a/redis/client.py +++ b/redis/client.py @@ -1024,7 +1024,7 @@ class Redis(threading.local): """ items = [] [items.extend(pair) for pair in mapping.iteritems()] - return self.execute_command('HMSET', key, *items) + return self.execute_command('HMSET', name, *items) def hmget(self, name, keys): "Returns a list of values ordered identically to ``keys``" @@ -1036,6 +1036,25 @@ class Redis(threading.local): # channels + def psubscribe(self, patterns): + "Subscribe to all channels matching any pattern in ``patterns``" + if isinstance(patterns, basestring): + patterns = [patterns] + response = self.execute_command('PSUBSCRIBE', *patterns) + # this is *after* the SUBSCRIBE in order to allow for lazy and broken + # connections that need to issue AUTH and SELECT commands + self.subscribed = True + return response + + def punsubscribe(self, patterns=[]): + """ + Unsubscribe from any channel matching any pattern in ``patterns``. + If empty, unsubscribe from all channels. + """ + if isinstance(patterns, basestring): + patterns = [patterns] + return self.execute_command('PUNSUBSCRIBE', *patterns) + def subscribe(self, channels): "Subscribe to ``channels``, waiting for messages to be published" if isinstance(channels, basestring): @@ -1047,7 +1066,10 @@ class Redis(threading.local): return response def unsubscribe(self, channels=[]): - "Unsubscribe to ``channels``. If empty, unsubscribe from all channels" + """ + Unsubscribe from ``channels``. If empty, unsubscribe + from all channels + """ if isinstance(channels, basestring): channels = [channels] return self.execute_command('UNSUBSCRIBE', *channels) diff --git a/tests/server_commands.py b/tests/server_commands.py index d46ee38..360b763 100644 --- a/tests/server_commands.py +++ b/tests/server_commands.py @@ -1,12 +1,17 @@ import redis import unittest import datetime +import threading +import time from distutils.version import StrictVersion class ServerCommandsTestCase(unittest.TestCase): + def get_client(self): + return redis.Redis(host='localhost', port=6379, db=9) + def setUp(self): - self.client = redis.Redis(host='localhost', port=6379, db=9) + self.client = self.get_client() self.client.flushdb() def tearDown(self): @@ -940,6 +945,50 @@ class ServerCommandsTestCase(unittest.TestCase): self.assertEquals(self.client.lrange('sorted', 0, 10), ['vodka', 'milk', 'gin', 'apple juice']) + # PUBSUB + def test_pubsub(self): + # create a new client to not polute the existing one + r = self.get_client() + channels = ('a1', 'a2', 'a3') + for c in channels: + r.subscribe(c) + channels_to_publish_to = channels + ('a4',) + messages_per_channel = 4 + def publish(): + for i in range(messages_per_channel): + for c in channels_to_publish_to: + self.client.publish(c, 'a message') + time.sleep(0.01) + t = threading.Thread(target=publish) + messages = [] + # should receive a message for each subscribe command + # plus a message for each iteration of the loop * num channels + num_messages_to_expect = len(channels) + \ + (messages_per_channel*len(channels)) + thread_started = False + for msg in r.listen(): + if not thread_started: + # start the thread delayed so that we are intermingling + # publish commands with pulling messsages off the socket + # with subscribe + thread_started = True + t.start() + messages.append(msg) + if len(messages) == num_messages_to_expect: + break + sent_types, sent_channels = {}, {} + for msg_type, channel, _ in messages: + sent_types.setdefault(msg_type, 0) + sent_types[msg_type] += 1 + if msg_type == 'message': + sent_channels.setdefault(channel, 0) + sent_channels[channel] += 1 + for channel in channels: + self.assertEquals(sent_channels[channel], messages_per_channel) + self.assert_(channel in channels) + self.assertEquals(sent_types['subscribe'], len(channels)) + self.assertEquals(sent_types['message'], + len(channels) * messages_per_channel) ## BINARY SAFE # TODO add more tests |