| author | Andi Albrecht <albrecht.andi@gmail.com> | 2016-05-15 21:02:17 +0200 |
|---|---|---|
| committer | Andi Albrecht <albrecht.andi@gmail.com> | 2016-05-15 21:02:17 +0200 |
| commit | efb6fd4bdb80b985c356bb6eb996e6e25cf63b05 (patch) | |
| tree | 5ca2ce45fc9cf0a35efbcf68bb643f3a12e2575d /sqlparse/sql.py | |
| parent | 9ab1464ea9c1d0296d698d9637ed3e3cd92326f9 (diff) | |
| parent | 955996e3e5c49fb6b7f200ceecee2f8082656ac4 (diff) | |
| download | sqlparse-efb6fd4bdb80b985c356bb6eb996e6e25cf63b05.tar.gz | |
Merge pull request #235 from vmuriart/refactor
Refactor
Diffstat (limited to 'sqlparse/sql.py')
| -rw-r--r-- | sqlparse/sql.py | 240 |

1 file changed, 108 insertions, 132 deletions
```diff
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index f357572..9afdac3 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -7,6 +7,7 @@ import sys
 
 from sqlparse import tokens as T
 from sqlparse.compat import string_types, u
+from sqlparse.utils import imt, remove_quotes
 
 
 class Token(object):
@@ -77,7 +78,7 @@ class Token(object):
         if regex:
             if isinstance(values, string_types):
-                values = set([values])
+                values = {values}
 
             if self.ttype is T.Keyword:
                 values = set(re.compile(v, re.IGNORECASE)
                              for v in values)
@@ -150,7 +151,7 @@
         if tokens is None:
             tokens = []
         self.tokens = tokens
-        Token.__init__(self, None, self._to_string())
+        super(TokenList, self).__init__(None, self.__str__())
 
     def __unicode__(self):
         return self._to_string()
@@ -184,14 +185,6 @@
             if (token.is_group() and (max_depth is None or depth < max_depth)):
                 token._pprint_tree(max_depth, depth + 1)
 
-    def _remove_quotes(self, val):
-        """Helper that removes surrounding quotes from strings."""
-        if not val:
-            return val
-        if val[0] in ('"', '\'') and val[-1] == val[0]:
-            val = val[1:-1]
-        return val
-
     def get_token_at_offset(self, offset):
         """Returns the token that is on position offset."""
         idx = 0
@@ -213,12 +206,12 @@
             else:
                 yield token
 
-#    def __iter__(self):
-#        return self
-#
-#    def next(self):
-#        for token in self.tokens:
-#            yield token
+    # def __iter__(self):
+    #     return self
+    #
+    # def next(self):
+    #     for token in self.tokens:
+    #         yield token
 
     def is_group(self):
         return True
@@ -232,6 +225,27 @@
     @property
     def _groupable_tokens(self):
         return self.tokens
 
+    def _token_matching(self, funcs, start=0, end=None, reverse=False):
+        """next token that match functions"""
+        if start is None:
+            return None
+
+        if not isinstance(start, int):
+            start = self.token_index(start) + 1
+
+        if not isinstance(funcs, (list, tuple)):
+            funcs = (funcs,)
+
+        if reverse:
+            iterable = iter(reversed(self.tokens[end:start - 1]))
+        else:
+            iterable = self.tokens[start:end]
+
+        for token in iterable:
+            for func in funcs:
+                if func(token):
+                    return token
+
     def token_first(self, ignore_whitespace=True, ignore_comments=False):
         """Returns the first child token.
@@ -241,12 +255,13 @@
         if *ignore_comments* is ``True`` (default: ``False``),
         comments are ignored too.
         """
-        for token in self.tokens:
-            if ignore_whitespace and token.is_whitespace():
-                continue
-            if ignore_comments and isinstance(token, Comment):
-                continue
-            return token
+        funcs = lambda tk: not ((ignore_whitespace and tk.is_whitespace()) or
+                                (ignore_comments and imt(tk, i=Comment)))
+        return self._token_matching(funcs)
+
+    def token_next_by(self, i=None, m=None, t=None, idx=0, end=None):
+        funcs = lambda tk: imt(tk, i, m, t)
+        return self._token_matching(funcs, idx, end)
```
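The search methods in the hunks above and below are rewritten to funnel through the new `_token_matching` helper together with the `imt` utility imported at the top of this diff. `imt` itself lives in `sqlparse/utils.py` and is not shown here; the following is only a rough sketch, written for illustration, of the instance/match/ttype test it is expected to perform:

```python
# Illustrative sketch only -- the real implementation is in
# sqlparse/utils.py, which this diff does not touch.
def imt(token, i=None, m=None, t=None):
    """True if token matches an Instance, Match-spec, or Ttype criterion."""
    if token is None:
        return False
    if i is not None and isinstance(token, i):   # class, or tuple of classes
        return True
    if m is not None and token.match(*m):        # (ttype, values[, regex])
        return True
    if t is not None and token.ttype in t:       # ttypes support 'in' checks
        return True
    return False
```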
""" - if not isinstance(clss, (list, tuple)): - clss = (clss,) - - for token in self.tokens[idx:end]: - if isinstance(token, clss): - return token + funcs = lambda tk: imt(tk, i=clss) + return self._token_matching(funcs, idx, end) def token_next_by_type(self, idx, ttypes): """Returns next matching token by it's token type.""" - if not isinstance(ttypes, (list, tuple)): - ttypes = [ttypes] - - for token in self.tokens[idx:]: - if token.ttype in ttypes: - return token + funcs = lambda tk: imt(tk, t=ttypes) + return self._token_matching(funcs, idx) def token_next_match(self, idx, ttype, value, regex=False): """Returns next token where it's ``match`` method returns ``True``.""" - if not isinstance(idx, int): - idx = self.token_index(idx) - - for n in range(idx, len(self.tokens)): - token = self.tokens[n] - if token.match(ttype, value, regex): - return token + funcs = lambda tk: imt(tk, m=(ttype, value, regex)) + return self._token_matching(funcs, idx) def token_not_matching(self, idx, funcs): - for token in self.tokens[idx:]: - passed = False - for func in funcs: - if func(token): - passed = True - break - - if not passed: - return token + funcs = (funcs,) if not isinstance(funcs, (list, tuple)) else funcs + funcs = [lambda tk: not func(tk) for func in funcs] + return self._token_matching(funcs, idx) def token_matching(self, idx, funcs): - for token in self.tokens[idx:]: - for func in funcs: - if func(token): - return token + return self._token_matching(funcs, idx) def token_prev(self, idx, skip_ws=True): """Returns the previous token relative to *idx*. @@ -305,17 +298,10 @@ class TokenList(Token): If *skip_ws* is ``True`` (the default) whitespace tokens are ignored. ``None`` is returned if there's no previous token. """ - if idx is None: - return None - - if not isinstance(idx, int): - idx = self.token_index(idx) - - while idx: - idx -= 1 - if self.tokens[idx].is_whitespace() and skip_ws: - continue - return self.tokens[idx] + if isinstance(idx, int): + idx += 1 # alot of code usage current pre-compensates for this + funcs = lambda tk: not (tk.is_whitespace() and skip_ws) + return self._token_matching(funcs, idx, reverse=True) def token_next(self, idx, skip_ws=True): """Returns the next token relative to *idx*. @@ -323,59 +309,56 @@ class TokenList(Token): If *skip_ws* is ``True`` (the default) whitespace tokens are ignored. ``None`` is returned if there's no next token. """ - if idx is None: - return None - - if not isinstance(idx, int): - idx = self.token_index(idx) - - while idx < len(self.tokens) - 1: - idx += 1 - if self.tokens[idx].is_whitespace() and skip_ws: - continue - return self.tokens[idx] + if isinstance(idx, int): + idx += 1 # alot of code usage current pre-compensates for this + funcs = lambda tk: not (tk.is_whitespace() and skip_ws) + return self._token_matching(funcs, idx) def token_index(self, token, start=0): """Return list index of token.""" - if start > 0: - # Performing `index` manually is much faster when starting - # in the middle of the list of tokens and expecting to find - # the token near to the starting index. - for i in range(start, len(self.tokens)): - if self.tokens[i] == token: - return i - return -1 - return self.tokens.index(token) - - def tokens_between(self, start, end, exclude_end=False): + start = self.token_index(start) if not isinstance(start, int) else start + return start + self.tokens[start:].index(token) + + def tokens_between(self, start, end, include_end=True): """Return all tokens between (and including) start and end. 
- If *exclude_end* is ``True`` (default is ``False``) the end token - is included too. + If *include_end* is ``False`` (default is ``True``) the end token + is excluded. """ - # FIXME(andi): rename exclude_end to inlcude_end - if exclude_end: - offset = 0 - else: - offset = 1 - end_idx = self.token_index(end) + offset start_idx = self.token_index(start) + end_idx = include_end + self.token_index(end) return self.tokens[start_idx:end_idx] - def group_tokens(self, grp_cls, tokens, ignore_ws=False): + def group_tokens(self, grp_cls, tokens, ignore_ws=False, extend=False): """Replace tokens by an instance of *grp_cls*.""" - idx = self.token_index(tokens[0]) if ignore_ws: while tokens and tokens[-1].is_whitespace(): tokens = tokens[:-1] - for t in tokens: - self.tokens.remove(t) - grp = grp_cls(tokens) + + left = tokens[0] + idx = self.token_index(left) + + if extend: + if not isinstance(left, grp_cls): + grp = grp_cls([left]) + self.tokens.remove(left) + self.tokens.insert(idx, grp) + left = grp + left.parent = self + tokens = tokens[1:] + left.tokens.extend(tokens) + left.value = left.__str__() + + else: + left = grp_cls(tokens) + left.parent = self + self.tokens.insert(idx, left) + for token in tokens: - token.parent = grp - grp.parent = self - self.tokens.insert(idx, grp) - return grp + token.parent = left + self.tokens.remove(token) + + return left def insert_before(self, where, token): """Inserts *token* before *where*.""" @@ -397,13 +380,12 @@ class TokenList(Token): """Returns the alias for this identifier or ``None``.""" # "name AS alias" - kw = self.token_next_match(0, T.Keyword, 'AS') + kw = self.token_next_by(m=(T.Keyword, 'AS')) if kw is not None: return self._get_first_name(kw, keywords=True) # "name alias" or "complicated column expression alias" - if len(self.tokens) > 2 \ - and self.token_next_by_type(0, T.Whitespace) is not None: + if len(self.tokens) > 2 and self.token_next_by(t=T.Whitespace): return self._get_first_name(reverse=True) return None @@ -440,7 +422,7 @@ class TokenList(Token): prev_ = self.token_prev(self.token_index(dot)) if prev_ is None: # something must be verry wrong here.. return None - return self._remove_quotes(prev_.value) + return remove_quotes(prev_.value) def _get_first_name(self, idx=None, reverse=False, keywords=False): """Returns the name of the first token with a name""" @@ -457,7 +439,7 @@ class TokenList(Token): for tok in tokens: if tok.ttype in types: - return self._remove_quotes(tok.value) + return remove_quotes(tok.value) elif isinstance(tok, Identifier) or isinstance(tok, Function): return tok.get_name() return None @@ -510,8 +492,6 @@ class Identifier(TokenList): Identifiers may have aliases or typecasts. """ - __slots__ = ('value', 'ttype', 'tokens') - def is_wildcard(self): """Return ``True`` if this identifier contains a wildcard.""" token = self.token_next_by_type(0, T.Wildcard) @@ -546,8 +526,6 @@ class Identifier(TokenList): class IdentifierList(TokenList): """A list of :class:`~sqlparse.sql.Identifier`\'s.""" - __slots__ = ('value', 'ttype', 'tokens') - def get_identifiers(self): """Returns the identifiers. 
```diff
@@ -560,7 +538,8 @@ class IdentifierList(TokenList):
 
 class Parenthesis(TokenList):
     """Tokens between parenthesis."""
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Punctuation, '(')
+    M_CLOSE = (T.Punctuation, ')')
 
     @property
     def _groupable_tokens(self):
@@ -569,8 +548,8 @@
 
 class SquareBrackets(TokenList):
     """Tokens between square brackets"""
-
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Punctuation, '[')
+    M_CLOSE = (T.Punctuation, ']')
 
     @property
     def _groupable_tokens(self):
@@ -579,22 +558,22 @@
 
 class Assignment(TokenList):
     """An assignment like 'var := val;'"""
-    __slots__ = ('value', 'ttype', 'tokens')
 
 
 class If(TokenList):
     """An 'if' clause with possible 'else if' or 'else' parts."""
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Keyword, 'IF')
+    M_CLOSE = (T.Keyword, 'END IF')
 
 
 class For(TokenList):
     """A 'FOR' loop."""
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Keyword, ('FOR', 'FOREACH'))
+    M_CLOSE = (T.Keyword, 'END LOOP')
 
 
 class Comparison(TokenList):
     """A comparison used for example in WHERE clauses."""
-    __slots__ = ('value', 'ttype', 'tokens')
 
     @property
     def left(self):
@@ -607,7 +586,6 @@
 
 class Comment(TokenList):
     """A comment."""
-    __slots__ = ('value', 'ttype', 'tokens')
 
     def is_multiline(self):
         return self.tokens and self.tokens[0].ttype == T.Comment.Multiline
@@ -615,13 +593,15 @@
 
 class Where(TokenList):
     """A WHERE clause."""
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Keyword, 'WHERE')
+    M_CLOSE = (T.Keyword,
+               ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT', 'HAVING'))
 
 
 class Case(TokenList):
     """A CASE statement with one or more WHEN and possibly an ELSE part."""
-
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Keyword, 'CASE')
+    M_CLOSE = (T.Keyword, 'END')
 
     def get_cases(self):
         """Returns a list of 2-tuples (condition, value).
@@ -671,22 +651,18 @@
 
 class Function(TokenList):
     """A function or procedure call."""
-    __slots__ = ('value', 'ttype', 'tokens')
-
     def get_parameters(self):
         """Return a list of parameters."""
         parenthesis = self.tokens[-1]
         for t in parenthesis.tokens:
-            if isinstance(t, IdentifierList):
+            if imt(t, i=IdentifierList):
                 return t.get_identifiers()
-            elif (isinstance(t, Identifier) or
-                  isinstance(t, Function) or
-                  t.ttype in T.Literal):
+            elif imt(t, i=(Function, Identifier), t=T.Literal):
                 return [t, ]
         return []
 
 
 class Begin(TokenList):
     """A BEGIN/END block."""
-
-    __slots__ = ('value', 'ttype', 'tokens')
+    M_OPEN = (T.Keyword, 'BEGIN')
+    M_CLOSE = (T.Keyword, 'END')
```
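The `M_OPEN`/`M_CLOSE` pairs that replace the per-class `__slots__` boilerplate give each block-like class declarative open and close markers. The helper below is hypothetical, not part of this diff, and only sketches how such a pair can be consumed through the existing match API:

```python
import sqlparse
from sqlparse.sql import Case

def block_bounds(tlist, cls):
    """Return (open_token, close_token) of a cls-style block inside tlist."""
    open_tk = tlist.token_next_match(0, *cls.M_OPEN)   # e.g. (T.Keyword, 'CASE')
    if open_tk is None:
        return None
    return open_tk, tlist.token_next_match(open_tk, *cls.M_CLOSE)

stmt = sqlparse.parse('SELECT CASE WHEN x = 1 THEN 2 ELSE 3 END FROM t')[0]
case = stmt.token_next_by(i=Case)   # the grouped CASE ... END block
if case is not None:                # grouping may nest it deeper in some inputs
    print(block_bounds(case, Case))
```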