diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/conftest.py | 22 | ||||
-rw-r--r-- | tests/test_commands.py | 79 | ||||
-rw-r--r-- | tests/test_connection_pool.py | 19 |
3 files changed, 74 insertions, 46 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index 0ab6428..87e6301 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -53,9 +53,11 @@ def skip_unless_arch_bits(arch_bits): reason="server is not {}-bit".format(arch_bits)) -def _get_client(cls, request, **kwargs): +def _get_client(cls, request, single_connection_client=True, **kwargs): redis_url = request.config.getoption("--redis-url") client = cls.from_url(redis_url, **kwargs) + if single_connection_client: + client = client.client() if request: def teardown(): try: @@ -64,31 +66,27 @@ def _get_client(cls, request, **kwargs): # handle cases where a test disconnected a client # just manually retry the flushdb client.flushdb() + client.close() client.connection_pool.disconnect() request.addfinalizer(teardown) return client @pytest.fixture() -def r(request, **kwargs): - return _get_client(redis.Redis, request, **kwargs) +def r(request): + return _get_client(redis.Redis, request) @pytest.fixture() -def r2(request, **kwargs): - return [ - _get_client(redis.Redis, request, **kwargs), - _get_client(redis.Redis, request, **kwargs), - ] +def r2(request): + "A second client for tests that need multiple" + return _get_client(redis.Redis, request) def _gen_cluster_mock_resp(r, response): - mock_connection_pool = Mock() connection = Mock() - response = response connection.read_response.return_value = response - mock_connection_pool.get_connection.return_value = connection - r.connection_pool = mock_connection_pool + r.connection = connection return r diff --git a/tests/test_commands.py b/tests/test_commands.py index a41e1a2..931fe9c 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -102,20 +102,20 @@ class TestRedisCommands(object): @skip_if_server_version_lt('2.6.9') def test_client_kill(self, r, r2): r.client_setname('redis-py-c1') - r2[0].client_setname('redis-py-c2') - r2[1].client_setname('redis-py-c3') - test_clients = [client for client in r.client_list() - if client.get('name') - in ['redis-py-c1', 'redis-py-c2', 'redis-py-c3']] - assert len(test_clients) == 3 + r2.client_setname('redis-py-c2') + clients = [client for client in r.client_list() + if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + assert len(clients) == 2 - resp = r.client_kill(test_clients[1].get('addr')) - assert isinstance(resp, bool) and resp is True + clients_by_name = dict([(client.get('name'), client) + for client in clients]) - test_clients = [client for client in r.client_list() - if client.get('name') - in ['redis-py-c1', 'redis-py-c2', 'redis-py-c3']] - assert len(test_clients) == 2 + assert r.client_kill(clients_by_name['redis-py-c2'].get('addr')) is True + + clients = [client for client in r.client_list() + if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + assert len(clients) == 1 + assert clients[0].get('name') == 'redis-py-c1' @skip_if_server_version_lt('2.8.12') def test_client_kill_filter_invalid_params(self, r): @@ -132,25 +132,44 @@ class TestRedisCommands(object): r.client_kill_filter(_type="caster") @skip_if_server_version_lt('2.8.12') - def test_client_kill_filter(self, r, r2): + def test_client_kill_filter_by_id(self, r, r2): + r.client_setname('redis-py-c1') + r2.client_setname('redis-py-c2') + clients = [client for client in r.client_list() + if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + assert len(clients) == 2 + + clients_by_name = dict([(client.get('name'), client) + for client in clients]) + + client_2_id = clients_by_name['redis-py-c2'].get('id') + resp = r.client_kill_filter(_id=client_2_id) + assert resp == 1 + + clients = [client for client in r.client_list() + if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + assert len(clients) == 1 + assert clients[0].get('name') == 'redis-py-c1' + + @skip_if_server_version_lt('2.8.12') + def test_client_kill_filter_by_addr(self, r, r2): r.client_setname('redis-py-c1') - r2[0].client_setname('redis-py-c2') - r2[1].client_setname('redis-py-c3') - test_clients = [client for client in r.client_list() - if client.get('name') - in ['redis-py-c1', 'redis-py-c2', 'redis-py-c3']] - assert len(test_clients) == 3 - - resp = r.client_kill_filter(_id=test_clients[1].get('id')) - assert isinstance(resp, int) and resp == 1 - - resp = r.client_kill_filter(addr=test_clients[2].get('addr')) - assert isinstance(resp, int) and resp == 1 - - test_clients = [client for client in r.client_list() - if client.get('name') - in ['redis-py-c1', 'redis-py-c2', 'redis-py-c3']] - assert len(test_clients) == 1 + r2.client_setname('redis-py-c2') + clients = [client for client in r.client_list() + if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + assert len(clients) == 2 + + clients_by_name = dict([(client.get('name'), client) + for client in clients]) + + client_2_addr = clients_by_name['redis-py-c2'].get('addr') + resp = r.client_kill_filter(addr=client_2_addr) + assert resp == 1 + + clients = [client for client in r.client_list() + if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + assert len(clients) == 1 + assert clients[0].get('name') == 'redis-py-c1' @skip_if_server_version_lt('2.6.9') def test_client_list_after_client_setname(self, r): diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 2aea1e4..0af615f 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -6,7 +6,7 @@ import re from threading import Thread from redis.connection import ssl_available, to_bool -from .conftest import skip_if_server_version_lt +from .conftest import skip_if_server_version_lt, _get_client class DummyConnection(object): @@ -448,9 +448,7 @@ class TestConnection(object): """ with pytest.raises(redis.BusyLoadingError): r.execute_command('DEBUG', 'ERROR', 'LOADING fake message') - pool = r.connection_pool - assert len(pool._available_connections) == 1 - assert not pool._available_connections[0]._sock + assert not r.connection._sock @skip_if_server_version_lt('2.8.8') def test_busy_loading_from_pipeline_immediate_command(self, r): @@ -521,3 +519,16 @@ class TestConnection(object): "AuthenticationError should be raised when sending the wrong password" with pytest.raises(redis.AuthenticationError): r.execute_command('DEBUG', 'ERROR', 'ERR invalid password') + + +class TestMultiConnectionClient(object): + @pytest.fixture() + def r(self, request): + return _get_client(redis.Redis, + request, + single_connection_client=False) + + def test_multi_connection_command(self, r): + assert not r.connection + assert r.set('a', '123') + assert r.get('a') == b'123' |