From 0ed06603695ba9533d1086dcd7d60cd5eb5e17d0 Mon Sep 17 00:00:00 2001 From: Chayim Date: Wed, 16 Feb 2022 10:26:31 +0200 Subject: 4.1.4 release cherry-picks (#1994) --- .github/workflows/integration.yaml | 2 + redis/commands/graph/commands.py | 52 +++++---- redis/commands/graph/execution_plan.py | 208 +++++++++++++++++++++++++++++++++ redis/commands/search/commands.py | 42 +++++-- setup.py | 2 +- tests/test_graph.py | 127 +++++++++++++++++--- tests/test_search.py | 55 ++++++++- 7 files changed, 443 insertions(+), 45 deletions(-) create mode 100644 redis/commands/graph/execution_plan.py diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index a73bb71..1d8a33a 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -8,9 +8,11 @@ on: - '**/*.md' branches: - master + - '[0-9].[0-9]' pull_request: branches: - master + - '[0-9].[0-9]' jobs: diff --git a/redis/commands/graph/commands.py b/redis/commands/graph/commands.py index 157f628..baed1fc 100644 --- a/redis/commands/graph/commands.py +++ b/redis/commands/graph/commands.py @@ -2,6 +2,7 @@ from redis import DataError from redis.exceptions import ResponseError from .exceptions import VersionMismatchException +from .execution_plan import ExecutionPlan from .query_result import QueryResult @@ -118,27 +119,6 @@ class GraphCommands: self.nodes = {} self.edges = [] - def explain(self, query, params=None): - """ - Get the execution plan for given query, - Returns an array of operations. - For more information see `GRAPH.EXPLAIN `_. # noqa - - Args: - - query: - The query that will be executed. - params: dict - Query parameters. - """ - if params is not None: - query = self._build_params_header(params) + query - - plan = self.execute_command("GRAPH.EXPLAIN", self.name, query) - if isinstance(plan[0], bytes): - plan = [b.decode() for b in plan] - return "\n".join(plan) - def bulk(self, **kwargs): """Internal only. Not supported.""" raise NotImplementedError( @@ -200,3 +180,33 @@ class GraphCommands: For more information see `GRAPH.LIST `_. # noqa """ return self.execute_command("GRAPH.LIST") + + def execution_plan(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns an array of operations. + + Args: + query: the query that will be executed + params: query parameters + """ + if params is not None: + query = self._build_params_header(params) + query + + plan = self.execute_command("GRAPH.EXPLAIN", self.name, query) + return "\n".join(plan) + + def explain(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns ExecutionPlan object. + + Args: + query: the query that will be executed + params: query parameters + """ + if params is not None: + query = self._build_params_header(params) + query + + plan = self.execute_command("GRAPH.EXPLAIN", self.name, query) + return ExecutionPlan(plan) diff --git a/redis/commands/graph/execution_plan.py b/redis/commands/graph/execution_plan.py new file mode 100644 index 0000000..a9b4ad0 --- /dev/null +++ b/redis/commands/graph/execution_plan.py @@ -0,0 +1,208 @@ +import re + + +class ProfileStats: + """ + ProfileStats, runtime execution statistics of operation. + """ + + def __init__(self, records_produced, execution_time): + self.records_produced = records_produced + self.execution_time = execution_time + + +class Operation: + """ + Operation, single operation within execution plan. 
+ """ + + def __init__(self, name, args=None, profile_stats=None): + """ + Create a new operation. + + Args: + name: string that represents the name of the operation + args: operation arguments + profile_stats: profile statistics + """ + self.name = name + self.args = args + self.profile_stats = profile_stats + self.children = [] + + def append_child(self, child): + if not isinstance(child, Operation) or self is child: + raise Exception("child must be Operation") + + self.children.append(child) + return self + + def child_count(self): + return len(self.children) + + def __eq__(self, o: object) -> bool: + if not isinstance(o, Operation): + return False + + return self.name == o.name and self.args == o.args + + def __str__(self) -> str: + args_str = "" if self.args is None else " | " + self.args + return f"{self.name}{args_str}" + + +class ExecutionPlan: + """ + ExecutionPlan, collection of operations. + """ + + def __init__(self, plan): + """ + Create a new execution plan. + + Args: + plan: array of strings that represents the collection operations + the output from GRAPH.EXPLAIN + """ + if not isinstance(plan, list): + raise Exception("plan must be an array") + + self.plan = plan + self.structured_plan = self._operation_tree() + + def _compare_operations(self, root_a, root_b): + """ + Compare execution plan operation tree + + Return: True if operation trees are equal, False otherwise + """ + + # compare current root + if root_a != root_b: + return False + + # make sure root have the same number of children + if root_a.child_count() != root_b.child_count(): + return False + + # recursively compare children + for i in range(root_a.child_count()): + if not self._compare_operations(root_a.children[i], root_b.children[i]): + return False + + return True + + def __str__(self) -> str: + def aggraget_str(str_children): + return "\n".join( + [ + " " + line + for str_child in str_children + for line in str_child.splitlines() + ] + ) + + def combine_str(x, y): + return f"{x}\n{y}" + + return self._operation_traverse( + self.structured_plan, str, aggraget_str, combine_str + ) + + def __eq__(self, o: object) -> bool: + """Compares two execution plans + + Return: True if the two plans are equal False otherwise + """ + # make sure 'o' is an execution-plan + if not isinstance(o, ExecutionPlan): + return False + + # get root for both plans + root_a = self.structured_plan + root_b = o.structured_plan + + # compare execution trees + return self._compare_operations(root_a, root_b) + + def _operation_traverse(self, op, op_f, aggregate_f, combine_f): + """ + Traverse operation tree recursively applying functions + + Args: + op: operation to traverse + op_f: function applied for each operation + aggregate_f: aggregation function applied for all children of a single operation + combine_f: combine function applied for the operation result and the children result + """ # noqa + # apply op_f for each operation + op_res = op_f(op) + if len(op.children) == 0: + return op_res # no children return + else: + # apply _operation_traverse recursively + children = [ + self._operation_traverse(child, op_f, aggregate_f, combine_f) + for child in op.children + ] + # combine the operation result with the children aggregated result + return combine_f(op_res, aggregate_f(children)) + + def _operation_tree(self): + """Build the operation tree from the string representation""" + + # initial state + i = 0 + level = 0 + stack = [] + current = None + + def _create_operation(args): + profile_stats = None + name = args[0].strip() + 
args.pop(0) + if len(args) > 0 and "Records produced" in args[-1]: + records_produced = int( + re.search("Records produced: (\\d+)", args[-1]).group(1) + ) + execution_time = float( + re.search("Execution time: (\\d+.\\d+) ms", args[-1]).group(1) + ) + profile_stats = ProfileStats(records_produced, execution_time) + args.pop(-1) + return Operation( + name, None if len(args) == 0 else args[0].strip(), profile_stats + ) + + # iterate plan operations + while i < len(self.plan): + current_op = self.plan[i] + op_level = current_op.count(" ") + if op_level == level: + # if the operation level equal to the current level + # set the current operation and move next + child = _create_operation(current_op.split("|")) + if current: + current = stack.pop() + current.append_child(child) + current = child + i += 1 + elif op_level == level + 1: + # if the operation is child of the current operation + # add it as child and set as current operation + child = _create_operation(current_op.split("|")) + current.append_child(child) + stack.append(current) + current = child + level += 1 + i += 1 + elif op_level < level: + # if the operation is not child of current operation + # go back to it's parent operation + levels_back = level - op_level + 1 + for _ in range(levels_back): + current = stack.pop() + level -= levels_back + else: + raise Exception("corrupted plan") + return stack[0] diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 3f768ab..5bcfdfd 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -1,5 +1,6 @@ import itertools import time +from typing import Dict, Union from ..helpers import parse_to_dict from ._util import to_string @@ -377,7 +378,17 @@ class SearchCommands: it = map(to_string, res) return dict(zip(it, it)) - def _mk_query_args(self, query): + def get_params_args(self, query_params: Dict[str, Union[str, int, float]]): + args = [] + if len(query_params) > 0: + args.append("params") + args.append(len(query_params) * 2) + for key, value in query_params.items(): + args.append(key) + args.append(value) + return args + + def _mk_query_args(self, query, query_params: Dict[str, Union[str, int, float]]): args = [self.index_name] if isinstance(query, str): @@ -387,9 +398,16 @@ class SearchCommands: raise ValueError(f"Bad query type {type(query)}") args += query.get_args() + if query_params is not None: + args += self.get_params_args(query_params) + return args, query - def search(self, query): + def search( + self, + query: Union[str, Query], + query_params: Dict[str, Union[str, int, float]] = None, + ): """ Search the index for a given query, and return a result of documents @@ -401,7 +419,7 @@ class SearchCommands: For more information: https://oss.redis.com/redisearch/Commands/#ftsearch """ # noqa - args, query = self._mk_query_args(query) + args, query = self._mk_query_args(query, query_params=query_params) st = time.time() res = self.execute_command(SEARCH_CMD, *args) @@ -413,18 +431,26 @@ class SearchCommands: with_scores=query._with_scores, ) - def explain(self, query): + def explain( + self, + query=Union[str, Query], + query_params: Dict[str, Union[str, int, float]] = None, + ): """Returns the execution plan for a complex query. 
For more information: https://oss.redis.com/redisearch/Commands/#ftexplain """ # noqa - args, query_text = self._mk_query_args(query) + args, query_text = self._mk_query_args(query, query_params=query_params) return self.execute_command(EXPLAIN_CMD, *args) - def explain_cli(self, query): # noqa + def explain_cli(self, query: Union[str, Query]): # noqa raise NotImplementedError("EXPLAINCLI will not be implemented.") - def aggregate(self, query): + def aggregate( + self, + query: Union[str, Query], + query_params: Dict[str, Union[str, int, float]] = None, + ): """ Issue an aggregation query. @@ -445,6 +471,8 @@ class SearchCommands: cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args() else: raise ValueError("Bad query", query) + if query_params is not None: + cmd += self.get_params_args(query_params) raw = self.execute_command(*cmd) return self._get_AggregateResult(raw, query, has_cursor) diff --git a/setup.py b/setup.py index f085f03..cfe1c32 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( long_description_content_type="text/markdown", keywords=["Redis", "key-value store", "database"], license="MIT", - version="4.1.3", + version="4.1.4", packages=find_packages( include=[ "redis", diff --git a/tests/test_graph.py b/tests/test_graph.py index c6dc9a4..c885aa4 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,6 +1,7 @@ import pytest from redis.commands.graph import Edge, Node, Path +from redis.commands.graph.execution_plan import Operation from redis.exceptions import ResponseError @@ -259,21 +260,6 @@ def test_cached_execution(client): assert cached_result.cached_execution -@pytest.mark.redismod -def test_explain(client): - create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), - (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), - (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" - client.graph().query(create_query) - - result = client.graph().explain( - "MATCH (r:Rider)-[:rides]->(t:Team) WHERE t.name = $name RETURN r.name, t.name, $params", # noqa - {"name": "Yehuda"}, - ) - expected = "Results\n Project\n Conditional Traverse | (t:Team)->(r:Rider)\n Filter\n Node By Label Scan | (t:Team)" # noqa - assert result == expected - - @pytest.mark.redismod def test_slowlog(client): create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), @@ -475,3 +461,114 @@ def test_cache_sync(client): assert A._properties[1] == "x" assert A._relationshipTypes[0] == "S" assert A._relationshipTypes[1] == "R" + + +@pytest.mark.redismod +def test_execution_plan(client): + redis_graph = client.graph("execution_plan") + create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), + (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), + (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" + redis_graph.query(create_query) + + result = redis_graph.execution_plan( + "MATCH (r:Rider)-[:rides]->(t:Team) WHERE t.name = $name RETURN r.name, t.name, $params", # noqa + {"name": "Yehuda"}, + ) + expected = "Results\n Project\n Conditional Traverse | (t:Team)->(r:Rider)\n Filter\n Node By Label Scan | (t:Team)" # noqa + assert result == expected + + redis_graph.delete() + + +@pytest.mark.redismod +def test_explain(client): + redis_graph = client.graph("execution_plan") + # graph creation / population + create_query = """CREATE +(:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), +(:Rider {name:'Dani 
Pedrosa'})-[:rides]->(:Team {name:'Honda'}), +(:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" + redis_graph.query(create_query) + + result = redis_graph.explain( + """MATCH (r:Rider)-[:rides]->(t:Team) +WHERE t.name = $name +RETURN r.name, t.name +UNION +MATCH (r:Rider)-[:rides]->(t:Team) +WHERE t.name = $name +RETURN r.name, t.name""", + {"name": "Yamaha"}, + ) + expected = """\ +Results +Distinct + Join + Project + Conditional Traverse | (t:Team)->(r:Rider) + Filter + Node By Label Scan | (t:Team) + Project + Conditional Traverse | (t:Team)->(r:Rider) + Filter + Node By Label Scan | (t:Team)""" + assert str(result).replace(" ", "").replace("\n", "") == expected.replace( + " ", "" + ).replace("\n", "") + + expected = Operation("Results").append_child( + Operation("Distinct").append_child( + Operation("Join") + .append_child( + Operation("Project").append_child( + Operation( + "Conditional Traverse", "(t:Team)->(r:Rider)" + ).append_child( + Operation("Filter").append_child( + Operation("Node By Label Scan", "(t:Team)") + ) + ) + ) + ) + .append_child( + Operation("Project").append_child( + Operation( + "Conditional Traverse", "(t:Team)->(r:Rider)" + ).append_child( + Operation("Filter").append_child( + Operation("Node By Label Scan", "(t:Team)") + ) + ) + ) + ) + ) + ) + + assert result.structured_plan == expected + + result = redis_graph.explain( + """MATCH (r:Rider), (t:Team) + RETURN r.name, t.name""" + ) + expected = """\ +Results +Project + Cartesian Product + Node By Label Scan | (r:Rider) + Node By Label Scan | (t:Team)""" + assert str(result).replace(" ", "").replace("\n", "") == expected.replace( + " ", "" + ).replace("\n", "") + + expected = Operation("Results").append_child( + Operation("Project").append_child( + Operation("Cartesian Product") + .append_child(Operation("Node By Label Scan")) + .append_child(Operation("Node By Label Scan")) + ) + ) + + assert result.structured_plan == expected + + redis_graph.delete() diff --git a/tests/test_search.py b/tests/test_search.py index 6c79041..0d15dcc 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -964,7 +964,7 @@ def test_aggregations_groupby(client): res = client.ft().aggregate(req).rows[0] assert res[1] == "redis" - assert res[3] == "10" + assert res[3] == "8" # median of 3,8,10 req = aggregations.AggregateRequest("redis").group_by( "@parent", @@ -1521,3 +1521,56 @@ def test_profile_limited(client): ) assert det["Iterators profile"]["Type"] == "INTERSECT" assert len(res.docs) == 3 # check also the search result + + +@pytest.mark.redismod +def test_text_params(modclient): + modclient.flushdb() + modclient.ft().create_index((TextField("name"),)) + + modclient.ft().add_document("doc1", name="Alice") + modclient.ft().add_document("doc2", name="Bob") + modclient.ft().add_document("doc3", name="Carol") + + params_dict = {"name1": "Alice", "name2": "Bob"} + q = Query("@name:($name1 | $name2 )") + res = modclient.ft().search(q, query_params=params_dict) + assert 2 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + + +@pytest.mark.redismod +def test_numeric_params(modclient): + modclient.flushdb() + modclient.ft().create_index((NumericField("numval"),)) + + modclient.ft().add_document("doc1", numval=101) + modclient.ft().add_document("doc2", numval=102) + modclient.ft().add_document("doc3", numval=103) + + params_dict = {"min": 101, "max": 102} + q = Query("@numval:[$min $max]") + res = modclient.ft().search(q, query_params=params_dict) + + assert 2 == res.total + 
assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + + +@pytest.mark.redismod +def test_geo_params(modclient): + + modclient.flushdb() + modclient.ft().create_index((GeoField("g"))) + modclient.ft().add_document("doc1", g="29.69465, 34.95126") + modclient.ft().add_document("doc2", g="29.69350, 34.94737") + modclient.ft().add_document("doc3", g="29.68746, 34.94882") + + params_dict = {"lat": "34.95126", "lon": "29.69465", "radius": 1000, "units": "km"} + q = Query("@g:[$lon $lat $radius $units]") + res = modclient.ft().search(q, query_params=params_dict) + assert 3 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + assert "doc3" == res.docs[2].id -- cgit v1.2.1