diff options
-rwxr-xr-x | redis/client.py | 71 | ||||
-rw-r--r-- | tests/test_commands.py | 27 | ||||
-rw-r--r-- | tests/test_pipeline.py | 28 |
3 files changed, 110 insertions, 16 deletions
diff --git a/redis/client.py b/redis/client.py index 621e8f8..8173ba8 100755 --- a/redis/client.py +++ b/redis/client.py @@ -25,6 +25,7 @@ from redis.exceptions import ( ) SYM_EMPTY = b'' +EMPTY_RESPONSE = 'EMPTY_RESPONSE' def list_or_args(keys, args): @@ -419,11 +420,11 @@ class StrictRedis(object): bool ), string_keys_to_dict( - 'BITCOUNT BITPOS DECRBY DEL GETBIT HDEL HLEN HSTRLEN INCRBY ' - 'LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD SDIFFSTORE ' - 'SETBIT SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ZADD ZCARD ' - 'ZLEXCOUNT ZREM ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE ' - 'GEOADD XACK XDEL XLEN XTRIM', + 'BITCOUNT BITPOS DECRBY DEL GEOADD GETBIT HDEL HLEN HSTRLEN ' + 'INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD ' + 'SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ' + 'UNLINK XACK XDEL XLEN XTRIM ZADD ZCARD ZLEXCOUNT ZREM ' + 'ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE', int ), string_keys_to_dict( @@ -439,7 +440,7 @@ class StrictRedis(object): string_keys_to_dict('ZSCORE ZINCRBY GEODIST', float_or_none), string_keys_to_dict( 'FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE RENAME ' - 'SAVE SELECT SHUTDOWN SLAVEOF WATCH UNWATCH', + 'SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH', bool_ok ), string_keys_to_dict('BLPOP BRPOP', lambda r: r and tuple(r) or None), @@ -459,9 +460,11 @@ class StrictRedis(object): string_keys_to_dict('BGREWRITEAOF BGSAVE', lambda r: True), { 'CLIENT GETNAME': lambda r: r and nativestr(r), + 'CLIENT ID': int, 'CLIENT KILL': bool_ok, 'CLIENT LIST': parse_client_list, 'CLIENT SETNAME': bool_ok, + 'CLIENT UNBLOCK': lambda r: r and int(r) == 1 or False, 'CLUSTER ADDSLOTS': bool_ok, 'CLUSTER COUNT-FAILURE-REPORTS': lambda x: int(x), 'CLUSTER COUNTKEYSINSLOT': lambda x: int(x), @@ -756,7 +759,12 @@ class StrictRedis(object): def parse_response(self, connection, command_name, **options): "Parses a response from the Redis server" - response = connection.read_response() + try: + response = connection.read_response() + except ResponseError: + if EMPTY_RESPONSE in options: + return options[EMPTY_RESPONSE] + raise if command_name in self.response_callbacks: return self.response_callbacks[command_name](response, **options) return response @@ -785,10 +793,26 @@ class StrictRedis(object): "Returns the current connection name" return self.execute_command('CLIENT GETNAME') + def client_id(self): + "Returns the current connection id" + return self.execute_command('CLIENT ID') + def client_setname(self, name): "Sets the current connection name" return self.execute_command('CLIENT SETNAME', name) + def client_unblock(self, client_id, error=False): + """ + Unblocks a connection by its client id. + If ``error`` is True, unblocks the client with a special error message. + If ``error`` is False (default), the client is unblocked using the + regular timeout mechanism. + """ + args = ['CLIENT UNBLOCK', int(client_id)] + if error: + args.append(Token.get_token('ERROR')) + return self.execute_command(*args) + def config_get(self, pattern="*"): "Return a dictionary of configuration based on the ``pattern``" return self.execute_command('CONFIG GET', pattern) @@ -825,6 +849,10 @@ class StrictRedis(object): "Delete all keys in the current database" return self.execute_command('FLUSHDB') + def swapdb(self, first, second): + "Swap two databases" + return self.execute_command('SWAPDB', first, second) + def info(self, section=None): """ Returns a dictionary containing information about the Redis server @@ -1115,7 +1143,10 @@ class StrictRedis(object): Returns a list of values ordered identically to ``keys`` """ args = list_or_args(keys, args) - return self.execute_command('MGET', *args) + options = {} + if not args: + options[EMPTY_RESPONSE] = [] + return self.execute_command('MGET', *args, **options) def mset(self, *args, **kwargs): """ @@ -1323,6 +1354,10 @@ class StrictRedis(object): warnings.warn( DeprecationWarning('Call UNWATCH from a Pipeline object')) + def unlink(self, *names): + "Unlink one or more keys specified by ``names``" + return self.execute_command('UNLINK', *names) + # LIST COMMANDS def blpop(self, keys, timeout=0): """ @@ -1862,7 +1897,7 @@ class StrictRedis(object): """ pieces = ['XGROUP CREATE', name, groupname, id] if mkstream: - pieces.append('MKSTREAM') + pieces.append(Token.get_token('MKSTREAM')) return self.execute_command(*pieces) def xgroup_delconsumer(self, name, groupname, consumername): @@ -3206,7 +3241,8 @@ class BasePipeline(object): def _execute_transaction(self, connection, commands, raise_on_error): cmds = chain([(('MULTI', ), {})], commands, [(('EXEC', ), {})]) - all_cmds = connection.pack_commands([args for args, _ in cmds]) + all_cmds = connection.pack_commands([args for args, options in cmds + if EMPTY_RESPONSE not in options]) connection.send_packed_command(all_cmds) errors = [] @@ -3221,12 +3257,15 @@ class BasePipeline(object): # and all the other commands for i, command in enumerate(commands): - try: - self.parse_response(connection, '_') - except ResponseError: - ex = sys.exc_info()[1] - self.annotate_exception(ex, i + 1, command[0]) - errors.append((i, ex)) + if EMPTY_RESPONSE in command[1]: + errors.append((i, command[1][EMPTY_RESPONSE])) + else: + try: + self.parse_response(connection, '_') + except ResponseError: + ex = sys.exc_info()[1] + self.annotate_exception(ex, i + 1, command[0]) + errors.append((i, ex)) # parse the EXEC. try: diff --git a/tests/test_commands.py b/tests/test_commands.py index 2b4f634..f0394d7 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -67,6 +67,17 @@ class TestRedisCommands(object): assert isinstance(clients[0], dict) assert 'addr' in clients[0] + @skip_if_server_version_lt('5.0.0') + def test_client_id(self, r): + assert r.client_id() > 0 + + @skip_if_server_version_lt('5.0.0') + def test_client_unblock(self, r): + myid = r.client_id() + assert not r.client_unblock(myid) + assert not r.client_unblock(myid, error=True) + assert not r.client_unblock(myid, error=False) + @skip_if_server_version_lt('2.6.9') def test_client_getname(self, r): assert r.client_getname() is None @@ -292,6 +303,21 @@ class TestRedisCommands(object): del r['a'] assert r.get('a') is None + @skip_if_server_version_lt('4.0.0') + def test_unlink(self, r): + assert r.unlink('a') == 0 + r['a'] = 'foo' + assert r.unlink('a') == 1 + assert r.get('a') is None + + @skip_if_server_version_lt('4.0.0') + def test_unlink_with_multiple_keys(self, r): + r['a'] = 'foo' + r['b'] = 'bar' + assert r.unlink('a', 'b') == 2 + assert r.get('a') is None + assert r.get('b') is None + @skip_if_server_version_lt('2.6.0') def test_dump_and_restore(self, r): r['a'] = 'foo' @@ -427,6 +453,7 @@ class TestRedisCommands(object): assert set(r.keys(pattern='test*')) == keys def test_mget(self, r): + assert r.mget([]) == [] assert r.mget(['a', 'b']) == [None, None] r['a'] = '1' r['b'] = '2' diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1f1ae40..a8941d7 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -114,6 +114,34 @@ class TestPipeline(object): assert pipe.set('z', 'zzz').execute() == [True] assert r['z'] == b'zzz' + def test_transaction_with_empty_error_command(self, r): + """ + Commands with custom EMPTY_ERROR functionality return their default + values in the pipeline no matter the raise_on_error preference + """ + for error_switch in (True, False): + with r.pipeline() as pipe: + pipe.set('a', 1).mget([]).set('c', 3) + result = pipe.execute(raise_on_error=error_switch) + + assert result[0] + assert result[1] == [] + assert result[2] + + def test_pipeline_with_empty_error_command(self, r): + """ + Commands with custom EMPTY_ERROR functionality return their default + values in the pipeline no matter the raise_on_error preference + """ + for error_switch in (True, False): + with r.pipeline(transaction=False) as pipe: + pipe.set('a', 1).mget([]).set('c', 3) + result = pipe.execute(raise_on_error=error_switch) + + assert result[0] + assert result[1] == [] + assert result[2] + def test_parse_error_raised(self, r): with r.pipeline() as pipe: # the zrem is invalid because we don't pass any keys to it |