summaryrefslogtreecommitdiff
path: root/tests/test_search.py
diff options
context:
space:
mode:
authorAvital Fine <98389525+Avital-Fine@users.noreply.github.com>2022-03-23 13:02:46 +0100
committerGitHub <noreply@github.com>2022-03-23 14:02:46 +0200
commit019f651c822c86525e57ad85a7e19f72757255bb (patch)
tree19e08fbba017f5917fd291c57851ccda203ad585 /tests/test_search.py
parent827dcde5c0af5f7aa9bdc3999fc86aa2ba945118 (diff)
downloadredis-py-019f651c822c86525e57ad85a7e19f72757255bb.tar.gz
Support for Vector Fields for Vector Similarity Search (#2041)
* Support Vector field in FT.CREATE command * linters * fix data error * change to dic * add type hints and docstring to constructor * test not supported algorithm * linters * fix errors * example * delete example Co-authored-by: dvora-h <dvora.heller@redis.com>
Diffstat (limited to 'tests/test_search.py')
-rw-r--r--tests/test_search.py42
1 files changed, 41 insertions, 1 deletions
diff --git a/tests/test_search.py b/tests/test_search.py
index 5dea739..b94ae05 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -12,7 +12,13 @@ import redis.commands.search.aggregation as aggregations
import redis.commands.search.reducers as reducers
from redis.commands.json.path import Path
from redis.commands.search import Search
-from redis.commands.search.field import GeoField, NumericField, TagField, TextField
+from redis.commands.search.field import (
+ GeoField,
+ NumericField,
+ TagField,
+ TextField,
+ VectorField,
+)
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import GeoFilter, NumericFilter, Query
from redis.commands.search.result import Result
@@ -1523,6 +1529,40 @@ def test_profile_limited(client):
@pytest.mark.redismod
+def test_vector_field(modclient):
+ modclient.flushdb()
+ modclient.ft().create_index(
+ (
+ VectorField(
+ "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"}
+ ),
+ )
+ )
+ modclient.hset("a", "v", "aaaaaaaa")
+ modclient.hset("b", "v", "aaaabaaa")
+ modclient.hset("c", "v", "aaaaabaa")
+
+ q = Query("*=>[KNN 2 @v $vec]").return_field("__v_score").sort_by("__v_score", True)
+ res = modclient.ft().search(q, query_params={"vec": "aaaaaaaa"})
+
+ assert "a" == res.docs[0].id
+ assert "0" == res.docs[0].__getattribute__("__v_score")
+
+
+@pytest.mark.redismod
+def test_vector_field_error(modclient):
+ modclient.flushdb()
+
+ # sortable tag
+ with pytest.raises(Exception):
+ modclient.ft().create_index((VectorField("v", "HNSW", {}, sortable=True),))
+
+ # not supported algorithm
+ with pytest.raises(Exception):
+ modclient.ft().create_index((VectorField("v", "SORT", {}),))
+
+
+@pytest.mark.redismod
def test_text_params(modclient):
modclient.flushdb()
modclient.ft().create_index((TextField("name"),))