summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
authorAvital Fine <79420960+AvitalFineRedis@users.noreply.github.com>2021-11-25 14:38:21 +0100
committerGitHub <noreply@github.com>2021-11-25 15:38:21 +0200
commit3de2e6b6b1bc061d875d36a6f40598453ce85c58 (patch)
treeb29c4664b8735b73b38e88e4cb1d880a679899ce
parent20c5f0fa4676c4f0fde778dae81c3f96078348b5 (diff)
downloadredis-py-3de2e6b6b1bc061d875d36a6f40598453ce85c58.tar.gz
Improve code coverage for aggregation tests (#1713)
-rw-r--r--redis/commands/search/aggregation.py6
-rw-r--r--tests/test_search.py322
2 files changed, 264 insertions, 64 deletions
diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py
index b391d1f..3d71329 100644
--- a/redis/commands/search/aggregation.py
+++ b/redis/commands/search/aggregation.py
@@ -345,12 +345,6 @@ class AggregateRequest(object):
self._cursor = args
return self
- def _limit_2_args(self, limit):
- if limit[1]:
- return ["LIMIT"] + [str(x) for x in limit]
- else:
- return []
-
def build_args(self):
# @foo:bar ...
ret = [self._query]
diff --git a/tests/test_search.py b/tests/test_search.py
index d1fc75f..0cba3b7 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -82,8 +82,8 @@ def createIndex(client, num_docs=100, definition=None):
try:
client.create_index(
(TextField("play", weight=5.0),
- TextField("txt"),
- NumericField("chapter")),
+ TextField("txt"),
+ NumericField("chapter")),
definition=definition,
)
except redis.ResponseError:
@@ -320,8 +320,8 @@ def test_stopwords(client):
def test_filters(client):
client.ft().create_index(
(TextField("txt"),
- NumericField("num"),
- GeoField("loc"))
+ NumericField("num"),
+ GeoField("loc"))
)
client.ft().add_document(
"doc1",
@@ -379,7 +379,7 @@ def test_payloads_with_no_content(client):
def test_sort_by(client):
client.ft().create_index(
(TextField("txt"),
- NumericField("num", sortable=True))
+ NumericField("num", sortable=True))
)
client.ft().add_document("doc1", txt="foo bar", num=1)
client.ft().add_document("doc2", txt="foo baz", num=2)
@@ -424,7 +424,7 @@ def test_example(client):
# Creating the index definition and schema
client.ft().create_index(
(TextField("title", weight=5.0),
- TextField("body"))
+ TextField("body"))
)
# Indexing a document
@@ -552,8 +552,8 @@ def test_no_index(client):
def test_partial(client):
client.ft().create_index(
(TextField("f1"),
- TextField("f2"),
- TextField("f3"))
+ TextField("f2"),
+ TextField("f3"))
)
client.ft().add_document("doc1", f1="f1_val", f2="f2_val")
client.ft().add_document("doc2", f1="f1_val", f2="f2_val")
@@ -574,8 +574,8 @@ def test_partial(client):
def test_no_create(client):
client.ft().create_index(
(TextField("f1"),
- TextField("f2"),
- TextField("f3"))
+ TextField("f2"),
+ TextField("f3"))
)
client.ft().add_document("doc1", f1="f1_val", f2="f2_val")
client.ft().add_document("doc2", f1="f1_val", f2="f2_val")
@@ -604,8 +604,8 @@ def test_no_create(client):
def test_explain(client):
client.ft().create_index(
(TextField("f1"),
- TextField("f2"),
- TextField("f3"))
+ TextField("f2"),
+ TextField("f3"))
)
res = client.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val")
assert res
@@ -629,8 +629,8 @@ def test_summarize(client):
doc = sorted(client.ft().search(q).docs)[0]
assert "<b>Henry</b> IV" == doc.play
assert (
- "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa
- == doc.txt
+ "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa
+ == doc.txt
)
q = Query("king henry").paging(0, 1).summarize().highlight()
@@ -638,8 +638,8 @@ def test_summarize(client):
doc = sorted(client.ft().search(q).docs)[0]
assert "<b>Henry</b> ... " == doc.play
assert (
- "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa
- == doc.txt
+ "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa
+ == doc.txt
)
@@ -812,10 +812,10 @@ def test_spell_check(client):
res = client.ft().spellcheck("lorm", include="dict")
assert len(res["lorm"]) == 3
assert (
- res["lorm"][0]["suggestion"],
- res["lorm"][1]["suggestion"],
- res["lorm"][2]["suggestion"],
- ) == ("lorem", "lore", "lorm")
+ res["lorm"][0]["suggestion"],
+ res["lorm"][1]["suggestion"],
+ res["lorm"][2]["suggestion"],
+ ) == ("lorem", "lore", "lorm")
assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0")
# test spellcheck exclude
@@ -873,7 +873,7 @@ def test_scorer(client):
)
client.ft().add_document(
"doc2",
- description="Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.", # noqa
+ description="Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.", # noqa
)
# default scorer is TFIDF
@@ -930,7 +930,7 @@ def test_config(client):
@pytest.mark.redismod
-def test_aggregations(client):
+def test_aggregations_groupby(client):
# Creating the index definition and schema
client.ft().create_index(
(
@@ -967,36 +967,242 @@ def test_aggregations(client):
req = aggregations.AggregateRequest("redis").group_by(
"@parent",
reducers.count(),
+ )
+
+ res = client.ft().aggregate(req).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "3"
+
+ req = aggregations.AggregateRequest("redis").group_by(
+ "@parent",
reducers.count_distinct("@title"),
+ )
+
+ res = client.ft().aggregate(req).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "3"
+
+ req = aggregations.AggregateRequest("redis").group_by(
+ "@parent",
reducers.count_distinctish("@title"),
+ )
+
+ res = client.ft().aggregate(req).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "3"
+
+ req = aggregations.AggregateRequest("redis").group_by(
+ "@parent",
reducers.sum("@random_num"),
+ )
+
+ res = client.ft().aggregate(req).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "21" # 10+8+3
+
+ req = aggregations.AggregateRequest("redis").group_by(
+ "@parent",
reducers.min("@random_num"),
+ )
+
+ res = client.ft().aggregate(req).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "3" # min(10,8,3)
+
+ req = aggregations.AggregateRequest("redis").group_by(
+ "@parent",
reducers.max("@random_num"),
+ )
+
+ res = client.ft().aggregate(req).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "10" # max(10,8,3)
+
+ req = aggregations.AggregateRequest("redis").group_by(
+ "@parent",
reducers.avg("@random_num"),
+ )
+
+ res = client.ft().aggregate(req).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "7" # (10+3+8)/3
+
+ req = aggregations.AggregateRequest("redis").group_by(
+ "@parent",
reducers.stddev("random_num"),
+ )
+
+ res = client.ft().aggregate(req).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "3.60555127546"
+
+ req = aggregations.AggregateRequest("redis").group_by(
+ "@parent",
reducers.quantile("@random_num", 0.5),
+ )
+
+ res = client.ft().aggregate(req).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == "10"
+
+ req = aggregations.AggregateRequest("redis").group_by(
+ "@parent",
reducers.tolist("@title"),
- reducers.first_value("@title"),
- reducers.random_sample("@title", 2),
)
+ res = client.ft().aggregate(req).rows[0]
+ assert res[1] == "redis"
+ assert res[3] == ["RediSearch", "RedisAI", "RedisJson"]
+
+ req = aggregations.AggregateRequest("redis").group_by(
+ "@parent",
+ reducers.first_value("@title").alias("first"),
+ )
+
+ res = client.ft().aggregate(req).rows[0]
+ assert res == ['parent', 'redis', 'first', 'RediSearch']
+
+ req = aggregations.AggregateRequest("redis").group_by(
+ "@parent",
+ reducers.random_sample("@title", 2).alias("random"),
+ )
+
+ res = client.ft().aggregate(req).rows[0]
+ assert res[1] == "redis"
+ assert res[2] == "random"
+ assert len(res[3]) == 2
+ assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"]
+
+
+@pytest.mark.redismod
+def test_aggregations_sort_by_and_limit(client):
+ client.ft().create_index(
+ (
+ TextField("t1"),
+ TextField("t2"),
+ )
+ )
+
+ client.ft().client.hset("doc1", mapping={'t1': 'a', 't2': 'b'})
+ client.ft().client.hset("doc2", mapping={'t1': 'b', 't2': 'a'})
+
+ # test sort_by using SortDirection
+ req = aggregations.AggregateRequest("*") \
+ .sort_by(aggregations.Asc("@t2"), aggregations.Desc("@t1"))
+ res = client.ft().aggregate(req)
+ assert res.rows[0] == ['t2', 'a', 't1', 'b']
+ assert res.rows[1] == ['t2', 'b', 't1', 'a']
+
+ # test sort_by without SortDirection
+ req = aggregations.AggregateRequest("*") \
+ .sort_by("@t1")
+ res = client.ft().aggregate(req)
+ assert res.rows[0] == ['t1', 'a']
+ assert res.rows[1] == ['t1', 'b']
+
+ # test sort_by with max
+ req = aggregations.AggregateRequest("*") \
+ .sort_by("@t1", max=1)
+ res = client.ft().aggregate(req)
+ assert len(res.rows) == 1
+
+ # test limit
+ req = aggregations.AggregateRequest("*") \
+ .sort_by("@t1").limit(1, 1)
res = client.ft().aggregate(req)
+ assert len(res.rows) == 1
+ assert res.rows[0] == ['t1', 'b']
- res = res.rows[0]
- assert len(res) == 26
- assert "redis" == res[1]
- assert "3" == res[3]
- assert "3" == res[5]
- assert "3" == res[7]
- assert "21" == res[9]
- assert "3" == res[11]
- assert "10" == res[13]
- assert "7" == res[15]
- assert "3.60555127546" == res[17]
- assert "10" == res[19]
- assert ["RediSearch", "RedisAI", "RedisJson"] == res[21]
- assert "RediSearch" == res[23]
- assert 2 == len(res[25])
+
+@pytest.mark.redismod
+def test_aggregations_load(client):
+ client.ft().create_index(
+ (
+ TextField("t1"),
+ TextField("t2"),
+ )
+ )
+
+ client.ft().client.hset("doc1", mapping={'t1': 'hello', 't2': 'world'})
+
+ # load t1
+ req = aggregations.AggregateRequest("*").load("t1")
+ res = client.ft().aggregate(req)
+ assert res.rows[0] == ['t1', 'hello']
+
+ # load t2
+ req = aggregations.AggregateRequest("*").load("t2")
+ res = client.ft().aggregate(req)
+ assert res.rows[0] == ['t2', 'world']
+
+
+@pytest.mark.redismod
+def test_aggregations_apply(client):
+ client.ft().create_index(
+ (
+ TextField("PrimaryKey", sortable=True),
+ NumericField("CreatedDateTimeUTC", sortable=True),
+ )
+ )
+
+ client.ft().client.hset(
+ "doc1",
+ mapping={
+ 'PrimaryKey': '9::362330',
+ 'CreatedDateTimeUTC': '637387878524969984'
+ }
+ )
+ client.ft().client.hset(
+ "doc2",
+ mapping={
+ 'PrimaryKey': '9::362329',
+ 'CreatedDateTimeUTC': '637387875859270016'
+ }
+ )
+
+ req = aggregations.AggregateRequest("*") \
+ .apply(CreatedDateTimeUTC='@CreatedDateTimeUTC * 10')
+ res = client.ft().aggregate(req)
+ assert res.rows[0] == ['CreatedDateTimeUTC', '6373878785249699840']
+ assert res.rows[1] == ['CreatedDateTimeUTC', '6373878758592700416']
+
+
+@pytest.mark.redismod
+def test_aggregations_filter(client):
+ client.ft().create_index(
+ (
+ TextField("name", sortable=True),
+ NumericField("age", sortable=True),
+ )
+ )
+
+ client.ft().client.hset(
+ "doc1",
+ mapping={
+ 'name': 'bar',
+ 'age': '25'
+ }
+ )
+ client.ft().client.hset(
+ "doc2",
+ mapping={
+ 'name': 'foo',
+ 'age': '19'
+ }
+ )
+
+ req = aggregations.AggregateRequest("*") \
+ .filter("@name=='foo' && @age < 20")
+ res = client.ft().aggregate(req)
+ assert len(res.rows) == 1
+ assert res.rows[0] == ['name', 'foo', 'age', '19']
+
+ req = aggregations.AggregateRequest("*") \
+ .filter("@age > 15").sort_by("@age")
+ res = client.ft().aggregate(req)
+ assert len(res.rows) == 2
+ assert res.rows[0] == ['age', '19']
+ assert res.rows[1] == ['age', '25']
@pytest.mark.redismod
@@ -1020,25 +1226,25 @@ def test_index_definition(client):
)
assert [
- "ON",
- "JSON",
- "PREFIX",
- 2,
- "hset:",
- "henry",
- "FILTER",
- "@f1==32",
- "LANGUAGE_FIELD",
- "play",
- "LANGUAGE",
- "English",
- "SCORE_FIELD",
- "chapter",
- "SCORE",
- 0.5,
- "PAYLOAD_FIELD",
- "txt",
- ] == definition.args
+ "ON",
+ "JSON",
+ "PREFIX",
+ 2,
+ "hset:",
+ "henry",
+ "FILTER",
+ "@f1==32",
+ "LANGUAGE_FIELD",
+ "play",
+ "LANGUAGE",
+ "English",
+ "SCORE_FIELD",
+ "chapter",
+ "SCORE",
+ 0.5,
+ "PAYLOAD_FIELD",
+ "txt",
+ ] == definition.args
createIndex(client.ft(), num_docs=500, definition=definition)