summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAvital Fine <79420960+AvitalFineRedis@users.noreply.github.com>2021-11-30 17:47:25 +0100
committerGitHub <noreply@github.com>2021-11-30 18:47:25 +0200
commit175a05f4de17918b74bde7f554182968b1f6aabb (patch)
tree6fe2b660c03c4e0342d1b0c6e124490baf14f7a0
parentb94e230b17d08e6c89d134e933c706256b79bc4a (diff)
downloadredis-py-175a05f4de17918b74bde7f554182968b1f6aabb.tar.gz
Adding RedisGraph support (#1673)
Co-authored-by: Chayim I. Kirshen <c@kirshen.com>
-rw-r--r--redis/commands/graph/__init__.py162
-rw-r--r--redis/commands/graph/commands.py200
-rw-r--r--redis/commands/graph/edge.py87
-rw-r--r--redis/commands/graph/exceptions.py3
-rw-r--r--redis/commands/graph/node.py84
-rw-r--r--redis/commands/graph/path.py74
-rw-r--r--redis/commands/graph/query_result.py362
-rw-r--r--redis/commands/helpers.py38
-rw-r--r--redis/commands/redismodules.py10
-rw-r--r--setup.py1
-rw-r--r--tests/test_graph.py477
-rw-r--r--tests/test_graph_utils/__init__.py0
-rw-r--r--tests/test_graph_utils/test_edge.py77
-rw-r--r--tests/test_graph_utils/test_node.py52
-rw-r--r--tests/test_graph_utils/test_path.py91
-rw-r--r--tox.ini3
16 files changed, 1720 insertions, 1 deletions
diff --git a/redis/commands/graph/__init__.py b/redis/commands/graph/__init__.py
new file mode 100644
index 0000000..7b9972a
--- /dev/null
+++ b/redis/commands/graph/__init__.py
@@ -0,0 +1,162 @@
+from ..helpers import quote_string, random_string, stringify_param_value
+from .commands import GraphCommands
+from .edge import Edge # noqa
+from .node import Node # noqa
+from .path import Path # noqa
+
+
+class Graph(GraphCommands):
+ """
+ Graph, collection of nodes and edges.
+ """
+
+ def __init__(self, client, name=random_string()):
+ """
+ Create a new graph.
+ """
+ self.NAME = name # Graph key
+ self.client = client
+ self.execute_command = client.execute_command
+
+ self.nodes = {}
+ self.edges = []
+ self._labels = [] # List of node labels.
+ self._properties = [] # List of properties.
+ self._relationshipTypes = [] # List of relation types.
+ self.version = 0 # Graph version
+
+ @property
+ def name(self):
+ return self.NAME
+
+ def _clear_schema(self):
+ self._labels = []
+ self._properties = []
+ self._relationshipTypes = []
+
+ def _refresh_schema(self):
+ self._clear_schema()
+ self._refresh_labels()
+ self._refresh_relations()
+ self._refresh_attributes()
+
+ def _refresh_labels(self):
+ lbls = self.labels()
+
+ # Unpack data.
+ self._labels = [None] * len(lbls)
+ for i, l in enumerate(lbls):
+ self._labels[i] = l[0]
+
+ def _refresh_relations(self):
+ rels = self.relationshipTypes()
+
+ # Unpack data.
+ self._relationshipTypes = [None] * len(rels)
+ for i, r in enumerate(rels):
+ self._relationshipTypes[i] = r[0]
+
+ def _refresh_attributes(self):
+ props = self.propertyKeys()
+
+ # Unpack data.
+ self._properties = [None] * len(props)
+ for i, p in enumerate(props):
+ self._properties[i] = p[0]
+
+ def get_label(self, idx):
+ """
+ Returns a label by it's index
+
+ Args:
+
+ idx:
+ The index of the label
+ """
+ try:
+ label = self._labels[idx]
+ except IndexError:
+ # Refresh labels.
+ self._refresh_labels()
+ label = self._labels[idx]
+ return label
+
+ def get_relation(self, idx):
+ """
+ Returns a relationship type by it's index
+
+ Args:
+
+ idx:
+ The index of the relation
+ """
+ try:
+ relationship_type = self._relationshipTypes[idx]
+ except IndexError:
+ # Refresh relationship types.
+ self._refresh_relations()
+ relationship_type = self._relationshipTypes[idx]
+ return relationship_type
+
+ def get_property(self, idx):
+ """
+ Returns a property by it's index
+
+ Args:
+
+ idx:
+ The index of the property
+ """
+ try:
+ propertie = self._properties[idx]
+ except IndexError:
+ # Refresh properties.
+ self._refresh_attributes()
+ propertie = self._properties[idx]
+ return propertie
+
+ def add_node(self, node):
+ """
+ Adds a node to the graph.
+ """
+ if node.alias is None:
+ node.alias = random_string()
+ self.nodes[node.alias] = node
+
+ def add_edge(self, edge):
+ """
+ Adds an edge to the graph.
+ """
+ if not (self.nodes[edge.src_node.alias] and self.nodes[edge.dest_node.alias]):
+ raise AssertionError("Both edge's end must be in the graph")
+
+ self.edges.append(edge)
+
+ def _build_params_header(self, params):
+ if not isinstance(params, dict):
+ raise TypeError("'params' must be a dict")
+ # Header starts with "CYPHER"
+ params_header = "CYPHER "
+ for key, value in params.items():
+ params_header += str(key) + "=" + stringify_param_value(value) + " "
+ return params_header
+
+ # Procedures.
+ def call_procedure(self, procedure, *args, read_only=False, **kwagrs):
+ args = [quote_string(arg) for arg in args]
+ q = f"CALL {procedure}({','.join(args)})"
+
+ y = kwagrs.get("y", None)
+ if y:
+ q += f" YIELD {','.join(y)}"
+
+ return self.query(q, read_only=read_only)
+
+ def labels(self):
+ return self.call_procedure("db.labels", read_only=True).result_set
+
+ def relationshipTypes(self):
+ return self.call_procedure("db.relationshipTypes", read_only=True).result_set
+
+ def propertyKeys(self):
+ return self.call_procedure("db.propertyKeys", read_only=True).result_set
diff --git a/redis/commands/graph/commands.py b/redis/commands/graph/commands.py
new file mode 100644
index 0000000..f0c1d68
--- /dev/null
+++ b/redis/commands/graph/commands.py
@@ -0,0 +1,200 @@
+from redis import DataError
+from redis.exceptions import ResponseError
+
+from .exceptions import VersionMismatchException
+from .query_result import QueryResult
+
+
+class GraphCommands:
+ def commit(self):
+ """
+ Create entire graph.
+ For more information see `CREATE <https://oss.redis.com/redisgraph/master/commands/#create>`_. # noqa
+ """
+ if len(self.nodes) == 0 and len(self.edges) == 0:
+ return None
+
+ query = "CREATE "
+ for _, node in self.nodes.items():
+ query += str(node) + ","
+
+ query += ",".join([str(edge) for edge in self.edges])
+
+ # Discard leading comma.
+ if query[-1] == ",":
+ query = query[:-1]
+
+ return self.query(query)
+
+ def query(self, q, params=None, timeout=None, read_only=False, profile=False):
+ """
+ Executes a query against the graph.
+ For more information see `GRAPH.QUERY <https://oss.redis.com/redisgraph/master/commands/#graphquery>`_. # noqa
+
+ Args:
+
+ -------
+ q :
+ The query.
+ params : dict
+ Query parameters.
+ timeout : int
+ Maximum runtime for read queries in milliseconds.
+ read_only : bool
+ Executes a readonly query if set to True.
+ profile : bool
+ Return details on results produced by and time
+ spent in each operation.
+ """
+
+ # maintain original 'q'
+ query = q
+
+ # handle query parameters
+ if params is not None:
+ query = self._build_params_header(params) + query
+
+ # construct query command
+ # ask for compact result-set format
+ # specify known graph version
+ if profile:
+ cmd = "GRAPH.PROFILE"
+ else:
+ cmd = "GRAPH.RO_QUERY" if read_only else "GRAPH.QUERY"
+ command = [cmd, self.name, query, "--compact"]
+
+ # include timeout is specified
+ if timeout:
+ if not isinstance(timeout, int):
+ raise Exception("Timeout argument must be a positive integer")
+ command += ["timeout", timeout]
+
+ # issue query
+ try:
+ response = self.execute_command(*command)
+ return QueryResult(self, response, profile)
+ except ResponseError as e:
+ if "wrong number of arguments" in str(e):
+ print(
+ "Note: RedisGraph Python requires server version 2.2.8 or above"
+ ) # noqa
+ if "unknown command" in str(e) and read_only:
+ # `GRAPH.RO_QUERY` is unavailable in older versions.
+ return self.query(q, params, timeout, read_only=False)
+ raise e
+ except VersionMismatchException as e:
+ # client view over the graph schema is out of sync
+ # set client version and refresh local schema
+ self.version = e.version
+ self._refresh_schema()
+ # re-issue query
+ return self.query(q, params, timeout, read_only)
+
+ def merge(self, pattern):
+ """
+ Merge pattern.
+ For more information see `MERGE <https://oss.redis.com/redisgraph/master/commands/#merge>`_. # noqa
+ """
+ query = "MERGE "
+ query += str(pattern)
+
+ return self.query(query)
+
+ def delete(self):
+ """
+ Deletes graph.
+ For more information see `DELETE <https://oss.redis.com/redisgraph/master/commands/#delete>`_. # noqa
+ """
+ self._clear_schema()
+ return self.execute_command("GRAPH.DELETE", self.name)
+
+ # declared here, to override the built in redis.db.flush()
+ def flush(self):
+ """
+ Commit the graph and reset the edges and the nodes to zero length.
+ """
+ self.commit()
+ self.nodes = {}
+ self.edges = []
+
+ def explain(self, query, params=None):
+ """
+ Get the execution plan for given query,
+ Returns an array of operations.
+ For more information see `GRAPH.EXPLAIN <https://oss.redis.com/redisgraph/master/commands/#graphexplain>`_. # noqa
+
+ Args:
+
+ -------
+ query:
+ The query that will be executed.
+ params: dict
+ Query parameters.
+ """
+ if params is not None:
+ query = self._build_params_header(params) + query
+
+ plan = self.execute_command("GRAPH.EXPLAIN", self.name, query)
+ return "\n".join(plan)
+
+ def bulk(self, **kwargs):
+ """Internal only. Not supported."""
+ raise NotImplementedError(
+ "GRAPH.BULK is internal only. "
+ "Use https://github.com/redisgraph/redisgraph-bulk-loader."
+ )
+
+ def profile(self, query):
+ """
+ Execute a query and produce an execution plan augmented with metrics
+ for each operation's execution. Return a string representation of a
+ query execution plan, with details on results produced by and time
+ spent in each operation.
+ For more information see `GRAPH.PROFILE <https://oss.redis.com/redisgraph/master/commands/#graphprofile>`_. # noqa
+ """
+ return self.query(query, profile=True)
+
+ def slowlog(self):
+ """
+ Get a list containing up to 10 of the slowest queries issued
+ against the given graph ID.
+ For more information see `GRAPH.SLOWLOG <https://oss.redis.com/redisgraph/master/commands/#graphslowlog>`_. # noqa
+
+ Each item in the list has the following structure:
+ 1. A unix timestamp at which the log entry was processed.
+ 2. The issued command.
+ 3. The issued query.
+ 4. The amount of time needed for its execution, in milliseconds.
+ """
+ return self.execute_command("GRAPH.SLOWLOG", self.name)
+
+ def config(self, name, value=None, set=False):
+ """
+ Retrieve or update a RedisGraph configuration.
+ For more information see `GRAPH.CONFIG <https://oss.redis.com/redisgraph/master/commands/#graphconfig>`_. # noqa
+
+ Args:
+
+ name : str
+ The name of the configuration
+ value :
+ The value we want to ser (can be used only when `set` is on)
+ set : bool
+ Turn on to set a configuration. Default behavior is get.
+ """
+ params = ["SET" if set else "GET", name]
+ if value is not None:
+ if set:
+ params.append(value)
+ else:
+ raise DataError(
+ "``value`` can be provided only when ``set`` is True"
+ ) # noqa
+ return self.execute_command("GRAPH.CONFIG", *params)
+
+ def list_keys(self):
+ """
+ Lists all graph keys in the keyspace.
+ For more information see `GRAPH.LIST <https://oss.redis.com/redisgraph/master/commands/#graphlist>`_. # noqa
+ """
+ return self.execute_command("GRAPH.LIST")
diff --git a/redis/commands/graph/edge.py b/redis/commands/graph/edge.py
new file mode 100644
index 0000000..b334293
--- /dev/null
+++ b/redis/commands/graph/edge.py
@@ -0,0 +1,87 @@
+from ..helpers import quote_string
+from .node import Node
+
+
+class Edge:
+ """
+ An edge connecting two nodes.
+ """
+
+ def __init__(self, src_node, relation, dest_node, edge_id=None, properties=None):
+ """
+ Create a new edge.
+ """
+ if src_node is None or dest_node is None:
+ # NOTE(bors-42): It makes sense to change AssertionError to
+ # ValueError here
+ raise AssertionError("Both src_node & dest_node must be provided")
+
+ self.id = edge_id
+ self.relation = relation or ""
+ self.properties = properties or {}
+ self.src_node = src_node
+ self.dest_node = dest_node
+
+ def toString(self):
+ res = ""
+ if self.properties:
+ props = ",".join(
+ key + ":" + str(quote_string(val))
+ for key, val in sorted(self.properties.items())
+ )
+ res += "{" + props + "}"
+
+ return res
+
+ def __str__(self):
+ # Source node.
+ if isinstance(self.src_node, Node):
+ res = str(self.src_node)
+ else:
+ res = "()"
+
+ # Edge
+ res += "-["
+ if self.relation:
+ res += ":" + self.relation
+ if self.properties:
+ props = ",".join(
+ key + ":" + str(quote_string(val))
+ for key, val in sorted(self.properties.items())
+ )
+ res += "{" + props + "}"
+ res += "]->"
+
+ # Dest node.
+ if isinstance(self.dest_node, Node):
+ res += str(self.dest_node)
+ else:
+ res += "()"
+
+ return res
+
+ def __eq__(self, rhs):
+ # Quick positive check, if both IDs are set.
+ if self.id is not None and rhs.id is not None and self.id == rhs.id:
+ return True
+
+ # Source and destination nodes should match.
+ if self.src_node != rhs.src_node:
+ return False
+
+ if self.dest_node != rhs.dest_node:
+ return False
+
+ # Relation should match.
+ if self.relation != rhs.relation:
+ return False
+
+ # Quick check for number of properties.
+ if len(self.properties) != len(rhs.properties):
+ return False
+
+ # Compare properties.
+ if self.properties != rhs.properties:
+ return False
+
+ return True
diff --git a/redis/commands/graph/exceptions.py b/redis/commands/graph/exceptions.py
new file mode 100644
index 0000000..4bbac10
--- /dev/null
+++ b/redis/commands/graph/exceptions.py
@@ -0,0 +1,3 @@
+class VersionMismatchException(Exception):
+ def __init__(self, version):
+ self.version = version
diff --git a/redis/commands/graph/node.py b/redis/commands/graph/node.py
new file mode 100644
index 0000000..47e4eeb
--- /dev/null
+++ b/redis/commands/graph/node.py
@@ -0,0 +1,84 @@
+from ..helpers import quote_string
+
+
+class Node:
+ """
+ A node within the graph.
+ """
+
+ def __init__(self, node_id=None, alias=None, label=None, properties=None):
+ """
+ Create a new node.
+ """
+ self.id = node_id
+ self.alias = alias
+ if isinstance(label, list):
+ label = [inner_label for inner_label in label if inner_label != ""]
+
+ if (
+ label is None
+ or label == ""
+ or (isinstance(label, list) and len(label) == 0)
+ ):
+ self.label = None
+ self.labels = None
+ elif isinstance(label, str):
+ self.label = label
+ self.labels = [label]
+ elif isinstance(label, list) and all(
+ [isinstance(inner_label, str) for inner_label in label]
+ ):
+ self.label = label[0]
+ self.labels = label
+ else:
+ raise AssertionError(
+ "label should be either None, " "string or a list of strings"
+ )
+
+ self.properties = properties or {}
+
+ def toString(self):
+ res = ""
+ if self.properties:
+ props = ",".join(
+ key + ":" + str(quote_string(val))
+ for key, val in sorted(self.properties.items())
+ )
+ res += "{" + props + "}"
+
+ return res
+
+ def __str__(self):
+ res = "("
+ if self.alias:
+ res += self.alias
+ if self.labels:
+ res += ":" + ":".join(self.labels)
+ if self.properties:
+ props = ",".join(
+ key + ":" + str(quote_string(val))
+ for key, val in sorted(self.properties.items())
+ )
+ res += "{" + props + "}"
+ res += ")"
+
+ return res
+
+ def __eq__(self, rhs):
+ # Quick positive check, if both IDs are set.
+ if self.id is not None and rhs.id is not None and self.id != rhs.id:
+ return False
+
+ # Label should match.
+ if self.label != rhs.label:
+ return False
+
+ # Quick check for number of properties.
+ if len(self.properties) != len(rhs.properties):
+ return False
+
+ # Compare properties.
+ if self.properties != rhs.properties:
+ return False
+
+ return True
diff --git a/redis/commands/graph/path.py b/redis/commands/graph/path.py
new file mode 100644
index 0000000..6f2214a
--- /dev/null
+++ b/redis/commands/graph/path.py
@@ -0,0 +1,74 @@
+from .edge import Edge
+from .node import Node
+
+
+class Path:
+ def __init__(self, nodes, edges):
+ if not (isinstance(nodes, list) and isinstance(edges, list)):
+ raise TypeError("nodes and edges must be list")
+
+ self._nodes = nodes
+ self._edges = edges
+ self.append_type = Node
+
+ @classmethod
+ def new_empty_path(cls):
+ return cls([], [])
+
+ def nodes(self):
+ return self._nodes
+
+ def edges(self):
+ return self._edges
+
+ def get_node(self, index):
+ return self._nodes[index]
+
+ def get_relationship(self, index):
+ return self._edges[index]
+
+ def first_node(self):
+ return self._nodes[0]
+
+ def last_node(self):
+ return self._nodes[-1]
+
+ def edge_count(self):
+ return len(self._edges)
+
+ def nodes_count(self):
+ return len(self._nodes)
+
+ def add_node(self, node):
+ if not isinstance(node, self.append_type):
+ raise AssertionError("Add Edge before adding Node")
+ self._nodes.append(node)
+ self.append_type = Edge
+ return self
+
+ def add_edge(self, edge):
+ if not isinstance(edge, self.append_type):
+ raise AssertionError("Add Node before adding Edge")
+ self._edges.append(edge)
+ self.append_type = Node
+ return self
+
+ def __eq__(self, other):
+ return self.nodes() == other.nodes() and self.edges() == other.edges()
+
+ def __str__(self):
+ res = "<"
+ edge_count = self.edge_count()
+ for i in range(0, edge_count):
+ node_id = self.get_node(i).id
+ res += "(" + str(node_id) + ")"
+ edge = self.get_relationship(i)
+ res += (
+ "-[" + str(int(edge.id)) + "]->"
+ if edge.src_node == node_id
+ else "<-[" + str(int(edge.id)) + "]-"
+ )
+ node_id = self.get_node(edge_count).id
+ res += "(" + str(node_id) + ")"
+ res += ">"
+ return res
diff --git a/redis/commands/graph/query_result.py b/redis/commands/graph/query_result.py
new file mode 100644
index 0000000..e9d9f4d
--- /dev/null
+++ b/redis/commands/graph/query_result.py
@@ -0,0 +1,362 @@
+from collections import OrderedDict
+
+# from prettytable import PrettyTable
+from redis import ResponseError
+
+from .edge import Edge
+from .exceptions import VersionMismatchException
+from .node import Node
+from .path import Path
+
+LABELS_ADDED = "Labels added"
+NODES_CREATED = "Nodes created"
+NODES_DELETED = "Nodes deleted"
+RELATIONSHIPS_DELETED = "Relationships deleted"
+PROPERTIES_SET = "Properties set"
+RELATIONSHIPS_CREATED = "Relationships created"
+INDICES_CREATED = "Indices created"
+INDICES_DELETED = "Indices deleted"
+CACHED_EXECUTION = "Cached execution"
+INTERNAL_EXECUTION_TIME = "internal execution time"
+
+STATS = [
+ LABELS_ADDED,
+ NODES_CREATED,
+ PROPERTIES_SET,
+ RELATIONSHIPS_CREATED,
+ NODES_DELETED,
+ RELATIONSHIPS_DELETED,
+ INDICES_CREATED,
+ INDICES_DELETED,
+ CACHED_EXECUTION,
+ INTERNAL_EXECUTION_TIME,
+]
+
+
+class ResultSetColumnTypes:
+ COLUMN_UNKNOWN = 0
+ COLUMN_SCALAR = 1
+ COLUMN_NODE = 2 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa
+ COLUMN_RELATION = 3 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa
+
+
+class ResultSetScalarTypes:
+ VALUE_UNKNOWN = 0
+ VALUE_NULL = 1
+ VALUE_STRING = 2
+ VALUE_INTEGER = 3
+ VALUE_BOOLEAN = 4
+ VALUE_DOUBLE = 5
+ VALUE_ARRAY = 6
+ VALUE_EDGE = 7
+ VALUE_NODE = 8
+ VALUE_PATH = 9
+ VALUE_MAP = 10
+ VALUE_POINT = 11
+
+
+class QueryResult:
+ def __init__(self, graph, response, profile=False):
+ """
+ A class that represents a result of the query operation.
+
+ Args:
+
+ graph:
+ The graph on which the query was executed.
+ response:
+ The response from the server.
+ profile:
+ A boolean indicating if the query command was "GRAPH.PROFILE"
+ """
+ self.graph = graph
+ self.header = []
+ self.result_set = []
+
+ # in case of an error an exception will be raised
+ self._check_for_errors(response)
+
+ if len(response) == 1:
+ self.parse_statistics(response[0])
+ elif profile:
+ self.parse_profile(response)
+ else:
+ # start by parsing statistics, matches the one we have
+ self.parse_statistics(response[-1]) # Last element.
+ self.parse_results(response)
+
+ def _check_for_errors(self, response):
+ if isinstance(response[0], ResponseError):
+ error = response[0]
+ if str(error) == "version mismatch":
+ version = response[1]
+ error = VersionMismatchException(version)
+ raise error
+
+ # If we encountered a run-time error, the last response
+ # element will be an exception
+ if isinstance(response[-1], ResponseError):
+ raise response[-1]
+
+ def parse_results(self, raw_result_set):
+ self.header = self.parse_header(raw_result_set)
+
+ # Empty header.
+ if len(self.header) == 0:
+ return
+
+ self.result_set = self.parse_records(raw_result_set)
+
+ def parse_statistics(self, raw_statistics):
+ self.statistics = {}
+
+ # decode statistics
+ for idx, stat in enumerate(raw_statistics):
+ if isinstance(stat, bytes):
+ raw_statistics[idx] = stat.decode()
+
+ for s in STATS:
+ v = self._get_value(s, raw_statistics)
+ if v is not None:
+ self.statistics[s] = v
+
+ def parse_header(self, raw_result_set):
+ # An array of column name/column type pairs.
+ header = raw_result_set[0]
+ return header
+
+ def parse_records(self, raw_result_set):
+ records = []
+ result_set = raw_result_set[1]
+ for row in result_set:
+ record = []
+ for idx, cell in enumerate(row):
+ if self.header[idx][0] == ResultSetColumnTypes.COLUMN_SCALAR: # noqa
+ record.append(self.parse_scalar(cell))
+ elif self.header[idx][0] == ResultSetColumnTypes.COLUMN_NODE: # noqa
+ record.append(self.parse_node(cell))
+ elif (
+ self.header[idx][0] == ResultSetColumnTypes.COLUMN_RELATION
+ ): # noqa
+ record.append(self.parse_edge(cell))
+ else:
+ print("Unknown column type.\n")
+ records.append(record)
+
+ return records
+
+ def parse_entity_properties(self, props):
+ # [[name, value type, value] X N]
+ properties = {}
+ for prop in props:
+ prop_name = self.graph.get_property(prop[0])
+ prop_value = self.parse_scalar(prop[1:])
+ properties[prop_name] = prop_value
+
+ return properties
+
+ def parse_string(self, cell):
+ if isinstance(cell, bytes):
+ return cell.decode()
+ elif not isinstance(cell, str):
+ return str(cell)
+ else:
+ return cell
+
+ def parse_node(self, cell):
+ # Node ID (integer),
+ # [label string offset (integer)],
+ # [[name, value type, value] X N]
+
+ node_id = int(cell[0])
+ labels = None
+ if len(cell[1]) > 0:
+ labels = []
+ for inner_label in cell[1]:
+ labels.append(self.graph.get_label(inner_label))
+ properties = self.parse_entity_properties(cell[2])
+ return Node(node_id=node_id, label=labels, properties=properties)
+
+ def parse_edge(self, cell):
+ # Edge ID (integer),
+ # reltype string offset (integer),
+ # src node ID offset (integer),
+ # dest node ID offset (integer),
+ # [[name, value, value type] X N]
+
+ edge_id = int(cell[0])
+ relation = self.graph.get_relation(cell[1])
+ src_node_id = int(cell[2])
+ dest_node_id = int(cell[3])
+ properties = self.parse_entity_properties(cell[4])
+ return Edge(
+ src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties
+ )
+
+ def parse_path(self, cell):
+ nodes = self.parse_scalar(cell[0])
+ edges = self.parse_scalar(cell[1])
+ return Path(nodes, edges)
+
+ def parse_map(self, cell):
+ m = OrderedDict()
+ n_entries = len(cell)
+
+ # A map is an array of key value pairs.
+ # 1. key (string)
+ # 2. array: (value type, value)
+ for i in range(0, n_entries, 2):
+ key = self.parse_string(cell[i])
+ m[key] = self.parse_scalar(cell[i + 1])
+
+ return m
+
+ def parse_point(self, cell):
+ p = {}
+ # A point is received an array of the form: [latitude, longitude]
+ # It is returned as a map of the form: {"latitude": latitude, "longitude": longitude} # noqa
+ p["latitude"] = float(cell[0])
+ p["longitude"] = float(cell[1])
+ return p
+
+ def parse_scalar(self, cell):
+ scalar_type = int(cell[0])
+ value = cell[1]
+ scalar = None
+
+ if scalar_type == ResultSetScalarTypes.VALUE_NULL:
+ scalar = None
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_STRING:
+ scalar = self.parse_string(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_INTEGER:
+ scalar = int(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_BOOLEAN:
+ value = value.decode() if isinstance(value, bytes) else value
+ if value == "true":
+ scalar = True
+ elif value == "false":
+ scalar = False
+ else:
+ print("Unknown boolean type\n")
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_DOUBLE:
+ scalar = float(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_ARRAY:
+ # array variable is introduced only for readability
+ scalar = array = value
+ for i in range(len(array)):
+ scalar[i] = self.parse_scalar(array[i])
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_NODE:
+ scalar = self.parse_node(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_EDGE:
+ scalar = self.parse_edge(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_PATH:
+ scalar = self.parse_path(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_MAP:
+ scalar = self.parse_map(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_POINT:
+ scalar = self.parse_point(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_UNKNOWN:
+ print("Unknown scalar type\n")
+
+ return scalar
+
+ def parse_profile(self, response):
+ self.result_set = [x[0 : x.index(",")].strip() for x in response]
+
+ # """Prints the data from the query response:
+ # 1. First row result_set contains the columns names.
+ # Thus the first row in PrettyTable will contain the
+ # columns.
+ # 2. The row after that will contain the data returned,
+ # or 'No Data returned' if there is none.
+ # 3. Prints the statistics of the query.
+ # """
+
+ # def pretty_print(self):
+ # if not self.is_empty():
+ # header = [col[1] for col in self.header]
+ # tbl = PrettyTable(header)
+
+ # for row in self.result_set:
+ # record = []
+ # for idx, cell in enumerate(row):
+ # if type(cell) is Node:
+ # record.append(cell.toString())
+ # elif type(cell) is Edge:
+ # record.append(cell.toString())
+ # else:
+ # record.append(cell)
+ # tbl.add_row(record)
+
+ # if len(self.result_set) == 0:
+ # tbl.add_row(['No data returned.'])
+
+ # print(str(tbl) + '\n')
+
+ # for stat in self.statistics:
+ # print("%s %s" % (stat, self.statistics[stat]))
+
+ def is_empty(self):
+ return len(self.result_set) == 0
+
+ @staticmethod
+ def _get_value(prop, statistics):
+ for stat in statistics:
+ if prop in stat:
+ return float(stat.split(": ")[1].split(" ")[0])
+
+ return None
+
+ def _get_stat(self, stat):
+ return self.statistics[stat] if stat in self.statistics else 0
+
+ @property
+ def labels_added(self):
+ return self._get_stat(LABELS_ADDED)
+
+ @property
+ def nodes_created(self):
+ return self._get_stat(NODES_CREATED)
+
+ @property
+ def nodes_deleted(self):
+ return self._get_stat(NODES_DELETED)
+
+ @property
+ def properties_set(self):
+ return self._get_stat(PROPERTIES_SET)
+
+ @property
+ def relationships_created(self):
+ return self._get_stat(RELATIONSHIPS_CREATED)
+
+ @property
+ def relationships_deleted(self):
+ return self._get_stat(RELATIONSHIPS_DELETED)
+
+ @property
+ def indices_created(self):
+ return self._get_stat(INDICES_CREATED)
+
+ @property
+ def indices_deleted(self):
+ return self._get_stat(INDICES_DELETED)
+
+ @property
+ def cached_execution(self):
+ return self._get_stat(CACHED_EXECUTION) == 1
+
+ @property
+ def run_time_ms(self):
+ return self._get_stat(INTERNAL_EXECUTION_TIME)
diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py
index 80dfd76..afb4f9f 100644
--- a/redis/commands/helpers.py
+++ b/redis/commands/helpers.py
@@ -1,3 +1,4 @@
+import copy
import random
import string
@@ -114,3 +115,40 @@ def quote_string(v):
v = v.replace('"', '\\"')
return f'"{v}"'
+
+
+def decodeDictKeys(obj):
+ """Decode the keys of the given dictionary with utf-8."""
+ newobj = copy.copy(obj)
+ for k in obj.keys():
+ if isinstance(k, bytes):
+ newobj[k.decode("utf-8")] = newobj[k]
+ newobj.pop(k)
+ return newobj
+
+
+def stringify_param_value(value):
+ """
+ Turn a parameter value into a string suitable for the params header of
+ a Cypher command.
+ You may pass any value that would be accepted by `json.dumps()`.
+
+ Ways in which output differs from that of `str()`:
+ * Strings are quoted.
+ * None --> "null".
+ * In dictionaries, keys are _not_ quoted.
+
+ :param value: The parameter value to be turned into a string.
+ :return: string
+ """
+
+ if isinstance(value, str):
+ return quote_string(value)
+ elif value is None:
+ return "null"
+ elif isinstance(value, (list, tuple)):
+ return f'[{",".join(map(stringify_param_value, value))}]'
+ elif isinstance(value, dict):
+ return f'{{{",".join(f"{k}:{stringify_param_value(v)}" for k, v in value.items())}}}' # noqa
+ else:
+ return str(value)
diff --git a/redis/commands/redismodules.py b/redis/commands/redismodules.py
index 2420d7b..e5ace63 100644
--- a/redis/commands/redismodules.py
+++ b/redis/commands/redismodules.py
@@ -31,3 +31,13 @@ class RedisModuleCommands:
s = TimeSeries(client=self)
return s
+
+ def graph(self, index_name="idx"):
+ """Access the timeseries namespace, providing support for
+ redis timeseries data.
+ """
+
+ from .graph import Graph
+
+ g = Graph(client=self, name=index_name)
+ return g
diff --git a/setup.py b/setup.py
index ee91298..d830801 100644
--- a/setup.py
+++ b/setup.py
@@ -18,6 +18,7 @@ setup(
"redis.commands.json",
"redis.commands.search",
"redis.commands.timeseries",
+ "redis.commands.graph",
]
),
url="https://github.com/redis/redis-py",
diff --git a/tests/test_graph.py b/tests/test_graph.py
new file mode 100644
index 0000000..c6dc9a4
--- /dev/null
+++ b/tests/test_graph.py
@@ -0,0 +1,477 @@
+import pytest
+
+from redis.commands.graph import Edge, Node, Path
+from redis.exceptions import ResponseError
+
+
+@pytest.fixture
+def client(modclient):
+ modclient.flushdb()
+ return modclient
+
+
+@pytest.mark.redismod
+def test_bulk(client):
+ with pytest.raises(NotImplementedError):
+ client.graph().bulk()
+ client.graph().bulk(foo="bar!")
+
+
+@pytest.mark.redismod
+def test_graph_creation(client):
+ graph = client.graph()
+
+ john = Node(
+ label="person",
+ properties={
+ "name": "John Doe",
+ "age": 33,
+ "gender": "male",
+ "status": "single",
+ },
+ )
+ graph.add_node(john)
+ japan = Node(label="country", properties={"name": "Japan"})
+
+ graph.add_node(japan)
+ edge = Edge(john, "visited", japan, properties={"purpose": "pleasure"})
+ graph.add_edge(edge)
+
+ graph.commit()
+
+ query = (
+ 'MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) '
+ "RETURN p, v, c"
+ )
+
+ result = graph.query(query)
+
+ person = result.result_set[0][0]
+ visit = result.result_set[0][1]
+ country = result.result_set[0][2]
+
+ assert person == john
+ assert visit.properties == edge.properties
+ assert country == japan
+
+ query = """RETURN [1, 2.3, "4", true, false, null]"""
+ result = graph.query(query)
+ assert [1, 2.3, "4", True, False, None] == result.result_set[0][0]
+
+ # All done, remove graph.
+ graph.delete()
+
+
+@pytest.mark.redismod
+def test_array_functions(client):
+ query = """CREATE (p:person{name:'a',age:32, array:[0,1,2]})"""
+ client.graph().query(query)
+
+ query = """WITH [0,1,2] as x return x"""
+ result = client.graph().query(query)
+ assert [0, 1, 2] == result.result_set[0][0]
+
+ query = """MATCH(n) return collect(n)"""
+ result = client.graph().query(query)
+
+ a = Node(
+ node_id=0,
+ label="person",
+ properties={"name": "a", "age": 32, "array": [0, 1, 2]},
+ )
+
+ assert [a] == result.result_set[0][0]
+
+
+@pytest.mark.redismod
+def test_path(client):
+ node0 = Node(node_id=0, label="L1")
+ node1 = Node(node_id=1, label="L1")
+ edge01 = Edge(node0, "R1", node1, edge_id=0, properties={"value": 1})
+
+ graph = client.graph()
+ graph.add_node(node0)
+ graph.add_node(node1)
+ graph.add_edge(edge01)
+ graph.flush()
+
+ path01 = Path.new_empty_path().add_node(node0).add_edge(edge01).add_node(node1)
+ expected_results = [[path01]]
+
+ query = "MATCH p=(:L1)-[:R1]->(:L1) RETURN p ORDER BY p"
+ result = graph.query(query)
+ assert expected_results == result.result_set
+
+
+@pytest.mark.redismod
+def test_param(client):
+ params = [1, 2.3, "str", True, False, None, [0, 1, 2]]
+ query = "RETURN $param"
+ for param in params:
+ result = client.graph().query(query, {"param": param})
+ expected_results = [[param]]
+ assert expected_results == result.result_set
+
+
+@pytest.mark.redismod
+def test_map(client):
+ query = "RETURN {a:1, b:'str', c:NULL, d:[1,2,3], e:True, f:{x:1, y:2}}"
+
+ actual = client.graph().query(query).result_set[0][0]
+ expected = {
+ "a": 1,
+ "b": "str",
+ "c": None,
+ "d": [1, 2, 3],
+ "e": True,
+ "f": {"x": 1, "y": 2},
+ }
+
+ assert actual == expected
+
+
+@pytest.mark.redismod
+def test_point(client):
+ query = "RETURN point({latitude: 32.070794860, longitude: 34.820751118})"
+ expected_lat = 32.070794860
+ expected_lon = 34.820751118
+ actual = client.graph().query(query).result_set[0][0]
+ assert abs(actual["latitude"] - expected_lat) < 0.001
+ assert abs(actual["longitude"] - expected_lon) < 0.001
+
+ query = "RETURN point({latitude: 32, longitude: 34.0})"
+ expected_lat = 32
+ expected_lon = 34
+ actual = client.graph().query(query).result_set[0][0]
+ assert abs(actual["latitude"] - expected_lat) < 0.001
+ assert abs(actual["longitude"] - expected_lon) < 0.001
+
+
+@pytest.mark.redismod
+def test_index_response(client):
+ result_set = client.graph().query("CREATE INDEX ON :person(age)")
+ assert 1 == result_set.indices_created
+
+ result_set = client.graph().query("CREATE INDEX ON :person(age)")
+ assert 0 == result_set.indices_created
+
+ result_set = client.graph().query("DROP INDEX ON :person(age)")
+ assert 1 == result_set.indices_deleted
+
+ with pytest.raises(ResponseError):
+ client.graph().query("DROP INDEX ON :person(age)")
+
+
+@pytest.mark.redismod
+def test_stringify_query_result(client):
+ graph = client.graph()
+
+ john = Node(
+ alias="a",
+ label="person",
+ properties={
+ "name": "John Doe",
+ "age": 33,
+ "gender": "male",
+ "status": "single",
+ },
+ )
+ graph.add_node(john)
+
+ japan = Node(alias="b", label="country", properties={"name": "Japan"})
+ graph.add_node(japan)
+
+ edge = Edge(john, "visited", japan, properties={"purpose": "pleasure"})
+ graph.add_edge(edge)
+
+ assert (
+ str(john)
+ == """(a:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa
+ )
+ assert (
+ str(edge)
+ == """(a:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa
+ + """-[:visited{purpose:"pleasure"}]->"""
+ + """(b:country{name:"Japan"})"""
+ )
+ assert str(japan) == """(b:country{name:"Japan"})"""
+
+ graph.commit()
+
+ query = """MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country)
+ RETURN p, v, c"""
+
+ result = client.graph().query(query)
+ person = result.result_set[0][0]
+ visit = result.result_set[0][1]
+ country = result.result_set[0][2]
+
+ assert (
+ str(person)
+ == """(:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa
+ )
+ assert str(visit) == """()-[:visited{purpose:"pleasure"}]->()"""
+ assert str(country) == """(:country{name:"Japan"})"""
+
+ graph.delete()
+
+
+@pytest.mark.redismod
+def test_optional_match(client):
+ # Build a graph of form (a)-[R]->(b)
+ node0 = Node(node_id=0, label="L1", properties={"value": "a"})
+ node1 = Node(node_id=1, label="L1", properties={"value": "b"})
+
+ edge01 = Edge(node0, "R", node1, edge_id=0)
+
+ graph = client.graph()
+ graph.add_node(node0)
+ graph.add_node(node1)
+ graph.add_edge(edge01)
+ graph.flush()
+
+ # Issue a query that collects all outgoing edges from both nodes
+ # (the second has none)
+ query = """MATCH (a) OPTIONAL MATCH (a)-[e]->(b) RETURN a, e, b ORDER BY a.value""" # noqa
+ expected_results = [[node0, edge01, node1], [node1, None, None]]
+
+ result = client.graph().query(query)
+ assert expected_results == result.result_set
+
+ graph.delete()
+
+
+@pytest.mark.redismod
+def test_cached_execution(client):
+ client.graph().query("CREATE ()")
+
+ uncached_result = client.graph().query("MATCH (n) RETURN n, $param", {"param": [0]})
+ assert uncached_result.cached_execution is False
+
+ # loop to make sure the query is cached on each thread on server
+ for x in range(0, 64):
+ cached_result = client.graph().query(
+ "MATCH (n) RETURN n, $param", {"param": [0]}
+ )
+ assert uncached_result.result_set == cached_result.result_set
+
+ # should be cached on all threads by now
+ assert cached_result.cached_execution
+
+
+@pytest.mark.redismod
+def test_explain(client):
+ create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}),
+ (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}),
+ (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})"""
+ client.graph().query(create_query)
+
+ result = client.graph().explain(
+ "MATCH (r:Rider)-[:rides]->(t:Team) WHERE t.name = $name RETURN r.name, t.name, $params", # noqa
+ {"name": "Yehuda"},
+ )
+ expected = "Results\n Project\n Conditional Traverse | (t:Team)->(r:Rider)\n Filter\n Node By Label Scan | (t:Team)" # noqa
+ assert result == expected
+
+
+@pytest.mark.redismod
+def test_slowlog(client):
+ create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}),
+ (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}),
+ (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})"""
+ client.graph().query(create_query)
+
+ results = client.graph().slowlog()
+ assert results[0][1] == "GRAPH.QUERY"
+ assert results[0][2] == create_query
+
+
+@pytest.mark.redismod
+def test_query_timeout(client):
+ # Build a sample graph with 1000 nodes.
+ client.graph().query("UNWIND range(0,1000) as val CREATE ({v: val})")
+ # Issue a long-running query with a 1-millisecond timeout.
+ with pytest.raises(ResponseError):
+ client.graph().query("MATCH (a), (b), (c), (d) RETURN *", timeout=1)
+ assert False is False
+
+ with pytest.raises(Exception):
+ client.graph().query("RETURN 1", timeout="str")
+ assert False is False
+
+
+@pytest.mark.redismod
+def test_read_only_query(client):
+ with pytest.raises(Exception):
+ # Issue a write query, specifying read-only true,
+ # this call should fail.
+ client.graph().query("CREATE (p:person {name:'a'})", read_only=True)
+ assert False is False
+
+
+@pytest.mark.redismod
+def test_profile(client):
+ q = """UNWIND range(1, 3) AS x CREATE (p:Person {v:x})"""
+ profile = client.graph().profile(q).result_set
+ assert "Create | Records produced: 3" in profile
+ assert "Unwind | Records produced: 3" in profile
+
+ q = "MATCH (p:Person) WHERE p.v > 1 RETURN p"
+ profile = client.graph().profile(q).result_set
+ assert "Results | Records produced: 2" in profile
+ assert "Project | Records produced: 2" in profile
+ assert "Filter | Records produced: 2" in profile
+ assert "Node By Label Scan | (p:Person) | Records produced: 3" in profile
+
+
+@pytest.mark.redismod
+def test_config(client):
+ config_name = "RESULTSET_SIZE"
+ config_value = 3
+
+ # Set configuration
+ response = client.graph().config(config_name, config_value, set=True)
+ assert response == "OK"
+
+ # Make sure config been updated.
+ response = client.graph().config(config_name, set=False)
+ expected_response = [config_name, config_value]
+ assert response == expected_response
+
+ config_name = "QUERY_MEM_CAPACITY"
+ config_value = 1 << 20 # 1MB
+
+ # Set configuration
+ response = client.graph().config(config_name, config_value, set=True)
+ assert response == "OK"
+
+ # Make sure config been updated.
+ response = client.graph().config(config_name, set=False)
+ expected_response = [config_name, config_value]
+ assert response == expected_response
+
+ # reset to default
+ client.graph().config("QUERY_MEM_CAPACITY", 0, set=True)
+ client.graph().config("RESULTSET_SIZE", -100, set=True)
+
+
+@pytest.mark.redismod
+def test_list_keys(client):
+ result = client.graph().list_keys()
+ assert result == []
+
+ client.execute_command("GRAPH.EXPLAIN", "G", "RETURN 1")
+ result = client.graph().list_keys()
+ assert result == ["G"]
+
+ client.execute_command("GRAPH.EXPLAIN", "X", "RETURN 1")
+ result = client.graph().list_keys()
+ assert result == ["G", "X"]
+
+ client.delete("G")
+ client.rename("X", "Z")
+ result = client.graph().list_keys()
+ assert result == ["Z"]
+
+ client.delete("Z")
+ result = client.graph().list_keys()
+ assert result == []
+
+
+@pytest.mark.redismod
+def test_multi_label(client):
+ redis_graph = client.graph("g")
+
+ node = Node(label=["l", "ll"])
+ redis_graph.add_node(node)
+ redis_graph.commit()
+
+ query = "MATCH (n) RETURN n"
+ result = redis_graph.query(query)
+ result_node = result.result_set[0][0]
+ assert result_node == node
+
+ try:
+ Node(label=1)
+ assert False
+ except AssertionError:
+ assert True
+
+ try:
+ Node(label=["l", 1])
+ assert False
+ except AssertionError:
+ assert True
+
+
+@pytest.mark.redismod
+def test_cache_sync(client):
+ pass
+ return
+ # This test verifies that client internal graph schema cache stays
+ # in sync with the graph schema
+ #
+ # Client B will try to get Client A out of sync by:
+ # 1. deleting the graph
+ # 2. reconstructing the graph in a different order, this will casuse
+ # a differance in the current mapping between string IDs and the
+ # mapping Client A is aware of
+ #
+ # Client A should pick up on the changes by comparing graph versions
+ # and resyncing its cache.
+
+ A = client.graph("cache-sync")
+ B = client.graph("cache-sync")
+
+ # Build order:
+ # 1. introduce label 'L' and 'K'
+ # 2. introduce attribute 'x' and 'q'
+ # 3. introduce relationship-type 'R' and 'S'
+
+ A.query("CREATE (:L)")
+ B.query("CREATE (:K)")
+ A.query("MATCH (n) SET n.x = 1")
+ B.query("MATCH (n) SET n.q = 1")
+ A.query("MATCH (n) CREATE (n)-[:R]->()")
+ B.query("MATCH (n) CREATE (n)-[:S]->()")
+
+ # Cause client A to populate its cache
+ A.query("MATCH (n)-[e]->() RETURN n, e")
+
+ assert len(A._labels) == 2
+ assert len(A._properties) == 2
+ assert len(A._relationshipTypes) == 2
+ assert A._labels[0] == "L"
+ assert A._labels[1] == "K"
+ assert A._properties[0] == "x"
+ assert A._properties[1] == "q"
+ assert A._relationshipTypes[0] == "R"
+ assert A._relationshipTypes[1] == "S"
+
+ # Have client B reconstruct the graph in a different order.
+ B.delete()
+
+ # Build order:
+ # 1. introduce relationship-type 'R'
+ # 2. introduce label 'L'
+ # 3. introduce attribute 'x'
+ B.query("CREATE ()-[:S]->()")
+ B.query("CREATE ()-[:R]->()")
+ B.query("CREATE (:K)")
+ B.query("CREATE (:L)")
+ B.query("MATCH (n) SET n.q = 1")
+ B.query("MATCH (n) SET n.x = 1")
+
+ # A's internal cached mapping is now out of sync
+ # issue a query and make sure A's cache is synced.
+ A.query("MATCH (n)-[e]->() RETURN n, e")
+
+ assert len(A._labels) == 2
+ assert len(A._properties) == 2
+ assert len(A._relationshipTypes) == 2
+ assert A._labels[0] == "K"
+ assert A._labels[1] == "L"
+ assert A._properties[0] == "q"
+ assert A._properties[1] == "x"
+ assert A._relationshipTypes[0] == "S"
+ assert A._relationshipTypes[1] == "R"
diff --git a/tests/test_graph_utils/__init__.py b/tests/test_graph_utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/test_graph_utils/__init__.py
diff --git a/tests/test_graph_utils/test_edge.py b/tests/test_graph_utils/test_edge.py
new file mode 100644
index 0000000..42358de
--- /dev/null
+++ b/tests/test_graph_utils/test_edge.py
@@ -0,0 +1,77 @@
+import pytest
+
+from redis.commands.graph import edge, node
+
+
+@pytest.mark.redismod
+def test_init():
+
+ with pytest.raises(AssertionError):
+ edge.Edge(None, None, None)
+ edge.Edge(node.Node(), None, None)
+ edge.Edge(None, None, node.Node())
+
+ assert isinstance(
+ edge.Edge(node.Node(node_id=1), None, node.Node(node_id=2)), edge.Edge
+ )
+
+
+@pytest.mark.redismod
+def test_toString():
+ props_result = edge.Edge(
+ node.Node(), None, node.Node(), properties={"a": "a", "b": 10}
+ ).toString()
+ assert props_result == '{a:"a",b:10}'
+
+ no_props_result = edge.Edge(
+ node.Node(), None, node.Node(), properties={}
+ ).toString()
+ assert no_props_result == ""
+
+
+@pytest.mark.redismod
+def test_stringify():
+ john = node.Node(
+ alias="a",
+ label="person",
+ properties={"name": "John Doe", "age": 33, "someArray": [1, 2, 3]},
+ )
+ japan = node.Node(alias="b", label="country", properties={"name": "Japan"})
+ edge_with_relation = edge.Edge(
+ john, "visited", japan, properties={"purpose": "pleasure"}
+ )
+ assert (
+ '(a:person{age:33,name:"John Doe",someArray:[1, 2, 3]})'
+ '-[:visited{purpose:"pleasure"}]->'
+ '(b:country{name:"Japan"})' == str(edge_with_relation)
+ )
+
+ edge_no_relation_no_props = edge.Edge(japan, "", john)
+ assert (
+ '(b:country{name:"Japan"})'
+ "-[]->"
+ '(a:person{age:33,name:"John Doe",someArray:[1, 2, 3]})'
+ == str(edge_no_relation_no_props)
+ )
+
+ edge_only_props = edge.Edge(john, "", japan, properties={"a": "b", "c": 3})
+ assert (
+ '(a:person{age:33,name:"John Doe",someArray:[1, 2, 3]})'
+ '-[{a:"b",c:3}]->'
+ '(b:country{name:"Japan"})' == str(edge_only_props)
+ )
+
+
+@pytest.mark.redismod
+def test_comparision():
+ node1 = node.Node(node_id=1)
+ node2 = node.Node(node_id=2)
+ node3 = node.Node(node_id=3)
+
+ edge1 = edge.Edge(node1, None, node2)
+ assert edge1 == edge.Edge(node1, None, node2)
+ assert edge1 != edge.Edge(node1, "bla", node2)
+ assert edge1 != edge.Edge(node1, None, node3)
+ assert edge1 != edge.Edge(node3, None, node2)
+ assert edge1 != edge.Edge(node2, None, node1)
+ assert edge1 != edge.Edge(node1, None, node2, properties={"a": 10})
diff --git a/tests/test_graph_utils/test_node.py b/tests/test_graph_utils/test_node.py
new file mode 100644
index 0000000..faf8ab6
--- /dev/null
+++ b/tests/test_graph_utils/test_node.py
@@ -0,0 +1,52 @@
+import pytest
+
+from redis.commands.graph import node
+
+
+@pytest.fixture
+def fixture():
+ no_args = node.Node()
+ no_props = node.Node(node_id=1, alias="alias", label="l")
+ props_only = node.Node(properties={"a": "a", "b": 10})
+ no_label = node.Node(node_id=1, alias="alias", properties={"a": "a"})
+ multi_label = node.Node(node_id=1, alias="alias", label=["l", "ll"])
+ return no_args, no_props, props_only, no_label, multi_label
+
+
+@pytest.mark.redismod
+def test_toString(fixture):
+ no_args, no_props, props_only, no_label, multi_label = fixture
+ assert no_args.toString() == ""
+ assert no_props.toString() == ""
+ assert props_only.toString() == '{a:"a",b:10}'
+ assert no_label.toString() == '{a:"a"}'
+ assert multi_label.toString() == ""
+
+
+@pytest.mark.redismod
+def test_stringify(fixture):
+ no_args, no_props, props_only, no_label, multi_label = fixture
+ assert str(no_args) == "()"
+ assert str(no_props) == "(alias:l)"
+ assert str(props_only) == '({a:"a",b:10})'
+ assert str(no_label) == '(alias{a:"a"})'
+ assert str(multi_label) == "(alias:l:ll)"
+
+
+@pytest.mark.redismod
+def test_comparision(fixture):
+ no_args, no_props, props_only, no_label, multi_label = fixture
+
+ assert node.Node() == node.Node()
+ assert node.Node(node_id=1) == node.Node(node_id=1)
+ assert node.Node(node_id=1) != node.Node(node_id=2)
+ assert node.Node(node_id=1, alias="a") == node.Node(node_id=1, alias="b")
+ assert node.Node(node_id=1, alias="a") == node.Node(node_id=1, alias="a")
+ assert node.Node(node_id=1, label="a") == node.Node(node_id=1, label="a")
+ assert node.Node(node_id=1, label="a") != node.Node(node_id=1, label="b")
+ assert node.Node(node_id=1, alias="a", label="l") == node.Node(
+ node_id=1, alias="a", label="l"
+ )
+ assert node.Node(alias="a", label="l") != node.Node(alias="a", label="l1")
+ assert node.Node(properties={"a": 10}) == node.Node(properties={"a": 10})
+ assert node.Node() != node.Node(properties={"a": 10})
diff --git a/tests/test_graph_utils/test_path.py b/tests/test_graph_utils/test_path.py
new file mode 100644
index 0000000..d581269
--- /dev/null
+++ b/tests/test_graph_utils/test_path.py
@@ -0,0 +1,91 @@
+import pytest
+
+from redis.commands.graph import edge, node, path
+
+
+@pytest.mark.redismod
+def test_init():
+ with pytest.raises(TypeError):
+ path.Path(None, None)
+ path.Path([], None)
+ path.Path(None, [])
+
+ assert isinstance(path.Path([], []), path.Path)
+
+
+@pytest.mark.redismod
+def test_new_empty_path():
+ new_empty_path = path.Path.new_empty_path()
+ assert isinstance(new_empty_path, path.Path)
+ assert new_empty_path._nodes == []
+ assert new_empty_path._edges == []
+
+
+@pytest.mark.redismod
+def test_wrong_flows():
+ node_1 = node.Node(node_id=1)
+ node_2 = node.Node(node_id=2)
+ node_3 = node.Node(node_id=3)
+
+ edge_1 = edge.Edge(node_1, None, node_2)
+ edge_2 = edge.Edge(node_1, None, node_3)
+
+ p = path.Path.new_empty_path()
+ with pytest.raises(AssertionError):
+ p.add_edge(edge_1)
+
+ p.add_node(node_1)
+ with pytest.raises(AssertionError):
+ p.add_node(node_2)
+
+ p.add_edge(edge_1)
+ with pytest.raises(AssertionError):
+ p.add_edge(edge_2)
+
+
+@pytest.mark.redismod
+def test_nodes_and_edges():
+ node_1 = node.Node(node_id=1)
+ node_2 = node.Node(node_id=2)
+ edge_1 = edge.Edge(node_1, None, node_2)
+
+ p = path.Path.new_empty_path()
+ assert p.nodes() == []
+ p.add_node(node_1)
+ assert [] == p.edges()
+ assert 0 == p.edge_count()
+ assert [node_1] == p.nodes()
+ assert node_1 == p.get_node(0)
+ assert node_1 == p.first_node()
+ assert node_1 == p.last_node()
+ assert 1 == p.nodes_count()
+ p.add_edge(edge_1)
+ assert [edge_1] == p.edges()
+ assert 1 == p.edge_count()
+ assert edge_1 == p.get_relationship(0)
+ p.add_node(node_2)
+ assert [node_1, node_2] == p.nodes()
+ assert node_1 == p.first_node()
+ assert node_2 == p.last_node()
+ assert 2 == p.nodes_count()
+
+
+@pytest.mark.redismod
+def test_compare():
+ node_1 = node.Node(node_id=1)
+ node_2 = node.Node(node_id=2)
+ edge_1 = edge.Edge(node_1, None, node_2)
+
+ assert path.Path.new_empty_path() == path.Path.new_empty_path()
+ assert path.Path(nodes=[node_1, node_2], edges=[edge_1]) == path.Path(
+ nodes=[node_1, node_2], edges=[edge_1]
+ )
+ assert path.Path(nodes=[node_1], edges=[]) != path.Path(nodes=[], edges=[])
+ assert path.Path(nodes=[node_1], edges=[]) != path.Path(nodes=[], edges=[])
+ assert path.Path(nodes=[node_1], edges=[]) != path.Path(nodes=[node_2], edges=[])
+ assert path.Path(nodes=[node_1], edges=[edge_1]) != path.Path(
+ nodes=[node_1], edges=[]
+ )
+ assert path.Path(nodes=[node_1], edges=[edge_1]) != path.Path(
+ nodes=[node_2], edges=[edge_1]
+ )
diff --git a/tox.ini b/tox.ini
index 9d78e2a..0ccc9bb 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,13 +2,14 @@
addopts = -s
markers =
redismod: run only the redis module tests
+ pipeline: pipeline tests
onlycluster: marks tests to be run only with cluster mode redis
onlynoncluster: marks tests to be run only with standalone redis
[tox]
minversion = 3.2.0
requires = tox-docker
-envlist = {standalone,cluster}-{plain,hiredis}-{py35,py36,py37,py38,py39,pypy3},linters,docs
+envlist = {standalone,cluster}-{plain,hiredis}-{py36,py37,py38,py39,pypy3},linters,docs
[docker:master]
name = master