From 6f4ee2dee63171afa4c59742a3726594894916ae Mon Sep 17 00:00:00 2001 From: Avital Fine <79420960+AvitalFineRedis@users.noreply.github.com> Date: Sun, 1 Aug 2021 13:35:55 +0300 Subject: zinter (#1520) * zinter * change options in _zaggregate * skip for previous versions * flake8 * validate the aggregate value * invalid aggregation * invalid aggregation * change options to get Co-authored-by: Chayim --- redis/client.py | 46 ++++++++++++++++++++++++++++++++++++---------- tests/test_commands.py | 22 ++++++++++++++++++++++ 2 files changed, 58 insertions(+), 10 deletions(-) diff --git a/redis/client.py b/redis/client.py index 4ab1b1e..0d8a5c2 100755 --- a/redis/client.py +++ b/redis/client.py @@ -595,8 +595,8 @@ class Redis: lambda r: r and set(r) or set() ), **string_keys_to_dict( - 'ZPOPMAX ZPOPMIN ZDIFF ZRANGE ZRANGEBYSCORE ZREVRANGE ' - 'ZREVRANGEBYSCORE', zset_score_pairs + 'ZPOPMAX ZPOPMIN ZINTER ZDIFF ZRANGE ZRANGEBYSCORE ' + 'ZREVRANGE ZREVRANGEBYSCORE', zset_score_pairs ), **string_keys_to_dict('BZPOPMIN BZPOPMAX', \ lambda r: @@ -2959,11 +2959,28 @@ class Redis: "Increment the score of ``value`` in sorted set ``name`` by ``amount``" return self.execute_command('ZINCRBY', name, amount, value) + def zinter(self, keys, aggregate=None, withscores=False): + """ + Return the intersect of multiple sorted sets specified by ``keys``. + With the ``aggregate`` option, it is possible to specify how the + results of the union are aggregated. This option defaults to SUM, + where the score of an element is summed across the inputs where it + exists. When this option is set to either MIN or MAX, the resulting + set will contain the minimum or maximum score of an element across + the inputs where it exists. + """ + return self._zaggregate('ZINTER', None, keys, aggregate, + withscores=withscores) + def zinterstore(self, dest, keys, aggregate=None): """ - Intersect multiple sorted sets specified by ``keys`` into - a new sorted set, ``dest``. Scores in the destination will be - aggregated based on the ``aggregate``, or SUM if none is provided. + Intersect multiple sorted sets specified by ``keys`` into a new + sorted set, ``dest``. Scores in the destination will be aggregated + based on the ``aggregate``. This option defaults to SUM, where the + score of an element is summed across the inputs where it exists. + When this option is set to either MIN or MAX, the resulting set will + contain the minimum or maximum score of an element across the inputs + where it exists. """ return self._zaggregate('ZINTERSTORE', dest, keys, aggregate) @@ -3253,8 +3270,12 @@ class Redis: """ return self._zaggregate('ZUNIONSTORE', dest, keys, aggregate) - def _zaggregate(self, command, dest, keys, aggregate=None): - pieces = [command, dest, len(keys)] + def _zaggregate(self, command, dest, keys, aggregate=None, + **options): + pieces = [command] + if dest is not None: + pieces.append(dest) + pieces.append(len(keys)) if isinstance(keys, dict): keys, weights = keys.keys(), keys.values() else: @@ -3264,9 +3285,14 @@ class Redis: pieces.append(b'WEIGHTS') pieces.extend(weights) if aggregate: - pieces.append(b'AGGREGATE') - pieces.append(aggregate) - return self.execute_command(*pieces) + if aggregate.upper() in ['SUM', 'MIN', 'MAX']: + pieces.append(b'AGGREGATE') + pieces.append(aggregate) + else: + raise DataError("aggregate can be sum, min or max.") + if options.get('withscores', False): + pieces.append(b'WITHSCORES') + return self.execute_command(*pieces, **options) # HYPERLOGLOG COMMANDS def pfadd(self, name, *values): diff --git a/tests/test_commands.py b/tests/test_commands.py index 9682603..40c813d 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1519,6 +1519,28 @@ class TestRedisCommands: assert r.zlexcount('a', '-', '+') == 7 assert r.zlexcount('a', '[b', '[f') == 5 + @skip_if_server_version_lt('6.2.0') + def test_zinter(self, r): + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 1}) + r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) + assert r.zinter(['a', 'b', 'c']) == [b'a3', b'a1'] + # invalid aggregation + with pytest.raises(exceptions.DataError): + r.zinter(['a', 'b', 'c'], aggregate='foo', withscores=True) + # aggregate with SUM + assert r.zinter(['a', 'b', 'c'], withscores=True) \ + == [(b'a3', 8), (b'a1', 9)] + # aggregate with MAX + assert r.zinter(['a', 'b', 'c'], aggregate='MAX', withscores=True) \ + == [(b'a3', 5), (b'a1', 6)] + # aggregate with MIN + assert r.zinter(['a', 'b', 'c'], aggregate='MIN', withscores=True) \ + == [(b'a1', 1), (b'a3', 1)] + # with weights + assert r.zinter({'a': 1, 'b': 2, 'c': 3}, withscores=True) \ + == [(b'a3', 20), (b'a1', 23)] + def test_zinterstore_sum(self, r): r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) -- cgit v1.2.1