diff options
author | DvirDukhan <dvir@redis.com> | 2023-01-02 16:28:53 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-02 16:28:53 +0200 |
commit | ed38e77050241a84c0108e47254d1804e8640531 (patch) | |
tree | 6aa3182a4f3c509820e35cbe242666e373514c1f | |
parent | f10c81ac2406fc9cacea0f2e9938910db1c751e5 (diff) | |
download | redis-py-ed38e77050241a84c0108e47254d1804e8640531.tar.gz |
Add dialect to ft aggregate (#2537)
* add dialect to aggregate request
* added test
* format
* async test
-rw-r--r-- | redis/commands/search/aggregation.py | 13 | ||||
-rw-r--r-- | tests/test_asyncio/test_search.py | 169 | ||||
-rw-r--r-- | tests/test_search.py | 30 |
3 files changed, 131 insertions, 81 deletions
diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index a171fa1..93a3d92 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -108,6 +108,7 @@ class AggregateRequest: self._with_schema = False self._verbatim = False self._cursor = [] + self._dialect = None def load(self, *fields): """ @@ -321,10 +322,22 @@ class AggregateRequest: ret.append(str(len(self._loadfields))) ret.extend(self._loadfields) + if self._dialect: + ret.extend(["DIALECT", self._dialect]) + ret.extend(self._aggregateplan) return ret + def dialect(self, dialect): + """ + Add a dialect field to the aggregate command. + + - **dialect** - dialect version to execute the query under + """ + self._dialect = dialect + return self + class Cursor: def __init__(self, cid): diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 12313b6..8707cdf 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -832,100 +832,127 @@ async def test_aggregations_groupby(modclient: redis.Redis): }, ) - req = aggregations.AggregateRequest("redis").group_by("@parent", reducers.count()) + for dialect in [1, 2]: + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count()) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count_distinct("@title") - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinct("@title")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count_distinctish("@title") - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinctish("@title")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.sum("@random_num") - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.sum("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "21" # 10+8+3 + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "21" # 10+8+3 - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.min("@random_num") - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.min("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" # min(10,8,3) + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" # min(10,8,3) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.max("@random_num") - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.max("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "10" # max(10,8,3) + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "10" # max(10,8,3) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.avg("@random_num") - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.avg("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "7" # (10+3+8)/3 + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "7" # (10+3+8)/3 - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.stddev("random_num") - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.stddev("random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3.60555127546" + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3.60555127546" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.quantile("@random_num", 0.5) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.quantile("@random_num", 0.5)) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "8" # median of 3,8,10 + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "8" # median of 3,8,10 - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.tolist("@title") - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.tolist("@title")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.first_value("@title").alias("first") - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.first_value("@title").alias("first")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res == ["parent", "redis", "first", "RediSearch"] + res = (await modclient.ft().aggregate(req)).rows[0] + assert res == ["parent", "redis", "first", "RediSearch"] - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.random_sample("@title", 2).alias("random") - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.random_sample("@title", 2).alias("random")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[2] == "random" - assert len(res[3]) == 2 - assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + res = (await modclient.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[2] == "random" + assert len(res[3]) == 2 + assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] @pytest.mark.redismod diff --git a/tests/test_search.py b/tests/test_search.py index 12876f6..57d4338 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -973,16 +973,26 @@ def test_aggregations_filter(client): client.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"}) client.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"}) - req = aggregations.AggregateRequest("*").filter("@name=='foo' && @age < 20") - res = client.ft().aggregate(req) - assert len(res.rows) == 1 - assert res.rows[0] == ["name", "foo", "age", "19"] - - req = aggregations.AggregateRequest("*").filter("@age > 15").sort_by("@age") - res = client.ft().aggregate(req) - assert len(res.rows) == 2 - assert res.rows[0] == ["age", "19"] - assert res.rows[1] == ["age", "25"] + for dialect in [1, 2]: + req = ( + aggregations.AggregateRequest("*") + .filter("@name=='foo' && @age < 20") + .dialect(dialect) + ) + res = client.ft().aggregate(req) + assert len(res.rows) == 1 + assert res.rows[0] == ["name", "foo", "age", "19"] + + req = ( + aggregations.AggregateRequest("*") + .filter("@age > 15") + .sort_by("@age") + .dialect(dialect) + ) + res = client.ft().aggregate(req) + assert len(res.rows) == 2 + assert res.rows[0] == ["age", "19"] + assert res.rows[1] == ["age", "25"] @pytest.mark.redismod |