summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndy McCurdy <andy@andymccurdy.com>2011-05-23 10:55:40 -0700
committerAndy McCurdy <andy@andymccurdy.com>2011-05-23 10:55:40 -0700
commitae71200de265a9353982789c3871dd8329a8d265 (patch)
tree5c9b8c1886c05dd466909fd76b6cb0eb2ba2b5e2
parentdef45c6133358a79274815e738e6c4005e10e188 (diff)
downloadredis-py-ae71200de265a9353982789c3871dd8329a8d265.tar.gz
new pubsub tests
-rw-r--r--redis/client.py2
-rw-r--r--redis/connection.py2
-rw-r--r--tests/__init__.py2
-rw-r--r--tests/connection_pool.py88
-rw-r--r--tests/pubsub.py51
-rw-r--r--tests/server_commands.py56
6 files changed, 90 insertions, 111 deletions
diff --git a/redis/client.py b/redis/client.py
index 3d0e6b7..58b02ca 100644
--- a/redis/client.py
+++ b/redis/client.py
@@ -1011,7 +1011,7 @@ class PubSub(object):
self.patterns = set()
self.subscription_count = 0
self.subscribe_commands = set(
- ('subscribe', 'psusbscribe', 'unsubscribe', 'punsubscribe')
+ ('subscribe', 'psubscribe', 'unsubscribe', 'punsubscribe')
)
def execute_command(self, *args, **kwargs):
diff --git a/redis/connection.py b/redis/connection.py
index af82372..061ee3a 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -262,7 +262,7 @@ class ConnectionPool(object):
def make_connection(self):
"Create a new connection"
if self._created_connections >= self.max_connections:
- raise Exception("Too many connections")
+ raise ConnectionError("Too many connections")
self._created_connections += 1
return self.connection_class(**self.connection_kwargs)
diff --git a/tests/__init__.py b/tests/__init__.py
index ab94711..81a48ee 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -3,6 +3,7 @@ from server_commands import ServerCommandsTestCase
from connection_pool import ConnectionPoolTestCase
from pipeline import PipelineTestCase
from lock import LockTestCase
+from pubsub import PubSubTestCase
use_hiredis = False
try:
@@ -17,4 +18,5 @@ def all_tests():
suite.addTest(unittest.makeSuite(ConnectionPoolTestCase))
suite.addTest(unittest.makeSuite(PipelineTestCase))
suite.addTest(unittest.makeSuite(LockTestCase))
+ suite.addTest(unittest.makeSuite(PubSubTestCase))
return suite
diff --git a/tests/connection_pool.py b/tests/connection_pool.py
index 6c85e81..8252cee 100644
--- a/tests/connection_pool.py
+++ b/tests/connection_pool.py
@@ -1,57 +1,39 @@
import redis
-import threading
-import time
import unittest
-class ConnectionPoolTestCase(unittest.TestCase):
- # TODO:
- # THIS TEST IS INVALID WITH THE DEFAULT CONNECTIONPOOL
- #
- # def test_multiple_connections(self):
- # # 2 clients to the same host/port/db/pool should use the same connection
- # pool = redis.ConnectionPool()
- # r1 = redis.Redis(host='localhost', port=6379, db=9, connection_pool=pool)
- # r2 = redis.Redis(host='localhost', port=6379, db=9, connection_pool=pool)
- # self.assertEquals(r1.connection, r2.connection)
-
- # # if one of them switches, they should have
- # # separate conncetion objects
- # r2.select(db=10, host='localhost', port=6379)
- # self.assertNotEqual(r1.connection, r2.connection)
-
- # conns = [r1.connection, r2.connection]
- # conns.sort()
-
- # # but returning to the original state shares the object again
- # r2.select(db=9, host='localhost', port=6379)
- # self.assertEquals(r1.connection, r2.connection)
-
- # # the connection manager should still have just 2 connections
- # mgr_conns = pool.get_all_connections()
- # mgr_conns.sort()
- # self.assertEquals(conns, mgr_conns)
-
- def test_threaded_workers(self):
- # TODO: review this, does it even make sense anymore?
- r = redis.Redis(host='localhost', port=6379, db=9)
- r.set('a', 'foo')
- r.set('b', 'bar')
-
- def _info_worker():
- for i in range(50):
- _ = r.info()
- time.sleep(0.01)
-
- def _keys_worker():
- for i in range(50):
- _ = r.keys()
- time.sleep(0.01)
-
- t1 = threading.Thread(target=_info_worker)
- t2 = threading.Thread(target=_keys_worker)
- t1.start()
- t2.start()
-
- for i in [t1, t2]:
- i.join()
+class DummyConnection(object):
+ def __init__(self, **kwargs):
+ self.kwargs = kwargs
+class ConnectionPoolTestCase(unittest.TestCase):
+ def get_pool(self, connection_info=None, max_connections=None):
+ connection_info = connection_info or {'a': 1, 'b': 2, 'c': 3}
+ pool = redis.ConnectionPool(
+ connection_class=DummyConnection, max_connections=max_connections,
+ **connection_info)
+ return pool
+
+ def test_connection_creation(self):
+ connection_info = {'foo': 'bar', 'biz': 'baz'}
+ pool = self.get_pool(connection_info=connection_info)
+ connection = pool.get_connection('_')
+ self.assertEquals(connection.kwargs, connection_info)
+
+ def test_multiple_connections(self):
+ pool = self.get_pool()
+ c1 = pool.get_connection('_')
+ c2 = pool.get_connection('_')
+ self.assert_(c1 != c2)
+
+ def test_max_connections(self):
+ pool = self.get_pool(max_connections=2)
+ c1 = pool.get_connection('_')
+ c2 = pool.get_connection('_')
+ self.assertRaises(redis.ConnectionError, pool.get_connection, '_')
+
+ def test_release(self):
+ pool = self.get_pool()
+ c1 = pool.get_connection('_')
+ pool.release(c1)
+ c2 = pool.get_connection('_')
+ self.assertEquals(c1, c2)
diff --git a/tests/pubsub.py b/tests/pubsub.py
new file mode 100644
index 0000000..1746ac4
--- /dev/null
+++ b/tests/pubsub.py
@@ -0,0 +1,51 @@
+import redis
+import unittest
+
+class PubSubTestCase(unittest.TestCase):
+ def setUp(self):
+ self.connection_pool = redis.ConnectionPool()
+ self.client = redis.Redis(connection_pool=self.connection_pool)
+ self.pubsub = self.client.pubsub()
+
+ def tearDown(self):
+ self.connection_pool.disconnect()
+
+ def test_channel_subscribe(self):
+ self.assertEquals(
+ self.pubsub.subscribe('foo'),
+ ['subscribe', 'foo', 1]
+ )
+ self.assertEquals(self.client.publish('foo', 'hello foo'), 1)
+ self.assertEquals(
+ self.pubsub.listen().next(),
+ {
+ 'type': 'message',
+ 'pattern': None,
+ 'channel': 'foo',
+ 'data': 'hello foo'
+ }
+ )
+ self.assertEquals(
+ self.pubsub.unsubscribe('foo'),
+ ['unsubscribe', 'foo', 0]
+ )
+
+ def test_pattern_subscribe(self):
+ self.assertEquals(
+ self.pubsub.psubscribe('fo*'),
+ ['psubscribe', 'fo*', 1]
+ )
+ self.assertEquals(self.client.publish('foo', 'hello foo'), 1)
+ self.assertEquals(
+ self.pubsub.listen().next(),
+ {
+ 'type': 'pmessage',
+ 'pattern': 'fo*',
+ 'channel': 'foo',
+ 'data': 'hello foo'
+ }
+ )
+ self.assertEquals(
+ self.pubsub.punsubscribe('fo*'),
+ ['punsubscribe', 'fo*', 0]
+ )
diff --git a/tests/server_commands.py b/tests/server_commands.py
index 1501441..361a954 100644
--- a/tests/server_commands.py
+++ b/tests/server_commands.py
@@ -1,7 +1,6 @@
import redis
import unittest
import datetime
-import threading
import time
from distutils.version import StrictVersion
@@ -1194,61 +1193,6 @@ class ServerCommandsTestCase(unittest.TestCase):
self.assertEquals(self.client.lrange('sorted', 0, 10),
['vodka', 'milk', 'gin', 'apple juice'])
- # PUBSUB
- # def test_pubsub(self):
- # # create a new client to not polute the existing one
- # r = self.get_client()
- # channels = ('a1', 'a2', 'a3')
- # for c in channels:
- # r.subscribe(c)
- # # state variable should be flipped
- # self.assertEquals(r.subscribed, True)
-
- # channels_to_publish_to = channels + ('a4',)
- # messages_per_channel = 4
- # def publish():
- # for i in range(messages_per_channel):
- # for c in channels_to_publish_to:
- # self.client.publish(c, 'a message')
- # time.sleep(0.01)
- # for c in channels_to_publish_to:
- # self.client.publish(c, 'unsubscribe')
- # time.sleep(0.01)
-
- # messages = []
- # self.assertRaises(redis.PubSubError, r.set, 'foo', 'bar')
- # # should receive a message for each subscribe/unsubscribe command
- # # plus a message for each iteration of the loop * num channels
- # # we hide the data messages that tell the client to unsubscribe
- # num_messages_to_expect = len(channels)*2 + \
- # (messages_per_channel*len(channels))
- # t = threading.Thread(target=publish)
- # t.start()
- # for msg in r.listen():
- # if msg['data'] == 'unsubscribe':
- # r.unsubscribe(msg['channel'])
- # else:
- # messages.append(msg)
-
- # self.assertEquals(r.subscribed, False)
- # self.assertEquals(len(messages), num_messages_to_expect)
- # sent_types, sent_channels = {}, {}
- # for msg in messages:
- # msg_type = msg['type']
- # channel = msg['channel']
- # sent_types.setdefault(msg_type, 0)
- # sent_types[msg_type] += 1
- # if msg_type == 'message':
- # sent_channels.setdefault(channel, 0)
- # sent_channels[channel] += 1
- # for channel in channels:
- # self.assert_(channel in sent_channels)
- # self.assertEquals(sent_channels[channel], messages_per_channel)
- # self.assertEquals(sent_types['subscribe'], len(channels))
- # self.assertEquals(sent_types['unsubscribe'], len(channels))
- # self.assertEquals(sent_types['message'],
- # len(channels) * messages_per_channel)
-
## BINARY SAFE
# TODO add more tests
def test_binary_get_set(self):