summaryrefslogtreecommitdiff
path: root/sqlparse
diff options
context:
space:
mode:
authorVictor Uriarte <victor.m.uriarte@intel.com>2016-06-14 21:20:31 -0700
committerVictor Uriarte <victor.m.uriarte@intel.com>2016-06-15 13:29:21 -0700
commit74b3464d781cbad4c39cd082daa80334aa7aed78 (patch)
tree90a4ed95baab3d513db08edf48a3e06489e5d263 /sqlparse
parentaf9b82e0b2d00732704fedf7d7b03dcb598dca84 (diff)
downloadsqlparse-74b3464d781cbad4c39cd082daa80334aa7aed78.tar.gz
Re-Write grouping functions
Diffstat (limited to 'sqlparse')
-rw-r--r--sqlparse/engine/grouping.py76
1 files changed, 47 insertions, 29 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index 7879f76..ae214c2 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -152,46 +152,42 @@ def group_arrays(tlist):
@recurse(sql.Identifier)
def group_operator(tlist):
- I_CYCLE = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
+ ttypes = T_NUMERICAL + T_STRING + T_NAME
+ clss = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
sql.Identifier, sql.Operation)
- # wilcards wouldn't have operations next to them
- T_CYCLE = T_NUMERICAL + T_STRING + T_NAME
- func = lambda tk: imt(tk, i=I_CYCLE, t=T_CYCLE)
- tidx, token = tlist.token_next_by(t=(T.Operator, T.Wildcard))
- while token:
- pidx, prev_ = tlist.token_prev(tidx)
- nidx, next_ = tlist.token_next(tidx)
+ def match(token):
+ return imt(token, t=(T.Operator, T.Wildcard))
- if func(prev_) and func(next_):
- token.ttype = T.Operator
- tlist.group_tokens(sql.Operation, pidx, nidx)
- tidx = pidx
+ def valid(token):
+ return imt(token, i=clss, t=ttypes)
+
+ def post(tlist, pidx, tidx, nidx):
+ tlist[tidx].ttype = T.Operator
+ return pidx, nidx
- tidx, token = tlist.token_next_by(t=(T.Operator, T.Wildcard), idx=tidx)
+ _group(tlist, sql.Operation, match, valid, valid, post, extend=False)
-@recurse(sql.IdentifierList)
def group_identifier_list(tlist):
- M_ROLE = T.Keyword, ('null', 'role')
- M_COMMA = T.Punctuation, ','
+ m_role = T.Keyword, ('null', 'role')
+ m_comma = T.Punctuation, ','
+ clss = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
+ sql.IdentifierList, sql.Operation)
+ ttypes = (T_NUMERICAL + T_STRING + T_NAME +
+ (T.Keyword, T.Comment, T.Wildcard))
- I_IDENT_LIST = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
- sql.IdentifierList, sql.Operation)
- T_IDENT_LIST = (T_NUMERICAL + T_STRING + T_NAME +
- (T.Keyword, T.Comment, T.Wildcard))
+ def match(token):
+ return imt(token, m=m_comma)
- func = lambda t: imt(t, i=I_IDENT_LIST, m=M_ROLE, t=T_IDENT_LIST)
+ def func(token):
+ return imt(token, i=clss, m=m_role, t=ttypes)
- tidx, token = tlist.token_next_by(m=M_COMMA)
- while token:
- pidx, prev_ = tlist.token_prev(tidx)
- nidx, next_ = tlist.token_next(tidx)
+ def post(tlist, pidx, tidx, nidx):
+ return pidx, nidx
- if func(prev_) and func(next_):
- tlist.group_tokens(sql.IdentifierList, pidx, nidx, extend=True)
- tidx = pidx
- tidx, token = tlist.token_next_by(m=M_COMMA, idx=tidx)
+ _group(tlist, sql.IdentifierList, match,
+ valid_left=func, valid_right=func, post=post, extend=True)
@recurse(sql.Comment)
@@ -309,3 +305,25 @@ def group(stmt):
]:
func(stmt)
return stmt
+
+
+def _group(tlist, cls, match,
+ valid_left=lambda t: True,
+ valid_right=lambda t: True,
+ post=None,
+ extend=True):
+ """Groups together tokens that are joined by a middle token. ie. x < y"""
+ for token in list(tlist):
+ if token.is_group() and not isinstance(token, cls):
+ _group(token, cls, match, valid_left, valid_right, post, extend)
+ continue
+ if not match(token):
+ continue
+
+ tidx = tlist.token_index(token)
+ pidx, prev_ = tlist.token_prev(tidx)
+ nidx, next_ = tlist.token_next(tidx)
+
+ if valid_left(prev_) and valid_right(next_):
+ from_idx, to_idx = post(tlist, pidx, tidx, nidx)
+ tlist.group_tokens(cls, from_idx, to_idx, extend=extend)