summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xredis/client.py47
-rw-r--r--tests/test_commands.py34
2 files changed, 78 insertions, 3 deletions
diff --git a/redis/client.py b/redis/client.py
index 7bc2156..4c05e04 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -302,6 +302,12 @@ def bool_ok(response):
return nativestr(response) == 'OK'
+def parse_zadd(response, **options):
+ if options.get('as_score'):
+ return float(response)
+ return int(response)
+
+
def parse_client_list(response, **options):
clients = []
for c in nativestr(response).splitlines():
@@ -425,7 +431,7 @@ class Redis(object):
'BITCOUNT BITPOS DECRBY DEL GEOADD GETBIT HDEL HLEN HSTRLEN '
'INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD SCARD '
'SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN SUNIONSTORE '
- 'UNLINK XACK XDEL XLEN XTRIM ZADD ZCARD ZLEXCOUNT ZREM '
+ 'UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM '
'ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE',
int
),
@@ -532,6 +538,7 @@ class Redis(object):
'XINFO GROUPS': parse_list_of_dicts,
'XINFO STREAM': parse_xinfo_stream,
'XPENDING': parse_xpending,
+ 'ZADD': parse_zadd,
'ZSCAN': parse_zscan,
}
)
@@ -2201,18 +2208,52 @@ class Redis(object):
return self.execute_command('XTRIM', name, *pieces)
# SORTED SET COMMANDS
- def zadd(self, name, mapping):
+ def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False):
"""
Set any number of element-name, score pairs to the key ``name``. Pairs
are specified as a dict of element-names keys to score values.
+
+ ``nx`` forces ZADD to only create new elements and not to update
+ scores for elements that already exist.
+
+ ``xx`` forces ZADD to only update scores of elements that already
+ exist. New elements will not be added.
+
+ ``ch`` modifies the return value to be the numbers of elements changed.
+ Changed elements include new elements that were added and elements
+ whose scores changed.
+
+ ``incr`` modifies ZADD to behave like ZINCRBY. In this mode only a
+ single element/score pair can be specified and the score is the amount
+ the existing score will be incremented by. When using this mode the
+ return value of ZADD will be the new score of the element.
+
+ The return value of ZADD varies based on the mode specified. With no
+ options, ZADD returns the number of new elements added to the sorted
+ set.
"""
if not mapping:
raise DataError("ZADD requires at least one element/score pair")
+ if nx and xx:
+ raise DataError("ZADD allows either 'nx' or 'xx', not both")
+ if incr and len(mapping) != 1:
+ raise DataError("ZADD option 'incr' only works when passing a "
+ "single element/score pair")
pieces = []
+ options = {}
+ if nx:
+ pieces.append(Token.get_token('NX'))
+ if xx:
+ pieces.append(Token.get_token('XX'))
+ if ch:
+ pieces.append(Token.get_token('CH'))
+ if incr:
+ pieces.append(Token.get_token('INCR'))
+ options['as_score'] = True
for pair in iteritems(mapping):
pieces.append(pair[1])
pieces.append(pair[0])
- return self.execute_command('ZADD', name, *pieces)
+ return self.execute_command('ZADD', name, *pieces, **options)
def zcard(self, name):
"Return the number of elements in the sorted set ``name``"
diff --git a/tests/test_commands.py b/tests/test_commands.py
index ade8ecc..c8f259c 100644
--- a/tests/test_commands.py
+++ b/tests/test_commands.py
@@ -979,6 +979,40 @@ class TestRedisCommands(object):
assert r.zrange('a', 0, -1, withscores=True) == \
[(b'a1', 1.0), (b'a2', 2.0), (b'a3', 3.0)]
+ # error cases
+ with pytest.raises(exceptions.DataError):
+ r.zadd('a', {})
+
+ # cannot use both nx and xx options
+ with pytest.raises(exceptions.DataError):
+ r.zadd('a', mapping, nx=True, xx=True)
+
+ # cannot use the incr options with more than one value
+ with pytest.raises(exceptions.DataError):
+ r.zadd('a', mapping, incr=True)
+
+ def test_zadd_nx(self, r):
+ assert r.zadd('a', {'a1': 1}) == 1
+ assert r.zadd('a', {'a1': 99, 'a2': 2}, nx=True) == 1
+ assert r.zrange('a', 0, -1, withscores=True) == \
+ [(b'a1', 1.0), (b'a2', 2.0)]
+
+ def test_zadd_xx(self, r):
+ assert r.zadd('a', {'a1': 1}) == 1
+ assert r.zadd('a', {'a1': 99, 'a2': 2}, xx=True) == 0
+ assert r.zrange('a', 0, -1, withscores=True) == \
+ [(b'a1', 99.0)]
+
+ def test_zadd_ch(self, r):
+ assert r.zadd('a', {'a1': 1}) == 1
+ assert r.zadd('a', {'a1': 99, 'a2': 2}, ch=True) == 2
+ assert r.zrange('a', 0, -1, withscores=True) == \
+ [(b'a2', 2.0), (b'a1', 99.0)]
+
+ def test_zadd_incr(self, r):
+ assert r.zadd('a', {'a1': 1}) == 1
+ assert r.zadd('a', {'a1': 4.5}, incr=True) == 5.5
+
def test_zcard(self, r):
r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3})
assert r.zcard('a') == 3