summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAvital Fine <79420960+AvitalFineRedis@users.noreply.github.com>2021-12-02 14:52:46 +0100
committerGitHub <noreply@github.com>2021-12-02 15:52:46 +0200
commit8f5c1e6b14c4c82d04a3ad141821e2fdabdd0dab (patch)
tree11099c17afc90058db5429cf4583e9cdde671d09
parent42101fc383829bb179a266420132d3f862861972 (diff)
downloadredis-py-8f5c1e6b14c4c82d04a3ad141821e2fdabdd0dab.tar.gz
Aggregation loadall (#1735)
Co-authored-by: Chayim <chayim@users.noreply.github.com>
-rw-r--r--redis/commands/search/aggregation.py14
-rw-r--r--tests/test_search.py5
2 files changed, 16 insertions, 3 deletions
diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py
index 3db5542..061e69c 100644
--- a/redis/commands/search/aggregation.py
+++ b/redis/commands/search/aggregation.py
@@ -103,6 +103,7 @@ class AggregateRequest:
self._query = query
self._aggregateplan = []
self._loadfields = []
+ self._loadall = False
self._limit = Limit()
self._max = 0
self._with_schema = False
@@ -116,9 +117,13 @@ class AggregateRequest:
### Parameters
- - **fields**: One or more fields in the format of `@field`
+ - **fields**: If fields not specified, all the fields will be loaded.
+ Otherwise, fields should be given in the format of `@field`.
"""
- self._loadfields.extend(fields)
+ if fields:
+ self._loadfields.extend(fields)
+ else:
+ self._loadall = True
return self
def group_by(self, fields, *reducers):
@@ -308,7 +313,10 @@ class AggregateRequest:
if self._cursor:
ret += self._cursor
- if self._loadfields:
+ if self._loadall:
+ ret.append("LOAD")
+ ret.append("*")
+ elif self._loadfields:
ret.append("LOAD")
ret.append(str(len(self._loadfields)))
ret.extend(self._loadfields)
diff --git a/tests/test_search.py b/tests/test_search.py
index 5b6a660..1a22b66 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -1054,6 +1054,11 @@ def test_aggregations_load(client):
res = client.ft().aggregate(req)
assert res.rows[0] == ["t2", "world"]
+ # load all
+ req = aggregations.AggregateRequest("*").load()
+ res = client.ft().aggregate(req)
+ assert res.rows[0] == ["t1", "hello", "t2", "world"]
+
@pytest.mark.redismod
def test_aggregations_apply(client):