summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r--  sqlparse/filters.py  107
1 file changed, 45 insertions(+), 62 deletions(-)
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 8376326..95ac74c 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -8,7 +8,7 @@
import re
from sqlparse import sql, tokens as T
-from sqlparse.compat import u, text_type
+from sqlparse.compat import text_type
from sqlparse.utils import split_unquoted_newlines
@@ -16,13 +16,10 @@ from sqlparse.utils import split_unquoted_newlines
# token process
class _CaseFilter(object):
-
ttype = None
def __init__(self, case=None):
- if case is None:
- case = 'upper'
- assert case in ['lower', 'upper', 'capitalize']
+ case = case or 'upper'
self.convert = getattr(text_type, case)
def process(self, stream):
@@ -37,33 +34,35 @@ class KeywordCaseFilter(_CaseFilter):
class IdentifierCaseFilter(_CaseFilter):
- ttype = (T.Name, T.String.Symbol)
+ ttype = T.Name, T.String.Symbol
def process(self, stream):
for ttype, value in stream:
- if ttype in self.ttype and not value.strip()[0] == '"':
+ if ttype in self.ttype and value.strip()[0] != '"':
value = self.convert(value)
yield ttype, value
class TruncateStringFilter(object):
-
def __init__(self, width, char):
- self.width = max(width, 1)
- self.char = u(char)
+ self.width = width
+ self.char = char
def process(self, stream):
for ttype, value in stream:
- if ttype is T.Literal.String.Single:
- if value[:2] == '\'\'':
- inner = value[2:-2]
- quote = u'\'\''
- else:
- inner = value[1:-1]
- quote = u'\''
- if len(inner) > self.width:
- value = u''.join((quote, inner[:self.width], self.char,
- quote))
+ if ttype != T.Literal.String.Single:
+ yield ttype, value
+ continue
+
+ if value[:2] == "''":
+ inner = value[2:-2]
+ quote = "''"
+ else:
+ inner = value[1:-1]
+ quote = "'"
+
+ if len(inner) > self.width:
+ value = ''.join((quote, inner[:self.width], self.char, quote))
yield ttype, value
@@ -71,7 +70,6 @@ class TruncateStringFilter(object):
# statement process
class StripCommentsFilter(object):
-
def _get_next_comment(self, tlist):
# TODO(andi) Comment types should be unified, see related issue38
token = tlist.token_next_by(i=sql.Comment, t=T.Comment)
@@ -81,8 +79,8 @@ class StripCommentsFilter(object):
token = self._get_next_comment(tlist)
while token:
tidx = tlist.token_index(token)
- prev = tlist.token_prev(tidx, False)
- next_ = tlist.token_next(tidx, False)
+ prev = tlist.token_prev(tidx, skip_ws=False)
+ next_ = tlist.token_next(tidx, skip_ws=False)
# Replace by whitespace if prev and next exist and if they're not
# whitespaces. This doesn't apply if prev or next is a paranthesis.
if (prev is not None and next_ is not None
@@ -100,7 +98,6 @@ class StripCommentsFilter(object):
class StripWhitespaceFilter(object):
-
def _stripws(self, tlist):
func_name = '_stripws_%s' % tlist.__class__.__name__.lower()
func = getattr(self, func_name, self._stripws_default)
@@ -122,14 +119,10 @@ class StripWhitespaceFilter(object):
# Removes newlines before commas, see issue140
last_nl = None
for token in tlist.tokens[:]:
- if token.ttype is T.Punctuation \
- and token.value == ',' \
- and last_nl is not None:
+ if last_nl and token.ttype is T.Punctuation and token.value == ',':
tlist.tokens.remove(last_nl)
- if token.is_whitespace():
- last_nl = token
- else:
- last_nl = None
+
+ last_nl = token if token.is_whitespace() else None
return self._stripws_default(tlist)
def _stripws_parenthesis(self, tlist):
@@ -140,19 +133,13 @@ class StripWhitespaceFilter(object):
self._stripws_default(tlist)
def process(self, stmt, depth=0):
- [self.process(sgroup, depth + 1)
- for sgroup in stmt.get_sublists()]
+ [self.process(sgroup, depth + 1) for sgroup in stmt.get_sublists()]
self._stripws(stmt)
- if (
- depth == 0
- and stmt.tokens
- and stmt.tokens[-1].is_whitespace()
- ):
+ if depth == 0 and stmt.tokens and stmt.tokens[-1].is_whitespace():
stmt.tokens.pop(-1)
class ReindentFilter(object):
-
def __init__(self, width=2, char=' ', line_width=None, wrap_after=0):
self.width = width
self.char = char
@@ -196,8 +183,7 @@ class ReindentFilter(object):
'SET', 'BETWEEN', 'EXCEPT', 'HAVING')
def _next_token(i):
- t = tlist.token_next_match(i, T.Keyword, split_words,
- regex=True)
+ t = tlist.token_next_by(m=(T.Keyword, split_words, True), idx=i)
if t and t.value.upper() == 'BETWEEN':
t = _next_token(tlist.token_index(t) + 1)
if t and t.value.upper() == 'AND':
@@ -208,13 +194,13 @@ class ReindentFilter(object):
token = _next_token(idx)
added = set()
while token:
- prev = tlist.token_prev(tlist.token_index(token), False)
+ prev = tlist.token_prev(token, skip_ws=False)
offset = 1
if prev and prev.is_whitespace() and prev not in added:
tlist.tokens.pop(tlist.token_index(prev))
offset += 1
- uprev = u(prev)
- if (prev and (uprev.endswith('\n') or uprev.endswith('\r'))):
+ uprev = text_type(prev)
+ if prev and (uprev.endswith('\n') or uprev.endswith('\r')):
nl = tlist.token_next(token)
else:
nl = self.nl()
@@ -224,18 +210,17 @@ class ReindentFilter(object):
token = _next_token(tlist.token_index(nl) + offset)
def _split_statements(self, tlist):
- idx = 0
- token = tlist.token_next_by_type(idx, (T.Keyword.DDL, T.Keyword.DML))
+ token = tlist.token_next_by(t=(T.Keyword.DDL, T.Keyword.DML))
while token:
- prev = tlist.token_prev(tlist.token_index(token), False)
+ prev = tlist.token_prev(token, skip_ws=False)
if prev and prev.is_whitespace():
tlist.tokens.pop(tlist.token_index(prev))
# only break if it's not the first token
if prev:
nl = self.nl()
tlist.insert_before(token, nl)
- token = tlist.token_next_by_type(tlist.token_index(token) + 1,
- (T.Keyword.DDL, T.Keyword.DML))
+ token = tlist.token_next_by(t=(T.Keyword.DDL, T.Keyword.DML),
+ idx=token)
def _process(self, tlist):
func_name = '_process_%s' % tlist.__class__.__name__.lower()
@@ -243,7 +228,7 @@ class ReindentFilter(object):
func(tlist)
def _process_where(self, tlist):
- token = tlist.token_next_match(0, T.Keyword, 'WHERE')
+ token = tlist.token_next_by(m=(T.Keyword, 'WHERE'))
try:
tlist.insert_before(token, self.nl())
except ValueError: # issue121, errors in statement
@@ -253,7 +238,7 @@ class ReindentFilter(object):
self.indent -= 1
def _process_having(self, tlist):
- token = tlist.token_next_match(0, T.Keyword, 'HAVING')
+ token = tlist.token_next_by(m=(T.Keyword, 'HAVING'))
try:
tlist.insert_before(token, self.nl())
except ValueError: # issue121, errors in statement
@@ -270,7 +255,7 @@ class ReindentFilter(object):
tlist.tokens.insert(0, self.nl())
indented = True
num_offset = self._get_offset(
- tlist.token_next_match(0, T.Punctuation, '('))
+ tlist.token_next_by(m=(T.Punctuation, '(')))
self.offset += num_offset
self._process_default(tlist, stmts=not indented)
if indented:
@@ -323,7 +308,7 @@ class ReindentFilter(object):
self.offset -= 5
if num_offset is not None:
self.offset -= num_offset
- end = tlist.token_next_match(0, T.Keyword, 'END')
+ end = tlist.token_next_by(m=(T.Keyword, 'END'))
tlist.insert_before(end, self.nl())
self.offset -= outer_offset
@@ -340,7 +325,7 @@ class ReindentFilter(object):
self._process(stmt)
if isinstance(stmt, sql.Statement):
if self._last_stmt is not None:
- if u(self._last_stmt).endswith('\n'):
+ if text_type(self._last_stmt).endswith('\n'):
nl = '\n'
else:
nl = '\n\n'
@@ -352,7 +337,6 @@ class ReindentFilter(object):
# FIXME: Doesn't work
class RightMarginFilter(object):
-
keep_together = (
# sql.TypeCast, sql.Identifier, sql.Alias,
)
@@ -368,13 +352,12 @@ class RightMarginFilter(object):
self.line = ''
else:
self.line = token.value.splitlines()[-1]
- elif (token.is_group()
- and token.__class__ not in self.keep_together):
+ elif token.is_group() and type(token) not in self.keep_together:
token.tokens = self._process(token, token.tokens)
else:
- val = u(token)
+ val = text_type(token)
if len(self.line) + len(val) > self.width:
- match = re.search('^ +', self.line)
+ match = re.search(r'^ +', self.line)
if match is not None:
indent = match.group()
else:
@@ -389,13 +372,13 @@ class RightMarginFilter(object):
# group.tokens = self._process(group, group.tokens)
raise NotImplementedError
+
# ---------------------------
# postprocess
class SerializerUnicode(object):
-
def process(self, stmt):
- raw = u(stmt)
+ raw = text_type(stmt)
lines = split_unquoted_newlines(raw)
res = '\n'.join(line.rstrip() for line in lines)
return res
@@ -418,7 +401,7 @@ class OutputFilter(object):
else:
varname = self.varname
- has_nl = len(u(stmt).strip().splitlines()) > 1
+ has_nl = len(text_type(stmt).strip().splitlines()) > 1
stmt.tokens = self._process(stmt.tokens, varname, has_nl)
return stmt