diff options
author | Avital Fine <79420960+AvitalFineRedis@users.noreply.github.com> | 2021-08-01 13:35:55 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-08-01 13:35:55 +0300 |
commit | 6f4ee2dee63171afa4c59742a3726594894916ae (patch) | |
tree | 92029675306eca84d5e51b6b470e64b6d3be97c2 | |
parent | e9c2e4574a06f240dc528d5ac20cdbeb5eb6564d (diff) | |
download | redis-py-6f4ee2dee63171afa4c59742a3726594894916ae.tar.gz |
zinter (#1520)
* zinter
* change options in _zaggregate
* skip for previous versions
* flake8
* validate the aggregate value
* invalid aggregation
* invalid aggregation
* change options to get
Co-authored-by: Chayim <chayim@users.noreply.github.com>
-rwxr-xr-x | redis/client.py | 46 | ||||
-rw-r--r-- | tests/test_commands.py | 22 |
2 files changed, 58 insertions, 10 deletions
diff --git a/redis/client.py b/redis/client.py index 4ab1b1e..0d8a5c2 100755 --- a/redis/client.py +++ b/redis/client.py @@ -595,8 +595,8 @@ class Redis: lambda r: r and set(r) or set() ), **string_keys_to_dict( - 'ZPOPMAX ZPOPMIN ZDIFF ZRANGE ZRANGEBYSCORE ZREVRANGE ' - 'ZREVRANGEBYSCORE', zset_score_pairs + 'ZPOPMAX ZPOPMIN ZINTER ZDIFF ZRANGE ZRANGEBYSCORE ' + 'ZREVRANGE ZREVRANGEBYSCORE', zset_score_pairs ), **string_keys_to_dict('BZPOPMIN BZPOPMAX', \ lambda r: @@ -2959,11 +2959,28 @@ class Redis: "Increment the score of ``value`` in sorted set ``name`` by ``amount``" return self.execute_command('ZINCRBY', name, amount, value) + def zinter(self, keys, aggregate=None, withscores=False): + """ + Return the intersect of multiple sorted sets specified by ``keys``. + With the ``aggregate`` option, it is possible to specify how the + results of the union are aggregated. This option defaults to SUM, + where the score of an element is summed across the inputs where it + exists. When this option is set to either MIN or MAX, the resulting + set will contain the minimum or maximum score of an element across + the inputs where it exists. + """ + return self._zaggregate('ZINTER', None, keys, aggregate, + withscores=withscores) + def zinterstore(self, dest, keys, aggregate=None): """ - Intersect multiple sorted sets specified by ``keys`` into - a new sorted set, ``dest``. Scores in the destination will be - aggregated based on the ``aggregate``, or SUM if none is provided. + Intersect multiple sorted sets specified by ``keys`` into a new + sorted set, ``dest``. Scores in the destination will be aggregated + based on the ``aggregate``. This option defaults to SUM, where the + score of an element is summed across the inputs where it exists. + When this option is set to either MIN or MAX, the resulting set will + contain the minimum or maximum score of an element across the inputs + where it exists. """ return self._zaggregate('ZINTERSTORE', dest, keys, aggregate) @@ -3253,8 +3270,12 @@ class Redis: """ return self._zaggregate('ZUNIONSTORE', dest, keys, aggregate) - def _zaggregate(self, command, dest, keys, aggregate=None): - pieces = [command, dest, len(keys)] + def _zaggregate(self, command, dest, keys, aggregate=None, + **options): + pieces = [command] + if dest is not None: + pieces.append(dest) + pieces.append(len(keys)) if isinstance(keys, dict): keys, weights = keys.keys(), keys.values() else: @@ -3264,9 +3285,14 @@ class Redis: pieces.append(b'WEIGHTS') pieces.extend(weights) if aggregate: - pieces.append(b'AGGREGATE') - pieces.append(aggregate) - return self.execute_command(*pieces) + if aggregate.upper() in ['SUM', 'MIN', 'MAX']: + pieces.append(b'AGGREGATE') + pieces.append(aggregate) + else: + raise DataError("aggregate can be sum, min or max.") + if options.get('withscores', False): + pieces.append(b'WITHSCORES') + return self.execute_command(*pieces, **options) # HYPERLOGLOG COMMANDS def pfadd(self, name, *values): diff --git a/tests/test_commands.py b/tests/test_commands.py index 9682603..40c813d 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1519,6 +1519,28 @@ class TestRedisCommands: assert r.zlexcount('a', '-', '+') == 7 assert r.zlexcount('a', '[b', '[f') == 5 + @skip_if_server_version_lt('6.2.0') + def test_zinter(self, r): + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 1}) + r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('c', {'a1': 6, 'a3': 5, 'a4': 4}) + assert r.zinter(['a', 'b', 'c']) == [b'a3', b'a1'] + # invalid aggregation + with pytest.raises(exceptions.DataError): + r.zinter(['a', 'b', 'c'], aggregate='foo', withscores=True) + # aggregate with SUM + assert r.zinter(['a', 'b', 'c'], withscores=True) \ + == [(b'a3', 8), (b'a1', 9)] + # aggregate with MAX + assert r.zinter(['a', 'b', 'c'], aggregate='MAX', withscores=True) \ + == [(b'a3', 5), (b'a1', 6)] + # aggregate with MIN + assert r.zinter(['a', 'b', 'c'], aggregate='MIN', withscores=True) \ + == [(b'a1', 1), (b'a3', 1)] + # with weights + assert r.zinter({'a': 1, 'b': 2, 'c': 3}, withscores=True) \ + == [(b'a3', 20), (b'a1', 23)] + def test_zinterstore_sum(self, r): r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) |