summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAvital Fine <79420960+AvitalFineRedis@users.noreply.github.com>2021-08-01 13:35:55 +0300
committerGitHub <noreply@github.com>2021-08-01 13:35:55 +0300
commit6f4ee2dee63171afa4c59742a3726594894916ae (patch)
tree92029675306eca84d5e51b6b470e64b6d3be97c2
parente9c2e4574a06f240dc528d5ac20cdbeb5eb6564d (diff)
downloadredis-py-6f4ee2dee63171afa4c59742a3726594894916ae.tar.gz
zinter (#1520)
* zinter * change options in _zaggregate * skip for previous versions * flake8 * validate the aggregate value * invalid aggregation * invalid aggregation * change options to get Co-authored-by: Chayim <chayim@users.noreply.github.com>
-rwxr-xr-xredis/client.py46
-rw-r--r--tests/test_commands.py22
2 files changed, 58 insertions, 10 deletions
diff --git a/redis/client.py b/redis/client.py
index 4ab1b1e..0d8a5c2 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -595,8 +595,8 @@ class Redis:
lambda r: r and set(r) or set()
),
**string_keys_to_dict(
- 'ZPOPMAX ZPOPMIN ZDIFF ZRANGE ZRANGEBYSCORE ZREVRANGE '
- 'ZREVRANGEBYSCORE', zset_score_pairs
+ 'ZPOPMAX ZPOPMIN ZINTER ZDIFF ZRANGE ZRANGEBYSCORE '
+ 'ZREVRANGE ZREVRANGEBYSCORE', zset_score_pairs
),
**string_keys_to_dict('BZPOPMIN BZPOPMAX', \
lambda r:
@@ -2959,11 +2959,28 @@ class Redis:
"Increment the score of ``value`` in sorted set ``name`` by ``amount``"
return self.execute_command('ZINCRBY', name, amount, value)
+ def zinter(self, keys, aggregate=None, withscores=False):
+ """
+ Return the intersect of multiple sorted sets specified by ``keys``.
+ With the ``aggregate`` option, it is possible to specify how the
+ results of the union are aggregated. This option defaults to SUM,
+ where the score of an element is summed across the inputs where it
+ exists. When this option is set to either MIN or MAX, the resulting
+ set will contain the minimum or maximum score of an element across
+ the inputs where it exists.
+ """
+ return self._zaggregate('ZINTER', None, keys, aggregate,
+ withscores=withscores)
+
def zinterstore(self, dest, keys, aggregate=None):
"""
- Intersect multiple sorted sets specified by ``keys`` into
- a new sorted set, ``dest``. Scores in the destination will be
- aggregated based on the ``aggregate``, or SUM if none is provided.
+ Intersect multiple sorted sets specified by ``keys`` into a new
+ sorted set, ``dest``. Scores in the destination will be aggregated
+ based on the ``aggregate``. This option defaults to SUM, where the
+ score of an element is summed across the inputs where it exists.
+ When this option is set to either MIN or MAX, the resulting set will
+ contain the minimum or maximum score of an element across the inputs
+ where it exists.
"""
return self._zaggregate('ZINTERSTORE', dest, keys, aggregate)
@@ -3253,8 +3270,12 @@ class Redis:
"""
return self._zaggregate('ZUNIONSTORE', dest, keys, aggregate)
- def _zaggregate(self, command, dest, keys, aggregate=None):
- pieces = [command, dest, len(keys)]
+ def _zaggregate(self, command, dest, keys, aggregate=None,
+ **options):
+ pieces = [command]
+ if dest is not None:
+ pieces.append(dest)
+ pieces.append(len(keys))
if isinstance(keys, dict):
keys, weights = keys.keys(), keys.values()
else:
@@ -3264,9 +3285,14 @@ class Redis:
pieces.append(b'WEIGHTS')
pieces.extend(weights)
if aggregate:
- pieces.append(b'AGGREGATE')
- pieces.append(aggregate)
- return self.execute_command(*pieces)
+ if aggregate.upper() in ['SUM', 'MIN', 'MAX']:
+ pieces.append(b'AGGREGATE')
+ pieces.append(aggregate)
+ else:
+ raise DataError("aggregate can be sum, min or max.")
+ if options.get('withscores', False):
+ pieces.append(b'WITHSCORES')
+ return self.execute_command(*pieces, **options)
# HYPERLOGLOG COMMANDS
def pfadd(self, name, *values):
diff --git a/tests/test_commands.py b/tests/test_commands.py
index 9682603..40c813d 100644
--- a/tests/test_commands.py
+++ b/tests/test_commands.py
@@ -1519,6 +1519,28 @@ class TestRedisCommands:
assert r.zlexcount('a', '-', '+') == 7
assert r.zlexcount('a', '[b', '[f') == 5
+ @skip_if_server_version_lt('6.2.0')
+ def test_zinter(self, r):
+ r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 1})
+ r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2})
+ r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4})
+ assert r.zinter(['a', 'b', 'c']) == [b'a3', b'a1']
+ # invalid aggregation
+ with pytest.raises(exceptions.DataError):
+ r.zinter(['a', 'b', 'c'], aggregate='foo', withscores=True)
+ # aggregate with SUM
+ assert r.zinter(['a', 'b', 'c'], withscores=True) \
+ == [(b'a3', 8), (b'a1', 9)]
+ # aggregate with MAX
+ assert r.zinter(['a', 'b', 'c'], aggregate='MAX', withscores=True) \
+ == [(b'a3', 5), (b'a1', 6)]
+ # aggregate with MIN
+ assert r.zinter(['a', 'b', 'c'], aggregate='MIN', withscores=True) \
+ == [(b'a1', 1), (b'a3', 1)]
+ # with weights
+ assert r.zinter({'a': 1, 'b': 2, 'c': 3}, withscores=True) \
+ == [(b'a3', 20), (b'a1', 23)]
+
def test_zinterstore_sum(self, r):
r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1})
r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2})