diff options
author | Avital Fine <98389525+Avital-Fine@users.noreply.github.com> | 2022-03-23 13:02:46 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-03-23 14:02:46 +0200 |
commit | 019f651c822c86525e57ad85a7e19f72757255bb (patch) | |
tree | 19e08fbba017f5917fd291c57851ccda203ad585 /tests/test_search.py | |
parent | 827dcde5c0af5f7aa9bdc3999fc86aa2ba945118 (diff) | |
download | redis-py-019f651c822c86525e57ad85a7e19f72757255bb.tar.gz |
Support for Vector Fields for Vector Similarity Search (#2041)
* Support Vector field in FT.CREATE command
* linters
* fix data error
* change to dic
* add type hints and docstring to constructor
* test not supported algorithm
* linters
* fix errors
* example
* delete example
Co-authored-by: dvora-h <dvora.heller@redis.com>
Diffstat (limited to 'tests/test_search.py')
-rw-r--r-- | tests/test_search.py | 42 |
1 files changed, 41 insertions, 1 deletions
diff --git a/tests/test_search.py b/tests/test_search.py index 5dea739..b94ae05 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -12,7 +12,13 @@ import redis.commands.search.aggregation as aggregations import redis.commands.search.reducers as reducers from redis.commands.json.path import Path from redis.commands.search import Search -from redis.commands.search.field import GeoField, NumericField, TagField, TextField +from redis.commands.search.field import ( + GeoField, + NumericField, + TagField, + TextField, + VectorField, +) from redis.commands.search.indexDefinition import IndexDefinition, IndexType from redis.commands.search.query import GeoFilter, NumericFilter, Query from redis.commands.search.result import Result @@ -1523,6 +1529,40 @@ def test_profile_limited(client): @pytest.mark.redismod +def test_vector_field(modclient): + modclient.flushdb() + modclient.ft().create_index( + ( + VectorField( + "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} + ), + ) + ) + modclient.hset("a", "v", "aaaaaaaa") + modclient.hset("b", "v", "aaaabaaa") + modclient.hset("c", "v", "aaaaabaa") + + q = Query("*=>[KNN 2 @v $vec]").return_field("__v_score").sort_by("__v_score", True) + res = modclient.ft().search(q, query_params={"vec": "aaaaaaaa"}) + + assert "a" == res.docs[0].id + assert "0" == res.docs[0].__getattribute__("__v_score") + + +@pytest.mark.redismod +def test_vector_field_error(modclient): + modclient.flushdb() + + # sortable tag + with pytest.raises(Exception): + modclient.ft().create_index((VectorField("v", "HNSW", {}, sortable=True),)) + + # not supported algorithm + with pytest.raises(Exception): + modclient.ft().create_index((VectorField("v", "SORT", {}),)) + + +@pytest.mark.redismod def test_text_params(modclient): modclient.flushdb() modclient.ft().create_index((TextField("name"),)) |