author     Chayim <chayim@users.noreply.github.com>  2022-02-16 10:26:31 +0200
committer  GitHub <noreply@github.com>               2022-02-16 10:26:31 +0200
commit     0ed06603695ba9533d1086dcd7d60cd5eb5e17d0 (patch)
tree       530e96469a396c221e28ac6a8b369984d01bd66c
parent     6c00e091e93d07834fcdd811b2a8473848310db0 (diff)
download   redis-py-4.1.tar.gz
4.1.4 release cherry-picks (#1994) (tag: v4.1.4, branch: 4.1)
-rw-r--r--  .github/workflows/integration.yaml       2
-rw-r--r--  redis/commands/graph/commands.py         52
-rw-r--r--  redis/commands/graph/execution_plan.py   208
-rw-r--r--  redis/commands/search/commands.py        42
-rw-r--r--  setup.py                                 2
-rw-r--r--  tests/test_graph.py                      127
-rw-r--r--  tests/test_search.py                     55
7 files changed, 443 insertions, 45 deletions
diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml
index a73bb71..1d8a33a 100644
--- a/.github/workflows/integration.yaml
+++ b/.github/workflows/integration.yaml
@@ -8,9 +8,11 @@ on:
- '**/*.md'
branches:
- master
+ - '[0-9].[0-9]'
pull_request:
branches:
- master
+ - '[0-9].[0-9]'
jobs:
diff --git a/redis/commands/graph/commands.py b/redis/commands/graph/commands.py
index 157f628..baed1fc 100644
--- a/redis/commands/graph/commands.py
+++ b/redis/commands/graph/commands.py
@@ -2,6 +2,7 @@ from redis import DataError
from redis.exceptions import ResponseError
from .exceptions import VersionMismatchException
+from .execution_plan import ExecutionPlan
from .query_result import QueryResult
@@ -118,27 +119,6 @@ class GraphCommands:
self.nodes = {}
self.edges = []
- def explain(self, query, params=None):
- """
- Get the execution plan for given query,
- Returns an array of operations.
- For more information see `GRAPH.EXPLAIN <https://oss.redis.com/redisgraph/master/commands/#graphexplain>`_. # noqa
-
- Args:
-
- query:
- The query that will be executed.
- params: dict
- Query parameters.
- """
- if params is not None:
- query = self._build_params_header(params) + query
-
- plan = self.execute_command("GRAPH.EXPLAIN", self.name, query)
- if isinstance(plan[0], bytes):
- plan = [b.decode() for b in plan]
- return "\n".join(plan)
-
def bulk(self, **kwargs):
"""Internal only. Not supported."""
raise NotImplementedError(
@@ -200,3 +180,33 @@ class GraphCommands:
For more information see `GRAPH.LIST <https://oss.redis.com/redisgraph/master/commands/#graphlist>`_. # noqa
"""
return self.execute_command("GRAPH.LIST")
+
+ def execution_plan(self, query, params=None):
+ """
+ Get the execution plan for given query,
+ GRAPH.EXPLAIN returns an array of operations.
+
+ Args:
+ query: the query that will be executed
+ params: query parameters
+ """
+ if params is not None:
+ query = self._build_params_header(params) + query
+
+ plan = self.execute_command("GRAPH.EXPLAIN", self.name, query)
+ return "\n".join(plan)
+
+ def explain(self, query, params=None):
+ """
+ Get the execution plan for given query,
+ GRAPH.EXPLAIN returns ExecutionPlan object.
+
+ Args:
+ query: the query that will be executed
+ params: query parameters
+ """
+ if params is not None:
+ query = self._build_params_header(params) + query
+
+ plan = self.execute_command("GRAPH.EXPLAIN", self.name, query)
+ return ExecutionPlan(plan)
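
This hunk splits the old string-returning explain() into two methods: execution_plan() keeps the previous behaviour and returns the GRAPH.EXPLAIN reply as a newline-joined string, while explain() now wraps the same reply in an ExecutionPlan object. A minimal usage sketch, assuming a local server with the RedisGraph module loaded; the graph name "g" and the Cypher query are illustrative only:

    import redis

    r = redis.Redis(decode_responses=True)  # module replies handled as strings
    g = r.graph("g")
    g.query("CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'})")

    cypher = "MATCH (r:Rider)-[:rides]->(t:Team) RETURN r.name, t.name"

    # Previous behaviour under its new name: one newline-joined plan string.
    print(g.execution_plan(cypher))

    # explain() now returns a parsed ExecutionPlan with an operation tree.
    plan = g.explain(cypher)
    print(plan)                          # renders the indented tree via __str__
    print(plan.structured_plan.name)     # root Operation, e.g. "Results"
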
diff --git a/redis/commands/graph/execution_plan.py b/redis/commands/graph/execution_plan.py
new file mode 100644
index 0000000..a9b4ad0
--- /dev/null
+++ b/redis/commands/graph/execution_plan.py
@@ -0,0 +1,208 @@
+import re
+
+
+class ProfileStats:
+ """
+ ProfileStats, runtime execution statistics of operation.
+ """
+
+ def __init__(self, records_produced, execution_time):
+ self.records_produced = records_produced
+ self.execution_time = execution_time
+
+
+class Operation:
+ """
+ Operation, single operation within execution plan.
+ """
+
+ def __init__(self, name, args=None, profile_stats=None):
+ """
+ Create a new operation.
+
+ Args:
+ name: string that represents the name of the operation
+ args: operation arguments
+ profile_stats: profile statistics
+ """
+ self.name = name
+ self.args = args
+ self.profile_stats = profile_stats
+ self.children = []
+
+ def append_child(self, child):
+ if not isinstance(child, Operation) or self is child:
+ raise Exception("child must be Operation")
+
+ self.children.append(child)
+ return self
+
+ def child_count(self):
+ return len(self.children)
+
+ def __eq__(self, o: object) -> bool:
+ if not isinstance(o, Operation):
+ return False
+
+ return self.name == o.name and self.args == o.args
+
+ def __str__(self) -> str:
+ args_str = "" if self.args is None else " | " + self.args
+ return f"{self.name}{args_str}"
+
+
+class ExecutionPlan:
+ """
+ ExecutionPlan, collection of operations.
+ """
+
+ def __init__(self, plan):
+ """
+ Create a new execution plan.
+
+ Args:
+ plan: array of strings that represents the collection operations
+ the output from GRAPH.EXPLAIN
+ """
+ if not isinstance(plan, list):
+ raise Exception("plan must be an array")
+
+ self.plan = plan
+ self.structured_plan = self._operation_tree()
+
+ def _compare_operations(self, root_a, root_b):
+ """
+ Compare execution plan operation tree
+
+ Return: True if operation trees are equal, False otherwise
+ """
+
+ # compare current root
+ if root_a != root_b:
+ return False
+
+ # make sure root have the same number of children
+ if root_a.child_count() != root_b.child_count():
+ return False
+
+ # recursively compare children
+ for i in range(root_a.child_count()):
+ if not self._compare_operations(root_a.children[i], root_b.children[i]):
+ return False
+
+ return True
+
+ def __str__(self) -> str:
+ def aggraget_str(str_children):
+ return "\n".join(
+ [
+ " " + line
+ for str_child in str_children
+ for line in str_child.splitlines()
+ ]
+ )
+
+ def combine_str(x, y):
+ return f"{x}\n{y}"
+
+ return self._operation_traverse(
+ self.structured_plan, str, aggraget_str, combine_str
+ )
+
+ def __eq__(self, o: object) -> bool:
+ """Compares two execution plans
+
+ Return: True if the two plans are equal False otherwise
+ """
+ # make sure 'o' is an execution-plan
+ if not isinstance(o, ExecutionPlan):
+ return False
+
+ # get root for both plans
+ root_a = self.structured_plan
+ root_b = o.structured_plan
+
+ # compare execution trees
+ return self._compare_operations(root_a, root_b)
+
+ def _operation_traverse(self, op, op_f, aggregate_f, combine_f):
+ """
+ Traverse operation tree recursively applying functions
+
+ Args:
+ op: operation to traverse
+ op_f: function applied for each operation
+ aggregate_f: aggregation function applied for all children of a single operation
+ combine_f: combine function applied for the operation result and the children result
+ """ # noqa
+ # apply op_f for each operation
+ op_res = op_f(op)
+ if len(op.children) == 0:
+ return op_res # no children return
+ else:
+ # apply _operation_traverse recursively
+ children = [
+ self._operation_traverse(child, op_f, aggregate_f, combine_f)
+ for child in op.children
+ ]
+ # combine the operation result with the children aggregated result
+ return combine_f(op_res, aggregate_f(children))
+
+ def _operation_tree(self):
+ """Build the operation tree from the string representation"""
+
+ # initial state
+ i = 0
+ level = 0
+ stack = []
+ current = None
+
+ def _create_operation(args):
+ profile_stats = None
+ name = args[0].strip()
+ args.pop(0)
+ if len(args) > 0 and "Records produced" in args[-1]:
+ records_produced = int(
+ re.search("Records produced: (\\d+)", args[-1]).group(1)
+ )
+ execution_time = float(
+ re.search("Execution time: (\\d+.\\d+) ms", args[-1]).group(1)
+ )
+ profile_stats = ProfileStats(records_produced, execution_time)
+ args.pop(-1)
+ return Operation(
+ name, None if len(args) == 0 else args[0].strip(), profile_stats
+ )
+
+ # iterate plan operations
+ while i < len(self.plan):
+ current_op = self.plan[i]
+ op_level = current_op.count("    ")
+ if op_level == level:
+ # if the operation level equal to the current level
+ # set the current operation and move next
+ child = _create_operation(current_op.split("|"))
+ if current:
+ current = stack.pop()
+ current.append_child(child)
+ current = child
+ i += 1
+ elif op_level == level + 1:
+ # if the operation is child of the current operation
+ # add it as child and set as current operation
+ child = _create_operation(current_op.split("|"))
+ current.append_child(child)
+ stack.append(current)
+ current = child
+ level += 1
+ i += 1
+ elif op_level < level:
+ # if the operation is not child of current operation
+ # go back to it's parent operation
+ levels_back = level - op_level + 1
+ for _ in range(levels_back):
+ current = stack.pop()
+ level -= levels_back
+ else:
+ raise Exception("corrupted plan")
+ return stack[0]
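
The parser above builds its tree from the textual layout of the GRAPH.EXPLAIN reply: each additional four-space indent marks one nesting level, and an optional " | " suffix carries the operation's arguments. A small self-contained sketch of that behaviour, using a hand-written plan instead of a live server reply:

    from redis.commands.graph.execution_plan import ExecutionPlan, Operation

    raw_plan = [
        "Results",
        "    Project",
        "        Node By Label Scan | (t:Team)",
    ]
    plan = ExecutionPlan(raw_plan)

    # __str__ re-renders the operation tree, one indent level per depth.
    print(plan)

    # structured_plan is the root Operation. Note that Operation.__eq__ only
    # compares a node's own name and args; full-tree comparison happens in
    # ExecutionPlan.__eq__ via _compare_operations.
    root = Operation("Results").append_child(
        Operation("Project").append_child(
            Operation("Node By Label Scan", "(t:Team)")
        )
    )
    assert plan.structured_plan == root
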
diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py
index 3f768ab..5bcfdfd 100644
--- a/redis/commands/search/commands.py
+++ b/redis/commands/search/commands.py
@@ -1,5 +1,6 @@
import itertools
import time
+from typing import Dict, Union
from ..helpers import parse_to_dict
from ._util import to_string
@@ -377,7 +378,17 @@ class SearchCommands:
it = map(to_string, res)
return dict(zip(it, it))
- def _mk_query_args(self, query):
+ def get_params_args(self, query_params: Dict[str, Union[str, int, float]]):
+ args = []
+ if len(query_params) > 0:
+ args.append("params")
+ args.append(len(query_params) * 2)
+ for key, value in query_params.items():
+ args.append(key)
+ args.append(value)
+ return args
+
+ def _mk_query_args(self, query, query_params: Dict[str, Union[str, int, float]]):
args = [self.index_name]
if isinstance(query, str):
@@ -387,9 +398,16 @@ class SearchCommands:
raise ValueError(f"Bad query type {type(query)}")
args += query.get_args()
+ if query_params is not None:
+ args += self.get_params_args(query_params)
+
return args, query
- def search(self, query):
+ def search(
+ self,
+ query: Union[str, Query],
+ query_params: Dict[str, Union[str, int, float]] = None,
+ ):
"""
Search the index for a given query, and return a result of documents
@@ -401,7 +419,7 @@ class SearchCommands:
For more information: https://oss.redis.com/redisearch/Commands/#ftsearch
""" # noqa
- args, query = self._mk_query_args(query)
+ args, query = self._mk_query_args(query, query_params=query_params)
st = time.time()
res = self.execute_command(SEARCH_CMD, *args)
@@ -413,18 +431,26 @@ class SearchCommands:
with_scores=query._with_scores,
)
- def explain(self, query):
+ def explain(
+ self,
+ query=Union[str, Query],
+ query_params: Dict[str, Union[str, int, float]] = None,
+ ):
"""Returns the execution plan for a complex query.
For more information: https://oss.redis.com/redisearch/Commands/#ftexplain
""" # noqa
- args, query_text = self._mk_query_args(query)
+ args, query_text = self._mk_query_args(query, query_params=query_params)
return self.execute_command(EXPLAIN_CMD, *args)
- def explain_cli(self, query): # noqa
+ def explain_cli(self, query: Union[str, Query]): # noqa
raise NotImplementedError("EXPLAINCLI will not be implemented.")
- def aggregate(self, query):
+ def aggregate(
+ self,
+ query: Union[str, Query],
+ query_params: Dict[str, Union[str, int, float]] = None,
+ ):
"""
Issue an aggregation query.
@@ -445,6 +471,8 @@ class SearchCommands:
cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args()
else:
raise ValueError("Bad query", query)
+ if query_params is not None:
+ cmd += self.get_params_args(query_params)
raw = self.execute_command(*cmd)
return self._get_AggregateResult(raw, query, has_cursor)
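
The search-side change threads an optional query_params dict through FT.SEARCH, FT.EXPLAIN and FT.AGGREGATE by expanding it into a PARAMS clause. A short sketch against a hypothetical index "idx" with a numeric field "price" (both names are assumptions; a server with RediSearch parameter support is required):

    import redis
    from redis.commands.search.query import Query

    r = redis.Redis()
    q = Query("@price:[$min $max]")

    # get_params_args({"min": 100, "max": 200}) yields
    # ["params", 4, "min", 100, "max", 200], which _mk_query_args appends
    # to the FT.SEARCH argument list.
    res = r.ft("idx").search(q, query_params={"min": 100, "max": 200})
    print(res.total, [doc.id for doc in res.docs])
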
diff --git a/setup.py b/setup.py
index f085f03..cfe1c32 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@ setup(
long_description_content_type="text/markdown",
keywords=["Redis", "key-value store", "database"],
license="MIT",
- version="4.1.3",
+ version="4.1.4",
packages=find_packages(
include=[
"redis",
diff --git a/tests/test_graph.py b/tests/test_graph.py
index c6dc9a4..c885aa4 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -1,6 +1,7 @@
import pytest
from redis.commands.graph import Edge, Node, Path
+from redis.commands.graph.execution_plan import Operation
from redis.exceptions import ResponseError
@@ -260,21 +261,6 @@ def test_cached_execution(client):
@pytest.mark.redismod
-def test_explain(client):
- create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}),
- (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}),
- (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})"""
- client.graph().query(create_query)
-
- result = client.graph().explain(
- "MATCH (r:Rider)-[:rides]->(t:Team) WHERE t.name = $name RETURN r.name, t.name, $params", # noqa
- {"name": "Yehuda"},
- )
- expected = "Results\n Project\n Conditional Traverse | (t:Team)->(r:Rider)\n Filter\n Node By Label Scan | (t:Team)" # noqa
- assert result == expected
-
-
-@pytest.mark.redismod
def test_slowlog(client):
create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}),
(:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}),
@@ -475,3 +461,114 @@ def test_cache_sync(client):
assert A._properties[1] == "x"
assert A._relationshipTypes[0] == "S"
assert A._relationshipTypes[1] == "R"
+
+
+@pytest.mark.redismod
+def test_execution_plan(client):
+ redis_graph = client.graph("execution_plan")
+ create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}),
+ (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}),
+ (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})"""
+ redis_graph.query(create_query)
+
+ result = redis_graph.execution_plan(
+ "MATCH (r:Rider)-[:rides]->(t:Team) WHERE t.name = $name RETURN r.name, t.name, $params", # noqa
+ {"name": "Yehuda"},
+ )
+ expected = "Results\n Project\n Conditional Traverse | (t:Team)->(r:Rider)\n Filter\n Node By Label Scan | (t:Team)" # noqa
+ assert result == expected
+
+ redis_graph.delete()
+
+
+@pytest.mark.redismod
+def test_explain(client):
+ redis_graph = client.graph("execution_plan")
+ # graph creation / population
+ create_query = """CREATE
+(:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}),
+(:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}),
+(:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})"""
+ redis_graph.query(create_query)
+
+ result = redis_graph.explain(
+ """MATCH (r:Rider)-[:rides]->(t:Team)
+WHERE t.name = $name
+RETURN r.name, t.name
+UNION
+MATCH (r:Rider)-[:rides]->(t:Team)
+WHERE t.name = $name
+RETURN r.name, t.name""",
+ {"name": "Yamaha"},
+ )
+ expected = """\
+Results
+Distinct
+ Join
+ Project
+ Conditional Traverse | (t:Team)->(r:Rider)
+ Filter
+ Node By Label Scan | (t:Team)
+ Project
+ Conditional Traverse | (t:Team)->(r:Rider)
+ Filter
+ Node By Label Scan | (t:Team)"""
+ assert str(result).replace(" ", "").replace("\n", "") == expected.replace(
+ " ", ""
+ ).replace("\n", "")
+
+ expected = Operation("Results").append_child(
+ Operation("Distinct").append_child(
+ Operation("Join")
+ .append_child(
+ Operation("Project").append_child(
+ Operation(
+ "Conditional Traverse", "(t:Team)->(r:Rider)"
+ ).append_child(
+ Operation("Filter").append_child(
+ Operation("Node By Label Scan", "(t:Team)")
+ )
+ )
+ )
+ )
+ .append_child(
+ Operation("Project").append_child(
+ Operation(
+ "Conditional Traverse", "(t:Team)->(r:Rider)"
+ ).append_child(
+ Operation("Filter").append_child(
+ Operation("Node By Label Scan", "(t:Team)")
+ )
+ )
+ )
+ )
+ )
+ )
+
+ assert result.structured_plan == expected
+
+ result = redis_graph.explain(
+ """MATCH (r:Rider), (t:Team)
+ RETURN r.name, t.name"""
+ )
+ expected = """\
+Results
+Project
+ Cartesian Product
+ Node By Label Scan | (r:Rider)
+ Node By Label Scan | (t:Team)"""
+ assert str(result).replace(" ", "").replace("\n", "") == expected.replace(
+ " ", ""
+ ).replace("\n", "")
+
+ expected = Operation("Results").append_child(
+ Operation("Project").append_child(
+ Operation("Cartesian Product")
+ .append_child(Operation("Node By Label Scan"))
+ .append_child(Operation("Node By Label Scan"))
+ )
+ )
+
+ assert result.structured_plan == expected
+
+ redis_graph.delete()
diff --git a/tests/test_search.py b/tests/test_search.py
index 6c79041..0d15dcc 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -964,7 +964,7 @@ def test_aggregations_groupby(client):
res = client.ft().aggregate(req).rows[0]
assert res[1] == "redis"
- assert res[3] == "10"
+ assert res[3] == "8" # median of 3,8,10
req = aggregations.AggregateRequest("redis").group_by(
"@parent",
@@ -1521,3 +1521,56 @@ def test_profile_limited(client):
)
assert det["Iterators profile"]["Type"] == "INTERSECT"
assert len(res.docs) == 3 # check also the search result
+
+
+@pytest.mark.redismod
+def test_text_params(modclient):
+ modclient.flushdb()
+ modclient.ft().create_index((TextField("name"),))
+
+ modclient.ft().add_document("doc1", name="Alice")
+ modclient.ft().add_document("doc2", name="Bob")
+ modclient.ft().add_document("doc3", name="Carol")
+
+ params_dict = {"name1": "Alice", "name2": "Bob"}
+ q = Query("@name:($name1 | $name2 )")
+ res = modclient.ft().search(q, query_params=params_dict)
+ assert 2 == res.total
+ assert "doc1" == res.docs[0].id
+ assert "doc2" == res.docs[1].id
+
+
+@pytest.mark.redismod
+def test_numeric_params(modclient):
+ modclient.flushdb()
+ modclient.ft().create_index((NumericField("numval"),))
+
+ modclient.ft().add_document("doc1", numval=101)
+ modclient.ft().add_document("doc2", numval=102)
+ modclient.ft().add_document("doc3", numval=103)
+
+ params_dict = {"min": 101, "max": 102}
+ q = Query("@numval:[$min $max]")
+ res = modclient.ft().search(q, query_params=params_dict)
+
+ assert 2 == res.total
+ assert "doc1" == res.docs[0].id
+ assert "doc2" == res.docs[1].id
+
+
+@pytest.mark.redismod
+def test_geo_params(modclient):
+
+ modclient.flushdb()
+ modclient.ft().create_index((GeoField("g")))
+ modclient.ft().add_document("doc1", g="29.69465, 34.95126")
+ modclient.ft().add_document("doc2", g="29.69350, 34.94737")
+ modclient.ft().add_document("doc3", g="29.68746, 34.94882")
+
+ params_dict = {"lat": "34.95126", "lon": "29.69465", "radius": 1000, "units": "km"}
+ q = Query("@g:[$lon $lat $radius $units]")
+ res = modclient.ft().search(q, query_params=params_dict)
+ assert 3 == res.total
+ assert "doc1" == res.docs[0].id
+ assert "doc2" == res.docs[1].id
+ assert "doc3" == res.docs[2].id