path: root/sqlparse/sql.py
author     Andi Albrecht <albrecht.andi@gmail.com>  2016-05-15 21:02:17 +0200
committer  Andi Albrecht <albrecht.andi@gmail.com>  2016-05-15 21:02:17 +0200
commit     efb6fd4bdb80b985c356bb6eb996e6e25cf63b05 (patch)
tree       5ca2ce45fc9cf0a35efbcf68bb643f3a12e2575d /sqlparse/sql.py
parent     9ab1464ea9c1d0296d698d9637ed3e3cd92326f9 (diff)
parent     955996e3e5c49fb6b7f200ceecee2f8082656ac4 (diff)
download   sqlparse-efb6fd4bdb80b985c356bb6eb996e6e25cf63b05.tar.gz
Merge pull request #235 from vmuriart/refactor
Refactor
Diffstat (limited to 'sqlparse/sql.py')
-rw-r--r--  sqlparse/sql.py  240
1 file changed, 108 insertions, 132 deletions
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index f357572..9afdac3 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -7,6 +7,7 @@ import sys
from sqlparse import tokens as T
from sqlparse.compat import string_types, u
+from sqlparse.utils import imt, remove_quotes
class Token(object):
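The `imt` helper imported above is defined outside this diff. Judging from its call sites below (`i=` instance checks, `m=` match tuples, `t=` token types), it behaves roughly like the following sketch; this is an inference from usage, not the verbatim sqlparse.utils implementation:

def imt(token, i=None, m=None, t=None):
    """True if token Is an instance of i, Matches m, or has Ttype t."""
    if i is not None and isinstance(token, i):
        return True
    if m is not None and token.match(*m):
        return True
    if t is not None:
        types = t if isinstance(t, (list, tuple)) else (t,)
        if any(token.ttype in ttype for ttype in types):
            return True
    return False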
@@ -77,7 +78,7 @@ class Token(object):
if regex:
if isinstance(values, string_types):
- values = set([values])
+ values = {values}
if self.ttype is T.Keyword:
values = set(re.compile(v, re.IGNORECASE) for v in values)
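In effect a single string, a list of values, or (with regex=True) a pattern can all be passed as *values*; keyword tokens compare against their upper-cased normalized value. Illustrative, via the public API:

import sqlparse
from sqlparse import tokens as T

tok = sqlparse.parse('select 1')[0].token_first()
print(tok.match(T.DML, 'SELECT'))              # True: compared against the normalized 'SELECT'
print(tok.match(T.DML, ['INSERT', 'SELECT']))  # True: any value in the list may match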
@@ -150,7 +151,7 @@ class TokenList(Token):
if tokens is None:
tokens = []
self.tokens = tokens
- Token.__init__(self, None, self._to_string())
+ super(TokenList, self).__init__(None, self.__str__())
def __unicode__(self):
return self._to_string()
@@ -184,14 +185,6 @@ class TokenList(Token):
if (token.is_group() and (max_depth is None or depth < max_depth)):
token._pprint_tree(max_depth, depth + 1)
- def _remove_quotes(self, val):
- """Helper that removes surrounding quotes from strings."""
- if not val:
- return val
- if val[0] in ('"', '\'') and val[-1] == val[0]:
- val = val[1:-1]
- return val
-
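The helper is not dropped but moved verbatim to sqlparse.utils (hence the new import at the top of this diff); its behavior is unchanged:

from sqlparse.utils import remove_quotes

print(remove_quotes('"foo"'))  # foo
print(remove_quotes("'bar'"))  # bar
print(remove_quotes('baz'))    # baz (no surrounding quotes, returned as-is)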
def get_token_at_offset(self, offset):
"""Returns the token that is on position offset."""
idx = 0
@@ -213,12 +206,12 @@ class TokenList(Token):
else:
yield token
-# def __iter__(self):
-# return self
-#
-# def next(self):
-# for token in self.tokens:
-# yield token
+ # def __iter__(self):
+ # return self
+ #
+ # def next(self):
+ # for token in self.tokens:
+ # yield token
def is_group(self):
return True
@@ -232,6 +225,27 @@ class TokenList(Token):
def _groupable_tokens(self):
return self.tokens
+ def _token_matching(self, funcs, start=0, end=None, reverse=False):
+ """next token that match functions"""
+ if start is None:
+ return None
+
+ if not isinstance(start, int):
+ start = self.token_index(start) + 1
+
+ if not isinstance(funcs, (list, tuple)):
+ funcs = (funcs,)
+
+ if reverse:
+ iterable = iter(reversed(self.tokens[end:start - 1]))
+ else:
+ iterable = self.tokens[start:end]
+
+ for token in iterable:
+ for func in funcs:
+ if func(token):
+ return token
+
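All the public token_* lookups below now funnel into this single predicate scan. An illustrative call through a parsed statement (the method is private, so external code should prefer the public wrappers):

import sqlparse
from sqlparse import tokens as T

stmt = sqlparse.parse('SELECT a FROM t')[0]
kw = stmt._token_matching(lambda tk: tk.ttype is T.Keyword)
print(kw)  # the FROM keyword token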
def token_first(self, ignore_whitespace=True, ignore_comments=False):
"""Returns the first child token.
@@ -241,12 +255,13 @@ class TokenList(Token):
if *ignore_comments* is ``True`` (default: ``False``), comments are
ignored too.
"""
- for token in self.tokens:
- if ignore_whitespace and token.is_whitespace():
- continue
- if ignore_comments and isinstance(token, Comment):
- continue
- return token
+ funcs = lambda tk: not ((ignore_whitespace and tk.is_whitespace()) or
+ (ignore_comments and imt(tk, i=Comment)))
+ return self._token_matching(funcs)
+
+ def token_next_by(self, i=None, m=None, t=None, idx=0, end=None):
+ funcs = lambda tk: imt(tk, i, m, t)
+ return self._token_matching(funcs, idx, end)
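token_next_by consolidates the three legacy lookups that follow; each call below is the rough equivalent of its deprecated counterpart (illustrative):

import sqlparse
from sqlparse import tokens as T
from sqlparse.sql import Comment

stmt = sqlparse.parse('SELECT a AS b FROM t')[0]
stmt.token_next_by(i=Comment)            # ~ stmt.token_next_by_instance(0, Comment)
stmt.token_next_by(t=T.Whitespace)       # ~ stmt.token_next_by_type(0, T.Whitespace)
stmt.token_next_by(m=(T.Keyword, 'AS'))  # ~ stmt.token_next_match(0, T.Keyword, 'AS')
# Like the legacy methods, only direct children of stmt are scanned.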
def token_next_by_instance(self, idx, clss, end=None):
"""Returns the next token matching a class.
@@ -256,48 +271,26 @@ class TokenList(Token):
If no matching token can be found ``None`` is returned.
"""
- if not isinstance(clss, (list, tuple)):
- clss = (clss,)
-
- for token in self.tokens[idx:end]:
- if isinstance(token, clss):
- return token
+ funcs = lambda tk: imt(tk, i=clss)
+ return self._token_matching(funcs, idx, end)
def token_next_by_type(self, idx, ttypes):
"""Returns next matching token by it's token type."""
- if not isinstance(ttypes, (list, tuple)):
- ttypes = [ttypes]
-
- for token in self.tokens[idx:]:
- if token.ttype in ttypes:
- return token
+ funcs = lambda tk: imt(tk, t=ttypes)
+ return self._token_matching(funcs, idx)
def token_next_match(self, idx, ttype, value, regex=False):
"""Returns next token where it's ``match`` method returns ``True``."""
- if not isinstance(idx, int):
- idx = self.token_index(idx)
-
- for n in range(idx, len(self.tokens)):
- token = self.tokens[n]
- if token.match(ttype, value, regex):
- return token
+ funcs = lambda tk: imt(tk, m=(ttype, value, regex))
+ return self._token_matching(funcs, idx)
def token_not_matching(self, idx, funcs):
- for token in self.tokens[idx:]:
- passed = False
- for func in funcs:
- if func(token):
- passed = True
- break
-
- if not passed:
- return token
+ funcs = (funcs,) if not isinstance(funcs, (list, tuple)) else funcs
+ # Bind func via a default argument; a bare closure would late-bind the
+ # loop variable, leaving every wrapper negating only the last func.
+ funcs = [lambda tk, func=func: not func(tk) for func in funcs]
+ return self._token_matching(funcs, idx)
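The default-argument binding above matters because Python closures resolve loop variables late; without it every generated wrapper would negate only the last function in the list. A minimal demonstration:

fns = [lambda x: x > 10, lambda x: x < 0]
late = [lambda x: not f(x) for f in fns]        # every wrapper sees f == fns[-1]
bound = [lambda x, f=f: not f(x) for f in fns]  # each wrapper captures its own f
print([g(-3) for g in late])   # [False, False]: both negate (x < 0)
print([g(-3) for g in bound])  # [True, False]: negates x > 10, then x < 0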
def token_matching(self, idx, funcs):
- for token in self.tokens[idx:]:
- for func in funcs:
- if func(token):
- return token
+ return self._token_matching(funcs, idx)
def token_prev(self, idx, skip_ws=True):
"""Returns the previous token relative to *idx*.
@@ -305,17 +298,10 @@ class TokenList(Token):
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
``None`` is returned if there's no previous token.
"""
- if idx is None:
- return None
-
- if not isinstance(idx, int):
- idx = self.token_index(idx)
-
- while idx:
- idx -= 1
- if self.tokens[idx].is_whitespace() and skip_ws:
- continue
- return self.tokens[idx]
+ if isinstance(idx, int):
+ idx += 1 # a lot of calling code currently pre-compensates for this
+ funcs = lambda tk: not (tk.is_whitespace() and skip_ws)
+ return self._token_matching(funcs, idx, reverse=True)
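Both token_prev and token_next now share the same predicate scan; the += 1 preserves the legacy call convention, in which callers pass the index of the current token and expect the search to start next to it. Illustrative:

import sqlparse

stmt = sqlparse.parse('SELECT  x')[0]  # children: [DML, Whitespace, Identifier]
print(stmt.token_prev(2))              # skips the whitespace -> the SELECT token
print(stmt.token_next(0))              # skips the whitespace -> the identifier x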
def token_next(self, idx, skip_ws=True):
"""Returns the next token relative to *idx*.
@@ -323,59 +309,56 @@ class TokenList(Token):
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
``None`` is returned if there's no next token.
"""
- if idx is None:
- return None
-
- if not isinstance(idx, int):
- idx = self.token_index(idx)
-
- while idx < len(self.tokens) - 1:
- idx += 1
- if self.tokens[idx].is_whitespace() and skip_ws:
- continue
- return self.tokens[idx]
+ if isinstance(idx, int):
+ idx += 1 # a lot of calling code currently pre-compensates for this
+ funcs = lambda tk: not (tk.is_whitespace() and skip_ws)
+ return self._token_matching(funcs, idx)
def token_index(self, token, start=0):
"""Return list index of token."""
- if start > 0:
- # Performing `index` manually is much faster when starting
- # in the middle of the list of tokens and expecting to find
- # the token near to the starting index.
- for i in range(start, len(self.tokens)):
- if self.tokens[i] == token:
- return i
- return -1
- return self.tokens.index(token)
-
- def tokens_between(self, start, end, exclude_end=False):
+ start = self.token_index(start) if not isinstance(start, int) else start
+ return start + self.tokens[start:].index(token)
+
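token_index now also accepts another token as *start*, keeping the old fast path of searching near a known position without the manual loop. Illustrative:

import sqlparse
from sqlparse import tokens as T

stmt = sqlparse.parse('SELECT a FROM t')[0]
kw = stmt.token_next_match(0, T.Keyword, 'FROM')
print(stmt.token_index(kw))            # absolute index of the FROM token
print(stmt.token_index(kw, start=kw))  # same index; the search starts at kw itself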
+ def tokens_between(self, start, end, include_end=True):
"""Return all tokens between (and including) start and end.
- If *exclude_end* is ``True`` (default is ``False``) the end token
- is included too.
+ If *include_end* is ``False`` (default is ``True``) the end token
+ is excluded.
"""
- # FIXME(andi): rename exclude_end to inlcude_end
- if exclude_end:
- offset = 0
- else:
- offset = 1
- end_idx = self.token_index(end) + offset
start_idx = self.token_index(start)
+ end_idx = include_end + self.token_index(end)
return self.tokens[start_idx:end_idx]
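The end_idx expression relies on bool being an int subclass: True adds 1 so the slice reaches past the end token, False adds 0 so it stops just before it. For example:

# with token_index(start) == 1 and token_index(end) == 4:
print(True + 4, False + 4)  # 5 4
# tokens[1:5] includes the end token; tokens[1:4] excludes it.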
- def group_tokens(self, grp_cls, tokens, ignore_ws=False):
+ def group_tokens(self, grp_cls, tokens, ignore_ws=False, extend=False):
"""Replace tokens by an instance of *grp_cls*."""
- idx = self.token_index(tokens[0])
if ignore_ws:
while tokens and tokens[-1].is_whitespace():
tokens = tokens[:-1]
- for t in tokens:
- self.tokens.remove(t)
- grp = grp_cls(tokens)
+
+ left = tokens[0]
+ idx = self.token_index(left)
+
+ if extend:
+ if not isinstance(left, grp_cls):
+ grp = grp_cls([left])
+ self.tokens.remove(left)
+ self.tokens.insert(idx, grp)
+ left = grp
+ left.parent = self
+ tokens = tokens[1:]
+ left.tokens.extend(tokens)
+ left.value = left.__str__()
+
+ else:
+ left = grp_cls(tokens)
+ left.parent = self
+ self.tokens.insert(idx, left)
+
for token in tokens:
- token.parent = grp
- grp.parent = self
- self.tokens.insert(idx, grp)
- return grp
+ token.parent = left
+ self.tokens.remove(token)
+
+ return left
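With extend=True an existing group of the target class is grown in place instead of being wrapped in a second, nested group. A minimal sketch using this module's own classes:

from sqlparse import tokens as T
from sqlparse.sql import Identifier, Token, TokenList

ident = Identifier([Token(T.Name, 'a')])
dot, name = Token(T.Punctuation, '.'), Token(T.Name, 'b')
tlist = TokenList([ident, dot, name])

grp = tlist.group_tokens(Identifier, [ident, dot, name], extend=True)
print(grp is ident, grp.value)  # True a.b  (grown in place, not re-nested)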
def insert_before(self, where, token):
"""Inserts *token* before *where*."""
@@ -397,13 +380,12 @@ class TokenList(Token):
"""Returns the alias for this identifier or ``None``."""
# "name AS alias"
- kw = self.token_next_match(0, T.Keyword, 'AS')
+ kw = self.token_next_by(m=(T.Keyword, 'AS'))
if kw is not None:
return self._get_first_name(kw, keywords=True)
# "name alias" or "complicated column expression alias"
- if len(self.tokens) > 2 \
- and self.token_next_by_type(0, T.Whitespace) is not None:
+ if len(self.tokens) > 2 and self.token_next_by(t=T.Whitespace):
return self._get_first_name(reverse=True)
return None
@@ -440,7 +422,7 @@ class TokenList(Token):
prev_ = self.token_prev(self.token_index(dot))
if prev_ is None: # something must be very wrong here...
return None
- return self._remove_quotes(prev_.value)
+ return remove_quotes(prev_.value)
def _get_first_name(self, idx=None, reverse=False, keywords=False):
"""Returns the name of the first token with a name"""
@@ -457,7 +439,7 @@ class TokenList(Token):
for tok in tokens:
if tok.ttype in types:
- return self._remove_quotes(tok.value)
+ return remove_quotes(tok.value)
elif isinstance(tok, Identifier) or isinstance(tok, Function):
return tok.get_name()
return None
@@ -510,8 +492,6 @@ class Identifier(TokenList):
Identifiers may have aliases or typecasts.
"""
- __slots__ = ('value', 'ttype', 'tokens')
-
def is_wildcard(self):
"""Return ``True`` if this identifier contains a wildcard."""
token = self.token_next_by_type(0, T.Wildcard)
@@ -546,8 +526,6 @@ class Identifier(TokenList):
class IdentifierList(TokenList):
"""A list of :class:`~sqlparse.sql.Identifier`\'s."""
- __slots__ = ('value', 'ttype', 'tokens')
-
def get_identifiers(self):
"""Returns the identifiers.
@@ -560,7 +538,8 @@ class IdentifierList(TokenList):
class Parenthesis(TokenList):
"""Tokens between parenthesis."""
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Punctuation, '(')
+ M_CLOSE = (T.Punctuation, ')')
@property
def _groupable_tokens(self):
@@ -569,8 +548,8 @@ class Parenthesis(TokenList):
class SquareBrackets(TokenList):
"""Tokens between square brackets"""
-
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Punctuation, '[')
+ M_CLOSE = (T.Punctuation, ']')
@property
def _groupable_tokens(self):
@@ -579,22 +558,22 @@ class SquareBrackets(TokenList):
class Assignment(TokenList):
"""An assignment like 'var := val;'"""
- __slots__ = ('value', 'ttype', 'tokens')
class If(TokenList):
"""An 'if' clause with possible 'else if' or 'else' parts."""
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Keyword, 'IF')
+ M_CLOSE = (T.Keyword, 'END IF')
class For(TokenList):
"""A 'FOR' loop."""
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Keyword, ('FOR', 'FOREACH'))
+ M_CLOSE = (T.Keyword, 'END LOOP')
class Comparison(TokenList):
"""A comparison used for example in WHERE clauses."""
- __slots__ = ('value', 'ttype', 'tokens')
@property
def left(self):
@@ -607,7 +586,6 @@ class Comparison(TokenList):
class Comment(TokenList):
"""A comment."""
- __slots__ = ('value', 'ttype', 'tokens')
def is_multiline(self):
return self.tokens and self.tokens[0].ttype == T.Comment.Multiline
@@ -615,13 +593,15 @@ class Comment(TokenList):
class Where(TokenList):
"""A WHERE clause."""
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Keyword, 'WHERE')
+ M_CLOSE = (T.Keyword,
+ ('ORDER', 'GROUP', 'LIMIT', 'UNION', 'EXCEPT', 'HAVING'))
class Case(TokenList):
"""A CASE statement with one or more WHEN and possibly an ELSE part."""
-
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Keyword, 'CASE')
+ M_CLOSE = (T.Keyword, 'END')
def get_cases(self):
"""Returns a list of 2-tuples (condition, value).
@@ -671,22 +651,18 @@ class Case(TokenList):
class Function(TokenList):
"""A function or procedure call."""
- __slots__ = ('value', 'ttype', 'tokens')
-
def get_parameters(self):
"""Return a list of parameters."""
parenthesis = self.tokens[-1]
for t in parenthesis.tokens:
- if isinstance(t, IdentifierList):
+ if imt(t, i=IdentifierList):
return t.get_identifiers()
- elif (isinstance(t, Identifier) or
- isinstance(t, Function) or
- t.ttype in T.Literal):
+ elif imt(t, i=(Function, Identifier), t=T.Literal):
return [t, ]
return []
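Passing both i= and t= to imt replaces the chained isinstance/ttype checks; the token qualifies if either condition matches. Illustrative:

import sqlparse

func = sqlparse.parse('max(a, 1)')[0].token_first()
print([p.value for p in func.get_parameters()])  # ['a', '1']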
class Begin(TokenList):
"""A BEGIN/END block."""
-
- __slots__ = ('value', 'ttype', 'tokens')
+ M_OPEN = (T.Keyword, 'BEGIN')
+ M_CLOSE = (T.Keyword, 'END')
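The M_OPEN/M_CLOSE pairs introduced throughout this diff move the opening and closing token definitions onto the classes themselves, letting the grouping engine treat all bracketing constructs uniformly. A hedged sketch of such a consumer (the real one lives in sqlparse's grouping code and handles nesting; this simplified version groups only the first, non-nested pair):

def group_first_matching(tlist, cls):
    """Group the first cls.M_OPEN ... cls.M_CLOSE span into a cls instance."""
    open_tok = tlist.token_next_by(m=cls.M_OPEN)
    if open_tok is None:
        return
    close_tok = tlist.token_next_match(tlist.token_index(open_tok) + 1,
                                       *cls.M_CLOSE)
    if close_tok is not None:
        tlist.group_tokens(cls, tlist.tokens_between(open_tok, close_tok))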