summaryrefslogtreecommitdiff
path: root/redis/commands/search/querystring.py
diff options
context:
space:
mode:
Diffstat (limited to 'redis/commands/search/querystring.py')
-rw-r--r--redis/commands/search/querystring.py324
1 files changed, 324 insertions, 0 deletions
diff --git a/redis/commands/search/querystring.py b/redis/commands/search/querystring.py
new file mode 100644
index 0000000..f5f59b7
--- /dev/null
+++ b/redis/commands/search/querystring.py
@@ -0,0 +1,324 @@
+from six import string_types, integer_types
+
+
+def tags(*t):
+ """
+ Indicate that the values should be matched to a tag field
+
+ ### Parameters
+
+ - **t**: Tags to search for
+ """
+ if not t:
+ raise ValueError("At least one tag must be specified")
+ return TagValue(*t)
+
+
+def between(a, b, inclusive_min=True, inclusive_max=True):
+ """
+ Indicate that value is a numeric range
+ """
+ return RangeValue(a, b, inclusive_min=inclusive_min,
+ inclusive_max=inclusive_max)
+
+
+def equal(n):
+ """
+ Match a numeric value
+ """
+ return between(n, n)
+
+
+def lt(n):
+ """
+ Match any value less than n
+ """
+ return between(None, n, inclusive_max=False)
+
+
+def le(n):
+ """
+ Match any value less or equal to n
+ """
+ return between(None, n, inclusive_max=True)
+
+
+def gt(n):
+ """
+ Match any value greater than n
+ """
+ return between(n, None, inclusive_min=False)
+
+
+def ge(n):
+ """
+ Match any value greater or equal to n
+ """
+ return between(n, None, inclusive_min=True)
+
+
+def geo(lat, lon, radius, unit="km"):
+ """
+ Indicate that value is a geo region
+ """
+ return GeoValue(lat, lon, radius, unit)
+
+
+class Value(object):
+ @property
+ def combinable(self):
+ """
+ Whether this type of value may be combined with other values
+ for the same field. This makes the filter potentially more efficient
+ """
+ return False
+
+ @staticmethod
+ def make_value(v):
+ """
+ Convert an object to a value, if it is not a value already
+ """
+ if isinstance(v, Value):
+ return v
+ return ScalarValue(v)
+
+ def to_string(self):
+ raise NotImplementedError()
+
+ def __str__(self):
+ return self.to_string()
+
+
+class RangeValue(Value):
+ combinable = False
+
+ def __init__(self, a, b, inclusive_min=False, inclusive_max=False):
+ if a is None:
+ a = "-inf"
+ if b is None:
+ b = "inf"
+ self.range = [str(a), str(b)]
+ self.inclusive_min = inclusive_min
+ self.inclusive_max = inclusive_max
+
+ def to_string(self):
+ return "[{1}{0[0]} {2}{0[1]}]".format(
+ self.range,
+ "(" if not self.inclusive_min else "",
+ "(" if not self.inclusive_max else "",
+ )
+
+
+class ScalarValue(Value):
+ combinable = True
+
+ def __init__(self, v):
+ self.v = str(v)
+
+ def to_string(self):
+ return self.v
+
+
+class TagValue(Value):
+ combinable = False
+
+ def __init__(self, *tags):
+ self.tags = tags
+
+ def to_string(self):
+ return "{" + " | ".join(str(t) for t in self.tags) + "}"
+
+
+class GeoValue(Value):
+ def __init__(self, lon, lat, radius, unit="km"):
+ self.lon = lon
+ self.lat = lat
+ self.radius = radius
+ self.unit = unit
+
+
+class Node(object):
+ def __init__(self, *children, **kwparams):
+ """
+ Create a node
+
+ ### Parameters
+
+ - **children**: One or more sub-conditions. These can be additional
+ `intersect`, `disjunct`, `union`, `optional`, or any other `Node`
+ type.
+
+ The semantics of multiple conditions are dependent on the type of
+ query. For an `intersection` node, this amounts to a logical AND,
+ for a `union` node, this amounts to a logical `OR`.
+
+ - **kwparams**: key-value parameters. Each key is the name of a field,
+ and the value should be a field value. This can be one of the
+ following:
+
+ - Simple string (for text field matches)
+ - value returned by one of the helper functions
+ - list of either a string or a value
+
+
+ ### Examples
+
+ Field `num` should be between 1 and 10
+ ```
+ intersect(num=between(1, 10)
+ ```
+
+ Name can either be `bob` or `john`
+
+ ```
+ union(name=("bob", "john"))
+ ```
+
+ Don't select countries in Israel, Japan, or US
+
+ ```
+ disjunct_union(country=("il", "jp", "us"))
+ ```
+ """
+
+ self.params = []
+
+ kvparams = {}
+ for k, v in kwparams.items():
+ curvals = kvparams.setdefault(k, [])
+ if isinstance(v, (string_types, integer_types, float)):
+ curvals.append(Value.make_value(v))
+ elif isinstance(v, Value):
+ curvals.append(v)
+ else:
+ curvals.extend(Value.make_value(subv) for subv in v)
+
+ self.params += [Node.to_node(p) for p in children]
+
+ for k, v in kvparams.items():
+ self.params.extend(self.join_fields(k, v))
+
+ def join_fields(self, key, vals):
+ if len(vals) == 1:
+ return [BaseNode("@{}:{}".format(key, vals[0].to_string()))]
+ if not vals[0].combinable:
+ return [BaseNode("@{}:{}".format(key,
+ v.to_string())) for v in vals]
+ s = BaseNode(
+ "@{}:({})".format(key,
+ self.JOINSTR.join(v.to_string() for v in vals))
+ )
+ return [s]
+
+ @classmethod
+ def to_node(cls, obj): # noqa
+ if isinstance(obj, Node):
+ return obj
+ return BaseNode(obj)
+
+ @property
+ def JOINSTR(self):
+ raise NotImplementedError()
+
+ def to_string(self, with_parens=None):
+ with_parens = self._should_use_paren(with_parens)
+ pre, post = ("(", ")") if with_parens else ("", "")
+ return "{}{}{}".format(
+ pre, self.JOINSTR.join(n.to_string() for n in self.params), post
+ )
+
+ def _should_use_paren(self, optval):
+ if optval is not None:
+ return optval
+ return len(self.params) > 1
+
+ def __str__(self):
+ return self.to_string()
+
+
+class BaseNode(Node):
+ def __init__(self, s):
+ super(BaseNode, self).__init__()
+ self.s = str(s)
+
+ def to_string(self, with_parens=None):
+ return self.s
+
+
+class IntersectNode(Node):
+ """
+ Create an intersection node. All children need to be satisfied in order for
+ this node to evaluate as true
+ """
+
+ JOINSTR = " "
+
+
+class UnionNode(Node):
+ """
+ Create a union node. Any of the children need to be satisfied in order for
+ this node to evaluate as true
+ """
+
+ JOINSTR = "|"
+
+
+class DisjunctNode(IntersectNode):
+ """
+ Create a disjunct node. In order for this node to be true, all of its
+ children must evaluate to false
+ """
+
+ def to_string(self, with_parens=None):
+ with_parens = self._should_use_paren(with_parens)
+ ret = super(DisjunctNode, self).to_string(with_parens=False)
+ if with_parens:
+ return "(-" + ret + ")"
+ else:
+ return "-" + ret
+
+
+class DistjunctUnion(DisjunctNode):
+ """
+ This node is true if *all* of its children are false. This is equivalent to
+ ```
+ disjunct(union(...))
+ ```
+ """
+
+ JOINSTR = "|"
+
+
+class OptionalNode(IntersectNode):
+ """
+ Create an optional node. If this nodes evaluates to true, then the document
+ will be rated higher in score/rank.
+ """
+
+ def to_string(self, with_parens=None):
+ with_parens = self._should_use_paren(with_parens)
+ ret = super(OptionalNode, self).to_string(with_parens=False)
+ if with_parens:
+ return "(~" + ret + ")"
+ else:
+ return "~" + ret
+
+
+def intersect(*args, **kwargs):
+ return IntersectNode(*args, **kwargs)
+
+
+def union(*args, **kwargs):
+ return UnionNode(*args, **kwargs)
+
+
+def disjunct(*args, **kwargs):
+ return DisjunctNode(*args, **kwargs)
+
+
+def disjunct_union(*args, **kwargs):
+ return DistjunctUnion(*args, **kwargs)
+
+
+def querystring(*args, **kwargs):
+ return intersect(*args, **kwargs).to_string()