diff options
author | Andi Albrecht <albrecht.andi@gmail.com> | 2011-09-29 11:19:26 +0200 |
---|---|---|
committer | Andi Albrecht <albrecht.andi@gmail.com> | 2011-09-29 11:19:26 +0200 |
commit | 13c1d716a7798ae1a79c524dc967e9a33c534ff4 (patch) | |
tree | 5486734ad09578e3eb538a3faf545ae0681d0190 | |
parent | ff50b33074f5c276b0cff8094e85582dcd467095 (diff) | |
download | sqlparse-13c1d716a7798ae1a79c524dc967e9a33c534ff4.tar.gz |
Detect alias for CASE statements (targets issue46).
-rw-r--r-- | sqlparse/engine/grouping.py | 11 | ||||
-rw-r--r-- | sqlparse/sql.py | 69 | ||||
-rw-r--r-- | tests/test_grouping.py | 5 |
3 files changed, 45 insertions, 40 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py index cc75de4..f92a812 100644 --- a/sqlparse/engine/grouping.py +++ b/sqlparse/engine/grouping.py @@ -274,21 +274,20 @@ def group_where(tlist): def group_aliased(tlist): + clss = (sql.Identifier, sql.Function, sql.Case) [group_aliased(sgroup) for sgroup in tlist.get_sublists() - if not isinstance(sgroup, (sql.Identifier, sql.Function))] + if not isinstance(sgroup, clss)] idx = 0 - token = tlist.token_next_by_instance(idx, (sql.Identifier, sql.Function)) + token = tlist.token_next_by_instance(idx, clss) while token: next_ = tlist.token_next(tlist.token_index(token)) - if next_ is not None and isinstance(next_, - (sql.Identifier, sql.Function)): + if next_ is not None and isinstance(next_, clss): grp = tlist.tokens_between(token, next_)[1:] token.tokens.extend(grp) for t in grp: tlist.tokens.remove(t) idx = tlist.token_index(token) + 1 - token = tlist.token_next_by_instance(idx, - (sql.Identifier, sql.Function)) + token = tlist.token_next_by_instance(idx, clss) def group_typecasts(tlist): diff --git a/sqlparse/sql.py b/sqlparse/sql.py index 4d56bf3..244733b 100644 --- a/sqlparse/sql.py +++ b/sqlparse/sql.py @@ -315,38 +315,6 @@ class TokenList(Token): """Inserts *token* before *where*.""" self.tokens.insert(self.token_index(where), token) - -class Statement(TokenList): - """Represents a SQL statement.""" - - __slots__ = ('value', 'ttype', 'tokens') - - def get_type(self): - """Returns the type of a statement. - - The returned value is a string holding an upper-cased reprint of - the first DML or DDL keyword. If the first token in this group - isn't a DML or DDL keyword "UNKNOWN" is returned. - """ - first_token = self.token_first() - if first_token is None: - # An "empty" statement that either has not tokens at all - # or only whitespace tokens. - return 'UNKNOWN' - elif first_token.ttype in (T.Keyword.DML, T.Keyword.DDL): - return first_token.value.upper() - else: - return 'UNKNOWN' - - -class Identifier(TokenList): - """Represents an identifier. - - Identifiers may have aliases or typecasts. - """ - - __slots__ = ('value', 'ttype', 'tokens') - def has_alias(self): """Returns ``True`` if an alias is present.""" return self.get_alias() is not None @@ -359,8 +327,8 @@ class Identifier(TokenList): if alias is None: return None else: - next_ = self.token_next(0) - if next_ is None or not isinstance(next_, Identifier): + next_ = self.token_next_by_instance(0, Identifier) + if next_ is None: return None alias = next_ if isinstance(alias, Identifier): @@ -393,6 +361,39 @@ class Identifier(TokenList): return None return next_.value + + +class Statement(TokenList): + """Represents a SQL statement.""" + + __slots__ = ('value', 'ttype', 'tokens') + + def get_type(self): + """Returns the type of a statement. + + The returned value is a string holding an upper-cased reprint of + the first DML or DDL keyword. If the first token in this group + isn't a DML or DDL keyword "UNKNOWN" is returned. + """ + first_token = self.token_first() + if first_token is None: + # An "empty" statement that either has not tokens at all + # or only whitespace tokens. + return 'UNKNOWN' + elif first_token.ttype in (T.Keyword.DML, T.Keyword.DDL): + return first_token.value.upper() + else: + return 'UNKNOWN' + + +class Identifier(TokenList): + """Represents an identifier. + + Identifiers may have aliases or typecasts. + """ + + __slots__ = ('value', 'ttype', 'tokens') + def get_parent_name(self): """Return name of the parent object if any. diff --git a/tests/test_grouping.py b/tests/test_grouping.py index 134409e..a66d39d 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -157,6 +157,11 @@ class TestGrouping(TestCaseBase): self.ndiffAssertEqual(s, p.to_unicode()) self.assertEqual(p.tokens[4].get_alias(), 'view') + def test_alias_case(self): # see issue46 + p = sqlparse.parse('CASE WHEN 1 THEN 2 ELSE 3 END foo')[0] + self.assertEqual(len(p.tokens), 1) + self.assertEqual(p.tokens[0].get_alias(), 'foo') + def test_idlist_function(self): # see issue10 too p = sqlparse.parse('foo(1) x, bar')[0] self.assert_(isinstance(p.tokens[0], sql.IdentifierList)) |