From ae71200de265a9353982789c3871dd8329a8d265 Mon Sep 17 00:00:00 2001 From: Andy McCurdy Date: Mon, 23 May 2011 10:55:40 -0700 Subject: new pubsub tests --- redis/client.py | 2 +- redis/connection.py | 2 +- tests/__init__.py | 2 ++ tests/connection_pool.py | 88 +++++++++++++++++++----------------------------- tests/pubsub.py | 51 ++++++++++++++++++++++++++++ tests/server_commands.py | 56 ------------------------------ 6 files changed, 90 insertions(+), 111 deletions(-) create mode 100644 tests/pubsub.py diff --git a/redis/client.py b/redis/client.py index 3d0e6b7..58b02ca 100644 --- a/redis/client.py +++ b/redis/client.py @@ -1011,7 +1011,7 @@ class PubSub(object): self.patterns = set() self.subscription_count = 0 self.subscribe_commands = set( - ('subscribe', 'psusbscribe', 'unsubscribe', 'punsubscribe') + ('subscribe', 'psubscribe', 'unsubscribe', 'punsubscribe') ) def execute_command(self, *args, **kwargs): diff --git a/redis/connection.py b/redis/connection.py index af82372..061ee3a 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -262,7 +262,7 @@ class ConnectionPool(object): def make_connection(self): "Create a new connection" if self._created_connections >= self.max_connections: - raise Exception("Too many connections") + raise ConnectionError("Too many connections") self._created_connections += 1 return self.connection_class(**self.connection_kwargs) diff --git a/tests/__init__.py b/tests/__init__.py index ab94711..81a48ee 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -3,6 +3,7 @@ from server_commands import ServerCommandsTestCase from connection_pool import ConnectionPoolTestCase from pipeline import PipelineTestCase from lock import LockTestCase +from pubsub import PubSubTestCase use_hiredis = False try: @@ -17,4 +18,5 @@ def all_tests(): suite.addTest(unittest.makeSuite(ConnectionPoolTestCase)) suite.addTest(unittest.makeSuite(PipelineTestCase)) suite.addTest(unittest.makeSuite(LockTestCase)) + suite.addTest(unittest.makeSuite(PubSubTestCase)) return suite diff --git a/tests/connection_pool.py b/tests/connection_pool.py index 6c85e81..8252cee 100644 --- a/tests/connection_pool.py +++ b/tests/connection_pool.py @@ -1,57 +1,39 @@ import redis -import threading -import time import unittest -class ConnectionPoolTestCase(unittest.TestCase): - # TODO: - # THIS TEST IS INVALID WITH THE DEFAULT CONNECTIONPOOL - # - # def test_multiple_connections(self): - # # 2 clients to the same host/port/db/pool should use the same connection - # pool = redis.ConnectionPool() - # r1 = redis.Redis(host='localhost', port=6379, db=9, connection_pool=pool) - # r2 = redis.Redis(host='localhost', port=6379, db=9, connection_pool=pool) - # self.assertEquals(r1.connection, r2.connection) - - # # if one of them switches, they should have - # # separate conncetion objects - # r2.select(db=10, host='localhost', port=6379) - # self.assertNotEqual(r1.connection, r2.connection) - - # conns = [r1.connection, r2.connection] - # conns.sort() - - # # but returning to the original state shares the object again - # r2.select(db=9, host='localhost', port=6379) - # self.assertEquals(r1.connection, r2.connection) - - # # the connection manager should still have just 2 connections - # mgr_conns = pool.get_all_connections() - # mgr_conns.sort() - # self.assertEquals(conns, mgr_conns) - - def test_threaded_workers(self): - # TODO: review this, does it even make sense anymore? - r = redis.Redis(host='localhost', port=6379, db=9) - r.set('a', 'foo') - r.set('b', 'bar') - - def _info_worker(): - for i in range(50): - _ = r.info() - time.sleep(0.01) - - def _keys_worker(): - for i in range(50): - _ = r.keys() - time.sleep(0.01) - - t1 = threading.Thread(target=_info_worker) - t2 = threading.Thread(target=_keys_worker) - t1.start() - t2.start() - - for i in [t1, t2]: - i.join() +class DummyConnection(object): + def __init__(self, **kwargs): + self.kwargs = kwargs +class ConnectionPoolTestCase(unittest.TestCase): + def get_pool(self, connection_info=None, max_connections=None): + connection_info = connection_info or {'a': 1, 'b': 2, 'c': 3} + pool = redis.ConnectionPool( + connection_class=DummyConnection, max_connections=max_connections, + **connection_info) + return pool + + def test_connection_creation(self): + connection_info = {'foo': 'bar', 'biz': 'baz'} + pool = self.get_pool(connection_info=connection_info) + connection = pool.get_connection('_') + self.assertEquals(connection.kwargs, connection_info) + + def test_multiple_connections(self): + pool = self.get_pool() + c1 = pool.get_connection('_') + c2 = pool.get_connection('_') + self.assert_(c1 != c2) + + def test_max_connections(self): + pool = self.get_pool(max_connections=2) + c1 = pool.get_connection('_') + c2 = pool.get_connection('_') + self.assertRaises(redis.ConnectionError, pool.get_connection, '_') + + def test_release(self): + pool = self.get_pool() + c1 = pool.get_connection('_') + pool.release(c1) + c2 = pool.get_connection('_') + self.assertEquals(c1, c2) diff --git a/tests/pubsub.py b/tests/pubsub.py new file mode 100644 index 0000000..1746ac4 --- /dev/null +++ b/tests/pubsub.py @@ -0,0 +1,51 @@ +import redis +import unittest + +class PubSubTestCase(unittest.TestCase): + def setUp(self): + self.connection_pool = redis.ConnectionPool() + self.client = redis.Redis(connection_pool=self.connection_pool) + self.pubsub = self.client.pubsub() + + def tearDown(self): + self.connection_pool.disconnect() + + def test_channel_subscribe(self): + self.assertEquals( + self.pubsub.subscribe('foo'), + ['subscribe', 'foo', 1] + ) + self.assertEquals(self.client.publish('foo', 'hello foo'), 1) + self.assertEquals( + self.pubsub.listen().next(), + { + 'type': 'message', + 'pattern': None, + 'channel': 'foo', + 'data': 'hello foo' + } + ) + self.assertEquals( + self.pubsub.unsubscribe('foo'), + ['unsubscribe', 'foo', 0] + ) + + def test_pattern_subscribe(self): + self.assertEquals( + self.pubsub.psubscribe('fo*'), + ['psubscribe', 'fo*', 1] + ) + self.assertEquals(self.client.publish('foo', 'hello foo'), 1) + self.assertEquals( + self.pubsub.listen().next(), + { + 'type': 'pmessage', + 'pattern': 'fo*', + 'channel': 'foo', + 'data': 'hello foo' + } + ) + self.assertEquals( + self.pubsub.punsubscribe('fo*'), + ['punsubscribe', 'fo*', 0] + ) diff --git a/tests/server_commands.py b/tests/server_commands.py index 1501441..361a954 100644 --- a/tests/server_commands.py +++ b/tests/server_commands.py @@ -1,7 +1,6 @@ import redis import unittest import datetime -import threading import time from distutils.version import StrictVersion @@ -1194,61 +1193,6 @@ class ServerCommandsTestCase(unittest.TestCase): self.assertEquals(self.client.lrange('sorted', 0, 10), ['vodka', 'milk', 'gin', 'apple juice']) - # PUBSUB - # def test_pubsub(self): - # # create a new client to not polute the existing one - # r = self.get_client() - # channels = ('a1', 'a2', 'a3') - # for c in channels: - # r.subscribe(c) - # # state variable should be flipped - # self.assertEquals(r.subscribed, True) - - # channels_to_publish_to = channels + ('a4',) - # messages_per_channel = 4 - # def publish(): - # for i in range(messages_per_channel): - # for c in channels_to_publish_to: - # self.client.publish(c, 'a message') - # time.sleep(0.01) - # for c in channels_to_publish_to: - # self.client.publish(c, 'unsubscribe') - # time.sleep(0.01) - - # messages = [] - # self.assertRaises(redis.PubSubError, r.set, 'foo', 'bar') - # # should receive a message for each subscribe/unsubscribe command - # # plus a message for each iteration of the loop * num channels - # # we hide the data messages that tell the client to unsubscribe - # num_messages_to_expect = len(channels)*2 + \ - # (messages_per_channel*len(channels)) - # t = threading.Thread(target=publish) - # t.start() - # for msg in r.listen(): - # if msg['data'] == 'unsubscribe': - # r.unsubscribe(msg['channel']) - # else: - # messages.append(msg) - - # self.assertEquals(r.subscribed, False) - # self.assertEquals(len(messages), num_messages_to_expect) - # sent_types, sent_channels = {}, {} - # for msg in messages: - # msg_type = msg['type'] - # channel = msg['channel'] - # sent_types.setdefault(msg_type, 0) - # sent_types[msg_type] += 1 - # if msg_type == 'message': - # sent_channels.setdefault(channel, 0) - # sent_channels[channel] += 1 - # for channel in channels: - # self.assert_(channel in sent_channels) - # self.assertEquals(sent_channels[channel], messages_per_channel) - # self.assertEquals(sent_types['subscribe'], len(channels)) - # self.assertEquals(sent_types['unsubscribe'], len(channels)) - # self.assertEquals(sent_types['message'], - # len(channels) * messages_per_channel) - ## BINARY SAFE # TODO add more tests def test_binary_get_set(self): -- cgit v1.2.1