| author | Victor Uriarte <victor.m.uriarte@intel.com> | 2016-06-02 23:07:17 -0700 |
|---|---|---|
| committer | Victor Uriarte <victor.m.uriarte@intel.com> | 2016-06-04 15:06:04 -0700 |
| commit | be62c7a673b5f0fe973523d01e22b7ad0bb76600 (patch) | |
| tree | 77dc21dfac881ad0b8e005ebd6b80852209f24da | |
| parent | 2b8ede11388e81e0f6dc871a45c5327eaf456e44 (diff) | |
| download | sqlparse-be62c7a673b5f0fe973523d01e22b7ad0bb76600.tar.gz | |
Refactor filters
| -rw-r--r-- | sqlparse/filters.py | 107 |

1 file changed, 45 insertions, 62 deletions
```diff
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 8376326..95ac74c 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -8,7 +8,7 @@
 import re
 
 from sqlparse import sql, tokens as T
-from sqlparse.compat import u, text_type
+from sqlparse.compat import text_type
 from sqlparse.utils import split_unquoted_newlines
 
 
@@ -16,13 +16,10 @@ from sqlparse.utils import split_unquoted_newlines
 # token process
 
 
class _CaseFilter(object):
-    ttype = None
 
     def __init__(self, case=None):
-        if case is None:
-            case = 'upper'
-        assert case in ['lower', 'upper', 'capitalize']
+        case = case or 'upper'
         self.convert = getattr(text_type, case)
 
     def process(self, stream):
@@ -37,33 +34,35 @@ class KeywordCaseFilter(_CaseFilter):
 
 
 class IdentifierCaseFilter(_CaseFilter):
-    ttype = (T.Name, T.String.Symbol)
+    ttype = T.Name, T.String.Symbol
 
     def process(self, stream):
         for ttype, value in stream:
-            if ttype in self.ttype and not value.strip()[0] == '"':
+            if ttype in self.ttype and value.strip()[0] != '"':
                 value = self.convert(value)
             yield ttype, value
 
 
 class TruncateStringFilter(object):
-
     def __init__(self, width, char):
-        self.width = max(width, 1)
-        self.char = u(char)
+        self.width = width
+        self.char = char
 
     def process(self, stream):
         for ttype, value in stream:
-            if ttype is T.Literal.String.Single:
-                if value[:2] == '\'\'':
-                    inner = value[2:-2]
-                    quote = u'\'\''
-                else:
-                    inner = value[1:-1]
-                    quote = u'\''
-                if len(inner) > self.width:
-                    value = u''.join((quote, inner[:self.width], self.char,
-                                      quote))
+            if ttype != T.Literal.String.Single:
+                yield ttype, value
+                continue
+
+            if value[:2] == "''":
+                inner = value[2:-2]
+                quote = "''"
+            else:
+                inner = value[1:-1]
+                quote = "'"
+
+            if len(inner) > self.width:
+                value = ''.join((quote, inner[:self.width], self.char, quote))
             yield ttype, value
@@ -71,7 +70,6 @@ class TruncateStringFilter(object):
 # statement process
 
 class StripCommentsFilter(object):
-
     def _get_next_comment(self, tlist):
         # TODO(andi) Comment types should be unified, see related issue38
         token = tlist.token_next_by(i=sql.Comment, t=T.Comment)
@@ -81,8 +79,8 @@ class StripCommentsFilter(object):
         token = self._get_next_comment(tlist)
         while token:
             tidx = tlist.token_index(token)
-            prev = tlist.token_prev(tidx, False)
-            next_ = tlist.token_next(tidx, False)
+            prev = tlist.token_prev(tidx, skip_ws=False)
+            next_ = tlist.token_next(tidx, skip_ws=False)
             # Replace by whitespace if prev and next exist and if they're not
             # whitespaces. This doesn't apply if prev or next is a paranthesis.
             if (prev is not None and next_ is not None
@@ -100,7 +98,6 @@ class StripCommentsFilter(object):
 
 
 class StripWhitespaceFilter(object):
-
     def _stripws(self, tlist):
         func_name = '_stripws_%s' % tlist.__class__.__name__.lower()
         func = getattr(self, func_name, self._stripws_default)
@@ -122,14 +119,10 @@ class StripWhitespaceFilter(object):
         # Removes newlines before commas, see issue140
         last_nl = None
         for token in tlist.tokens[:]:
-            if token.ttype is T.Punctuation \
-                    and token.value == ',' \
-                    and last_nl is not None:
+            if last_nl and token.ttype is T.Punctuation and token.value == ',':
                 tlist.tokens.remove(last_nl)
-            if token.is_whitespace():
-                last_nl = token
-            else:
-                last_nl = None
+
+            last_nl = token if token.is_whitespace() else None
         return self._stripws_default(tlist)
 
     def _stripws_parenthesis(self, tlist):
@@ -140,19 +133,13 @@
         self._stripws_default(tlist)
 
     def process(self, stmt, depth=0):
-        [self.process(sgroup, depth + 1)
-         for sgroup in stmt.get_sublists()]
+        [self.process(sgroup, depth + 1) for sgroup in stmt.get_sublists()]
         self._stripws(stmt)
-        if (
-            depth == 0
-            and stmt.tokens
-            and stmt.tokens[-1].is_whitespace()
-        ):
+        if depth == 0 and stmt.tokens and stmt.tokens[-1].is_whitespace():
             stmt.tokens.pop(-1)
 
 
 class ReindentFilter(object):
-
     def __init__(self, width=2, char=' ', line_width=None, wrap_after=0):
         self.width = width
         self.char = char
@@ -196,8 +183,7 @@ class ReindentFilter(object):
                        'SET', 'BETWEEN', 'EXCEPT', 'HAVING')
 
         def _next_token(i):
-            t = tlist.token_next_match(i, T.Keyword, split_words,
-                                       regex=True)
+            t = tlist.token_next_by(m=(T.Keyword, split_words, True), idx=i)
             if t and t.value.upper() == 'BETWEEN':
                 t = _next_token(tlist.token_index(t) + 1)
                 if t and t.value.upper() == 'AND':
@@ -208,13 +194,13 @@
         token = _next_token(idx)
         added = set()
         while token:
-            prev = tlist.token_prev(tlist.token_index(token), False)
+            prev = tlist.token_prev(token, skip_ws=False)
             offset = 1
             if prev and prev.is_whitespace() and prev not in added:
                 tlist.tokens.pop(tlist.token_index(prev))
                 offset += 1
-            uprev = u(prev)
-            if (prev and (uprev.endswith('\n') or uprev.endswith('\r'))):
+            uprev = text_type(prev)
+            if prev and (uprev.endswith('\n') or uprev.endswith('\r')):
                 nl = tlist.token_next(token)
             else:
                 nl = self.nl()
@@ -224,18 +210,17 @@
             token = _next_token(tlist.token_index(nl) + offset)
 
     def _split_statements(self, tlist):
-        idx = 0
-        token = tlist.token_next_by_type(idx, (T.Keyword.DDL, T.Keyword.DML))
+        token = tlist.token_next_by(t=(T.Keyword.DDL, T.Keyword.DML))
         while token:
-            prev = tlist.token_prev(tlist.token_index(token), False)
+            prev = tlist.token_prev(token, skip_ws=False)
             if prev and prev.is_whitespace():
                 tlist.tokens.pop(tlist.token_index(prev))
             # only break if it's not the first token
             if prev:
                 nl = self.nl()
                 tlist.insert_before(token, nl)
-            token = tlist.token_next_by_type(tlist.token_index(token) + 1,
-                                             (T.Keyword.DDL, T.Keyword.DML))
+            token = tlist.token_next_by(t=(T.Keyword.DDL, T.Keyword.DML),
+                                        idx=token)
 
     def _process(self, tlist):
         func_name = '_process_%s' % tlist.__class__.__name__.lower()
@@ -243,7 +228,7 @@
         func(tlist)
 
     def _process_where(self, tlist):
-        token = tlist.token_next_match(0, T.Keyword, 'WHERE')
+        token = tlist.token_next_by(m=(T.Keyword, 'WHERE'))
         try:
             tlist.insert_before(token, self.nl())
         except ValueError:  # issue121, errors in statement
@@ -253,7 +238,7 @@
         self.indent -= 1
 
     def _process_having(self, tlist):
-        token = tlist.token_next_match(0, T.Keyword, 'HAVING')
+        token = tlist.token_next_by(m=(T.Keyword, 'HAVING'))
         try:
             tlist.insert_before(token, self.nl())
         except ValueError:  # issue121, errors in statement
@@ -270,7 +255,7 @@
             tlist.tokens.insert(0, self.nl())
             indented = True
         num_offset = self._get_offset(
-            tlist.token_next_match(0, T.Punctuation, '('))
+            tlist.token_next_by(m=(T.Punctuation, '(')))
         self.offset += num_offset
         self._process_default(tlist, stmts=not indented)
         if indented:
@@ -323,7 +308,7 @@
         self.offset -= 5
         if num_offset is not None:
             self.offset -= num_offset
-        end = tlist.token_next_match(0, T.Keyword, 'END')
+        end = tlist.token_next_by(m=(T.Keyword, 'END'))
         tlist.insert_before(end, self.nl())
 
         self.offset -= outer_offset
@@ -340,7 +325,7 @@
         self._process(stmt)
         if isinstance(stmt, sql.Statement):
             if self._last_stmt is not None:
-                if u(self._last_stmt).endswith('\n'):
+                if text_type(self._last_stmt).endswith('\n'):
                     nl = '\n'
                 else:
                     nl = '\n\n'
@@ -352,7 +337,6 @@
 
 # FIXME: Doesn't work
 class RightMarginFilter(object):
-
     keep_together = (
         # sql.TypeCast, sql.Identifier, sql.Alias,
     )
@@ -368,13 +352,12 @@
                     self.line = ''
                 else:
                     self.line = token.value.splitlines()[-1]
-            elif (token.is_group()
-                  and token.__class__ not in self.keep_together):
+            elif token.is_group() and type(token) not in self.keep_together:
                 token.tokens = self._process(token, token.tokens)
             else:
-                val = u(token)
+                val = text_type(token)
                 if len(self.line) + len(val) > self.width:
-                    match = re.search('^ +', self.line)
+                    match = re.search(r'^ +', self.line)
                     if match is not None:
                         indent = match.group()
                     else:
@@ -389,13 +372,13 @@
                 # group.tokens = self._process(group, group.tokens)
         raise NotImplementedError
+
 # ---------------------------
 # postprocess
 
 
class SerializerUnicode(object):
-
     def process(self, stmt):
-        raw = u(stmt)
+        raw = text_type(stmt)
         lines = split_unquoted_newlines(raw)
         res = '\n'.join(line.rstrip() for line in lines)
         return res
@@ -418,7 +401,7 @@ class OutputFilter(object):
         else:
             varname = self.varname
 
-        has_nl = len(u(stmt).strip().splitlines()) > 1
+        has_nl = len(text_type(stmt).strip().splitlines()) > 1
         stmt.tokens = self._process(stmt.tokens, varname, has_nl)
         return stmt
```
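The `_CaseFilter` hunk replaces the `None` check and the `assert` with `case = case or 'upper'`, resolving the case name to an unbound `text_type` method via `getattr`. A minimal standalone sketch of that idiom, with plain `str` standing in for `sqlparse.compat.text_type` and the per-ttype filtering done by the subclasses omitted:

```python
# Sketch of the getattr-based case dispatch adopted by _CaseFilter.
class CaseFilter(object):
    def __init__(self, case=None):
        # With the assert removed, an unknown case name now surfaces as an
        # AttributeError here rather than an AssertionError.
        case = case or 'upper'
        self.convert = getattr(str, case)  # str.upper / str.lower / str.capitalize

    def process(self, stream):
        for ttype, value in stream:
            yield ttype, self.convert(value)


stream = [('Keyword', 'select'), ('Keyword', 'from')]
print(list(CaseFilter('capitalize').process(stream)))
# -> [('Keyword', 'Select'), ('Keyword', 'From')]
```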
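`TruncateStringFilter.process` is restructured around an early `yield`/`continue` for non-string tokens, which lets the truncation path lose one indentation level; the truncation logic itself is unchanged (though the `max(width, 1)` and `u(char)` coercions in the constructor are dropped, so callers are now trusted to pass sane values). A self-contained sketch of just the truncation step, with an invented function name and marker default:

```python
def truncate_single_quoted(value, width, char='[...]'):
    """Peel the quotes, clip the inner text at `width`, append the
    marker, and re-quote -- the same steps as the hunk above."""
    if value[:2] == "''":
        inner, quote = value[2:-2], "''"
    else:
        inner, quote = value[1:-1], "'"
    if len(inner) > width:
        value = ''.join((quote, inner[:width], char, quote))
    return value


print(truncate_single_quoted("'hello world'", 5))
# -> 'hello[...]'
```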
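The issue140 loop in `StripWhitespaceFilter` collapses two conditionals into one: a single `last_nl = token if token.is_whitespace() else None` tracks the most recent whitespace token, which is removed when a comma follows it. The same effect on a plain token list, simplified so the string `'WS'` marks whitespace and list surgery stands in for `tlist.tokens.remove`:

```python
def strip_ws_before_comma(tokens):
    # tokens: (ttype, value) pairs; drop whitespace that directly
    # precedes a ',' punctuation token, as in the issue140 fix.
    out = []
    for ttype, value in tokens:
        if out and out[-1][0] == 'WS' and (ttype, value) == ('Punctuation', ','):
            out.pop()
        out.append((ttype, value))
    return out


print(strip_ws_before_comma(
    [('Name', 'foo'), ('WS', '\n'), ('Punctuation', ','), ('Name', 'bar')]))
# -> [('Name', 'foo'), ('Punctuation', ','), ('Name', 'bar')]
```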
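Most of the remaining hunks are a mechanical migration: the old trio of `token_next_match(idx, ttype, value)`, `token_next_by_type(idx, ttypes)`, and positional `token_prev(idx, False)` gives way to a single `token_next_by(i=..., t=..., m=...)` entry point plus an explicit `skip_ws=False` keyword. A rough standalone stand-in for that consolidated lookup over `(ttype, value)` pairs; the helper below is illustrative only, not sqlparse's implementation:

```python
def token_next_by(tokens, idx=0, t=None, m=None):
    """Return the first (ttype, value) pair at or after `idx` that matches a
    ttype filter (t, a tuple of ttypes) or an exact (ttype, value) matcher
    (m, compared case-insensitively, as keywords are in the hunks above)."""
    for pair in tokens[idx:]:
        ttype, value = pair
        if t is not None and ttype in t:
            return pair
        if m is not None and ttype == m[0] and value.upper() == m[1].upper():
            return pair
    return None


tokens = [('Whitespace', ' '), ('Keyword.DML', 'SELECT'), ('Keyword', 'WHERE')]
print(token_next_by(tokens, t=('Keyword.DML',)))      # ('Keyword.DML', 'SELECT')
print(token_next_by(tokens, m=('Keyword', 'where')))  # ('Keyword', 'WHERE')
```

The regex form seen in `_next_token` (`m=(T.Keyword, split_words, True)`) is left out for brevity; the point is the single keyword-driven signature replacing three differently shaped lookups, which is what lets the call sites above shrink.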