summaryrefslogtreecommitdiff
path: root/redis/commands/graph/query_result.py
diff options
context:
space:
mode:
Diffstat (limited to 'redis/commands/graph/query_result.py')
-rw-r--r--redis/commands/graph/query_result.py362
1 files changed, 362 insertions, 0 deletions
diff --git a/redis/commands/graph/query_result.py b/redis/commands/graph/query_result.py
new file mode 100644
index 0000000..e9d9f4d
--- /dev/null
+++ b/redis/commands/graph/query_result.py
@@ -0,0 +1,362 @@
+from collections import OrderedDict
+
+# from prettytable import PrettyTable
+from redis import ResponseError
+
+from .edge import Edge
+from .exceptions import VersionMismatchException
+from .node import Node
+from .path import Path
+
+LABELS_ADDED = "Labels added"
+NODES_CREATED = "Nodes created"
+NODES_DELETED = "Nodes deleted"
+RELATIONSHIPS_DELETED = "Relationships deleted"
+PROPERTIES_SET = "Properties set"
+RELATIONSHIPS_CREATED = "Relationships created"
+INDICES_CREATED = "Indices created"
+INDICES_DELETED = "Indices deleted"
+CACHED_EXECUTION = "Cached execution"
+INTERNAL_EXECUTION_TIME = "internal execution time"
+
+STATS = [
+ LABELS_ADDED,
+ NODES_CREATED,
+ PROPERTIES_SET,
+ RELATIONSHIPS_CREATED,
+ NODES_DELETED,
+ RELATIONSHIPS_DELETED,
+ INDICES_CREATED,
+ INDICES_DELETED,
+ CACHED_EXECUTION,
+ INTERNAL_EXECUTION_TIME,
+]
+
+
+class ResultSetColumnTypes:
+ COLUMN_UNKNOWN = 0
+ COLUMN_SCALAR = 1
+ COLUMN_NODE = 2 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa
+ COLUMN_RELATION = 3 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa
+
+
+class ResultSetScalarTypes:
+ VALUE_UNKNOWN = 0
+ VALUE_NULL = 1
+ VALUE_STRING = 2
+ VALUE_INTEGER = 3
+ VALUE_BOOLEAN = 4
+ VALUE_DOUBLE = 5
+ VALUE_ARRAY = 6
+ VALUE_EDGE = 7
+ VALUE_NODE = 8
+ VALUE_PATH = 9
+ VALUE_MAP = 10
+ VALUE_POINT = 11
+
+
+class QueryResult:
+ def __init__(self, graph, response, profile=False):
+ """
+ A class that represents a result of the query operation.
+
+ Args:
+
+ graph:
+ The graph on which the query was executed.
+ response:
+ The response from the server.
+ profile:
+ A boolean indicating if the query command was "GRAPH.PROFILE"
+ """
+ self.graph = graph
+ self.header = []
+ self.result_set = []
+
+ # in case of an error an exception will be raised
+ self._check_for_errors(response)
+
+ if len(response) == 1:
+ self.parse_statistics(response[0])
+ elif profile:
+ self.parse_profile(response)
+ else:
+ # start by parsing statistics, matches the one we have
+ self.parse_statistics(response[-1]) # Last element.
+ self.parse_results(response)
+
+ def _check_for_errors(self, response):
+ if isinstance(response[0], ResponseError):
+ error = response[0]
+ if str(error) == "version mismatch":
+ version = response[1]
+ error = VersionMismatchException(version)
+ raise error
+
+ # If we encountered a run-time error, the last response
+ # element will be an exception
+ if isinstance(response[-1], ResponseError):
+ raise response[-1]
+
+ def parse_results(self, raw_result_set):
+ self.header = self.parse_header(raw_result_set)
+
+ # Empty header.
+ if len(self.header) == 0:
+ return
+
+ self.result_set = self.parse_records(raw_result_set)
+
+ def parse_statistics(self, raw_statistics):
+ self.statistics = {}
+
+ # decode statistics
+ for idx, stat in enumerate(raw_statistics):
+ if isinstance(stat, bytes):
+ raw_statistics[idx] = stat.decode()
+
+ for s in STATS:
+ v = self._get_value(s, raw_statistics)
+ if v is not None:
+ self.statistics[s] = v
+
+ def parse_header(self, raw_result_set):
+ # An array of column name/column type pairs.
+ header = raw_result_set[0]
+ return header
+
+ def parse_records(self, raw_result_set):
+ records = []
+ result_set = raw_result_set[1]
+ for row in result_set:
+ record = []
+ for idx, cell in enumerate(row):
+ if self.header[idx][0] == ResultSetColumnTypes.COLUMN_SCALAR: # noqa
+ record.append(self.parse_scalar(cell))
+ elif self.header[idx][0] == ResultSetColumnTypes.COLUMN_NODE: # noqa
+ record.append(self.parse_node(cell))
+ elif (
+ self.header[idx][0] == ResultSetColumnTypes.COLUMN_RELATION
+ ): # noqa
+ record.append(self.parse_edge(cell))
+ else:
+ print("Unknown column type.\n")
+ records.append(record)
+
+ return records
+
+ def parse_entity_properties(self, props):
+ # [[name, value type, value] X N]
+ properties = {}
+ for prop in props:
+ prop_name = self.graph.get_property(prop[0])
+ prop_value = self.parse_scalar(prop[1:])
+ properties[prop_name] = prop_value
+
+ return properties
+
+ def parse_string(self, cell):
+ if isinstance(cell, bytes):
+ return cell.decode()
+ elif not isinstance(cell, str):
+ return str(cell)
+ else:
+ return cell
+
+ def parse_node(self, cell):
+ # Node ID (integer),
+ # [label string offset (integer)],
+ # [[name, value type, value] X N]
+
+ node_id = int(cell[0])
+ labels = None
+ if len(cell[1]) > 0:
+ labels = []
+ for inner_label in cell[1]:
+ labels.append(self.graph.get_label(inner_label))
+ properties = self.parse_entity_properties(cell[2])
+ return Node(node_id=node_id, label=labels, properties=properties)
+
+ def parse_edge(self, cell):
+ # Edge ID (integer),
+ # reltype string offset (integer),
+ # src node ID offset (integer),
+ # dest node ID offset (integer),
+ # [[name, value, value type] X N]
+
+ edge_id = int(cell[0])
+ relation = self.graph.get_relation(cell[1])
+ src_node_id = int(cell[2])
+ dest_node_id = int(cell[3])
+ properties = self.parse_entity_properties(cell[4])
+ return Edge(
+ src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties
+ )
+
+ def parse_path(self, cell):
+ nodes = self.parse_scalar(cell[0])
+ edges = self.parse_scalar(cell[1])
+ return Path(nodes, edges)
+
+ def parse_map(self, cell):
+ m = OrderedDict()
+ n_entries = len(cell)
+
+ # A map is an array of key value pairs.
+ # 1. key (string)
+ # 2. array: (value type, value)
+ for i in range(0, n_entries, 2):
+ key = self.parse_string(cell[i])
+ m[key] = self.parse_scalar(cell[i + 1])
+
+ return m
+
+ def parse_point(self, cell):
+ p = {}
+ # A point is received an array of the form: [latitude, longitude]
+ # It is returned as a map of the form: {"latitude": latitude, "longitude": longitude} # noqa
+ p["latitude"] = float(cell[0])
+ p["longitude"] = float(cell[1])
+ return p
+
+ def parse_scalar(self, cell):
+ scalar_type = int(cell[0])
+ value = cell[1]
+ scalar = None
+
+ if scalar_type == ResultSetScalarTypes.VALUE_NULL:
+ scalar = None
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_STRING:
+ scalar = self.parse_string(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_INTEGER:
+ scalar = int(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_BOOLEAN:
+ value = value.decode() if isinstance(value, bytes) else value
+ if value == "true":
+ scalar = True
+ elif value == "false":
+ scalar = False
+ else:
+ print("Unknown boolean type\n")
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_DOUBLE:
+ scalar = float(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_ARRAY:
+ # array variable is introduced only for readability
+ scalar = array = value
+ for i in range(len(array)):
+ scalar[i] = self.parse_scalar(array[i])
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_NODE:
+ scalar = self.parse_node(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_EDGE:
+ scalar = self.parse_edge(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_PATH:
+ scalar = self.parse_path(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_MAP:
+ scalar = self.parse_map(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_POINT:
+ scalar = self.parse_point(value)
+
+ elif scalar_type == ResultSetScalarTypes.VALUE_UNKNOWN:
+ print("Unknown scalar type\n")
+
+ return scalar
+
+ def parse_profile(self, response):
+ self.result_set = [x[0 : x.index(",")].strip() for x in response]
+
+ # """Prints the data from the query response:
+ # 1. First row result_set contains the columns names.
+ # Thus the first row in PrettyTable will contain the
+ # columns.
+ # 2. The row after that will contain the data returned,
+ # or 'No Data returned' if there is none.
+ # 3. Prints the statistics of the query.
+ # """
+
+ # def pretty_print(self):
+ # if not self.is_empty():
+ # header = [col[1] for col in self.header]
+ # tbl = PrettyTable(header)
+
+ # for row in self.result_set:
+ # record = []
+ # for idx, cell in enumerate(row):
+ # if type(cell) is Node:
+ # record.append(cell.toString())
+ # elif type(cell) is Edge:
+ # record.append(cell.toString())
+ # else:
+ # record.append(cell)
+ # tbl.add_row(record)
+
+ # if len(self.result_set) == 0:
+ # tbl.add_row(['No data returned.'])
+
+ # print(str(tbl) + '\n')
+
+ # for stat in self.statistics:
+ # print("%s %s" % (stat, self.statistics[stat]))
+
+ def is_empty(self):
+ return len(self.result_set) == 0
+
+ @staticmethod
+ def _get_value(prop, statistics):
+ for stat in statistics:
+ if prop in stat:
+ return float(stat.split(": ")[1].split(" ")[0])
+
+ return None
+
+ def _get_stat(self, stat):
+ return self.statistics[stat] if stat in self.statistics else 0
+
+ @property
+ def labels_added(self):
+ return self._get_stat(LABELS_ADDED)
+
+ @property
+ def nodes_created(self):
+ return self._get_stat(NODES_CREATED)
+
+ @property
+ def nodes_deleted(self):
+ return self._get_stat(NODES_DELETED)
+
+ @property
+ def properties_set(self):
+ return self._get_stat(PROPERTIES_SET)
+
+ @property
+ def relationships_created(self):
+ return self._get_stat(RELATIONSHIPS_CREATED)
+
+ @property
+ def relationships_deleted(self):
+ return self._get_stat(RELATIONSHIPS_DELETED)
+
+ @property
+ def indices_created(self):
+ return self._get_stat(INDICES_CREATED)
+
+ @property
+ def indices_deleted(self):
+ return self._get_stat(INDICES_DELETED)
+
+ @property
+ def cached_execution(self):
+ return self._get_stat(CACHED_EXECUTION) == 1
+
+ @property
+ def run_time_ms(self):
+ return self._get_stat(INTERNAL_EXECUTION_TIME)