import itertools
import re
from collections import OrderedDict, deque
from contextlib import contextmanager


class Cache(OrderedDict):
    """Cache with LRU algorithm using an OrderedDict as basis
    """

    def __init__(self, maxsize=100):
        OrderedDict.__init__(self)

        self._maxsize = maxsize

    def __getitem__(self, key, *args, **kwargs):
        # Look up the value and remove the pair from the cache,
        # or raise KeyError if it's missing
        value = OrderedDict.__getitem__(self, key)
        del self[key]

        # Re-insert the (key, value) pair at the end of the cache,
        # which is the most-recently-used position
        OrderedDict.__setitem__(self, key, value)

        # Return the cached value
        return value

    def __setitem__(self, key, value, *args, **kwargs):
        # Key was inserted before; remove it so it moves to the
        # most-recently-used position below
        if key in self:
            del self[key]

        # Cache is full; evict the least recently used item, which
        # sits at the front of the OrderedDict
        elif len(self) >= self._maxsize:
            self.popitem(False)

        # Insert the (key, value) pair at the end (most recently used)
        OrderedDict.__setitem__(self, key, value, *args, **kwargs)
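
# A minimal usage sketch of the LRU behaviour above (illustrative only; the
# tiny maxsize is chosen just to force an eviction):
#
#     c = Cache(maxsize=2)
#     c['a'] = 1
#     c['b'] = 2
#     c['a']        # touching 'a' makes 'b' the least recently used
#     c['c'] = 3    # cache is full, so 'b' is evicted
#     'b' in c      # -> False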


def memoize_generator(func):
    """Memoize decorator for generators

    Store `func` results in a cache according to their arguments as 'memoize'
    does but instead this works on decorators instead of regular functions.
    Obviusly, this is only useful if the generator will always return the same
    values for each specific parameters...
    """
    cache = Cache()

    def wrapped_func(*args, **kwargs):
        # Sort kwargs so the cache key does not depend on the order in
        # which keyword arguments were passed
        params = (args, tuple(sorted(kwargs.items())))

        # Check whether this call is already cached
        try:
            cached = cache[params]

        # Not cached: run the generator, collecting items as they are
        # yielded; the result is stored only once the generator has been
        # fully consumed
        except KeyError:
            cached = []

            for item in func(*args, **kwargs):
                cached.append(item)
                yield item

            cache[params] = cached

        # Cached: replay the stored items
        else:
            for item in cached:
                yield item

    return wrapped_func
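
# Illustrative usage (`squares` is a hypothetical generator; note that the
# result is only cached once the generator has been fully consumed):
#
#     @memoize_generator
#     def squares(n):
#         for i in range(n):
#             yield i * i
#
#     list(squares(4))  # computes and stores [0, 1, 4, 9]
#     list(squares(4))  # replays the cached items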


# This regular expression replaces the home-cooked parser that was here
# before. It is much faster, but requires an extra post-processing step to
# produce results compatible with what you would expect from the
# str.splitlines() method.
#
# It matches groups of characters: newlines, quoted strings, or unquoted
# text, and splits on that basis. The post-processing step then reassembles
# those groups into the actual lines of SQL.
SPLIT_REGEX = re.compile(r"""
(
 (?:                     # Start of non-capturing group
  (?:\r\n|\r|\n)      |  # Match any single newline, or
  [^\r\n'"]+          |  # Match any character series without quotes or
                         # newlines, or
  "(?:[^"\\]|\\.)*"   |  # Match double-quoted strings, or
  '(?:[^'\\]|\\.)*'      # Match single-quoted strings
 )
)
""", re.VERBOSE)

LINE_MATCH = re.compile(r'(\r\n|\r|\n)')


def split_unquoted_newlines(text):
    """Split a string on all unquoted newlines.

    Unlike str.splitlines(), this will ignore CR/LF/CR+LF if the requisite
    character is inside of a string."""
    lines = SPLIT_REGEX.split(text)
    outputlines = ['']
    for line in lines:
        if not line:
            continue
        elif LINE_MATCH.match(line):
            outputlines.append('')
        else:
            outputlines[-1] += line
    return outputlines
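
# Illustrative example: the newline inside the quoted literal does not start
# a new line, while the unquoted one does:
#
#     split_unquoted_newlines("select 'a\nb'\nfrom t")
#     # -> ["select 'a\nb'", 'from t']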


def remove_quotes(val):
    """Helper that removes surrounding quotes from strings."""
    if not val:
        # None and the empty string are returned unchanged
        return val
    if val[0] in ('"', "'") and val[0] == val[-1]:
        val = val[1:-1]
    return val
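
# Illustrative behaviour:
#
#     remove_quotes('"foo"')  # -> 'foo'
#     remove_quotes("'bar'")  # -> 'bar'
#     remove_quotes('baz')    # -> 'baz' (no surrounding quotes, unchanged)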


def recurse(*cls):
    """Function decorator to help with recursion.

    The decorated function is applied bottom-up to `tlist` and to every
    nested sublist, without descending into sublists that are instances
    of the given classes.

    :param cls: Classes to not recurse into
    :return: function
    """
    def wrap(f):
        def wrapped_f(tlist):
            for sgroup in tlist.get_sublists():
                if not isinstance(sgroup, cls):
                    wrapped_f(sgroup)
            f(tlist)

        return wrapped_f

    return wrap
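
# Illustrative sketch (assumes sqlparse's `sql.Identifier` class; the
# decorated function itself is hypothetical):
#
#     @recurse(sql.Identifier)
#     def process(tlist):
#         ...  # runs on tlist and every nested sublist, but never
#              # descends into Identifier groups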


def imt(token, i=None, m=None, t=None):
    """Helper function to simplify comparisons by Instance, Match and
    TokenType.

    :param token: the token to check
    :param i: Class or tuple/list of classes checked with isinstance
    :param m: (ttype, values) pair passed to token.match; can be a list
        of such pairs to match against multiple patterns
    :param t: TokenType or tuple/list of TokenTypes compared to token.ttype
    :return: bool
    """
    t = (t,) if t and not isinstance(t, (list, tuple)) else t
    m = (m,) if m and not isinstance(m, (list,)) else m

    if token is None:
        return False
    elif i is not None and isinstance(token, i):
        return True
    elif m is not None and any((token.match(*x) for x in m)):
        return True
    elif t is not None and token.ttype in t:
        return True
    else:
        return False


def find_matching(tlist, token, M1, M2):
    """Find the token that closes `token`.

    Starting at `token`, which must match the opening pattern `M1`, scan
    `tlist` and return the token matching the closing pattern `M2` at the
    same nesting depth, or None if there is no such token.
    """
    idx = tlist.token_index(token)
    depth = 0
    for tok in tlist[idx:]:
        if tok.match(*M1):
            depth += 1
        elif tok.match(*M2):
            depth -= 1
            if depth == 0:
                return tok
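
# Illustrative sketch (assumes `T` is sqlparse's `tokens` module and `token`
# is the opening parenthesis of a group inside `tlist`):
#
#     M_OPEN = (T.Punctuation, '(')
#     M_CLOSE = (T.Punctuation, ')')
#     closing = find_matching(tlist, token, M_OPEN, M_CLOSE)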


def consume(iterator, n):
    """Advance the iterator n-steps ahead. If n is none, consume entirely."""
    deque(itertools.islice(iterator, n), maxlen=0)
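
# Illustrative behaviour:
#
#     it = iter(range(10))
#     consume(it, 3)  # discards 0, 1 and 2
#     next(it)        # -> 3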


@contextmanager
def offset(filter_, n=0):
    """Temporarily increase the filter's offset by `n`."""
    filter_.offset += n
    try:
        yield
    finally:
        filter_.offset -= n


@contextmanager
def indent(filter_, n=1):
    """Temporarily increase the filter's indentation level by `n`."""
    filter_.indent += n
    try:
        yield
    finally:
        filter_.indent -= n
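
# Illustrative usage (`filter_` stands for any filter object with numeric
# `offset` and `indent` attributes, as used by sqlparse's output filters):
#
#     with indent(filter_):
#         ...  # code here sees filter_.indent increased by 1
#     # filter_.indent is restored afterwards, even if an exception occurred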