summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/conftest.py22
-rw-r--r--tests/test_commands.py79
-rw-r--r--tests/test_connection_pool.py19
3 files changed, 74 insertions, 46 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index 0ab6428..87e6301 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -53,9 +53,11 @@ def skip_unless_arch_bits(arch_bits):
reason="server is not {}-bit".format(arch_bits))
-def _get_client(cls, request, **kwargs):
+def _get_client(cls, request, single_connection_client=True, **kwargs):
redis_url = request.config.getoption("--redis-url")
client = cls.from_url(redis_url, **kwargs)
+ if single_connection_client:
+ client = client.client()
if request:
def teardown():
try:
@@ -64,31 +66,27 @@ def _get_client(cls, request, **kwargs):
# handle cases where a test disconnected a client
# just manually retry the flushdb
client.flushdb()
+ client.close()
client.connection_pool.disconnect()
request.addfinalizer(teardown)
return client
@pytest.fixture()
-def r(request, **kwargs):
- return _get_client(redis.Redis, request, **kwargs)
+def r(request):
+ return _get_client(redis.Redis, request)
@pytest.fixture()
-def r2(request, **kwargs):
- return [
- _get_client(redis.Redis, request, **kwargs),
- _get_client(redis.Redis, request, **kwargs),
- ]
+def r2(request):
+ "A second client for tests that need multiple"
+ return _get_client(redis.Redis, request)
def _gen_cluster_mock_resp(r, response):
- mock_connection_pool = Mock()
connection = Mock()
- response = response
connection.read_response.return_value = response
- mock_connection_pool.get_connection.return_value = connection
- r.connection_pool = mock_connection_pool
+ r.connection = connection
return r
diff --git a/tests/test_commands.py b/tests/test_commands.py
index a41e1a2..931fe9c 100644
--- a/tests/test_commands.py
+++ b/tests/test_commands.py
@@ -102,20 +102,20 @@ class TestRedisCommands(object):
@skip_if_server_version_lt('2.6.9')
def test_client_kill(self, r, r2):
r.client_setname('redis-py-c1')
- r2[0].client_setname('redis-py-c2')
- r2[1].client_setname('redis-py-c3')
- test_clients = [client for client in r.client_list()
- if client.get('name')
- in ['redis-py-c1', 'redis-py-c2', 'redis-py-c3']]
- assert len(test_clients) == 3
+ r2.client_setname('redis-py-c2')
+ clients = [client for client in r.client_list()
+ if client.get('name') in ['redis-py-c1', 'redis-py-c2']]
+ assert len(clients) == 2
- resp = r.client_kill(test_clients[1].get('addr'))
- assert isinstance(resp, bool) and resp is True
+ clients_by_name = dict([(client.get('name'), client)
+ for client in clients])
- test_clients = [client for client in r.client_list()
- if client.get('name')
- in ['redis-py-c1', 'redis-py-c2', 'redis-py-c3']]
- assert len(test_clients) == 2
+ assert r.client_kill(clients_by_name['redis-py-c2'].get('addr')) is True
+
+ clients = [client for client in r.client_list()
+ if client.get('name') in ['redis-py-c1', 'redis-py-c2']]
+ assert len(clients) == 1
+ assert clients[0].get('name') == 'redis-py-c1'
@skip_if_server_version_lt('2.8.12')
def test_client_kill_filter_invalid_params(self, r):
@@ -132,25 +132,44 @@ class TestRedisCommands(object):
r.client_kill_filter(_type="caster")
@skip_if_server_version_lt('2.8.12')
- def test_client_kill_filter(self, r, r2):
+ def test_client_kill_filter_by_id(self, r, r2):
+ r.client_setname('redis-py-c1')
+ r2.client_setname('redis-py-c2')
+ clients = [client for client in r.client_list()
+ if client.get('name') in ['redis-py-c1', 'redis-py-c2']]
+ assert len(clients) == 2
+
+ clients_by_name = dict([(client.get('name'), client)
+ for client in clients])
+
+ client_2_id = clients_by_name['redis-py-c2'].get('id')
+ resp = r.client_kill_filter(_id=client_2_id)
+ assert resp == 1
+
+ clients = [client for client in r.client_list()
+ if client.get('name') in ['redis-py-c1', 'redis-py-c2']]
+ assert len(clients) == 1
+ assert clients[0].get('name') == 'redis-py-c1'
+
+ @skip_if_server_version_lt('2.8.12')
+ def test_client_kill_filter_by_addr(self, r, r2):
r.client_setname('redis-py-c1')
- r2[0].client_setname('redis-py-c2')
- r2[1].client_setname('redis-py-c3')
- test_clients = [client for client in r.client_list()
- if client.get('name')
- in ['redis-py-c1', 'redis-py-c2', 'redis-py-c3']]
- assert len(test_clients) == 3
-
- resp = r.client_kill_filter(_id=test_clients[1].get('id'))
- assert isinstance(resp, int) and resp == 1
-
- resp = r.client_kill_filter(addr=test_clients[2].get('addr'))
- assert isinstance(resp, int) and resp == 1
-
- test_clients = [client for client in r.client_list()
- if client.get('name')
- in ['redis-py-c1', 'redis-py-c2', 'redis-py-c3']]
- assert len(test_clients) == 1
+ r2.client_setname('redis-py-c2')
+ clients = [client for client in r.client_list()
+ if client.get('name') in ['redis-py-c1', 'redis-py-c2']]
+ assert len(clients) == 2
+
+ clients_by_name = dict([(client.get('name'), client)
+ for client in clients])
+
+ client_2_addr = clients_by_name['redis-py-c2'].get('addr')
+ resp = r.client_kill_filter(addr=client_2_addr)
+ assert resp == 1
+
+ clients = [client for client in r.client_list()
+ if client.get('name') in ['redis-py-c1', 'redis-py-c2']]
+ assert len(clients) == 1
+ assert clients[0].get('name') == 'redis-py-c1'
@skip_if_server_version_lt('2.6.9')
def test_client_list_after_client_setname(self, r):
diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py
index 2aea1e4..0af615f 100644
--- a/tests/test_connection_pool.py
+++ b/tests/test_connection_pool.py
@@ -6,7 +6,7 @@ import re
from threading import Thread
from redis.connection import ssl_available, to_bool
-from .conftest import skip_if_server_version_lt
+from .conftest import skip_if_server_version_lt, _get_client
class DummyConnection(object):
@@ -448,9 +448,7 @@ class TestConnection(object):
"""
with pytest.raises(redis.BusyLoadingError):
r.execute_command('DEBUG', 'ERROR', 'LOADING fake message')
- pool = r.connection_pool
- assert len(pool._available_connections) == 1
- assert not pool._available_connections[0]._sock
+ assert not r.connection._sock
@skip_if_server_version_lt('2.8.8')
def test_busy_loading_from_pipeline_immediate_command(self, r):
@@ -521,3 +519,16 @@ class TestConnection(object):
"AuthenticationError should be raised when sending the wrong password"
with pytest.raises(redis.AuthenticationError):
r.execute_command('DEBUG', 'ERROR', 'ERR invalid password')
+
+
+class TestMultiConnectionClient(object):
+ @pytest.fixture()
+ def r(self, request):
+ return _get_client(redis.Redis,
+ request,
+ single_connection_client=False)
+
+ def test_multi_connection_command(self, r):
+ assert not r.connection
+ assert r.set('a', '123')
+ assert r.get('a') == b'123'