diff options
author | Avital Fine <79420960+AvitalFineRedis@users.noreply.github.com> | 2021-12-02 14:52:46 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-02 15:52:46 +0200 |
commit | 8f5c1e6b14c4c82d04a3ad141821e2fdabdd0dab (patch) | |
tree | 11099c17afc90058db5429cf4583e9cdde671d09 | |
parent | 42101fc383829bb179a266420132d3f862861972 (diff) | |
download | redis-py-8f5c1e6b14c4c82d04a3ad141821e2fdabdd0dab.tar.gz |
Aggregation loadall (#1735)
Co-authored-by: Chayim <chayim@users.noreply.github.com>
-rw-r--r-- | redis/commands/search/aggregation.py | 14 | ||||
-rw-r--r-- | tests/test_search.py | 5 |
2 files changed, 16 insertions, 3 deletions
diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 3db5542..061e69c 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -103,6 +103,7 @@ class AggregateRequest: self._query = query self._aggregateplan = [] self._loadfields = [] + self._loadall = False self._limit = Limit() self._max = 0 self._with_schema = False @@ -116,9 +117,13 @@ class AggregateRequest: ### Parameters - - **fields**: One or more fields in the format of `@field` + - **fields**: If fields not specified, all the fields will be loaded. + Otherwise, fields should be given in the format of `@field`. """ - self._loadfields.extend(fields) + if fields: + self._loadfields.extend(fields) + else: + self._loadall = True return self def group_by(self, fields, *reducers): @@ -308,7 +313,10 @@ class AggregateRequest: if self._cursor: ret += self._cursor - if self._loadfields: + if self._loadall: + ret.append("LOAD") + ret.append("*") + elif self._loadfields: ret.append("LOAD") ret.append(str(len(self._loadfields))) ret.extend(self._loadfields) diff --git a/tests/test_search.py b/tests/test_search.py index 5b6a660..1a22b66 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1054,6 +1054,11 @@ def test_aggregations_load(client): res = client.ft().aggregate(req) assert res.rows[0] == ["t2", "world"] + # load all + req = aggregations.AggregateRequest("*").load() + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "hello", "t2", "world"] + @pytest.mark.redismod def test_aggregations_apply(client): |