diff options
author | Victor Uriarte <victor.m.uriarte@intel.com> | 2016-06-14 21:20:31 -0700 |
---|---|---|
committer | Victor Uriarte <victor.m.uriarte@intel.com> | 2016-06-15 13:29:21 -0700 |
commit | 74b3464d781cbad4c39cd082daa80334aa7aed78 (patch) | |
tree | 90a4ed95baab3d513db08edf48a3e06489e5d263 /sqlparse | |
parent | af9b82e0b2d00732704fedf7d7b03dcb598dca84 (diff) | |
download | sqlparse-74b3464d781cbad4c39cd082daa80334aa7aed78.tar.gz |
Re-Write grouping functions
Diffstat (limited to 'sqlparse')
-rw-r--r-- | sqlparse/engine/grouping.py | 76 |
1 files changed, 47 insertions, 29 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py index 7879f76..ae214c2 100644 --- a/sqlparse/engine/grouping.py +++ b/sqlparse/engine/grouping.py @@ -152,46 +152,42 @@ def group_arrays(tlist): @recurse(sql.Identifier) def group_operator(tlist): - I_CYCLE = (sql.SquareBrackets, sql.Parenthesis, sql.Function, + ttypes = T_NUMERICAL + T_STRING + T_NAME + clss = (sql.SquareBrackets, sql.Parenthesis, sql.Function, sql.Identifier, sql.Operation) - # wilcards wouldn't have operations next to them - T_CYCLE = T_NUMERICAL + T_STRING + T_NAME - func = lambda tk: imt(tk, i=I_CYCLE, t=T_CYCLE) - tidx, token = tlist.token_next_by(t=(T.Operator, T.Wildcard)) - while token: - pidx, prev_ = tlist.token_prev(tidx) - nidx, next_ = tlist.token_next(tidx) + def match(token): + return imt(token, t=(T.Operator, T.Wildcard)) - if func(prev_) and func(next_): - token.ttype = T.Operator - tlist.group_tokens(sql.Operation, pidx, nidx) - tidx = pidx + def valid(token): + return imt(token, i=clss, t=ttypes) + + def post(tlist, pidx, tidx, nidx): + tlist[tidx].ttype = T.Operator + return pidx, nidx - tidx, token = tlist.token_next_by(t=(T.Operator, T.Wildcard), idx=tidx) + _group(tlist, sql.Operation, match, valid, valid, post, extend=False) -@recurse(sql.IdentifierList) def group_identifier_list(tlist): - M_ROLE = T.Keyword, ('null', 'role') - M_COMMA = T.Punctuation, ',' + m_role = T.Keyword, ('null', 'role') + m_comma = T.Punctuation, ',' + clss = (sql.Function, sql.Case, sql.Identifier, sql.Comparison, + sql.IdentifierList, sql.Operation) + ttypes = (T_NUMERICAL + T_STRING + T_NAME + + (T.Keyword, T.Comment, T.Wildcard)) - I_IDENT_LIST = (sql.Function, sql.Case, sql.Identifier, sql.Comparison, - sql.IdentifierList, sql.Operation) - T_IDENT_LIST = (T_NUMERICAL + T_STRING + T_NAME + - (T.Keyword, T.Comment, T.Wildcard)) + def match(token): + return imt(token, m=m_comma) - func = lambda t: imt(t, i=I_IDENT_LIST, m=M_ROLE, t=T_IDENT_LIST) + def func(token): + return imt(token, i=clss, m=m_role, t=ttypes) - tidx, token = tlist.token_next_by(m=M_COMMA) - while token: - pidx, prev_ = tlist.token_prev(tidx) - nidx, next_ = tlist.token_next(tidx) + def post(tlist, pidx, tidx, nidx): + return pidx, nidx - if func(prev_) and func(next_): - tlist.group_tokens(sql.IdentifierList, pidx, nidx, extend=True) - tidx = pidx - tidx, token = tlist.token_next_by(m=M_COMMA, idx=tidx) + _group(tlist, sql.IdentifierList, match, + valid_left=func, valid_right=func, post=post, extend=True) @recurse(sql.Comment) @@ -309,3 +305,25 @@ def group(stmt): ]: func(stmt) return stmt + + +def _group(tlist, cls, match, + valid_left=lambda t: True, + valid_right=lambda t: True, + post=None, + extend=True): + """Groups together tokens that are joined by a middle token. ie. x < y""" + for token in list(tlist): + if token.is_group() and not isinstance(token, cls): + _group(token, cls, match, valid_left, valid_right, post, extend) + continue + if not match(token): + continue + + tidx = tlist.token_index(token) + pidx, prev_ = tlist.token_prev(tidx) + nidx, next_ = tlist.token_next(tidx) + + if valid_left(prev_) and valid_right(next_): + from_idx, to_idx = post(tlist, pidx, tidx, nidx) + tlist.group_tokens(cls, from_idx, to_idx, extend=extend) |