diff options
-rwxr-xr-x | redis/client.py | 47 | ||||
-rw-r--r-- | tests/test_commands.py | 34 |
2 files changed, 78 insertions, 3 deletions
diff --git a/redis/client.py b/redis/client.py index 7bc2156..4c05e04 100755 --- a/redis/client.py +++ b/redis/client.py @@ -302,6 +302,12 @@ def bool_ok(response): return nativestr(response) == 'OK' +def parse_zadd(response, **options): + if options.get('as_score'): + return float(response) + return int(response) + + def parse_client_list(response, **options): clients = [] for c in nativestr(response).splitlines(): @@ -425,7 +431,7 @@ class Redis(object): 'BITCOUNT BITPOS DECRBY DEL GEOADD GETBIT HDEL HLEN HSTRLEN ' 'INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD ' 'SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE ' - 'UNLINK XACK XDEL XLEN XTRIM ZADD ZCARD ZLEXCOUNT ZREM ' + 'UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM ' 'ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE', int ), @@ -532,6 +538,7 @@ class Redis(object): 'XINFO GROUPS': parse_list_of_dicts, 'XINFO STREAM': parse_xinfo_stream, 'XPENDING': parse_xpending, + 'ZADD': parse_zadd, 'ZSCAN': parse_zscan, } ) @@ -2201,18 +2208,52 @@ class Redis(object): return self.execute_command('XTRIM', name, *pieces) # SORTED SET COMMANDS - def zadd(self, name, mapping): + def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False): """ Set any number of element-name, score pairs to the key ``name``. Pairs are specified as a dict of element-names keys to score values. + + ``nx`` forces ZADD to only create new elements and not to update + scores for elements that already exist. + + ``xx`` forces ZADD to only update scores of elements that already + exist. New elements will not be added. + + ``ch`` modifies the return value to be the numbers of elements changed. + Changed elements include new elements that were added and elements + whose scores changed. + + ``incr`` modifies ZADD to behave like ZINCRBY. In this mode only a + single element/score pair can be specified and the score is the amount + the existing score will be incremented by. When using this mode the + return value of ZADD will be the new score of the element. + + The return value of ZADD varies based on the mode specified. With no + options, ZADD returns the number of new elements added to the sorted + set. """ if not mapping: raise DataError("ZADD requires at least one element/score pair") + if nx and xx: + raise DataError("ZADD allows either 'nx' or 'xx', not both") + if incr and len(mapping) != 1: + raise DataError("ZADD option 'incr' only works when passing a " + "single element/score pair") pieces = [] + options = {} + if nx: + pieces.append(Token.get_token('NX')) + if xx: + pieces.append(Token.get_token('XX')) + if ch: + pieces.append(Token.get_token('CH')) + if incr: + pieces.append(Token.get_token('INCR')) + options['as_score'] = True for pair in iteritems(mapping): pieces.append(pair[1]) pieces.append(pair[0]) - return self.execute_command('ZADD', name, *pieces) + return self.execute_command('ZADD', name, *pieces, **options) def zcard(self, name): "Return the number of elements in the sorted set ``name``" diff --git a/tests/test_commands.py b/tests/test_commands.py index ade8ecc..c8f259c 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -979,6 +979,40 @@ class TestRedisCommands(object): assert r.zrange('a', 0, -1, withscores=True) == \ [(b'a1', 1.0), (b'a2', 2.0), (b'a3', 3.0)] + # error cases + with pytest.raises(exceptions.DataError): + r.zadd('a', {}) + + # cannot use both nx and xx options + with pytest.raises(exceptions.DataError): + r.zadd('a', mapping, nx=True, xx=True) + + # cannot use the incr options with more than one value + with pytest.raises(exceptions.DataError): + r.zadd('a', mapping, incr=True) + + def test_zadd_nx(self, r): + assert r.zadd('a', {'a1': 1}) == 1 + assert r.zadd('a', {'a1': 99, 'a2': 2}, nx=True) == 1 + assert r.zrange('a', 0, -1, withscores=True) == \ + [(b'a1', 1.0), (b'a2', 2.0)] + + def test_zadd_xx(self, r): + assert r.zadd('a', {'a1': 1}) == 1 + assert r.zadd('a', {'a1': 99, 'a2': 2}, xx=True) == 0 + assert r.zrange('a', 0, -1, withscores=True) == \ + [(b'a1', 99.0)] + + def test_zadd_ch(self, r): + assert r.zadd('a', {'a1': 1}) == 1 + assert r.zadd('a', {'a1': 99, 'a2': 2}, ch=True) == 2 + assert r.zrange('a', 0, -1, withscores=True) == \ + [(b'a2', 2.0), (b'a1', 99.0)] + + def test_zadd_incr(self, r): + assert r.zadd('a', {'a1': 1}) == 1 + assert r.zadd('a', {'a1': 4.5}, incr=True) == 5.5 + def test_zcard(self, r): r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) assert r.zcard('a') == 3 |