summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDvirDukhan <dvir@redis.com>2023-01-02 16:28:53 +0200
committerGitHub <noreply@github.com>2023-01-02 16:28:53 +0200
commited38e77050241a84c0108e47254d1804e8640531 (patch)
tree6aa3182a4f3c509820e35cbe242666e373514c1f
parentf10c81ac2406fc9cacea0f2e9938910db1c751e5 (diff)
downloadredis-py-ed38e77050241a84c0108e47254d1804e8640531.tar.gz
Add dialect to ft aggregate (#2537)
* add dialect to aggregate request * added test * format * async test
-rw-r--r--redis/commands/search/aggregation.py13
-rw-r--r--tests/test_asyncio/test_search.py169
-rw-r--r--tests/test_search.py30
3 files changed, 131 insertions, 81 deletions
diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py
index a171fa1..93a3d92 100644
--- a/redis/commands/search/aggregation.py
+++ b/redis/commands/search/aggregation.py
@@ -108,6 +108,7 @@ class AggregateRequest:
self._with_schema = False
self._verbatim = False
self._cursor = []
+ self._dialect = None
def load(self, *fields):
"""
@@ -321,10 +322,22 @@ class AggregateRequest:
ret.append(str(len(self._loadfields)))
ret.extend(self._loadfields)
+ if self._dialect:
+ ret.extend(["DIALECT", self._dialect])
+
ret.extend(self._aggregateplan)
return ret
+ def dialect(self, dialect):
+ """
+ Add a dialect field to the aggregate command.
+
+ - **dialect** - dialect version to execute the query under
+ """
+ self._dialect = dialect
+ return self
+
class Cursor:
def __init__(self, cid):
diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py
index 12313b6..8707cdf 100644
--- a/tests/test_asyncio/test_search.py
+++ b/tests/test_asyncio/test_search.py
@@ -832,100 +832,127 @@ async def test_aggregations_groupby(modclient: redis.Redis):
},
)
- req = aggregations.AggregateRequest("redis").group_by("@parent", reducers.count())
+ for dialect in [1, 2]:
+ req = (
+ aggregations.AggregateRequest("redis")
+ .group_by("@parent", reducers.count())
+ .dialect(dialect)
+ )
- res = (await modclient.ft().aggregate(req)).rows[0]
- assert res[1] == "redis"
- assert res[3] == "3"
+ res = (await modclient.ft().aggregate(req)).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "3"
- req = aggregations.AggregateRequest("redis").group_by(
- "@parent", reducers.count_distinct("@title")
- )
+ req = (
+ aggregations.AggregateRequest("redis")
+ .group_by("@parent", reducers.count_distinct("@title"))
+ .dialect(dialect)
+ )
- res = (await modclient.ft().aggregate(req)).rows[0]
- assert res[1] == "redis"
- assert res[3] == "3"
+ res = (await modclient.ft().aggregate(req)).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "3"
- req = aggregations.AggregateRequest("redis").group_by(
- "@parent", reducers.count_distinctish("@title")
- )
+ req = (
+ aggregations.AggregateRequest("redis")
+ .group_by("@parent", reducers.count_distinctish("@title"))
+ .dialect(dialect)
+ )
- res = (await modclient.ft().aggregate(req)).rows[0]
- assert res[1] == "redis"
- assert res[3] == "3"
+ res = (await modclient.ft().aggregate(req)).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "3"
- req = aggregations.AggregateRequest("redis").group_by(
- "@parent", reducers.sum("@random_num")
- )
+ req = (
+ aggregations.AggregateRequest("redis")
+ .group_by("@parent", reducers.sum("@random_num"))
+ .dialect(dialect)
+ )
- res = (await modclient.ft().aggregate(req)).rows[0]
- assert res[1] == "redis"
- assert res[3] == "21" # 10+8+3
+ res = (await modclient.ft().aggregate(req)).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "21" # 10+8+3
- req = aggregations.AggregateRequest("redis").group_by(
- "@parent", reducers.min("@random_num")
- )
+ req = (
+ aggregations.AggregateRequest("redis")
+ .group_by("@parent", reducers.min("@random_num"))
+ .dialect(dialect)
+ )
- res = (await modclient.ft().aggregate(req)).rows[0]
- assert res[1] == "redis"
- assert res[3] == "3" # min(10,8,3)
+ res = (await modclient.ft().aggregate(req)).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "3" # min(10,8,3)
- req = aggregations.AggregateRequest("redis").group_by(
- "@parent", reducers.max("@random_num")
- )
+ req = (
+ aggregations.AggregateRequest("redis")
+ .group_by("@parent", reducers.max("@random_num"))
+ .dialect(dialect)
+ )
- res = (await modclient.ft().aggregate(req)).rows[0]
- assert res[1] == "redis"
- assert res[3] == "10" # max(10,8,3)
+ res = (await modclient.ft().aggregate(req)).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "10" # max(10,8,3)
- req = aggregations.AggregateRequest("redis").group_by(
- "@parent", reducers.avg("@random_num")
- )
+ req = (
+ aggregations.AggregateRequest("redis")
+ .group_by("@parent", reducers.avg("@random_num"))
+ .dialect(dialect)
+ )
- res = (await modclient.ft().aggregate(req)).rows[0]
- assert res[1] == "redis"
- assert res[3] == "7" # (10+3+8)/3
+ res = (await modclient.ft().aggregate(req)).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "7" # (10+3+8)/3
- req = aggregations.AggregateRequest("redis").group_by(
- "@parent", reducers.stddev("random_num")
- )
+ req = (
+ aggregations.AggregateRequest("redis")
+ .group_by("@parent", reducers.stddev("random_num"))
+ .dialect(dialect)
+ )
- res = (await modclient.ft().aggregate(req)).rows[0]
- assert res[1] == "redis"
- assert res[3] == "3.60555127546"
+ res = (await modclient.ft().aggregate(req)).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "3.60555127546"
- req = aggregations.AggregateRequest("redis").group_by(
- "@parent", reducers.quantile("@random_num", 0.5)
- )
+ req = (
+ aggregations.AggregateRequest("redis")
+ .group_by("@parent", reducers.quantile("@random_num", 0.5))
+ .dialect(dialect)
+ )
- res = (await modclient.ft().aggregate(req)).rows[0]
- assert res[1] == "redis"
- assert res[3] == "8" # median of 3,8,10
+ res = (await modclient.ft().aggregate(req)).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "8" # median of 3,8,10
- req = aggregations.AggregateRequest("redis").group_by(
- "@parent", reducers.tolist("@title")
- )
+ req = (
+ aggregations.AggregateRequest("redis")
+ .group_by("@parent", reducers.tolist("@title"))
+ .dialect(dialect)
+ )
- res = (await modclient.ft().aggregate(req)).rows[0]
- assert res[1] == "redis"
- assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"}
+ res = (await modclient.ft().aggregate(req)).rows[0]
+ assert res[1] == "redis"
+ assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"}
- req = aggregations.AggregateRequest("redis").group_by(
- "@parent", reducers.first_value("@title").alias("first")
- )
+ req = (
+ aggregations.AggregateRequest("redis")
+ .group_by("@parent", reducers.first_value("@title").alias("first"))
+ .dialect(dialect)
+ )
- res = (await modclient.ft().aggregate(req)).rows[0]
- assert res == ["parent", "redis", "first", "RediSearch"]
+ res = (await modclient.ft().aggregate(req)).rows[0]
+ assert res == ["parent", "redis", "first", "RediSearch"]
- req = aggregations.AggregateRequest("redis").group_by(
- "@parent", reducers.random_sample("@title", 2).alias("random")
- )
+ req = (
+ aggregations.AggregateRequest("redis")
+ .group_by("@parent", reducers.random_sample("@title", 2).alias("random"))
+ .dialect(dialect)
+ )
- res = (await modclient.ft().aggregate(req)).rows[0]
- assert res[1] == "redis"
- assert res[2] == "random"
- assert len(res[3]) == 2
- assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"]
+ res = (await modclient.ft().aggregate(req)).rows[0]
+ assert res[1] == "redis"
+ assert res[2] == "random"
+ assert len(res[3]) == 2
+ assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"]
@pytest.mark.redismod
diff --git a/tests/test_search.py b/tests/test_search.py
index 12876f6..57d4338 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -973,16 +973,26 @@ def test_aggregations_filter(client):
client.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"})
client.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"})
- req = aggregations.AggregateRequest("*").filter("@name=='foo' && @age < 20")
- res = client.ft().aggregate(req)
- assert len(res.rows) == 1
- assert res.rows[0] == ["name", "foo", "age", "19"]
-
- req = aggregations.AggregateRequest("*").filter("@age > 15").sort_by("@age")
- res = client.ft().aggregate(req)
- assert len(res.rows) == 2
- assert res.rows[0] == ["age", "19"]
- assert res.rows[1] == ["age", "25"]
+ for dialect in [1, 2]:
+ req = (
+ aggregations.AggregateRequest("*")
+ .filter("@name=='foo' && @age < 20")
+ .dialect(dialect)
+ )
+ res = client.ft().aggregate(req)
+ assert len(res.rows) == 1
+ assert res.rows[0] == ["name", "foo", "age", "19"]
+
+ req = (
+ aggregations.AggregateRequest("*")
+ .filter("@age > 15")
+ .sort_by("@age")
+ .dialect(dialect)
+ )
+ res = client.ft().aggregate(req)
+ assert len(res.rows) == 2
+ assert res.rows[0] == ["age", "19"]
+ assert res.rows[1] == ["age", "25"]
@pytest.mark.redismod