diff options
author | Avital Fine <79420960+AvitalFineRedis@users.noreply.github.com> | 2021-11-25 14:45:19 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-25 15:45:19 +0200 |
commit | 393cd6280c6fb5394cc512ae15617236ecddac2e (patch) | |
tree | 67b0e59c10fec5bd7db984a1e2ff5351e7140e3b | |
parent | 3de2e6b6b1bc061d875d36a6f40598453ce85c58 (diff) | |
download | redis-py-393cd6280c6fb5394cc512ae15617236ecddac2e.tar.gz |
Support RediSearch FT.PROFILE command (#1727)
-rw-r--r-- | redis/commands/helpers.py | 41 | ||||
-rw-r--r-- | redis/commands/search/commands.py | 53 | ||||
-rw-r--r-- | tests/test_helpers.py | 26 | ||||
-rw-r--r-- | tests/test_search.py | 43 |
4 files changed, 156 insertions, 7 deletions
diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py index 46eb83d..5e8ff49 100644 --- a/redis/commands/helpers.py +++ b/redis/commands/helpers.py @@ -35,9 +35,12 @@ def delist(x): def parse_to_list(response): - """Optimistally parse the response to a list. - """ + """Optimistically parse the response to a list.""" res = [] + + if response is None: + return res + for item in response: try: res.append(int(item)) @@ -51,6 +54,40 @@ def parse_to_list(response): return res +def parse_list_to_dict(response): + res = {} + for i in range(0, len(response), 2): + if isinstance(response[i], list): + res['Child iterators'].append(parse_list_to_dict(response[i])) + elif isinstance(response[i+1], list): + res['Child iterators'] = [parse_list_to_dict(response[i+1])] + else: + try: + res[response[i]] = float(response[i+1]) + except (TypeError, ValueError): + res[response[i]] = response[i+1] + return res + + +def parse_to_dict(response): + if response is None: + return {} + + res = {} + for det in response: + if isinstance(det[1], list): + res[det[0]] = parse_list_to_dict(det[1]) + else: + try: # try to set the attribute. may be provided without value + try: # try to convert the value to float + res[det[0]] = float(det[1]) + except (TypeError, ValueError): + res[det[0]] = det[1] + except IndexError: + pass + return res + + def random_string(length=10): """ Returns a random N character long string. diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 0cee2ad..ed58255 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -7,6 +7,7 @@ from .query import Query from ._util import to_string from .aggregation import AggregateRequest, AggregateResult, Cursor from .suggestion import SuggestionParser +from ..helpers import parse_to_dict NUMERIC = "NUMERIC" @@ -20,6 +21,7 @@ EXPLAIN_CMD = "FT.EXPLAIN" EXPLAINCLI_CMD = "FT.EXPLAINCLI" DEL_CMD = "FT.DEL" AGGREGATE_CMD = "FT.AGGREGATE" +PROFILE_CMD = "FT.PROFILE" CURSOR_CMD = "FT.CURSOR" SPELLCHECK_CMD = "FT.SPELLCHECK" DICT_ADD_CMD = "FT.DICTADD" @@ -382,11 +384,11 @@ class SearchCommands: def aggregate(self, query): """ - Issue an aggregation query + Issue an aggregation query. ### Parameters - **query**: This can be either an `AggeregateRequest`, or a `Cursor` + **query**: This can be either an `AggregateRequest`, or a `Cursor` An `AggregateResult` object is returned. You can access the rows from its `rows` property, which will always yield the rows of the result. @@ -401,6 +403,9 @@ class SearchCommands: raise ValueError("Bad query", query) raw = self.execute_command(*cmd) + return self._get_AggregateResult(raw, query, has_cursor) + + def _get_AggregateResult(self, raw, query, has_cursor): if has_cursor: if isinstance(query, Cursor): query.cid = raw[1] @@ -418,8 +423,48 @@ class SearchCommands: schema = None rows = raw[1:] - res = AggregateResult(rows, cursor, schema) - return res + return AggregateResult(rows, cursor, schema) + + def profile(self, query, limited=False): + """ + Performs a search or aggregate command and collects performance + information. + + ### Parameters + + **query**: This can be either an `AggregateRequest`, `Query` or + string. + **limited**: If set to True, removes details of reader iterator. + + """ + st = time.time() + cmd = [PROFILE_CMD, self.index_name, ""] + if limited: + cmd.append("LIMITED") + cmd.append('QUERY') + + if isinstance(query, AggregateRequest): + cmd[2] = "AGGREGATE" + cmd += query.build_args() + elif isinstance(query, Query): + cmd[2] = "SEARCH" + cmd += query.get_args() + else: + raise ValueError("Must provide AggregateRequest object or " + "Query object.") + + res = self.execute_command(*cmd) + + if isinstance(query, AggregateRequest): + result = self._get_AggregateResult(res[0], query, query._cursor) + else: + result = Result(res[0], + not query._no_content, + duration=(time.time() - st) * 1000.0, + has_payload=query._with_payloads, + with_scores=query._with_scores,) + + return result, parse_to_dict(res[1]) def spellcheck(self, query, distance=None, include=None, exclude=None): """ diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 467e00c..402eccf 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -5,7 +5,8 @@ from redis.commands.helpers import ( nativestr, parse_to_list, quote_string, - random_string + random_string, + parse_to_dict ) @@ -19,11 +20,34 @@ def test_list_or_args(): def test_parse_to_list(): + assert parse_to_list(None) == [] r = ["hello", b"my name", "45", "555.55", "is simon!", None] assert parse_to_list(r) == \ ["hello", "my name", 45, 555.55, "is simon!", None] +def test_parse_to_dict(): + assert parse_to_dict(None) == {} + r = [['Some number', '1.0345'], + ['Some string', 'hello'], + ['Child iterators', + ['Time', '0.2089', 'Counter', 3, 'Child iterators', + ['Type', 'bar', 'Time', '0.0729', 'Counter', 3], + ['Type', 'barbar', 'Time', '0.058', 'Counter', 3]]]] + assert parse_to_dict(r) == { + 'Child iterators': { + 'Child iterators': [ + {'Counter': 3.0, 'Time': 0.0729, 'Type': 'bar'}, + {'Counter': 3.0, 'Time': 0.058, 'Type': 'barbar'} + ], + 'Counter': 3.0, + 'Time': 0.2089 + }, + 'Some number': 1.0345, + 'Some string': 'hello' + } + + def test_nativestr(): assert nativestr('teststr') == 'teststr' assert nativestr(b'teststr') == 'teststr' diff --git a/tests/test_search.py b/tests/test_search.py index 0cba3b7..b65ac8d 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1519,3 +1519,46 @@ def test_json_with_jsonpath(client): assert res.docs[0].id == "doc:1" with pytest.raises(Exception): res.docs[0].name_unsupported + + +@pytest.mark.redismod +def test_profile(client): + client.ft().create_index((TextField('t'),)) + client.ft().client.hset('1', 't', 'hello') + client.ft().client.hset('2', 't', 'world') + + # check using Query + q = Query('hello|world').no_content() + res, det = client.ft().profile(q) + assert det['Iterators profile']['Counter'] == 2.0 + assert len(det['Iterators profile']['Child iterators']) == 2 + assert det['Iterators profile']['Type'] == 'UNION' + assert det['Parsing time'] < 0.3 + assert len(res.docs) == 2 # check also the search result + + # check using AggregateRequest + req = aggregations.AggregateRequest("*").load("t")\ + .apply(prefix="startswith(@t, 'hel')") + res, det = client.ft().profile(req) + assert det['Iterators profile']['Counter'] == 2.0 + assert det['Iterators profile']['Type'] == 'WILDCARD' + assert det['Parsing time'] < 0.3 + assert len(res.rows) == 2 # check also the search result + + +@pytest.mark.redismod +def test_profile_limited(client): + client.ft().create_index((TextField('t'),)) + client.ft().client.hset('1', 't', 'hello') + client.ft().client.hset('2', 't', 'hell') + client.ft().client.hset('3', 't', 'help') + client.ft().client.hset('4', 't', 'helowa') + + q = Query('%hell% hel*') + res, det = client.ft().profile(q, limited=True) + assert det['Iterators profile']['Child iterators'][0]['Child iterators'] \ + == 'The number of iterators in the union is 3' + assert det['Iterators profile']['Child iterators'][1]['Child iterators'] \ + == 'The number of iterators in the union is 4' + assert det['Iterators profile']['Type'] == 'INTERSECT' + assert len(res.docs) == 3 # check also the search result |