author     Vik <vmuriart@users.noreply.github.com>    2016-06-09 20:49:04 -0700
committer  GitHub <noreply@github.com>                2016-06-09 20:49:04 -0700
commit     00304afc15a554f2ac8decca1d916ba66c143b45 (patch)
tree       c2fbfe7132df666de91ada03f31acf82985b2ab6 /sqlparse
parent     3e7803e70d4c2546ebef566b42e4aeec12ee81a8 (diff)
parent     ba0100156fae959be5b44e69e95397113668c94d (diff)
download   sqlparse-00304afc15a554f2ac8decca1d916ba66c143b45.tar.gz
Merge pull request #253 from vmuriart/refactor_filters
Refactor filters
Diffstat (limited to 'sqlparse')

-rw-r--r--  sqlparse/filters/aligned_indent.py  181
-rw-r--r--  sqlparse/filters/others.py           95
-rw-r--r--  sqlparse/filters/reindent.py        232

3 files changed, 213 insertions, 295 deletions
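
A note before the diff: both aligned_indent.py and reindent.py now rely on offset and indent context managers imported from sqlparse.utils, whose implementation is not part of this commit. A minimal sketch consistent with how the filters use them (signatures assumed from the call sites below, not confirmed by this diff):

    # Hypothetical sketch of the offset/indent helpers the refactor relies on.
    # Their real implementation lives in sqlparse/utils.py, outside this diff.
    from contextlib import contextmanager

    @contextmanager
    def offset(filter_, n=0):
        # Temporarily widen the filter's column offset, restoring it on exit.
        filter_.offset += n
        try:
            yield
        finally:
            filter_.offset -= n

    @contextmanager
    def indent(filter_, n=1):
        # Temporarily deepen the filter's indentation level.
        filter_.indent += n
        try:
            yield
        finally:
            filter_.indent -= n

This is what lets the filters replace manual self.offset += n / self.offset -= n bookkeeping with "with offset(self, n):" blocks that unwind automatically.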
diff --git a/sqlparse/filters/aligned_indent.py b/sqlparse/filters/aligned_indent.py
index 7f7557e..ea749e9 100644
--- a/sqlparse/filters/aligned_indent.py
+++ b/sqlparse/filters/aligned_indent.py
@@ -6,6 +6,8 @@
 # the BSD License: http://www.opensource.org/licenses/bsd-license.php

 from sqlparse import sql, tokens as T
+from sqlparse.compat import text_type
+from sqlparse.utils import offset, indent


 class AlignedIndentFilter(object):
@@ -19,139 +21,106 @@ class AlignedIndentFilter(object):
                    'ORDER', 'UNION', 'VALUES',
                    'SET', 'BETWEEN', 'EXCEPT')

-    def __init__(self, char=' ', line_width=None):
+    def __init__(self, char=' ', n='\n'):
+        self.n = n
+        self.offset = 0
+        self.indent = 0
         self.char = char
         self._max_kwd_len = len('select')

-    def newline(self):
-        return sql.Token(T.Newline, '\n')
+    def nl(self, offset=1):
+        # offset = 1 represent a single space after SELECT
+        offset = -len(offset) if not isinstance(offset, int) else offset
+        # add two for the space and parens
+        indent = self.indent * (2 + self._max_kwd_len)

-    def whitespace(self, chars=0, newline_before=False, newline_after=False):
-        return sql.Token(T.Whitespace, ('\n' if newline_before else '') +
-                         self.char * chars + ('\n' if newline_after else ''))
+        return sql.Token(T.Whitespace, self.n + self.char * (
+            self._max_kwd_len + offset + indent + self.offset))

-    def _process_statement(self, tlist, base_indent=0):
-        if tlist.tokens[0].is_whitespace() and base_indent == 0:
+    def _process_statement(self, tlist):
+        if tlist.tokens[0].is_whitespace() and self.indent == 0:
             tlist.tokens.pop(0)

         # process the main query body
-        return self._process(sql.TokenList(tlist.tokens),
-                             base_indent=base_indent)
+        self._process(sql.TokenList(tlist.tokens))

-    def _process_parenthesis(self, tlist, base_indent=0):
-        if not tlist.token_next_by(m=(T.DML, 'SELECT')):
-            # if this isn't a subquery, don't re-indent
-            return tlist
+    def _process_parenthesis(self, tlist):
+        # if this isn't a subquery, don't re-indent
+        if tlist.token_next_by(m=(T.DML, 'SELECT')):
+            with indent(self):
+                tlist.insert_after(tlist[0], self.nl('SELECT'))
+                # process the inside of the parantheses
+                self._process_default(tlist)

-        # add two for the space and parens
-        sub_indent = base_indent + self._max_kwd_len + 2
-        tlist.insert_after(tlist.tokens[0],
-                           self.whitespace(sub_indent, newline_before=True))
-        # de-indent the last parenthesis
-        tlist.insert_before(tlist.tokens[-1],
-                            self.whitespace(sub_indent - 1,
-                                            newline_before=True))
-
-        # process the inside of the parantheses
-        tlist.tokens = (
-            [tlist.tokens[0]] +
-            self._process(sql.TokenList(tlist._groupable_tokens),
-                          base_indent=sub_indent).tokens +
-            [tlist.tokens[-1]]
-        )
-        return tlist
-
-    def _process_identifierlist(self, tlist, base_indent=0):
+            # de-indent last parenthesis
+            tlist.insert_before(tlist[-1], self.nl())
+
+    def _process_identifierlist(self, tlist):
         # columns being selected
-        new_tokens = []
-        identifiers = list(filter(
-            lambda t: t.ttype not in (T.Punctuation, T.Whitespace, T.Newline),
-            tlist.tokens))
-        for i, token in enumerate(identifiers):
-            if i > 0:
-                new_tokens.append(self.newline())
-                new_tokens.append(
-                    self.whitespace(self._max_kwd_len + base_indent + 1))
-            new_tokens.append(token)
-            if i < len(identifiers) - 1:
-                # if not last column in select, add a comma seperator
-                new_tokens.append(sql.Token(T.Punctuation, ','))
-        tlist.tokens = new_tokens
-
-        # process any sub-sub statements (like case statements)
-        for sgroup in tlist.get_sublists():
-            self._process(sgroup, base_indent=base_indent)
-        return tlist
+        identifiers = list(tlist.get_identifiers())
+        identifiers.pop(0)
+        [tlist.insert_before(token, self.nl()) for token in identifiers]
+        self._process_default(tlist)

-    def _process_case(self, tlist, base_indent=0):
-        base_offset = base_indent + self._max_kwd_len + len('case ')
-        case_offset = len('when ')
+    def _process_case(self, tlist):
+        offset_ = len('case ') + len('when ')
         cases = tlist.get_cases(skip_ws=True)
         # align the end as well
         end_token = tlist.token_next_by(m=(T.Keyword, 'END'))
         cases.append((None, [end_token]))

-        condition_width = max(
-            len(' '.join(map(str, cond))) for cond, value in cases if cond)
+        condition_width = [len(' '.join(map(text_type, cond))) if cond else 0
+                           for cond, _ in cases]
+        max_cond_width = max(condition_width)
+
         for i, (cond, value) in enumerate(cases):
-            if cond is None:  # else or end
-                stmt = value[0]
-                line = value
-            else:
-                stmt = cond[0]
-                line = cond + value
+            # cond is None when 'else or end'
+            stmt = cond[0] if cond else value[0]
+
             if i > 0:
-                tlist.insert_before(stmt, self.whitespace(
-                    base_offset + case_offset - len(str(stmt))))
+                tlist.insert_before(stmt, self.nl(
+                    offset_ - len(text_type(stmt))))
             if cond:
-                tlist.insert_after(cond[-1], self.whitespace(
-                    condition_width - len(' '.join(map(str, cond)))))
-
-            if i < len(cases) - 1:
-                # if not the END add a newline
-                tlist.insert_after(line[-1], self.newline())
-
-    def _process_substatement(self, tlist, base_indent=0):
-        def _next_token(i):
-            t = tlist.token_next_by(m=(T.Keyword, self.split_words, True),
-                                    idx=i)
-            # treat "BETWEEN x and y" as a single statement
-            if t and t.value.upper() == 'BETWEEN':
-                t = _next_token(tlist.token_index(t) + 1)
-                if t and t.value.upper() == 'AND':
-                    t = _next_token(tlist.token_index(t) + 1)
-            return t
-
-        idx = 0
-        token = _next_token(idx)
+                ws = sql.Token(T.Whitespace, self.char * (
+                    max_cond_width - condition_width[i]))
+                tlist.insert_after(cond[-1], ws)
+
+    def _next_token(self, tlist, idx=0):
+        split_words = T.Keyword, self.split_words, True
+        token = tlist.token_next_by(m=split_words, idx=idx)
+        # treat "BETWEEN x and y" as a single statement
+        if token and token.value.upper() == 'BETWEEN':
+            token = self._next_token(tlist, token)
+            if token and token.value.upper() == 'AND':
+                token = self._next_token(tlist, token)
+        return token
+
+    def _split_kwds(self, tlist):
+        token = self._next_token(tlist)
         while token:
             # joins are special case. only consider the first word as aligner
             if token.match(T.Keyword, self.join_words, regex=True):
-                token_indent = len(token.value.split()[0])
+                token_indent = token.value.split()[0]
             else:
-                token_indent = len(str(token))
-            tlist.insert_before(token, self.whitespace(
-                self._max_kwd_len - token_indent + base_indent,
-                newline_before=True))
-            next_idx = tlist.token_index(token) + 1
-            token = _next_token(next_idx)
+                token_indent = text_type(token)
+            tlist.insert_before(token, self.nl(token_indent))
+            token = self._next_token(tlist, token)

+    def _process_default(self, tlist):
+        self._split_kwds(tlist)
         # process any sub-sub statements
         for sgroup in tlist.get_sublists():
-            prev_token = tlist.token_prev(tlist.token_index(sgroup))
-            indent_offset = 0
-            # HACK: make "group/order by" work. Longer than _max_kwd_len.
-            if prev_token and prev_token.match(T.Keyword, 'BY'):
-                # TODO: generalize this
-                indent_offset = 3
-            self._process(sgroup, base_indent=base_indent + indent_offset)
-        return tlist
-
-    def _process(self, tlist, base_indent=0):
-        token_name = tlist.__class__.__name__.lower()
-        func_name = '_process_%s' % token_name
-        func = getattr(self, func_name, self._process_substatement)
-        return func(tlist, base_indent=base_indent)
+            prev = tlist.token_prev(sgroup)
+            # HACK: make "group/order by" work. Longer than max_len.
+            offset_ = 3 if (prev and prev.match(T.Keyword, 'BY')) else 0
+            with offset(self, offset_):
+                self._process(sgroup)
+
+    def _process(self, tlist):
+        func_name = '_process_{cls}'.format(cls=type(tlist).__name__)
+        func = getattr(self, func_name.lower(), self._process_default)
+        func(tlist)

     def process(self, stmt):
         self._process(stmt)
+        return stmt
diff --git a/sqlparse/filters/others.py b/sqlparse/filters/others.py
index 6132f9a..6951c74 100644
--- a/sqlparse/filters/others.py
+++ b/sqlparse/filters/others.py
@@ -11,59 +11,62 @@ from sqlparse.utils import split_unquoted_newlines


 class StripCommentsFilter(object):
-    def _get_next_comment(self, tlist):
-        # TODO(andi) Comment types should be unified, see related issue38
-        token = tlist.token_next_by(i=sql.Comment, t=T.Comment)
-        return token
+    @staticmethod
+    def _process(tlist):
+        def get_next_comment():
+            # TODO(andi) Comment types should be unified, see related issue38
+            return tlist.token_next_by(i=sql.Comment, t=T.Comment)

-    def _process(self, tlist):
-        token = self._get_next_comment(tlist)
+        token = get_next_comment()
         while token:
-            tidx = tlist.token_index(token)
-            prev = tlist.token_prev(tidx, skip_ws=False)
-            next_ = tlist.token_next(tidx, skip_ws=False)
+            prev = tlist.token_prev(token, skip_ws=False)
+            next_ = tlist.token_next(token, skip_ws=False)
             # Replace by whitespace if prev and next exist and if they're not
             # whitespaces. This doesn't apply if prev or next is a paranthesis.
-            if (prev is not None and next_ is not None
-                and not prev.is_whitespace() and not next_.is_whitespace()
-                and not (prev.match(T.Punctuation, '(')
-                         or next_.match(T.Punctuation, ')'))):
-                tlist.tokens[tidx] = sql.Token(T.Whitespace, ' ')
+            if (prev is None or next_ is None or
+                    prev.is_whitespace() or prev.match(T.Punctuation, '(') or
+                    next_.is_whitespace() or next_.match(T.Punctuation, ')')):
+                tlist.tokens.remove(token)
             else:
-                tlist.tokens.pop(tidx)
-            token = self._get_next_comment(tlist)
+                tidx = tlist.token_index(token)
+                tlist.tokens[tidx] = sql.Token(T.Whitespace, ' ')
+
+            token = get_next_comment()

     def process(self, stmt):
         [self.process(sgroup) for sgroup in stmt.get_sublists()]
-        self._process(stmt)
+        StripCommentsFilter._process(stmt)
+        return stmt


 class StripWhitespaceFilter(object):
     def _stripws(self, tlist):
-        func_name = '_stripws_%s' % tlist.__class__.__name__.lower()
-        func = getattr(self, func_name, self._stripws_default)
+        func_name = '_stripws_{cls}'.format(cls=type(tlist).__name__)
+        func = getattr(self, func_name.lower(), self._stripws_default)
         func(tlist)

-    def _stripws_default(self, tlist):
+    @staticmethod
+    def _stripws_default(tlist):
         last_was_ws = False
         is_first_char = True
         for token in tlist.tokens:
             if token.is_whitespace():
-                if last_was_ws or is_first_char:
-                    token.value = ''
-                else:
-                    token.value = ' '
+                token.value = '' if last_was_ws or is_first_char else ' '
             last_was_ws = token.is_whitespace()
             is_first_char = False

     def _stripws_identifierlist(self, tlist):
         # Removes newlines before commas, see issue140
         last_nl = None
-        for token in tlist.tokens[:]:
+        for token in list(tlist.tokens):
             if last_nl and token.ttype is T.Punctuation and token.value == ',':
                 tlist.tokens.remove(last_nl)
             last_nl = token if token.is_whitespace() else None
+
+            # next_ = tlist.token_next(token, skip_ws=False)
+            # if (next_ and not next_.is_whitespace() and
+            #         token.ttype is T.Punctuation and token.value == ','):
+            #     tlist.insert_after(token, sql.Token(T.Whitespace, ' '))
         return self._stripws_default(tlist)

     def _stripws_parenthesis(self, tlist):
@@ -78,43 +81,39 @@ class StripWhitespaceFilter(object):
         self._stripws(stmt)
         if depth == 0 and stmt.tokens and stmt.tokens[-1].is_whitespace():
             stmt.tokens.pop(-1)
+        return stmt


 class SpacesAroundOperatorsFilter(object):
-    whitelist = (sql.Identifier, sql.Comparison, sql.Where)
-
-    def _process(self, tlist):
-        def next_token(idx):
+    @staticmethod
+    def _process(tlist):
+        def next_token(idx=0):
             return tlist.token_next_by(t=(T.Operator, T.Comparison), idx=idx)

-        idx = 0
-        token = next_token(idx)
+        token = next_token()
         while token:
-            idx = tlist.token_index(token)
-            if idx > 0 and tlist.tokens[idx - 1].ttype != T.Whitespace:
-                # insert before
-                tlist.tokens.insert(idx, sql.Token(T.Whitespace, ' '))
-                idx += 1
-            if idx < len(tlist.tokens) - 1:
-                if tlist.tokens[idx + 1].ttype != T.Whitespace:
-                    tlist.tokens.insert(idx + 1, sql.Token(T.Whitespace, ' '))
+            prev_ = tlist.token_prev(token, skip_ws=False)
+            if prev_ and prev_.ttype != T.Whitespace:
+                tlist.insert_before(token, sql.Token(T.Whitespace, ' '))

-            idx += 1
-            token = next_token(idx)
+            next_ = tlist.token_next(token, skip_ws=False)
+            if next_ and next_.ttype != T.Whitespace:
+                tlist.insert_after(token, sql.Token(T.Whitespace, ' '))

-        for sgroup in tlist.get_sublists():
-            self._process(sgroup)
+            token = next_token(idx=token)

     def process(self, stmt):
-        self._process(stmt)
+        [self.process(sgroup) for sgroup in stmt.get_sublists()]
+        SpacesAroundOperatorsFilter._process(stmt)
+        return stmt


 # ---------------------------
 # postprocess

 class SerializerUnicode(object):
-    def process(self, stmt):
+    @staticmethod
+    def process(stmt):
         raw = text_type(stmt)
         lines = split_unquoted_newlines(raw)
-        res = '\n'.join(line.rstrip() for line in lines)
-        return res
+        return '\n'.join(line.rstrip() for line in lines)
diff --git a/sqlparse/filters/reindent.py b/sqlparse/filters/reindent.py
index f9c225f..f7ddfc9 100644
--- a/sqlparse/filters/reindent.py
+++ b/sqlparse/filters/reindent.py
@@ -7,199 +7,149 @@

 from sqlparse import sql, tokens as T
 from sqlparse.compat import text_type
+from sqlparse.utils import offset, indent


 class ReindentFilter(object):
-    def __init__(self, width=2, char=' ', line_width=None, wrap_after=0):
+    def __init__(self, width=2, char=' ', wrap_after=0, n='\n'):
+        self.n = n
         self.width = width
         self.char = char
         self.indent = 0
         self.offset = 0
-        self.line_width = line_width
         self.wrap_after = wrap_after
         self._curr_stmt = None
         self._last_stmt = None

     def _flatten_up_to_token(self, token):
-        """Yields all tokens up to token plus the next one."""
-        # helper for _get_offset
-        iterator = self._curr_stmt.flatten()
-        for t in iterator:
-            yield t
+        """Yields all tokens up to token but excluding current."""
+        if token.is_group():
+            token = next(token.flatten())
+
+        for t in self._curr_stmt.flatten():
             if t == token:
                 raise StopIteration
+            yield t
+
+    @property
+    def leading_ws(self):
+        return self.offset + self.indent * self.width

     def _get_offset(self, token):
         raw = ''.join(map(text_type, self._flatten_up_to_token(token)))
-        line = raw.splitlines()[-1]
+        line = (raw or '\n').splitlines()[-1]
         # Now take current offset into account and return relative offset.
-        full_offset = len(line) - len(self.char * (self.width * self.indent))
-        return full_offset - self.offset
+        return len(line) - len(self.char * self.leading_ws)

     def nl(self):
-        # TODO: newline character should be configurable
-        space = (self.char * ((self.indent * self.width) + self.offset))
-        # Detect runaway indenting due to parsing errors
-        if len(space) > 200:
-            # something seems to be wrong, flip back
-            self.indent = self.offset = 0
-            space = (self.char * ((self.indent * self.width) + self.offset))
-        ws = '\n' + space
-        return sql.Token(T.Whitespace, ws)
+        return sql.Token(T.Whitespace, self.n + self.char * self.leading_ws)

-    def _split_kwds(self, tlist):
+    def _next_token(self, tlist, idx=0):
         split_words = ('FROM', 'STRAIGHT_JOIN$', 'JOIN$', 'AND', 'OR',
                        'GROUP', 'ORDER', 'UNION', 'VALUES',
                        'SET', 'BETWEEN', 'EXCEPT', 'HAVING')
+        token = tlist.token_next_by(m=(T.Keyword, split_words, True), idx=idx)
+
+        if token and token.value.upper() == 'BETWEEN':
+            token = self._next_token(tlist, token)

-        def _next_token(i):
-            t = tlist.token_next_by(m=(T.Keyword, split_words, True), idx=i)
-            if t and t.value.upper() == 'BETWEEN':
-                t = _next_token(tlist.token_index(t) + 1)
-                if t and t.value.upper() == 'AND':
-                    t = _next_token(tlist.token_index(t) + 1)
-            return t
-
-        idx = 0
-        token = _next_token(idx)
-        added = set()
+            if token and token.value.upper() == 'AND':
+                token = self._next_token(tlist, token)
+
+        return token
+
+    def _split_kwds(self, tlist):
+        token = self._next_token(tlist)
         while token:
             prev = tlist.token_prev(token, skip_ws=False)
-            offset = 1
-            if prev and prev.is_whitespace() and prev not in added:
-                tlist.tokens.pop(tlist.token_index(prev))
-                offset += 1
             uprev = text_type(prev)
-            if prev and (uprev.endswith('\n') or uprev.endswith('\r')):
-                nl = tlist.token_next(token)
-            else:
-                nl = self.nl()
-                added.add(nl)
-                tlist.insert_before(token, nl)
-                offset += 1
-            token = _next_token(tlist.token_index(nl) + offset)
+
+            if prev and prev.is_whitespace():
+                tlist.tokens.remove(prev)
+
+            if not (uprev.endswith('\n') or uprev.endswith('\r')):
+                tlist.insert_before(token, self.nl())
+
+            token = self._next_token(tlist, token)

     def _split_statements(self, tlist):
         token = tlist.token_next_by(t=(T.Keyword.DDL, T.Keyword.DML))
         while token:
             prev = tlist.token_prev(token, skip_ws=False)
             if prev and prev.is_whitespace():
-                tlist.tokens.pop(tlist.token_index(prev))
+                tlist.tokens.remove(prev)
             # only break if it's not the first token
-            if prev:
-                nl = self.nl()
-                tlist.insert_before(token, nl)
+            tlist.insert_before(token, self.nl()) if prev else None
             token = tlist.token_next_by(t=(T.Keyword.DDL, T.Keyword.DML),
                                         idx=token)

     def _process(self, tlist):
-        func_name = '_process_%s' % tlist.__class__.__name__.lower()
-        func = getattr(self, func_name, self._process_default)
+        func_name = '_process_{cls}'.format(cls=type(tlist).__name__)
+        func = getattr(self, func_name.lower(), self._process_default)
         func(tlist)

     def _process_where(self, tlist):
         token = tlist.token_next_by(m=(T.Keyword, 'WHERE'))
-        try:
-            tlist.insert_before(token, self.nl())
-        except ValueError:  # issue121, errors in statement
-            pass
-        self.indent += 1
-        self._process_default(tlist)
-        self.indent -= 1
-
-    def _process_having(self, tlist):
-        token = tlist.token_next_by(m=(T.Keyword, 'HAVING'))
-        try:
-            tlist.insert_before(token, self.nl())
-        except ValueError:  # issue121, errors in statement
-            pass
-        self.indent += 1
-        self._process_default(tlist)
-        self.indent -= 1
+        # issue121, errors in statement fixed??
+        tlist.insert_before(token, self.nl())
+
+        with indent(self):
+            self._process_default(tlist)

     def _process_parenthesis(self, tlist):
-        first = tlist.token_next(0)
-        indented = False
-        if first and first.ttype in (T.Keyword.DML, T.Keyword.DDL):
-            self.indent += 1
-            tlist.tokens.insert(0, self.nl())
-            indented = True
-        num_offset = self._get_offset(
-            tlist.token_next_by(m=(T.Punctuation, '(')))
-        self.offset += num_offset
-        self._process_default(tlist, stmts=not indented)
-        if indented:
-            self.indent -= 1
-        self.offset -= num_offset
+        is_DML_DLL = tlist.token_next_by(t=(T.Keyword.DML, T.Keyword.DDL))
+        first = tlist.token_next_by(m=sql.Parenthesis.M_OPEN)
+
+        with indent(self, 1 if is_DML_DLL else 0):
+            tlist.tokens.insert(0, self.nl()) if is_DML_DLL else None
+            with offset(self, self._get_offset(first) + 1):
+                self._process_default(tlist, not is_DML_DLL)

     def _process_identifierlist(self, tlist):
         identifiers = list(tlist.get_identifiers())
-        if len(identifiers) > 1 and not tlist.within(sql.Function):
-            first = list(identifiers[0].flatten())[0]
-            if self.char == '\t':
-                # when using tabs we don't count the actual word length
-                # in spaces.
-                num_offset = 1
-            else:
-                num_offset = self._get_offset(first) - len(first.value)
-            self.offset += num_offset
-            position = self.offset
-            for token in identifiers[1:]:
-                # Add 1 for the "," separator
-                position += len(token.value) + 1
-                if position > self.wrap_after:
-                    tlist.insert_before(token, self.nl())
-                    position = self.offset
-            self.offset -= num_offset
+        first = next(identifiers.pop(0).flatten())
+        num_offset = 1 if self.char == '\t' else self._get_offset(first)
+        if not tlist.within(sql.Function):
+            with offset(self, num_offset):
+                position = 0
+                for token in identifiers:
+                    # Add 1 for the "," separator
+                    position += len(token.value) + 1
+                    if position > (self.wrap_after - self.offset):
+                        tlist.insert_before(token, self.nl())
+                        position = 0
         self._process_default(tlist)

     def _process_case(self, tlist):
-        is_first = True
-        num_offset = None
-        case = tlist.tokens[0]
-        outer_offset = self._get_offset(case) - len(case.value)
-        self.offset += outer_offset
-        for cond, value in tlist.get_cases():
-            if is_first:
-                tcond = list(cond[0].flatten())[0]
-                is_first = False
-                num_offset = self._get_offset(tcond) - len(tcond.value)
-                self.offset += num_offset
-                continue
-            if cond is None:
-                token = value[0]
-            else:
-                token = cond[0]
-            tlist.insert_before(token, self.nl())
-        # Line breaks on group level are done. Now let's add an offset of
-        # 5 (=length of "when", "then", "else") and process subgroups.
-        self.offset += 5
-        self._process_default(tlist)
-        self.offset -= 5
-        if num_offset is not None:
-            self.offset -= num_offset
-        end = tlist.token_next_by(m=(T.Keyword, 'END'))
-        tlist.insert_before(end, self.nl())
-        self.offset -= outer_offset
-
-    def _process_default(self, tlist, stmts=True, kwds=True):
-        if stmts:
-            self._split_statements(tlist)
-        if kwds:
-            self._split_kwds(tlist)
+        iterable = iter(tlist.get_cases())
+        cond, _ = next(iterable)
+        first = next(cond[0].flatten())
+
+        with offset(self, self._get_offset(tlist[0])):
+            with offset(self, self._get_offset(first)):
+                for cond, value in iterable:
+                    token = value[0] if cond is None else cond[0]
+                    tlist.insert_before(token, self.nl())
+
+                # Line breaks on group level are done. let's add an offset of
+                # len "when ", "then ", "else "
+                with offset(self, len("WHEN ")):
+                    self._process_default(tlist)
+        end = tlist.token_next_by(m=sql.Case.M_CLOSE)
+        tlist.insert_before(end, self.nl())
+
+    def _process_default(self, tlist, stmts=True):
+        self._split_statements(tlist) if stmts else None
+        self._split_kwds(tlist)
         [self._process(sgroup) for sgroup in tlist.get_sublists()]

     def process(self, stmt):
-        if isinstance(stmt, sql.Statement):
-            self._curr_stmt = stmt
+        self._curr_stmt = stmt
         self._process(stmt)
-        if isinstance(stmt, sql.Statement):
-            if self._last_stmt is not None:
-                if text_type(self._last_stmt).endswith('\n'):
-                    nl = '\n'
-                else:
-                    nl = '\n\n'
-                stmt.tokens.insert(
-                    0, sql.Token(T.Whitespace, nl))
-            if self._last_stmt != stmt:
-                self._last_stmt = stmt
+
+        if self._last_stmt is not None:
+            nl = '\n' if text_type(self._last_stmt).endswith('\n') else '\n\n'
+            stmt.tokens.insert(0, sql.Token(T.Whitespace, nl))
+
+        self._last_stmt = stmt
+        return stmt
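
A quick way to exercise the refactored filters end to end is through the public formatting API, which wires these classes into the filter stack. A minimal smoke test (option names are the standard sqlparse.format() options; reindent_aligned, if available in this revision, routes through AlignedIndentFilter instead of ReindentFilter):

    # reindent=True exercises ReindentFilter;
    # strip_comments=True exercises StripCommentsFilter.
    import sqlparse

    sql = "select a, b /* cols */ from t where x = 1 and y = 2"
    print(sqlparse.format(sql, reindent=True, strip_comments=True))

Note that after this change every filter's process() returns stmt, so filters can be chained directly rather than relying purely on in-place mutation of the statement.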