summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAvital Fine <79420960+AvitalFineRedis@users.noreply.github.com>2021-11-25 14:45:19 +0100
committerGitHub <noreply@github.com>2021-11-25 15:45:19 +0200
commit393cd6280c6fb5394cc512ae15617236ecddac2e (patch)
tree67b0e59c10fec5bd7db984a1e2ff5351e7140e3b
parent3de2e6b6b1bc061d875d36a6f40598453ce85c58 (diff)
downloadredis-py-393cd6280c6fb5394cc512ae15617236ecddac2e.tar.gz
Support RediSearch FT.PROFILE command (#1727)
-rw-r--r--redis/commands/helpers.py41
-rw-r--r--redis/commands/search/commands.py53
-rw-r--r--tests/test_helpers.py26
-rw-r--r--tests/test_search.py43
4 files changed, 156 insertions, 7 deletions
diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py
index 46eb83d..5e8ff49 100644
--- a/redis/commands/helpers.py
+++ b/redis/commands/helpers.py
@@ -35,9 +35,12 @@ def delist(x):
def parse_to_list(response):
- """Optimistally parse the response to a list.
- """
+ """Optimistically parse the response to a list."""
res = []
+
+ if response is None:
+ return res
+
for item in response:
try:
res.append(int(item))
@@ -51,6 +54,40 @@ def parse_to_list(response):
return res
+def parse_list_to_dict(response):
+ res = {}
+ for i in range(0, len(response), 2):
+ if isinstance(response[i], list):
+ res['Child iterators'].append(parse_list_to_dict(response[i]))
+ elif isinstance(response[i+1], list):
+ res['Child iterators'] = [parse_list_to_dict(response[i+1])]
+ else:
+ try:
+ res[response[i]] = float(response[i+1])
+ except (TypeError, ValueError):
+ res[response[i]] = response[i+1]
+ return res
+
+
+def parse_to_dict(response):
+ if response is None:
+ return {}
+
+ res = {}
+ for det in response:
+ if isinstance(det[1], list):
+ res[det[0]] = parse_list_to_dict(det[1])
+ else:
+ try: # try to set the attribute. may be provided without value
+ try: # try to convert the value to float
+ res[det[0]] = float(det[1])
+ except (TypeError, ValueError):
+ res[det[0]] = det[1]
+ except IndexError:
+ pass
+ return res
+
+
def random_string(length=10):
"""
Returns a random N character long string.
diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py
index 0cee2ad..ed58255 100644
--- a/redis/commands/search/commands.py
+++ b/redis/commands/search/commands.py
@@ -7,6 +7,7 @@ from .query import Query
from ._util import to_string
from .aggregation import AggregateRequest, AggregateResult, Cursor
from .suggestion import SuggestionParser
+from ..helpers import parse_to_dict
NUMERIC = "NUMERIC"
@@ -20,6 +21,7 @@ EXPLAIN_CMD = "FT.EXPLAIN"
EXPLAINCLI_CMD = "FT.EXPLAINCLI"
DEL_CMD = "FT.DEL"
AGGREGATE_CMD = "FT.AGGREGATE"
+PROFILE_CMD = "FT.PROFILE"
CURSOR_CMD = "FT.CURSOR"
SPELLCHECK_CMD = "FT.SPELLCHECK"
DICT_ADD_CMD = "FT.DICTADD"
@@ -382,11 +384,11 @@ class SearchCommands:
def aggregate(self, query):
"""
- Issue an aggregation query
+ Issue an aggregation query.
### Parameters
- **query**: This can be either an `AggeregateRequest`, or a `Cursor`
+ **query**: This can be either an `AggregateRequest`, or a `Cursor`
An `AggregateResult` object is returned. You can access the rows from
its `rows` property, which will always yield the rows of the result.
@@ -401,6 +403,9 @@ class SearchCommands:
raise ValueError("Bad query", query)
raw = self.execute_command(*cmd)
+ return self._get_AggregateResult(raw, query, has_cursor)
+
+ def _get_AggregateResult(self, raw, query, has_cursor):
if has_cursor:
if isinstance(query, Cursor):
query.cid = raw[1]
@@ -418,8 +423,48 @@ class SearchCommands:
schema = None
rows = raw[1:]
- res = AggregateResult(rows, cursor, schema)
- return res
+ return AggregateResult(rows, cursor, schema)
+
+ def profile(self, query, limited=False):
+ """
+ Performs a search or aggregate command and collects performance
+ information.
+
+ ### Parameters
+
+ **query**: This can be either an `AggregateRequest`, `Query` or
+ string.
+ **limited**: If set to True, removes details of reader iterator.
+
+ """
+ st = time.time()
+ cmd = [PROFILE_CMD, self.index_name, ""]
+ if limited:
+ cmd.append("LIMITED")
+ cmd.append('QUERY')
+
+ if isinstance(query, AggregateRequest):
+ cmd[2] = "AGGREGATE"
+ cmd += query.build_args()
+ elif isinstance(query, Query):
+ cmd[2] = "SEARCH"
+ cmd += query.get_args()
+ else:
+ raise ValueError("Must provide AggregateRequest object or "
+ "Query object.")
+
+ res = self.execute_command(*cmd)
+
+ if isinstance(query, AggregateRequest):
+ result = self._get_AggregateResult(res[0], query, query._cursor)
+ else:
+ result = Result(res[0],
+ not query._no_content,
+ duration=(time.time() - st) * 1000.0,
+ has_payload=query._with_payloads,
+ with_scores=query._with_scores,)
+
+ return result, parse_to_dict(res[1])
def spellcheck(self, query, distance=None, include=None, exclude=None):
"""
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 467e00c..402eccf 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -5,7 +5,8 @@ from redis.commands.helpers import (
nativestr,
parse_to_list,
quote_string,
- random_string
+ random_string,
+ parse_to_dict
)
@@ -19,11 +20,34 @@ def test_list_or_args():
def test_parse_to_list():
+ assert parse_to_list(None) == []
r = ["hello", b"my name", "45", "555.55", "is simon!", None]
assert parse_to_list(r) == \
["hello", "my name", 45, 555.55, "is simon!", None]
+def test_parse_to_dict():
+ assert parse_to_dict(None) == {}
+ r = [['Some number', '1.0345'],
+ ['Some string', 'hello'],
+ ['Child iterators',
+ ['Time', '0.2089', 'Counter', 3, 'Child iterators',
+ ['Type', 'bar', 'Time', '0.0729', 'Counter', 3],
+ ['Type', 'barbar', 'Time', '0.058', 'Counter', 3]]]]
+ assert parse_to_dict(r) == {
+ 'Child iterators': {
+ 'Child iterators': [
+ {'Counter': 3.0, 'Time': 0.0729, 'Type': 'bar'},
+ {'Counter': 3.0, 'Time': 0.058, 'Type': 'barbar'}
+ ],
+ 'Counter': 3.0,
+ 'Time': 0.2089
+ },
+ 'Some number': 1.0345,
+ 'Some string': 'hello'
+ }
+
+
def test_nativestr():
assert nativestr('teststr') == 'teststr'
assert nativestr(b'teststr') == 'teststr'
diff --git a/tests/test_search.py b/tests/test_search.py
index 0cba3b7..b65ac8d 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -1519,3 +1519,46 @@ def test_json_with_jsonpath(client):
assert res.docs[0].id == "doc:1"
with pytest.raises(Exception):
res.docs[0].name_unsupported
+
+
+@pytest.mark.redismod
+def test_profile(client):
+ client.ft().create_index((TextField('t'),))
+ client.ft().client.hset('1', 't', 'hello')
+ client.ft().client.hset('2', 't', 'world')
+
+ # check using Query
+ q = Query('hello|world').no_content()
+ res, det = client.ft().profile(q)
+ assert det['Iterators profile']['Counter'] == 2.0
+ assert len(det['Iterators profile']['Child iterators']) == 2
+ assert det['Iterators profile']['Type'] == 'UNION'
+ assert det['Parsing time'] < 0.3
+ assert len(res.docs) == 2 # check also the search result
+
+ # check using AggregateRequest
+ req = aggregations.AggregateRequest("*").load("t")\
+ .apply(prefix="startswith(@t, 'hel')")
+ res, det = client.ft().profile(req)
+ assert det['Iterators profile']['Counter'] == 2.0
+ assert det['Iterators profile']['Type'] == 'WILDCARD'
+ assert det['Parsing time'] < 0.3
+ assert len(res.rows) == 2 # check also the search result
+
+
+@pytest.mark.redismod
+def test_profile_limited(client):
+ client.ft().create_index((TextField('t'),))
+ client.ft().client.hset('1', 't', 'hello')
+ client.ft().client.hset('2', 't', 'hell')
+ client.ft().client.hset('3', 't', 'help')
+ client.ft().client.hset('4', 't', 'helowa')
+
+ q = Query('%hell% hel*')
+ res, det = client.ft().profile(q, limited=True)
+ assert det['Iterators profile']['Child iterators'][0]['Child iterators'] \
+ == 'The number of iterators in the union is 3'
+ assert det['Iterators profile']['Child iterators'][1]['Child iterators'] \
+ == 'The number of iterators in the union is 4'
+ assert det['Iterators profile']['Type'] == 'INTERSECT'
+ assert len(res.docs) == 3 # check also the search result