diff options
Diffstat (limited to 'redis/commands/graph/query_result.py')
-rw-r--r-- | redis/commands/graph/query_result.py | 371 |
1 files changed, 279 insertions, 92 deletions
diff --git a/redis/commands/graph/query_result.py b/redis/commands/graph/query_result.py index 3ffa664..b88b4b6 100644 --- a/redis/commands/graph/query_result.py +++ b/redis/commands/graph/query_result.py @@ -1,4 +1,6 @@ +import sys from collections import OrderedDict +from distutils.util import strtobool # from prettytable import PrettyTable from redis import ResponseError @@ -90,6 +92,9 @@ class QueryResult: self.parse_results(response) def _check_for_errors(self, response): + """ + Check if the response contains an error. + """ if isinstance(response[0], ResponseError): error = response[0] if str(error) == "version mismatch": @@ -103,6 +108,9 @@ class QueryResult: raise response[-1] def parse_results(self, raw_result_set): + """ + Parse the query execution result returned from the server. + """ self.header = self.parse_header(raw_result_set) # Empty header. @@ -112,6 +120,9 @@ class QueryResult: self.result_set = self.parse_records(raw_result_set) def parse_statistics(self, raw_statistics): + """ + Parse the statistics returned in the response. + """ self.statistics = {} # decode statistics @@ -125,31 +136,31 @@ class QueryResult: self.statistics[s] = v def parse_header(self, raw_result_set): + """ + Parse the header of the result. + """ # An array of column name/column type pairs. header = raw_result_set[0] return header def parse_records(self, raw_result_set): - records = [] - result_set = raw_result_set[1] - for row in result_set: - record = [] - for idx, cell in enumerate(row): - if self.header[idx][0] == ResultSetColumnTypes.COLUMN_SCALAR: # noqa - record.append(self.parse_scalar(cell)) - elif self.header[idx][0] == ResultSetColumnTypes.COLUMN_NODE: # noqa - record.append(self.parse_node(cell)) - elif ( - self.header[idx][0] == ResultSetColumnTypes.COLUMN_RELATION - ): # noqa - record.append(self.parse_edge(cell)) - else: - print("Unknown column type.\n") - records.append(record) + """ + Parses the result set and returns a list of records. + """ + records = [ + [ + self.parse_record_types[self.header[idx][0]](cell) + for idx, cell in enumerate(row) + ] + for row in raw_result_set[1] + ] return records def parse_entity_properties(self, props): + """ + Parse node / edge properties. + """ # [[name, value type, value] X N] properties = {} for prop in props: @@ -160,6 +171,9 @@ class QueryResult: return properties def parse_string(self, cell): + """ + Parse the cell as a string. + """ if isinstance(cell, bytes): return cell.decode() elif not isinstance(cell, str): @@ -168,6 +182,9 @@ class QueryResult: return cell def parse_node(self, cell): + """ + Parse the cell to a node. + """ # Node ID (integer), # [label string offset (integer)], # [[name, value type, value] X N] @@ -182,6 +199,9 @@ class QueryResult: return Node(node_id=node_id, label=labels, properties=properties) def parse_edge(self, cell): + """ + Parse the cell to an edge. + """ # Edge ID (integer), # reltype string offset (integer), # src node ID offset (integer), @@ -198,11 +218,17 @@ class QueryResult: ) def parse_path(self, cell): + """ + Parse the cell to a path. + """ nodes = self.parse_scalar(cell[0]) edges = self.parse_scalar(cell[1]) return Path(nodes, edges) def parse_map(self, cell): + """ + Parse the cell as a map. + """ m = OrderedDict() n_entries = len(cell) @@ -216,6 +242,9 @@ class QueryResult: return m def parse_point(self, cell): + """ + Parse the cell to point. + """ p = {} # A point is received an array of the form: [latitude, longitude] # It is returned as a map of the form: {"latitude": latitude, "longitude": longitude} # noqa @@ -223,94 +252,63 @@ class QueryResult: p["longitude"] = float(cell[1]) return p - def parse_scalar(self, cell): - scalar_type = int(cell[0]) - value = cell[1] - scalar = None - - if scalar_type == ResultSetScalarTypes.VALUE_NULL: - scalar = None - - elif scalar_type == ResultSetScalarTypes.VALUE_STRING: - scalar = self.parse_string(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_INTEGER: - scalar = int(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_BOOLEAN: - value = value.decode() if isinstance(value, bytes) else value - if value == "true": - scalar = True - elif value == "false": - scalar = False - else: - print("Unknown boolean type\n") - - elif scalar_type == ResultSetScalarTypes.VALUE_DOUBLE: - scalar = float(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_ARRAY: - # array variable is introduced only for readability - scalar = array = value - for i in range(len(array)): - scalar[i] = self.parse_scalar(array[i]) + def parse_null(self, cell): + """ + Parse a null value. + """ + return None - elif scalar_type == ResultSetScalarTypes.VALUE_NODE: - scalar = self.parse_node(value) + def parse_integer(self, cell): + """ + Parse the integer value from the cell. + """ + return int(cell) - elif scalar_type == ResultSetScalarTypes.VALUE_EDGE: - scalar = self.parse_edge(value) + def parse_boolean(self, value): + """ + Parse the cell value as a boolean. + """ + value = value.decode() if isinstance(value, bytes) else value + try: + scalar = strtobool(value) + except ValueError: + sys.stderr.write("unknown boolean type\n") + scalar = None + return scalar - elif scalar_type == ResultSetScalarTypes.VALUE_PATH: - scalar = self.parse_path(value) + def parse_double(self, cell): + """ + Parse the cell as a double. + """ + return float(cell) - elif scalar_type == ResultSetScalarTypes.VALUE_MAP: - scalar = self.parse_map(value) + def parse_array(self, value): + """ + Parse an array of values. + """ + scalar = [self.parse_scalar(value[i]) for i in range(len(value))] + return scalar - elif scalar_type == ResultSetScalarTypes.VALUE_POINT: - scalar = self.parse_point(value) + def parse_unknown(self, cell): + """ + Parse a cell of unknown type. + """ + sys.stderr.write("Unknown type\n") + return None - elif scalar_type == ResultSetScalarTypes.VALUE_UNKNOWN: - print("Unknown scalar type\n") + def parse_scalar(self, cell): + """ + Parse a scalar value from a cell in the result set. + """ + scalar_type = int(cell[0]) + value = cell[1] + scalar = self.parse_scalar_types[scalar_type](value) return scalar def parse_profile(self, response): self.result_set = [x[0 : x.index(",")].strip() for x in response] - # """Prints the data from the query response: - # 1. First row result_set contains the columns names. - # Thus the first row in PrettyTable will contain the - # columns. - # 2. The row after that will contain the data returned, - # or 'No Data returned' if there is none. - # 3. Prints the statistics of the query. - # """ - - # def pretty_print(self): - # if not self.is_empty(): - # header = [col[1] for col in self.header] - # tbl = PrettyTable(header) - - # for row in self.result_set: - # record = [] - # for idx, cell in enumerate(row): - # if type(cell) is Node: - # record.append(cell.to_string()) - # elif type(cell) is Edge: - # record.append(cell.to_string()) - # else: - # record.append(cell) - # tbl.add_row(record) - - # if len(self.result_set) == 0: - # tbl.add_row(['No data returned.']) - - # print(str(tbl) + '\n') - - # for stat in self.statistics: - # print("%s %s" % (stat, self.statistics[stat])) - def is_empty(self): return len(self.result_set) == 0 @@ -384,3 +382,192 @@ class QueryResult: def run_time_ms(self): """Returns the server execution time of the query""" return self._get_stat(INTERNAL_EXECUTION_TIME) + + @property + def parse_scalar_types(self): + return { + ResultSetScalarTypes.VALUE_NULL: self.parse_null, + ResultSetScalarTypes.VALUE_STRING: self.parse_string, + ResultSetScalarTypes.VALUE_INTEGER: self.parse_integer, + ResultSetScalarTypes.VALUE_BOOLEAN: self.parse_boolean, + ResultSetScalarTypes.VALUE_DOUBLE: self.parse_double, + ResultSetScalarTypes.VALUE_ARRAY: self.parse_array, + ResultSetScalarTypes.VALUE_NODE: self.parse_node, + ResultSetScalarTypes.VALUE_EDGE: self.parse_edge, + ResultSetScalarTypes.VALUE_PATH: self.parse_path, + ResultSetScalarTypes.VALUE_MAP: self.parse_map, + ResultSetScalarTypes.VALUE_POINT: self.parse_point, + ResultSetScalarTypes.VALUE_UNKNOWN: self.parse_unknown, + } + + @property + def parse_record_types(self): + return { + ResultSetColumnTypes.COLUMN_SCALAR: self.parse_scalar, + ResultSetColumnTypes.COLUMN_NODE: self.parse_node, + ResultSetColumnTypes.COLUMN_RELATION: self.parse_edge, + ResultSetColumnTypes.COLUMN_UNKNOWN: self.parse_unknown, + } + + +class AsyncQueryResult(QueryResult): + """ + Async version for the QueryResult class - a class that + represents a result of the query operation. + """ + + def __init__(self): + """ + To init the class you must call self.initialize() + """ + pass + + async def initialize(self, graph, response, profile=False): + """ + Initializes the class. + Args: + + graph: + The graph on which the query was executed. + response: + The response from the server. + profile: + A boolean indicating if the query command was "GRAPH.PROFILE" + """ + self.graph = graph + self.header = [] + self.result_set = [] + + # in case of an error an exception will be raised + self._check_for_errors(response) + + if len(response) == 1: + self.parse_statistics(response[0]) + elif profile: + self.parse_profile(response) + else: + # start by parsing statistics, matches the one we have + self.parse_statistics(response[-1]) # Last element. + await self.parse_results(response) + + return self + + async def parse_node(self, cell): + """ + Parses a node from the cell. + """ + # Node ID (integer), + # [label string offset (integer)], + # [[name, value type, value] X N] + + labels = None + if len(cell[1]) > 0: + labels = [] + for inner_label in cell[1]: + labels.append(await self.graph.get_label(inner_label)) + properties = await self.parse_entity_properties(cell[2]) + node_id = int(cell[0]) + return Node(node_id=node_id, label=labels, properties=properties) + + async def parse_scalar(self, cell): + """ + Parses a scalar value from the server response. + """ + scalar_type = int(cell[0]) + value = cell[1] + try: + scalar = await self.parse_scalar_types[scalar_type](value) + except TypeError: + # Not all of the functions are async + scalar = self.parse_scalar_types[scalar_type](value) + + return scalar + + async def parse_records(self, raw_result_set): + """ + Parses the result set and returns a list of records. + """ + records = [] + for row in raw_result_set[1]: + record = [ + await self.parse_record_types[self.header[idx][0]](cell) + for idx, cell in enumerate(row) + ] + records.append(record) + + return records + + async def parse_results(self, raw_result_set): + """ + Parse the query execution result returned from the server. + """ + self.header = self.parse_header(raw_result_set) + + # Empty header. + if len(self.header) == 0: + return + + self.result_set = await self.parse_records(raw_result_set) + + async def parse_entity_properties(self, props): + """ + Parse node / edge properties. + """ + # [[name, value type, value] X N] + properties = {} + for prop in props: + prop_name = await self.graph.get_property(prop[0]) + prop_value = await self.parse_scalar(prop[1:]) + properties[prop_name] = prop_value + + return properties + + async def parse_edge(self, cell): + """ + Parse the cell to an edge. + """ + # Edge ID (integer), + # reltype string offset (integer), + # src node ID offset (integer), + # dest node ID offset (integer), + # [[name, value, value type] X N] + + edge_id = int(cell[0]) + relation = await self.graph.get_relation(cell[1]) + src_node_id = int(cell[2]) + dest_node_id = int(cell[3]) + properties = await self.parse_entity_properties(cell[4]) + return Edge( + src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties + ) + + async def parse_path(self, cell): + """ + Parse the cell to a path. + """ + nodes = await self.parse_scalar(cell[0]) + edges = await self.parse_scalar(cell[1]) + return Path(nodes, edges) + + async def parse_map(self, cell): + """ + Parse the cell to a map. + """ + m = OrderedDict() + n_entries = len(cell) + + # A map is an array of key value pairs. + # 1. key (string) + # 2. array: (value type, value) + for i in range(0, n_entries, 2): + key = self.parse_string(cell[i]) + m[key] = await self.parse_scalar(cell[i + 1]) + + return m + + async def parse_array(self, value): + """ + Parse array value. + """ + scalar = [await self.parse_scalar(value[i]) for i in range(len(value))] + return scalar |