From 7880460b72aca49aa5b9512f0995c0d17d884a7d Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Wed, 1 Jun 2022 14:32:45 +0300 Subject: Add `query_params` to FT.PROFILE (#2198) * ft.profile query_params * fix pr comments * type hints --- redis/commands/search/commands.py | 30 +++++++++++++++++++----------- tests/test_search.py | 24 ++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index bf66147..0121436 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -1,6 +1,6 @@ import itertools import time -from typing import Dict, Union +from typing import Dict, Optional, Union from redis.client import Pipeline @@ -363,7 +363,11 @@ class SearchCommands: it = map(to_string, res) return dict(zip(it, it)) - def get_params_args(self, query_params: Dict[str, Union[str, int, float]]): + def get_params_args( + self, query_params: Union[Dict[str, Union[str, int, float]], None] + ): + if query_params is None: + return [] args = [] if len(query_params) > 0: args.append("params") @@ -383,8 +387,7 @@ class SearchCommands: raise ValueError(f"Bad query type {type(query)}") args += query.get_args() - if query_params is not None: - args += self.get_params_args(query_params) + args += self.get_params_args(query_params) return args, query @@ -459,8 +462,7 @@ class SearchCommands: cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args() else: raise ValueError("Bad query", query) - if query_params is not None: - cmd += self.get_params_args(query_params) + cmd += self.get_params_args(query_params) raw = self.execute_command(*cmd) return self._get_aggregate_result(raw, query, has_cursor) @@ -485,16 +487,22 @@ class SearchCommands: return AggregateResult(rows, cursor, schema) - def profile(self, query, limited=False): + def profile( + self, + query: Union[str, Query, AggregateRequest], + limited: bool = False, + query_params: Optional[Dict[str, Union[str, int, float]]] = None, + ): """ Performs a search or aggregate command and collects performance information. ### Parameters - **query**: This can be either an `AggregateRequest`, `Query` or - string. + **query**: This can be either an `AggregateRequest`, `Query` or string. **limited**: If set to True, removes details of reader iterator. + **query_params**: Define one or more value parameters. + Each parameter has a name and a value. """ st = time.time() @@ -509,6 +517,7 @@ class SearchCommands: elif isinstance(query, Query): cmd[2] = "SEARCH" cmd += query.get_args() + cmd += self.get_params_args(query_params) else: raise ValueError("Must provide AggregateRequest object or " "Query object.") @@ -907,8 +916,7 @@ class AsyncSearchCommands(SearchCommands): cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args() else: raise ValueError("Bad query", query) - if query_params is not None: - cmd += self.get_params_args(query_params) + cmd += self.get_params_args(query_params) raw = await self.execute_command(*cmd) return self._get_aggregate_result(raw, query, has_cursor) diff --git a/tests/test_search.py b/tests/test_search.py index dba914a..f0a1190 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1519,6 +1519,30 @@ def test_profile_limited(client): assert len(res.docs) == 3 # check also the search result +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +def test_profile_query_params(modclient: redis.Redis): + modclient.flushdb() + modclient.ft().create_index( + ( + VectorField( + "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} + ), + ) + ) + modclient.hset("a", "v", "aaaaaaaa") + modclient.hset("b", "v", "aaaabaaa") + modclient.hset("c", "v", "aaaaabaa") + query = "*=>[KNN 2 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2) + res, det = modclient.ft().profile(q, query_params={"vec": "aaaaaaaa"}) + assert det["Iterators profile"]["Counter"] == 2.0 + assert det["Iterators profile"]["Type"] == "VECTOR" + assert res.total == 2 + assert "a" == res.docs[0].id + assert "0" == res.docs[0].__getattribute__("__v_score") + + @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") def test_vector_field(modclient): -- cgit v1.2.1