summaryrefslogtreecommitdiff
path: root/redis/commands/search/querystring.py
blob: f5f59b7e534571322904ee441608e528f2c26028 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
from six import string_types, integer_types


def tags(*t):
    """
    Build a value that matches a tag field against one or more tags

    ### Parameters

    - **t**: Tags to search for; they are rendered inside `{...}` and
        separated by `|`

    ### Raises

    - `ValueError` when no tag is supplied
    """
    if len(t) == 0:
        raise ValueError("At least one tag must be specified")
    tag_value = TagValue(*t)
    return tag_value


def between(a, b, inclusive_min=True, inclusive_max=True):
    """
    Build a numeric range value from `a` to `b`.

    Both bounds are inclusive by default. Pass `None` for an open bound;
    `RangeValue` renders it as `-inf` / `inf`.
    """
    bound_flags = {
        "inclusive_min": inclusive_min,
        "inclusive_max": inclusive_max,
    }
    return RangeValue(a, b, **bound_flags)


def equal(n):
    """
    Match exactly `n`, expressed as the closed range from `n` to `n`
    """
    return between(a=n, b=n, inclusive_min=True, inclusive_max=True)


def lt(n):
    """
    Match any value strictly below `n` (the lower bound is open)
    """
    upper_bound = n
    return between(None, upper_bound, inclusive_min=True, inclusive_max=False)


def le(n):
    """
    Match any value at or below `n` (the lower bound is open)
    """
    return between(a=None, b=n, inclusive_max=True)


def gt(n):
    """
    Match any value strictly above `n` (the upper bound is open)
    """
    lower_bound = n
    return between(lower_bound, None, inclusive_min=False, inclusive_max=True)


def ge(n):
    """
    Match any value at or above `n` (the upper bound is open)
    """
    return between(a=n, b=None, inclusive_min=True)


def geo(lat, lon, radius, unit="km"):
    """
    Indicate that value is a geo region

    ### Parameters

    - **lat**: Latitude of the region's centre
    - **lon**: Longitude of the region's centre
    - **radius**: Radius around the centre
    - **unit**: Unit of `radius` (default `"km"`)
    """
    # GeoValue's constructor takes (lon, lat, ...), the reverse of this
    # function's (lat, lon, ...) order. Pass by keyword so each coordinate
    # lands in the attribute its name promises instead of being swapped.
    return GeoValue(lon=lon, lat=lat, radius=radius, unit=unit)


class Value(object):
    """
    Base class for field values used when building query strings.

    Subclasses implement `to_string`, which renders the value in query syntax.
    """

    @property
    def combinable(self):
        """
        Whether several values of this type given for the same field may be
        merged into one grouped clause, which can make the filter more
        efficient. The base class says no.
        """
        return False

    @staticmethod
    def make_value(v):
        """
        Coerce `v` into a `Value`: instances of `Value` pass through
        unchanged, anything else is wrapped in a `ScalarValue`.
        """
        return v if isinstance(v, Value) else ScalarValue(v)

    def to_string(self):
        """Render the value in query syntax; subclasses must override."""
        raise NotImplementedError()

    def __str__(self):
        rendered = self.to_string()
        return rendered


class RangeValue(Value):
    """
    A numeric range rendered as `[min max]`.

    An exclusive bound is prefixed with `(`. A `None` bound becomes `-inf`
    (lower) or `inf` (upper). Both bounds are stored as strings.
    """

    combinable = False

    def __init__(self, a, b, inclusive_min=False, inclusive_max=False):
        low = "-inf" if a is None else a
        high = "inf" if b is None else b
        self.range = [str(low), str(high)]
        self.inclusive_min = inclusive_min
        self.inclusive_max = inclusive_max

    def to_string(self):
        low_prefix = "" if self.inclusive_min else "("
        high_prefix = "" if self.inclusive_max else "("
        return "[{}{} {}{}]".format(
            low_prefix, self.range[0], high_prefix, self.range[1]
        )


class ScalarValue(Value):
    """
    A single plain value, kept as its `str()` form and rendered verbatim.

    Scalars are combinable: several of them for one field may be grouped.
    """

    combinable = True

    def __init__(self, v):
        text = str(v)
        self.v = text

    def to_string(self):
        return self.v


class TagValue(Value):
    """
    A tag-field match rendered as `{t1 | t2 | ...}`; each tag is passed
    through `str()` without escaping.
    """

    combinable = False

    def __init__(self, *tags):
        self.tags = tags

    def to_string(self):
        body = " | ".join(map(str, self.tags))
        return "{%s}" % body


class GeoValue(Value):
    """
    A geo-radius filter value rendered as `[lon lat radius unit]`.

    Not combinable: it inherits the base `combinable` property, which is False.
    """

    def __init__(self, lon, lat, radius, unit="km"):
        self.lon = lon
        self.lat = lat
        self.radius = radius
        self.unit = unit

    def to_string(self):
        # Without this override the base class raises NotImplementedError, so
        # any geo filter placed in a query node could not be rendered.
        # Longitude comes before latitude in the rendered filter.
        return "[{} {} {} {}]".format(
            self.lon, self.lat, self.radius, self.unit
        )


class Node(object):
    """
    Base class for query-tree nodes.

    Subclasses define `JOINSTR`, the separator placed between rendered
    children (and between grouped values of a single field).
    """

    def __init__(self, *children, **kwparams):
        """
        Create a node

        ### Parameters

        - **children**: One or more sub-conditions. These can be additional
            `intersect`, `disjunct`, `union`, `optional`, or any other `Node`
            type. Anything that is not a `Node` (e.g. a raw string) is
            wrapped verbatim in a `BaseNode`.

            The semantics of multiple conditions are dependent on the type of
            query. For an `intersection` node, this amounts to a logical AND,
            for a `union` node, this amounts to a logical `OR`.

        - **kwparams**: key-value parameters. Each key is the name of a field,
            and the value should be a field value. This can be one of the
            following:

            - Simple string (for text field matches)
            - value returned by one of the helper functions
            - list of either a string or a value

        ### Raises

        - `ValueError` if a field is given an empty list of values

        ### Examples

        Field `num` should be between 1 and 10
        ```
        intersect(num=between(1, 10))
        ```

        Name can either be `bob` or `john`

        ```
        union(name=("bob", "john"))
        ```

        Don't select countries in Israel, Japan, or US

        ```
        disjunct_union(country=("il", "jp", "us"))
        ```
        """

        self.params = []

        # Normalise each keyword argument into a list of Value objects.
        kvparams = {}
        for k, v in kwparams.items():
            curvals = kvparams.setdefault(k, [])
            if isinstance(v, (string_types, integer_types, float)):
                curvals.append(Value.make_value(v))
            elif isinstance(v, Value):
                curvals.append(v)
            else:
                # Anything else is treated as an iterable of values.
                curvals.extend(Value.make_value(subv) for subv in v)

        # Positional children come first, then one or more clauses per field.
        self.params += [Node.to_node(p) for p in children]

        for k, v in kvparams.items():
            self.params.extend(self.join_fields(k, v))

    def join_fields(self, key, vals):
        """
        Render the values given for one field as a list of `BaseNode` clauses.

        - One value yields `@key:value`.
        - Several values are merged into one `@key:(v1<JOINSTR>v2...)`
          clause when *every* value is combinable; otherwise each value gets
          its own `@key:value` clause.

        Raises `ValueError` when `vals` is empty.
        """
        if not vals:
            # Previously this fell through to `vals[0]` and raised IndexError.
            msg = "At least one value must be specified for field {!r}"
            raise ValueError(msg.format(key))
        if len(vals) == 1:
            return [BaseNode("@{}:{}".format(key, vals[0].to_string()))]
        # Inspect every value, not only the first, so the output does not
        # depend on the order in which mixed values were supplied.
        if not all(v.combinable for v in vals):
            return [
                BaseNode("@{}:{}".format(key, v.to_string())) for v in vals
            ]
        joined = self.JOINSTR.join(v.to_string() for v in vals)
        return [BaseNode("@{}:({})".format(key, joined))]

    @classmethod
    def to_node(cls, obj):  # noqa
        """Return `obj` unchanged if it is a `Node`, else wrap it in a `BaseNode`."""
        if isinstance(obj, Node):
            return obj
        return BaseNode(obj)

    @property
    def JOINSTR(self):
        """Separator between children; concrete subclasses must define it."""
        raise NotImplementedError()

    def to_string(self, with_parens=None):
        """
        Render the children joined by `JOINSTR`.

        When `with_parens` is None, the output is wrapped in parentheses only
        if there is more than one child; True/False force the choice.
        """
        with_parens = self._should_use_paren(with_parens)
        pre, post = ("(", ")") if with_parens else ("", "")
        return "{}{}{}".format(
            pre, self.JOINSTR.join(n.to_string() for n in self.params), post
        )

    def _should_use_paren(self, optval):
        """Honour an explicit `optval`; otherwise use parens for 2+ children."""
        if optval is not None:
            return optval
        return len(self.params) > 1

    def __str__(self):
        return self.to_string()


class BaseNode(Node):
    """
    A leaf node holding an already-rendered query fragment (its `str()`
    form), which is emitted verbatim.
    """

    def __init__(self, s):
        super(BaseNode, self).__init__()
        fragment = str(s)
        self.s = fragment

    def to_string(self, with_parens=None):
        # A leaf is never wrapped, so `with_parens` is accepted but ignored.
        return self.s


class IntersectNode(Node):
    """
    Intersection (logical AND) of its children: every child must be satisfied
    for this node to be true. Children are separated by a single space.
    """

    JOINSTR = " "


class UnionNode(Node):
    """
    Union (logical OR) of its children: satisfying any single child makes
    this node true. Children are separated by `|`.
    """

    JOINSTR = "|"


class DisjunctNode(IntersectNode):
    """
    Create a disjunct node: the negation (`-`) of the intersection of its
    children. It is true unless all of its children are true.

    A single child renders as `-child`. Several children render as
    `-(child1 child2 ...)`, so the negation covers the whole group.
    """

    def to_string(self, with_parens=None):
        with_parens = self._should_use_paren(with_parens)
        ret = super(DisjunctNode, self).to_string(with_parens=False)
        if with_parens:
            # Keep the "-" outside the parentheses. The old form "(-a b)"
            # leaves the operator's scope to precedence rules and can negate
            # only the first child instead of the whole group.
            return "-(" + ret + ")"
        return "-" + ret


class DistjunctUnion(DisjunctNode):
    """
    Negated union: true only when *none* of the children is true. It behaves
    like
    ```
    disjunct(union(...))
    ```
    It reuses the disjunct rendering but joins children with `|`.
    """

    JOINSTR = "|"


class OptionalNode(IntersectNode):
    """
    Create an optional node. If this node evaluates to true, then the document
    will be rated higher in score/rank.

    A single child renders as `~child`. Several children render as
    `~(child1 child2 ...)`, so the `~` marks the whole group.
    """

    def to_string(self, with_parens=None):
        with_parens = self._should_use_paren(with_parens)
        ret = super(OptionalNode, self).to_string(with_parens=False)
        if with_parens:
            # Keep the "~" outside the parentheses. The old form "(~a b)"
            # leaves the operator's scope to precedence rules and can mark
            # only the first child as optional.
            return "~(" + ret + ")"
        return "~" + ret


def intersect(*args, **kwargs):
    """Build an `IntersectNode` (logical AND) of the given children/fields."""
    node = IntersectNode(*args, **kwargs)
    return node


def union(*args, **kwargs):
    """Build a `UnionNode` (logical OR) of the given children/fields."""
    node = UnionNode(*args, **kwargs)
    return node


def disjunct(*args, **kwargs):
    """Build a `DisjunctNode`, which negates its children with `-`."""
    node = DisjunctNode(*args, **kwargs)
    return node


def disjunct_union(*args, **kwargs):
    """Build a `DistjunctUnion`: true only when none of the children is true."""
    node = DistjunctUnion(*args, **kwargs)
    return node


def querystring(*args, **kwargs):
    """
    Render the intersection of the given children and field conditions
    directly as a query string.
    """
    root = intersect(*args, **kwargs)
    return root.to_string()